mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
cumulative ops added back and topk fixed
This commit is contained in:
@@ -38,6 +38,7 @@ pretty-duration = "0.1.1"
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.1"
|
||||
candle-nn = "0.9.1"
|
||||
ordered-float = "5.1.0"
|
||||
|
||||
[workspace]
|
||||
members = [
|
||||
|
||||
@@ -53,12 +53,22 @@ impl GraphTensor {
|
||||
|
||||
/// add a new dimension of size 1 at the specified place
|
||||
pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
|
||||
// Insert contiguous call
|
||||
assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
|
||||
self.shape.expand_dim(dim, 1);
|
||||
self
|
||||
}
|
||||
|
||||
/// remove a dimension of size 1
|
||||
pub fn squeeze(mut self, axis: usize) -> GraphTensor {
|
||||
assert_eq!(
|
||||
self.dims()[axis],
|
||||
Expression::from(1),
|
||||
"Only dimensions of size 1 can be squeezed!"
|
||||
);
|
||||
self.shape.remove_dim(axis);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(
|
||||
indexes.dtype,
|
||||
@@ -78,18 +88,32 @@ impl GraphTensor {
|
||||
/// x = [3, 2, 4, 1, 5, 0]
|
||||
/// inv_perm(x) = [5, 3, 1, 0, 2, 4]
|
||||
pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
|
||||
// TODO: this is very inefficient because we need to do O(n^2) comparisons and allocations for the one-hot and then sum reduce. Need a better way!
|
||||
let ax_size = self.dims()[axis];
|
||||
let mut arange = self.graph().arange(ax_size);
|
||||
// TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
|
||||
assert_eq!(self.dtype, DType::Int);
|
||||
let dims = self.dims();
|
||||
let ax_size = dims[axis];
|
||||
let mut dims2 = dims.clone();
|
||||
dims2.insert(axis, ax_size);
|
||||
// candidate: varies along candidate dim (axis), broadcast elsewhere.
|
||||
let mut candidate = self.graph().arange(ax_size);
|
||||
for i in 0..axis {
|
||||
arange = arange.expand_dim(i, self.dims()[i]);
|
||||
candidate = candidate.expand_dim(i, dims2[i]);
|
||||
}
|
||||
for i in axis + 1..self.dims().len() {
|
||||
arange = arange.expand_dim(i, self.dims()[i]);
|
||||
for i in axis + 1..dims2.len() {
|
||||
candidate = candidate.expand_dim(i, dims2[i]);
|
||||
}
|
||||
arange = arange.expand_dim(axis + 1, ax_size);
|
||||
let one_hot = self.expand_dim(axis + 1, ax_size).eq(arange);
|
||||
(one_hot * arange).sum(axis + 1)
|
||||
// position: varies along position dim (axis+1), broadcast elsewhere.
|
||||
let mut position = self.graph().arange(ax_size);
|
||||
for i in 0..(axis + 1) {
|
||||
position = position.expand_dim(i, dims2[i]);
|
||||
}
|
||||
for i in (axis + 2)..dims2.len() {
|
||||
position = position.expand_dim(i, dims2[i]);
|
||||
}
|
||||
// one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
|
||||
let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
|
||||
// inv[candidate, ...] = Σ_pos one_hot * position
|
||||
(one_hot * position).sum(axis + 1)
|
||||
}
|
||||
|
||||
/// Extracts sliding local windows from an input tensor.
|
||||
@@ -101,6 +125,7 @@ impl GraphTensor {
|
||||
) -> GraphTensor {
|
||||
let (kernel, strides, dilation) =
|
||||
(kernel.to_shape(), strides.to_shape(), dilation.to_shape());
|
||||
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
kernel.len(),
|
||||
@@ -117,7 +142,7 @@ impl GraphTensor {
|
||||
"Dilation must be same number of dimensions as tensor!"
|
||||
);
|
||||
|
||||
// Compute input strides
|
||||
// Compute input strides (row-major contiguous)
|
||||
let dims = self.dims();
|
||||
let n = dims.len();
|
||||
let mut in_strides = vec![Expression::from(1); n];
|
||||
@@ -128,26 +153,28 @@ impl GraphTensor {
|
||||
}
|
||||
|
||||
// Per-dim window counts
|
||||
let mut win = Vec::with_capacity(dims.len());
|
||||
for (((dim, kernel), stride), dilation) in
|
||||
dims.into_iter().zip(&kernel).zip(&strides).zip(&dilation)
|
||||
{
|
||||
let effective_window = *dilation * (*kernel - 1) + 1;
|
||||
win.push(((dim - effective_window) / stride) + 1);
|
||||
let mut win = Vec::with_capacity(n);
|
||||
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
|
||||
let effective_window = *d * (*k - 1) + 1;
|
||||
win.push(((*dim - effective_window) / s) + 1);
|
||||
}
|
||||
|
||||
// final_shape = [kernel..., win...]
|
||||
let mut final_shape = kernel.clone();
|
||||
final_shape.extend(win.into_iter().map(|e| e.simplify()));
|
||||
// [win..., kernel...]
|
||||
let mut final_shape: Vec<Expression> = win.into_iter().map(|e| e.simplify()).collect();
|
||||
final_shape.extend(kernel.iter().copied());
|
||||
|
||||
// strides exprs over axes [k0..kN-1, w0..wN-1]
|
||||
// Axis exprs must match final_shape axis order: first w axes, then k axes.
|
||||
// idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
|
||||
let mut axis_exprs = Vec::with_capacity(2 * n);
|
||||
for i in 0..n {
|
||||
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
|
||||
}
|
||||
|
||||
// w axes
|
||||
for i in 0..n {
|
||||
axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
|
||||
}
|
||||
// k axes
|
||||
for i in 0..n {
|
||||
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
|
||||
}
|
||||
|
||||
let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
|
||||
let iota = self.graph().iota(index_expression, final_shape);
|
||||
@@ -266,6 +293,7 @@ impl GraphTensor {
|
||||
new_tensor * mask
|
||||
}
|
||||
|
||||
/// Pad along an existing dimension
|
||||
pub fn pad_along(
|
||||
self,
|
||||
left: impl Into<Expression>,
|
||||
@@ -338,10 +366,170 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_unfold() {
|
||||
let mut cx = Graph::new();
|
||||
// Need all this code because candle doesnt do unfold
|
||||
pub fn unfold_nd_f32(
|
||||
x: &[f32],
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
kernel: &[usize],
|
||||
step: &[usize],
|
||||
dilation: &[usize],
|
||||
pad_before: &[usize],
|
||||
pad_after: &[usize],
|
||||
) -> Vec<f32> {
|
||||
let n = shape.len();
|
||||
assert!(n > 0);
|
||||
assert_eq!(strides.len(), n);
|
||||
assert_eq!(kernel.len(), n);
|
||||
assert_eq!(step.len(), n);
|
||||
assert_eq!(dilation.len(), n);
|
||||
assert_eq!(pad_before.len(), n);
|
||||
assert_eq!(pad_after.len(), n);
|
||||
|
||||
let inp = cx.tensor((5,));
|
||||
let _pooled = inp.unfold((3,), (1,), (1,));
|
||||
for d in 0..n {
|
||||
assert!(kernel[d] > 0);
|
||||
assert!(step[d] > 0);
|
||||
assert!(dilation[d] > 0);
|
||||
assert!(shape[d] > 0);
|
||||
}
|
||||
|
||||
// Effective kernel size per dim: (K-1)*d + 1
|
||||
let eff_kernel: Vec<usize> =
|
||||
(0..n).map(|d| (kernel[d] - 1) * dilation[d] + 1).collect();
|
||||
|
||||
// Output spatial shape (number of windows) per dim
|
||||
let mut out_shape = vec![0usize; n];
|
||||
for d in 0..n {
|
||||
let padded = shape[d] + pad_before[d] + pad_after[d];
|
||||
if padded < eff_kernel[d] {
|
||||
return Vec::new();
|
||||
}
|
||||
out_shape[d] = (padded - eff_kernel[d]) / step[d] + 1;
|
||||
}
|
||||
|
||||
let windows = prod(&out_shape);
|
||||
let window_elems = prod(kernel);
|
||||
let mut out = vec![0.0f32; windows * window_elems];
|
||||
|
||||
// Precompute helpers
|
||||
let k_mul = row_major_multipliers(kernel);
|
||||
|
||||
// Current output window position (row-major)
|
||||
let mut out_pos = vec![0usize; n];
|
||||
|
||||
for w in 0..windows {
|
||||
if w > 0 {
|
||||
incr_row_major(&mut out_pos, &out_shape);
|
||||
}
|
||||
|
||||
// Window start in padded coordinates
|
||||
let start_padded: Vec<usize> = (0..n).map(|d| out_pos[d] * step[d]).collect();
|
||||
|
||||
let base_out = w * window_elems;
|
||||
|
||||
// Iterate kernel elements (flattened)
|
||||
for ke in 0..window_elems {
|
||||
let k_idx = unravel_row_major(ke, kernel, &k_mul);
|
||||
|
||||
let mut flat: isize = 0;
|
||||
let mut in_bounds = true;
|
||||
|
||||
for d in 0..n {
|
||||
let p = start_padded[d] + k_idx[d] * dilation[d];
|
||||
let logical = p as isize - pad_before[d] as isize;
|
||||
|
||||
if logical < 0 || logical >= shape[d] as isize {
|
||||
in_bounds = false;
|
||||
break;
|
||||
}
|
||||
flat += logical * strides[d] as isize;
|
||||
}
|
||||
|
||||
let out_idx = base_out + ke;
|
||||
out[out_idx] = if in_bounds { x[flat as usize] } else { 0.0 };
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
// -------- helpers --------
|
||||
|
||||
fn prod(xs: &[usize]) -> usize {
|
||||
xs.iter().copied().product()
|
||||
}
|
||||
|
||||
fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
|
||||
let n = shape.len();
|
||||
let mut mul = vec![1usize; n];
|
||||
let mut acc = 1usize;
|
||||
for d in (0..n).rev() {
|
||||
mul[d] = acc;
|
||||
acc *= shape[d];
|
||||
}
|
||||
mul
|
||||
}
|
||||
|
||||
fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
|
||||
let n = shape.len();
|
||||
let mut coords = vec![0usize; n];
|
||||
for d in 0..n {
|
||||
coords[d] = idx / mul[d];
|
||||
idx %= mul[d];
|
||||
}
|
||||
coords
|
||||
}
|
||||
|
||||
fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
|
||||
for d in (0..pos.len()).rev() {
|
||||
pos[d] += 1;
|
||||
if pos[d] < shape[d] {
|
||||
return;
|
||||
}
|
||||
pos[d] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
test_unary(
|
||||
5,
|
||||
|a| a.unfold(3, 1, 1),
|
||||
|a| {
|
||||
Tensor::new(
|
||||
unfold_nd_f32(
|
||||
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
|
||||
a.dims(),
|
||||
a.stride(),
|
||||
&[3],
|
||||
&[1],
|
||||
&[1],
|
||||
&[0],
|
||||
&[0],
|
||||
),
|
||||
a.device(),
|
||||
)
|
||||
.unwrap()
|
||||
},
|
||||
);
|
||||
test_unary(
|
||||
(8, 10),
|
||||
|a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
|
||||
|a| {
|
||||
Tensor::new(
|
||||
unfold_nd_f32(
|
||||
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
|
||||
a.dims(),
|
||||
a.stride(),
|
||||
&[2, 3],
|
||||
&[1, 2],
|
||||
&[2, 1],
|
||||
&[0, 4],
|
||||
&[2, 3],
|
||||
),
|
||||
a.device(),
|
||||
)
|
||||
.unwrap()
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,46 +1,5 @@
|
||||
use crate::{op::Constant, prelude::*};
|
||||
|
||||
// impl GraphTensor {
|
||||
// /// Cumulative sum last dimension
|
||||
// pub fn cumsum_last_dim(mut self) -> Self {
|
||||
// let axis = self.shape.len() - 1;
|
||||
// if !self.shape.is_contiguous() {
|
||||
// self = self.contiguous();
|
||||
// }
|
||||
// // Pad out length
|
||||
// let orig_length = self.dims()[axis];
|
||||
// self = self.pad_along(orig_length - 1, 0, axis).contiguous();
|
||||
|
||||
// // Pool
|
||||
// self = self.pool_last_dim(orig_length, 1, 1);
|
||||
|
||||
// // Sum Reduce along new dimension
|
||||
// self.sum(axis + 1)
|
||||
// }
|
||||
|
||||
// /// Cumulative max last dimension
|
||||
// pub fn cummax_last_dim(mut self) -> Self {
|
||||
// let axis = self.shape.len() - 1;
|
||||
// if !self.shape.is_contiguous() {
|
||||
// self = self.contiguous();
|
||||
// }
|
||||
// // Pad out length
|
||||
// let orig_length = self.dims()[axis];
|
||||
// self.shape.padding[self.shape.indexes[axis]].0 = orig_length - 1;
|
||||
// self = self.contiguous();
|
||||
|
||||
// // Pool
|
||||
// self = self.pool_last_dim(orig_length, 1, 1);
|
||||
// // Max Reduce along new dimension
|
||||
// self.max(axis + 1)
|
||||
// }
|
||||
|
||||
// /// Cumulative product last dimension
|
||||
// pub fn cumprod_last_dim(self) -> Self {
|
||||
// self.log().cumsum_last_dim().exp()
|
||||
// }
|
||||
// }
|
||||
|
||||
impl Graph {
|
||||
/// A scalar expression constant
|
||||
pub fn constant(&mut self, i: impl Into<Expression>) -> GraphTensor {
|
||||
@@ -90,27 +49,25 @@ impl Graph {
|
||||
self.iota((Expression::from('z') * step) + start, (end - start) / step)
|
||||
}
|
||||
|
||||
// /// Lower left-hand triangle of 1s. Currently required to be square
|
||||
// ///
|
||||
// /// Same API as https://pytorch.org/docs/stable/generated/torch.tril
|
||||
// pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
// let size = size.into();
|
||||
// let horizontal = self.arange(size).expand_dim(0, size);
|
||||
// let vertical = self.arange(size).expand_dim(1, size);
|
||||
/// Lower left-hand triangle of 1s. Currently required to be square
|
||||
///
|
||||
/// Same API as https://pytorch.org/docs/stable/generated/torch.tril
|
||||
pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
let size = size.into();
|
||||
let horizontal = self.arange(size).expand_dim(0, size);
|
||||
let vertical = self.arange(size).expand_dim(1, size);
|
||||
(horizontal - (diagonal as f32 + 1.)).lt(vertical)
|
||||
}
|
||||
|
||||
// (horizontal - (diagonal as f32 + 1.)).lt(vertical)
|
||||
// }
|
||||
|
||||
// /// Upper right-hand triangle of 1s
|
||||
// ///
|
||||
// /// Same API as https://pytorch.org/docs/stable/generated/torch.triu
|
||||
// pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
// let size = size.into();
|
||||
// let horizontal = self.arange(size).expand_dim(0, size).contiguous();
|
||||
// let vertical = self.arange(size).expand_dim(1, size).contiguous();
|
||||
|
||||
// (horizontal - (diagonal as f32 - 1.)).gt(vertical)
|
||||
// }
|
||||
/// Upper right-hand triangle of 1s
|
||||
///
|
||||
/// Same API as https://pytorch.org/docs/stable/generated/torch.triu
|
||||
pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
let size = size.into();
|
||||
let horizontal = self.arange(size).expand_dim(0, size);
|
||||
let vertical = self.arange(size).expand_dim(1, size);
|
||||
(horizontal - (diagonal as f32 - 1.)).gt(vertical)
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphTensor {
|
||||
@@ -194,70 +151,15 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_cumprod() {
|
||||
// let mut cx = Graph::new();
|
||||
|
||||
// let a = cx.tensor(3).set(vec![3., 2., 5.]);
|
||||
// let b = a.cumprod_last_dim().retrieve();
|
||||
// cx.execute();
|
||||
|
||||
// assert_close(&b.data(), &[3., 6., 30.]);
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_tril() {
|
||||
// let mut cx = Graph::new();
|
||||
|
||||
// let triangle = cx.tril(5, 1).retrieve();
|
||||
|
||||
// cx.execute();
|
||||
|
||||
// assert_exact(
|
||||
// &triangle.data(),
|
||||
// &[
|
||||
// [1.00, 1.00, 0.00, 0.00, 0.00],
|
||||
// [1.00, 1.00, 1.00, 0.00, 0.00],
|
||||
// [1.00, 1.00, 1.00, 1.00, 0.00],
|
||||
// [1.00, 1.00, 1.00, 1.00, 1.00],
|
||||
// [1.00, 1.00, 1.00, 1.00, 1.00],
|
||||
// ]
|
||||
// .into_iter()
|
||||
// .flatten()
|
||||
// .collect::<Vec<_>>(),
|
||||
// );
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_triu() {
|
||||
// let mut cx = Graph::new();
|
||||
|
||||
// let a = cx.triu(3, -1).retrieve();
|
||||
// let b = cx.triu(3, 0).retrieve();
|
||||
// let c = cx.triu(3, 1).retrieve();
|
||||
|
||||
// cx.execute();
|
||||
|
||||
// assert_exact(
|
||||
// &a.data(),
|
||||
// &[[1.00, 1.00, 1.00], [1.00, 1.00, 1.00], [0.00, 1.00, 1.00]]
|
||||
// .into_iter()
|
||||
// .flatten()
|
||||
// .collect::<Vec<_>>(),
|
||||
// );
|
||||
// assert_exact(
|
||||
// &b.data(),
|
||||
// &[[1.00, 1.00, 1.00], [0.00, 1.00, 1.00], [0.00, 0.00, 1.00]]
|
||||
// .into_iter()
|
||||
// .flatten()
|
||||
// .collect::<Vec<_>>(),
|
||||
// );
|
||||
// assert_exact(
|
||||
// &c.data(),
|
||||
// &[[0.00, 1.00, 1.00], [0.00, 0.00, 1.00], [0.00, 0.00, 0.00]]
|
||||
// .into_iter()
|
||||
// .flatten()
|
||||
// .collect::<Vec<_>>(),
|
||||
// );
|
||||
// }
|
||||
#[test]
|
||||
fn test_triangle_mask() {
|
||||
test_init(
|
||||
|cx| cx.tril(10, 0).cast(DType::F32),
|
||||
|dev| Tensor::tril2(10, candle_core::DType::F32, dev).unwrap(),
|
||||
);
|
||||
test_init(
|
||||
|cx| cx.triu(43, 0).cast(DType::F32),
|
||||
|dev| Tensor::triu2(43, candle_core::DType::F32, dev).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{op, prelude::*};
|
||||
use std::ops::{Add, Mul, Neg};
|
||||
|
||||
@@ -135,20 +137,15 @@ impl GraphTensor {
|
||||
m - m.exp().sum(axes.to_axes()).log().expand(m.shape)
|
||||
}
|
||||
|
||||
// /// Get the indicies of the max elements along the last axis
|
||||
// pub fn argmax(self) -> GraphTensor {
|
||||
// // Get one-hot along last dimension
|
||||
// let x_equal = self.eq(self.max(self.shape.len() - 1).expand(self.shape));
|
||||
// // Create index arange for last dimension
|
||||
// let r = self
|
||||
// .graph()
|
||||
// .constant(1.)
|
||||
// .expand(self.shape.dims.last().unwrap())
|
||||
// .cumsum_last_dim()
|
||||
// - 1.;
|
||||
// // Multiply one-hot by expanded index arange
|
||||
// (x_equal * r.expand(self.shape)).max(self.shape.len() - 1)
|
||||
// }
|
||||
/// Get the indicies of the max elements along an axis
|
||||
pub fn argmax(self, axis: usize) -> GraphTensor {
|
||||
// Get one-hot along last dimension
|
||||
let x_equal = self.eq(self.max(axis).expand(self.shape));
|
||||
// Create index arange for last dimension
|
||||
let r = self.graph().arange(self.dims()[axis]).cast(self.dtype);
|
||||
// Multiply one-hot by expanded index arange
|
||||
(x_equal * r.expand(self.shape)).max(axis).cast(DType::Int)
|
||||
}
|
||||
|
||||
/// Take the absolute value
|
||||
pub fn abs(self) -> GraphTensor {
|
||||
@@ -220,10 +217,58 @@ impl GraphTensor {
|
||||
pub fn topk_indexes(self, k: usize, axis: usize) -> GraphTensor {
|
||||
self.argsort_indexes(axis, true).slice_along(..k, axis)
|
||||
}
|
||||
|
||||
/// Apply a cumulative reduction operation along dimensions
|
||||
///
|
||||
/// See `cumsum` or `cummax` for usage examples.
|
||||
pub fn cumop(
|
||||
mut self,
|
||||
axes: impl ToAxes,
|
||||
op: impl Fn(GraphTensor, &[usize]) -> GraphTensor,
|
||||
) -> Self {
|
||||
let n_dims = self.shape.len();
|
||||
let axes = axes.to_axes();
|
||||
// Pad out length
|
||||
let mut kernel = vec![1.into(); n_dims];
|
||||
let mut padding = vec![(Expression::from(0), Expression::from(0)); n_dims];
|
||||
for &ax in &axes {
|
||||
let orig_length = self.dims()[ax];
|
||||
padding[ax] = ((orig_length - 1).into(), 0.into());
|
||||
kernel[ax] = orig_length;
|
||||
}
|
||||
self = self.pad(padding);
|
||||
// Unfold
|
||||
self = self.unfold(kernel, vec![1; n_dims], vec![1; n_dims]);
|
||||
// Remove non-cumulative dimensions
|
||||
for i in (0..n_dims).rev() {
|
||||
if !axes.contains(&i) {
|
||||
self = self.squeeze(n_dims + i);
|
||||
}
|
||||
}
|
||||
// apply operation along cumulative dimensions
|
||||
op(self, &(0..axes.len()).map(|ax| ax + n_dims).collect_vec())
|
||||
}
|
||||
|
||||
/// Apply a cumulative sum along dimensions
|
||||
pub fn cumsum(self, axes: impl ToAxes) -> Self {
|
||||
self.cumop(axes, |t, axes| t.sum(axes))
|
||||
}
|
||||
|
||||
/// Apply a cumulative max along dimensions
|
||||
pub fn cummax(self, axes: impl ToAxes) -> Self {
|
||||
self.cumop(axes, |t, axes| t.max(axes))
|
||||
}
|
||||
|
||||
/// Apply a cumulative product along dimensions
|
||||
pub fn cumprod(self, axes: impl ToAxes) -> Self {
|
||||
self.cumop(axes, |t, axes| t.prod(axes))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(super) mod tests {
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
use crate::{
|
||||
prelude::*,
|
||||
tests::{assert_close, random_vec},
|
||||
@@ -231,6 +276,7 @@ pub(super) mod tests {
|
||||
use candle_core::{Device, Tensor};
|
||||
use candle_nn::ops::softmax;
|
||||
use itertools::Itertools;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
pub fn test_unary(
|
||||
shape: impl ToShape,
|
||||
@@ -327,4 +373,77 @@ pub(super) mod tests {
|
||||
},
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_cumsum() {
|
||||
test_unary(27, |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
|
||||
test_unary((27, 63), |a| a.cumsum(1), |a| a.cumsum(1).unwrap());
|
||||
test_unary((27, 63), |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
|
||||
}
|
||||
#[test]
|
||||
fn test_argmax() {
|
||||
test_unary(
|
||||
(9, 27),
|
||||
|a| a.argmax(1).cast(DType::F32),
|
||||
|a| {
|
||||
a.argmax(1)
|
||||
.unwrap()
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
);
|
||||
test_unary(
|
||||
(9, 27),
|
||||
|a| a.argmax(0).cast(DType::F32),
|
||||
|a| {
|
||||
a.argmax(0)
|
||||
.unwrap()
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_topk() {
|
||||
pub fn topk_sorted_indices(x: &[f32], k: usize) -> Vec<usize> {
|
||||
if k == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut heap: BinaryHeap<std::cmp::Reverse<(NotNan<f32>, usize)>> =
|
||||
BinaryHeap::with_capacity(k);
|
||||
|
||||
for (i, &v) in x.iter().enumerate() {
|
||||
let v = NotNan::new(v).expect("NaN encountered in topk");
|
||||
if heap.len() < k {
|
||||
heap.push(std::cmp::Reverse((v, i)));
|
||||
} else if let Some(&std::cmp::Reverse((min_v, _))) = heap.peek() {
|
||||
if v > min_v {
|
||||
heap.pop();
|
||||
heap.push(std::cmp::Reverse((v, i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
(10, 9),
|
||||
|a| a.topk_indexes(5, 1).cast(DType::F32) * 1.0,
|
||||
|a| {
|
||||
let data = a.flatten_all().unwrap().to_vec1::<f32>().unwrap();
|
||||
let topk = data
|
||||
.chunks_exact(9)
|
||||
.into_iter()
|
||||
.flat_map(|c| topk_sorted_indices(c, 5))
|
||||
.into_iter()
|
||||
.map(|i| i as f32)
|
||||
.collect_vec();
|
||||
Tensor::new(topk, a.device()).unwrap()
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user