mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
cumulative ops added back and topk fixed
This commit is contained in:
@@ -38,6 +38,7 @@ pretty-duration = "0.1.1"
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
candle-core = "0.9.1"
|
candle-core = "0.9.1"
|
||||||
candle-nn = "0.9.1"
|
candle-nn = "0.9.1"
|
||||||
|
ordered-float = "5.1.0"
|
||||||
|
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
|
|||||||
@@ -53,12 +53,22 @@ impl GraphTensor {
|
|||||||
|
|
||||||
/// add a new dimension of size 1 at the specified place
|
/// add a new dimension of size 1 at the specified place
|
||||||
pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
|
pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
|
||||||
// Insert contiguous call
|
|
||||||
assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
|
assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
|
||||||
self.shape.expand_dim(dim, 1);
|
self.shape.expand_dim(dim, 1);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// remove a dimension of size 1
|
||||||
|
pub fn squeeze(mut self, axis: usize) -> GraphTensor {
|
||||||
|
assert_eq!(
|
||||||
|
self.dims()[axis],
|
||||||
|
Expression::from(1),
|
||||||
|
"Only dimensions of size 1 can be squeezed!"
|
||||||
|
);
|
||||||
|
self.shape.remove_dim(axis);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
|
pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
indexes.dtype,
|
indexes.dtype,
|
||||||
@@ -78,18 +88,32 @@ impl GraphTensor {
|
|||||||
/// x = [3, 2, 4, 1, 5, 0]
|
/// x = [3, 2, 4, 1, 5, 0]
|
||||||
/// inv_perm(x) = [5, 3, 1, 0, 2, 4]
|
/// inv_perm(x) = [5, 3, 1, 0, 2, 4]
|
||||||
pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
|
pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
|
||||||
// TODO: this is very inefficient because we need to do O(n^2) comparisons and allocations for the one-hot and then sum reduce. Need a better way!
|
// TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
|
||||||
let ax_size = self.dims()[axis];
|
assert_eq!(self.dtype, DType::Int);
|
||||||
let mut arange = self.graph().arange(ax_size);
|
let dims = self.dims();
|
||||||
|
let ax_size = dims[axis];
|
||||||
|
let mut dims2 = dims.clone();
|
||||||
|
dims2.insert(axis, ax_size);
|
||||||
|
// candidate: varies along candidate dim (axis), broadcast elsewhere.
|
||||||
|
let mut candidate = self.graph().arange(ax_size);
|
||||||
for i in 0..axis {
|
for i in 0..axis {
|
||||||
arange = arange.expand_dim(i, self.dims()[i]);
|
candidate = candidate.expand_dim(i, dims2[i]);
|
||||||
}
|
}
|
||||||
for i in axis + 1..self.dims().len() {
|
for i in axis + 1..dims2.len() {
|
||||||
arange = arange.expand_dim(i, self.dims()[i]);
|
candidate = candidate.expand_dim(i, dims2[i]);
|
||||||
}
|
}
|
||||||
arange = arange.expand_dim(axis + 1, ax_size);
|
// position: varies along position dim (axis+1), broadcast elsewhere.
|
||||||
let one_hot = self.expand_dim(axis + 1, ax_size).eq(arange);
|
let mut position = self.graph().arange(ax_size);
|
||||||
(one_hot * arange).sum(axis + 1)
|
for i in 0..(axis + 1) {
|
||||||
|
position = position.expand_dim(i, dims2[i]);
|
||||||
|
}
|
||||||
|
for i in (axis + 2)..dims2.len() {
|
||||||
|
position = position.expand_dim(i, dims2[i]);
|
||||||
|
}
|
||||||
|
// one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
|
||||||
|
let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
|
||||||
|
// inv[candidate, ...] = Σ_pos one_hot * position
|
||||||
|
(one_hot * position).sum(axis + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts sliding local windows from an input tensor.
|
/// Extracts sliding local windows from an input tensor.
|
||||||
@@ -101,6 +125,7 @@ impl GraphTensor {
|
|||||||
) -> GraphTensor {
|
) -> GraphTensor {
|
||||||
let (kernel, strides, dilation) =
|
let (kernel, strides, dilation) =
|
||||||
(kernel.to_shape(), strides.to_shape(), dilation.to_shape());
|
(kernel.to_shape(), strides.to_shape(), dilation.to_shape());
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
self.shape.len(),
|
self.shape.len(),
|
||||||
kernel.len(),
|
kernel.len(),
|
||||||
@@ -117,7 +142,7 @@ impl GraphTensor {
|
|||||||
"Dilation must be same number of dimensions as tensor!"
|
"Dilation must be same number of dimensions as tensor!"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Compute input strides
|
// Compute input strides (row-major contiguous)
|
||||||
let dims = self.dims();
|
let dims = self.dims();
|
||||||
let n = dims.len();
|
let n = dims.len();
|
||||||
let mut in_strides = vec![Expression::from(1); n];
|
let mut in_strides = vec![Expression::from(1); n];
|
||||||
@@ -128,26 +153,28 @@ impl GraphTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Per-dim window counts
|
// Per-dim window counts
|
||||||
let mut win = Vec::with_capacity(dims.len());
|
let mut win = Vec::with_capacity(n);
|
||||||
for (((dim, kernel), stride), dilation) in
|
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
|
||||||
dims.into_iter().zip(&kernel).zip(&strides).zip(&dilation)
|
let effective_window = *d * (*k - 1) + 1;
|
||||||
{
|
win.push(((*dim - effective_window) / s) + 1);
|
||||||
let effective_window = *dilation * (*kernel - 1) + 1;
|
|
||||||
win.push(((dim - effective_window) / stride) + 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// final_shape = [kernel..., win...]
|
// [win..., kernel...]
|
||||||
let mut final_shape = kernel.clone();
|
let mut final_shape: Vec<Expression> = win.into_iter().map(|e| e.simplify()).collect();
|
||||||
final_shape.extend(win.into_iter().map(|e| e.simplify()));
|
final_shape.extend(kernel.iter().copied());
|
||||||
|
|
||||||
// strides exprs over axes [k0..kN-1, w0..wN-1]
|
// Axis exprs must match final_shape axis order: first w axes, then k axes.
|
||||||
|
// idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
|
||||||
let mut axis_exprs = Vec::with_capacity(2 * n);
|
let mut axis_exprs = Vec::with_capacity(2 * n);
|
||||||
for i in 0..n {
|
|
||||||
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
|
// w axes
|
||||||
}
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
|
axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
|
||||||
}
|
}
|
||||||
|
// k axes
|
||||||
|
for i in 0..n {
|
||||||
|
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
|
||||||
|
}
|
||||||
|
|
||||||
let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
|
let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
|
||||||
let iota = self.graph().iota(index_expression, final_shape);
|
let iota = self.graph().iota(index_expression, final_shape);
|
||||||
@@ -266,6 +293,7 @@ impl GraphTensor {
|
|||||||
new_tensor * mask
|
new_tensor * mask
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Pad along an existing dimension
|
||||||
pub fn pad_along(
|
pub fn pad_along(
|
||||||
self,
|
self,
|
||||||
left: impl Into<Expression>,
|
left: impl Into<Expression>,
|
||||||
@@ -338,10 +366,170 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_unfold() {
|
fn test_unfold() {
|
||||||
let mut cx = Graph::new();
|
// Need all this code because candle doesn't do unfold
|
||||||
|
pub fn unfold_nd_f32(
|
||||||
|
x: &[f32],
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
kernel: &[usize],
|
||||||
|
step: &[usize],
|
||||||
|
dilation: &[usize],
|
||||||
|
pad_before: &[usize],
|
||||||
|
pad_after: &[usize],
|
||||||
|
) -> Vec<f32> {
|
||||||
|
let n = shape.len();
|
||||||
|
assert!(n > 0);
|
||||||
|
assert_eq!(strides.len(), n);
|
||||||
|
assert_eq!(kernel.len(), n);
|
||||||
|
assert_eq!(step.len(), n);
|
||||||
|
assert_eq!(dilation.len(), n);
|
||||||
|
assert_eq!(pad_before.len(), n);
|
||||||
|
assert_eq!(pad_after.len(), n);
|
||||||
|
|
||||||
let inp = cx.tensor((5,));
|
for d in 0..n {
|
||||||
let _pooled = inp.unfold((3,), (1,), (1,));
|
assert!(kernel[d] > 0);
|
||||||
|
assert!(step[d] > 0);
|
||||||
|
assert!(dilation[d] > 0);
|
||||||
|
assert!(shape[d] > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Effective kernel size per dim: (K-1)*d + 1
|
||||||
|
let eff_kernel: Vec<usize> =
|
||||||
|
(0..n).map(|d| (kernel[d] - 1) * dilation[d] + 1).collect();
|
||||||
|
|
||||||
|
// Output spatial shape (number of windows) per dim
|
||||||
|
let mut out_shape = vec![0usize; n];
|
||||||
|
for d in 0..n {
|
||||||
|
let padded = shape[d] + pad_before[d] + pad_after[d];
|
||||||
|
if padded < eff_kernel[d] {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
out_shape[d] = (padded - eff_kernel[d]) / step[d] + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let windows = prod(&out_shape);
|
||||||
|
let window_elems = prod(kernel);
|
||||||
|
let mut out = vec![0.0f32; windows * window_elems];
|
||||||
|
|
||||||
|
// Precompute helpers
|
||||||
|
let k_mul = row_major_multipliers(kernel);
|
||||||
|
|
||||||
|
// Current output window position (row-major)
|
||||||
|
let mut out_pos = vec![0usize; n];
|
||||||
|
|
||||||
|
for w in 0..windows {
|
||||||
|
if w > 0 {
|
||||||
|
incr_row_major(&mut out_pos, &out_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Window start in padded coordinates
|
||||||
|
let start_padded: Vec<usize> = (0..n).map(|d| out_pos[d] * step[d]).collect();
|
||||||
|
|
||||||
|
let base_out = w * window_elems;
|
||||||
|
|
||||||
|
// Iterate kernel elements (flattened)
|
||||||
|
for ke in 0..window_elems {
|
||||||
|
let k_idx = unravel_row_major(ke, kernel, &k_mul);
|
||||||
|
|
||||||
|
let mut flat: isize = 0;
|
||||||
|
let mut in_bounds = true;
|
||||||
|
|
||||||
|
for d in 0..n {
|
||||||
|
let p = start_padded[d] + k_idx[d] * dilation[d];
|
||||||
|
let logical = p as isize - pad_before[d] as isize;
|
||||||
|
|
||||||
|
if logical < 0 || logical >= shape[d] as isize {
|
||||||
|
in_bounds = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
flat += logical * strides[d] as isize;
|
||||||
|
}
|
||||||
|
|
||||||
|
let out_idx = base_out + ke;
|
||||||
|
out[out_idx] = if in_bounds { x[flat as usize] } else { 0.0 };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------- helpers --------
|
||||||
|
|
||||||
|
fn prod(xs: &[usize]) -> usize {
|
||||||
|
xs.iter().copied().product()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
|
||||||
|
let n = shape.len();
|
||||||
|
let mut mul = vec![1usize; n];
|
||||||
|
let mut acc = 1usize;
|
||||||
|
for d in (0..n).rev() {
|
||||||
|
mul[d] = acc;
|
||||||
|
acc *= shape[d];
|
||||||
|
}
|
||||||
|
mul
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
|
||||||
|
let n = shape.len();
|
||||||
|
let mut coords = vec![0usize; n];
|
||||||
|
for d in 0..n {
|
||||||
|
coords[d] = idx / mul[d];
|
||||||
|
idx %= mul[d];
|
||||||
|
}
|
||||||
|
coords
|
||||||
|
}
|
||||||
|
|
||||||
|
fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
|
||||||
|
for d in (0..pos.len()).rev() {
|
||||||
|
pos[d] += 1;
|
||||||
|
if pos[d] < shape[d] {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
pos[d] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test_unary(
|
||||||
|
5,
|
||||||
|
|a| a.unfold(3, 1, 1),
|
||||||
|
|a| {
|
||||||
|
Tensor::new(
|
||||||
|
unfold_nd_f32(
|
||||||
|
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
|
||||||
|
a.dims(),
|
||||||
|
a.stride(),
|
||||||
|
&[3],
|
||||||
|
&[1],
|
||||||
|
&[1],
|
||||||
|
&[0],
|
||||||
|
&[0],
|
||||||
|
),
|
||||||
|
a.device(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
test_unary(
|
||||||
|
(8, 10),
|
||||||
|
|a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
|
||||||
|
|a| {
|
||||||
|
Tensor::new(
|
||||||
|
unfold_nd_f32(
|
||||||
|
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
|
||||||
|
a.dims(),
|
||||||
|
a.stride(),
|
||||||
|
&[2, 3],
|
||||||
|
&[1, 2],
|
||||||
|
&[2, 1],
|
||||||
|
&[0, 4],
|
||||||
|
&[2, 3],
|
||||||
|
),
|
||||||
|
a.device(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,46 +1,5 @@
|
|||||||
use crate::{op::Constant, prelude::*};
|
use crate::{op::Constant, prelude::*};
|
||||||
|
|
||||||
// impl GraphTensor {
|
|
||||||
// /// Cumulative sum last dimension
|
|
||||||
// pub fn cumsum_last_dim(mut self) -> Self {
|
|
||||||
// let axis = self.shape.len() - 1;
|
|
||||||
// if !self.shape.is_contiguous() {
|
|
||||||
// self = self.contiguous();
|
|
||||||
// }
|
|
||||||
// // Pad out length
|
|
||||||
// let orig_length = self.dims()[axis];
|
|
||||||
// self = self.pad_along(orig_length - 1, 0, axis).contiguous();
|
|
||||||
|
|
||||||
// // Pool
|
|
||||||
// self = self.pool_last_dim(orig_length, 1, 1);
|
|
||||||
|
|
||||||
// // Sum Reduce along new dimension
|
|
||||||
// self.sum(axis + 1)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /// Cumulative max last dimension
|
|
||||||
// pub fn cummax_last_dim(mut self) -> Self {
|
|
||||||
// let axis = self.shape.len() - 1;
|
|
||||||
// if !self.shape.is_contiguous() {
|
|
||||||
// self = self.contiguous();
|
|
||||||
// }
|
|
||||||
// // Pad out length
|
|
||||||
// let orig_length = self.dims()[axis];
|
|
||||||
// self.shape.padding[self.shape.indexes[axis]].0 = orig_length - 1;
|
|
||||||
// self = self.contiguous();
|
|
||||||
|
|
||||||
// // Pool
|
|
||||||
// self = self.pool_last_dim(orig_length, 1, 1);
|
|
||||||
// // Max Reduce along new dimension
|
|
||||||
// self.max(axis + 1)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /// Cumulative product last dimension
|
|
||||||
// pub fn cumprod_last_dim(self) -> Self {
|
|
||||||
// self.log().cumsum_last_dim().exp()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
impl Graph {
|
impl Graph {
|
||||||
/// A scalar expression constant
|
/// A scalar expression constant
|
||||||
pub fn constant(&mut self, i: impl Into<Expression>) -> GraphTensor {
|
pub fn constant(&mut self, i: impl Into<Expression>) -> GraphTensor {
|
||||||
@@ -90,27 +49,25 @@ impl Graph {
|
|||||||
self.iota((Expression::from('z') * step) + start, (end - start) / step)
|
self.iota((Expression::from('z') * step) + start, (end - start) / step)
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// Lower left-hand triangle of 1s. Currently required to be square
|
/// Lower left-hand triangle of 1s. Currently required to be square
|
||||||
// ///
|
///
|
||||||
// /// Same API as https://pytorch.org/docs/stable/generated/torch.tril
|
/// Same API as https://pytorch.org/docs/stable/generated/torch.tril
|
||||||
// pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||||
// let size = size.into();
|
let size = size.into();
|
||||||
// let horizontal = self.arange(size).expand_dim(0, size);
|
let horizontal = self.arange(size).expand_dim(0, size);
|
||||||
// let vertical = self.arange(size).expand_dim(1, size);
|
let vertical = self.arange(size).expand_dim(1, size);
|
||||||
|
(horizontal - (diagonal as f32 + 1.)).lt(vertical)
|
||||||
|
}
|
||||||
|
|
||||||
// (horizontal - (diagonal as f32 + 1.)).lt(vertical)
|
/// Upper right-hand triangle of 1s
|
||||||
// }
|
///
|
||||||
|
/// Same API as https://pytorch.org/docs/stable/generated/torch.triu
|
||||||
// /// Upper right-hand triangle of 1s
|
pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||||
// ///
|
let size = size.into();
|
||||||
// /// Same API as https://pytorch.org/docs/stable/generated/torch.triu
|
let horizontal = self.arange(size).expand_dim(0, size);
|
||||||
// pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
let vertical = self.arange(size).expand_dim(1, size);
|
||||||
// let size = size.into();
|
(horizontal - (diagonal as f32 - 1.)).gt(vertical)
|
||||||
// let horizontal = self.arange(size).expand_dim(0, size).contiguous();
|
}
|
||||||
// let vertical = self.arange(size).expand_dim(1, size).contiguous();
|
|
||||||
|
|
||||||
// (horizontal - (diagonal as f32 - 1.)).gt(vertical)
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GraphTensor {
|
impl GraphTensor {
|
||||||
@@ -194,70 +151,15 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[test]
|
#[test]
|
||||||
// fn test_cumprod() {
|
fn test_triangle_mask() {
|
||||||
// let mut cx = Graph::new();
|
test_init(
|
||||||
|
|cx| cx.tril(10, 0).cast(DType::F32),
|
||||||
// let a = cx.tensor(3).set(vec![3., 2., 5.]);
|
|dev| Tensor::tril2(10, candle_core::DType::F32, dev).unwrap(),
|
||||||
// let b = a.cumprod_last_dim().retrieve();
|
);
|
||||||
// cx.execute();
|
test_init(
|
||||||
|
|cx| cx.triu(43, 0).cast(DType::F32),
|
||||||
// assert_close(&b.data(), &[3., 6., 30.]);
|
|dev| Tensor::triu2(43, candle_core::DType::F32, dev).unwrap(),
|
||||||
// }
|
);
|
||||||
|
}
|
||||||
// #[test]
|
|
||||||
// fn test_tril() {
|
|
||||||
// let mut cx = Graph::new();
|
|
||||||
|
|
||||||
// let triangle = cx.tril(5, 1).retrieve();
|
|
||||||
|
|
||||||
// cx.execute();
|
|
||||||
|
|
||||||
// assert_exact(
|
|
||||||
// &triangle.data(),
|
|
||||||
// &[
|
|
||||||
// [1.00, 1.00, 0.00, 0.00, 0.00],
|
|
||||||
// [1.00, 1.00, 1.00, 0.00, 0.00],
|
|
||||||
// [1.00, 1.00, 1.00, 1.00, 0.00],
|
|
||||||
// [1.00, 1.00, 1.00, 1.00, 1.00],
|
|
||||||
// [1.00, 1.00, 1.00, 1.00, 1.00],
|
|
||||||
// ]
|
|
||||||
// .into_iter()
|
|
||||||
// .flatten()
|
|
||||||
// .collect::<Vec<_>>(),
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
|
|
||||||
// #[test]
|
|
||||||
// fn test_triu() {
|
|
||||||
// let mut cx = Graph::new();
|
|
||||||
|
|
||||||
// let a = cx.triu(3, -1).retrieve();
|
|
||||||
// let b = cx.triu(3, 0).retrieve();
|
|
||||||
// let c = cx.triu(3, 1).retrieve();
|
|
||||||
|
|
||||||
// cx.execute();
|
|
||||||
|
|
||||||
// assert_exact(
|
|
||||||
// &a.data(),
|
|
||||||
// &[[1.00, 1.00, 1.00], [1.00, 1.00, 1.00], [0.00, 1.00, 1.00]]
|
|
||||||
// .into_iter()
|
|
||||||
// .flatten()
|
|
||||||
// .collect::<Vec<_>>(),
|
|
||||||
// );
|
|
||||||
// assert_exact(
|
|
||||||
// &b.data(),
|
|
||||||
// &[[1.00, 1.00, 1.00], [0.00, 1.00, 1.00], [0.00, 0.00, 1.00]]
|
|
||||||
// .into_iter()
|
|
||||||
// .flatten()
|
|
||||||
// .collect::<Vec<_>>(),
|
|
||||||
// );
|
|
||||||
// assert_exact(
|
|
||||||
// &c.data(),
|
|
||||||
// &[[0.00, 1.00, 1.00], [0.00, 0.00, 1.00], [0.00, 0.00, 0.00]]
|
|
||||||
// .into_iter()
|
|
||||||
// .flatten()
|
|
||||||
// .collect::<Vec<_>>(),
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::{op, prelude::*};
|
use crate::{op, prelude::*};
|
||||||
use std::ops::{Add, Mul, Neg};
|
use std::ops::{Add, Mul, Neg};
|
||||||
|
|
||||||
@@ -135,20 +137,15 @@ impl GraphTensor {
|
|||||||
m - m.exp().sum(axes.to_axes()).log().expand(m.shape)
|
m - m.exp().sum(axes.to_axes()).log().expand(m.shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// Get the indicies of the max elements along the last axis
|
/// Get the indices of the max elements along an axis
|
||||||
// pub fn argmax(self) -> GraphTensor {
|
pub fn argmax(self, axis: usize) -> GraphTensor {
|
||||||
// // Get one-hot along last dimension
|
// Get one-hot of the max positions along `axis`
|
||||||
// let x_equal = self.eq(self.max(self.shape.len() - 1).expand(self.shape));
|
let x_equal = self.eq(self.max(axis).expand(self.shape));
|
||||||
// // Create index arange for last dimension
|
// Create index arange for `axis`
|
||||||
// let r = self
|
let r = self.graph().arange(self.dims()[axis]).cast(self.dtype);
|
||||||
// .graph()
|
// Multiply one-hot by expanded index arange
|
||||||
// .constant(1.)
|
(x_equal * r.expand(self.shape)).max(axis).cast(DType::Int)
|
||||||
// .expand(self.shape.dims.last().unwrap())
|
}
|
||||||
// .cumsum_last_dim()
|
|
||||||
// - 1.;
|
|
||||||
// // Multiply one-hot by expanded index arange
|
|
||||||
// (x_equal * r.expand(self.shape)).max(self.shape.len() - 1)
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Take the absolute value
|
/// Take the absolute value
|
||||||
pub fn abs(self) -> GraphTensor {
|
pub fn abs(self) -> GraphTensor {
|
||||||
@@ -220,10 +217,58 @@ impl GraphTensor {
|
|||||||
pub fn topk_indexes(self, k: usize, axis: usize) -> GraphTensor {
|
pub fn topk_indexes(self, k: usize, axis: usize) -> GraphTensor {
|
||||||
self.argsort_indexes(axis, true).slice_along(..k, axis)
|
self.argsort_indexes(axis, true).slice_along(..k, axis)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply a cumulative reduction operation along dimensions
|
||||||
|
///
|
||||||
|
/// See `cumsum` or `cummax` for usage examples.
|
||||||
|
pub fn cumop(
|
||||||
|
mut self,
|
||||||
|
axes: impl ToAxes,
|
||||||
|
op: impl Fn(GraphTensor, &[usize]) -> GraphTensor,
|
||||||
|
) -> Self {
|
||||||
|
let n_dims = self.shape.len();
|
||||||
|
let axes = axes.to_axes();
|
||||||
|
// Left-pad each cumulative axis by (length - 1) and give it a kernel spanning the full axis
|
||||||
|
let mut kernel = vec![1.into(); n_dims];
|
||||||
|
let mut padding = vec![(Expression::from(0), Expression::from(0)); n_dims];
|
||||||
|
for &ax in &axes {
|
||||||
|
let orig_length = self.dims()[ax];
|
||||||
|
padding[ax] = ((orig_length - 1).into(), 0.into());
|
||||||
|
kernel[ax] = orig_length;
|
||||||
|
}
|
||||||
|
self = self.pad(padding);
|
||||||
|
// Unfold
|
||||||
|
self = self.unfold(kernel, vec![1; n_dims], vec![1; n_dims]);
|
||||||
|
// Squeeze the size-1 kernel dimensions of the non-cumulative axes
|
||||||
|
for i in (0..n_dims).rev() {
|
||||||
|
if !axes.contains(&i) {
|
||||||
|
self = self.squeeze(n_dims + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// apply operation along cumulative dimensions
|
||||||
|
op(self, &(0..axes.len()).map(|ax| ax + n_dims).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply a cumulative sum along dimensions
|
||||||
|
pub fn cumsum(self, axes: impl ToAxes) -> Self {
|
||||||
|
self.cumop(axes, |t, axes| t.sum(axes))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply a cumulative max along dimensions
|
||||||
|
pub fn cummax(self, axes: impl ToAxes) -> Self {
|
||||||
|
self.cumop(axes, |t, axes| t.max(axes))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply a cumulative product along dimensions
|
||||||
|
pub fn cumprod(self, axes: impl ToAxes) -> Self {
|
||||||
|
self.cumop(axes, |t, axes| t.prod(axes))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(super) mod tests {
|
pub(super) mod tests {
|
||||||
|
use std::collections::BinaryHeap;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
prelude::*,
|
prelude::*,
|
||||||
tests::{assert_close, random_vec},
|
tests::{assert_close, random_vec},
|
||||||
@@ -231,6 +276,7 @@ pub(super) mod tests {
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
use candle_nn::ops::softmax;
|
use candle_nn::ops::softmax;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use ordered_float::NotNan;
|
||||||
|
|
||||||
pub fn test_unary(
|
pub fn test_unary(
|
||||||
shape: impl ToShape,
|
shape: impl ToShape,
|
||||||
@@ -327,4 +373,77 @@ pub(super) mod tests {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_cumsum() {
|
||||||
|
test_unary(27, |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
|
||||||
|
test_unary((27, 63), |a| a.cumsum(1), |a| a.cumsum(1).unwrap());
|
||||||
|
test_unary((27, 63), |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_argmax() {
|
||||||
|
test_unary(
|
||||||
|
(9, 27),
|
||||||
|
|a| a.argmax(1).cast(DType::F32),
|
||||||
|
|a| {
|
||||||
|
a.argmax(1)
|
||||||
|
.unwrap()
|
||||||
|
.to_dtype(candle_core::DType::F32)
|
||||||
|
.unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
test_unary(
|
||||||
|
(9, 27),
|
||||||
|
|a| a.argmax(0).cast(DType::F32),
|
||||||
|
|a| {
|
||||||
|
a.argmax(0)
|
||||||
|
.unwrap()
|
||||||
|
.to_dtype(candle_core::DType::F32)
|
||||||
|
.unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_topk() {
|
||||||
|
pub fn topk_sorted_indices(x: &[f32], k: usize) -> Vec<usize> {
|
||||||
|
if k == 0 {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut heap: BinaryHeap<std::cmp::Reverse<(NotNan<f32>, usize)>> =
|
||||||
|
BinaryHeap::with_capacity(k);
|
||||||
|
|
||||||
|
for (i, &v) in x.iter().enumerate() {
|
||||||
|
let v = NotNan::new(v).expect("NaN encountered in topk");
|
||||||
|
if heap.len() < k {
|
||||||
|
heap.push(std::cmp::Reverse((v, i)));
|
||||||
|
} else if let Some(&std::cmp::Reverse((min_v, _))) = heap.peek() {
|
||||||
|
if v > min_v {
|
||||||
|
heap.pop();
|
||||||
|
heap.push(std::cmp::Reverse((v, i)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||||
|
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||||
|
|
||||||
|
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||||
|
out.into_iter().map(|(_, i)| i).collect()
|
||||||
|
}
|
||||||
|
test_unary(
|
||||||
|
(10, 9),
|
||||||
|
|a| a.topk_indexes(5, 1).cast(DType::F32) * 1.0,
|
||||||
|
|a| {
|
||||||
|
let data = a.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
||||||
|
let topk = data
|
||||||
|
.chunks_exact(9)
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|c| topk_sorted_indices(c, 5))
|
||||||
|
.into_iter()
|
||||||
|
.map(|i| i as f32)
|
||||||
|
.collect_vec();
|
||||||
|
Tensor::new(topk, a.device()).unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user