cumulative ops added back and topk fixed

This commit is contained in:
Joe Fioti
2025-12-23 10:46:47 -05:00
parent 6788b98d38
commit 90c96fbe91
4 changed files with 378 additions and 168 deletions

View File

@@ -38,6 +38,7 @@ pretty-duration = "0.1.1"
[dev-dependencies]
candle-core = "0.9.1"
candle-nn = "0.9.1"
ordered-float = "5.1.0"
[workspace]
members = [

View File

@@ -53,12 +53,22 @@ impl GraphTensor {
/// add a new dimension of size 1 at the specified place
pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
// Insert contiguous call
assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
self.shape.expand_dim(dim, 1);
self
}
/// remove a dimension of size 1
pub fn squeeze(mut self, axis: usize) -> GraphTensor {
assert_eq!(
self.dims()[axis],
Expression::from(1),
"Only dimensions of size 1 can be squeezed!"
);
self.shape.remove_dim(axis);
self
}
pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
assert_eq!(
indexes.dtype,
@@ -78,18 +88,32 @@ impl GraphTensor {
/// x = [3, 2, 4, 1, 5, 0]
/// inv_perm(x) = [5, 3, 1, 0, 2, 4]
pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
// TODO: this is very inefficient because we need to do O(n^2) comparisons and allocations for the one-hot and then sum reduce. Need a better way!
let ax_size = self.dims()[axis];
let mut arange = self.graph().arange(ax_size);
// TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
assert_eq!(self.dtype, DType::Int);
let dims = self.dims();
let ax_size = dims[axis];
let mut dims2 = dims.clone();
dims2.insert(axis, ax_size);
// candidate: varies along candidate dim (axis), broadcast elsewhere.
let mut candidate = self.graph().arange(ax_size);
for i in 0..axis {
arange = arange.expand_dim(i, self.dims()[i]);
candidate = candidate.expand_dim(i, dims2[i]);
}
for i in axis + 1..self.dims().len() {
arange = arange.expand_dim(i, self.dims()[i]);
for i in axis + 1..dims2.len() {
candidate = candidate.expand_dim(i, dims2[i]);
}
arange = arange.expand_dim(axis + 1, ax_size);
let one_hot = self.expand_dim(axis + 1, ax_size).eq(arange);
(one_hot * arange).sum(axis + 1)
// position: varies along position dim (axis+1), broadcast elsewhere.
let mut position = self.graph().arange(ax_size);
for i in 0..(axis + 1) {
position = position.expand_dim(i, dims2[i]);
}
for i in (axis + 2)..dims2.len() {
position = position.expand_dim(i, dims2[i]);
}
// one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
// inv[candidate, ...] = Σ_pos one_hot * position
(one_hot * position).sum(axis + 1)
}
/// Extracts sliding local windows from an input tensor.
@@ -101,6 +125,7 @@ impl GraphTensor {
) -> GraphTensor {
let (kernel, strides, dilation) =
(kernel.to_shape(), strides.to_shape(), dilation.to_shape());
assert_eq!(
self.shape.len(),
kernel.len(),
@@ -117,7 +142,7 @@ impl GraphTensor {
"Dilation must be same number of dimensions as tensor!"
);
// Compute input strides
// Compute input strides (row-major contiguous)
let dims = self.dims();
let n = dims.len();
let mut in_strides = vec![Expression::from(1); n];
@@ -128,26 +153,28 @@ impl GraphTensor {
}
// Per-dim window counts
let mut win = Vec::with_capacity(dims.len());
for (((dim, kernel), stride), dilation) in
dims.into_iter().zip(&kernel).zip(&strides).zip(&dilation)
{
let effective_window = *dilation * (*kernel - 1) + 1;
win.push(((dim - effective_window) / stride) + 1);
let mut win = Vec::with_capacity(n);
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
let effective_window = *d * (*k - 1) + 1;
win.push(((*dim - effective_window) / s) + 1);
}
// final_shape = [kernel..., win...]
let mut final_shape = kernel.clone();
final_shape.extend(win.into_iter().map(|e| e.simplify()));
// [win..., kernel...]
let mut final_shape: Vec<Expression> = win.into_iter().map(|e| e.simplify()).collect();
final_shape.extend(kernel.iter().copied());
// strides exprs over axes [k0..kN-1, w0..wN-1]
// Axis exprs must match final_shape axis order: first w axes, then k axes.
// idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
let mut axis_exprs = Vec::with_capacity(2 * n);
for i in 0..n {
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
}
// w axes
for i in 0..n {
axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
}
// k axes
for i in 0..n {
axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
}
let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
let iota = self.graph().iota(index_expression, final_shape);
@@ -266,6 +293,7 @@ impl GraphTensor {
new_tensor * mask
}
/// Pad along an existing dimension
pub fn pad_along(
self,
left: impl Into<Expression>,
@@ -338,10 +366,170 @@ mod tests {
#[test]
fn test_unfold() {
let mut cx = Graph::new();
// Need all this code because candle doesnt do unfold
pub fn unfold_nd_f32(
x: &[f32],
shape: &[usize],
strides: &[usize],
kernel: &[usize],
step: &[usize],
dilation: &[usize],
pad_before: &[usize],
pad_after: &[usize],
) -> Vec<f32> {
let n = shape.len();
assert!(n > 0);
assert_eq!(strides.len(), n);
assert_eq!(kernel.len(), n);
assert_eq!(step.len(), n);
assert_eq!(dilation.len(), n);
assert_eq!(pad_before.len(), n);
assert_eq!(pad_after.len(), n);
let inp = cx.tensor((5,));
let _pooled = inp.unfold((3,), (1,), (1,));
for d in 0..n {
assert!(kernel[d] > 0);
assert!(step[d] > 0);
assert!(dilation[d] > 0);
assert!(shape[d] > 0);
}
// Effective kernel size per dim: (K-1)*d + 1
let eff_kernel: Vec<usize> =
(0..n).map(|d| (kernel[d] - 1) * dilation[d] + 1).collect();
// Output spatial shape (number of windows) per dim
let mut out_shape = vec![0usize; n];
for d in 0..n {
let padded = shape[d] + pad_before[d] + pad_after[d];
if padded < eff_kernel[d] {
return Vec::new();
}
out_shape[d] = (padded - eff_kernel[d]) / step[d] + 1;
}
let windows = prod(&out_shape);
let window_elems = prod(kernel);
let mut out = vec![0.0f32; windows * window_elems];
// Precompute helpers
let k_mul = row_major_multipliers(kernel);
// Current output window position (row-major)
let mut out_pos = vec![0usize; n];
for w in 0..windows {
if w > 0 {
incr_row_major(&mut out_pos, &out_shape);
}
// Window start in padded coordinates
let start_padded: Vec<usize> = (0..n).map(|d| out_pos[d] * step[d]).collect();
let base_out = w * window_elems;
// Iterate kernel elements (flattened)
for ke in 0..window_elems {
let k_idx = unravel_row_major(ke, kernel, &k_mul);
let mut flat: isize = 0;
let mut in_bounds = true;
for d in 0..n {
let p = start_padded[d] + k_idx[d] * dilation[d];
let logical = p as isize - pad_before[d] as isize;
if logical < 0 || logical >= shape[d] as isize {
in_bounds = false;
break;
}
flat += logical * strides[d] as isize;
}
let out_idx = base_out + ke;
out[out_idx] = if in_bounds { x[flat as usize] } else { 0.0 };
}
}
out
}
// -------- helpers --------
fn prod(xs: &[usize]) -> usize {
xs.iter().copied().product()
}
fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
let n = shape.len();
let mut mul = vec![1usize; n];
let mut acc = 1usize;
for d in (0..n).rev() {
mul[d] = acc;
acc *= shape[d];
}
mul
}
fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
let n = shape.len();
let mut coords = vec![0usize; n];
for d in 0..n {
coords[d] = idx / mul[d];
idx %= mul[d];
}
coords
}
fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
for d in (0..pos.len()).rev() {
pos[d] += 1;
if pos[d] < shape[d] {
return;
}
pos[d] = 0;
}
}
test_unary(
5,
|a| a.unfold(3, 1, 1),
|a| {
Tensor::new(
unfold_nd_f32(
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
a.dims(),
a.stride(),
&[3],
&[1],
&[1],
&[0],
&[0],
),
a.device(),
)
.unwrap()
},
);
test_unary(
(8, 10),
|a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
|a| {
Tensor::new(
unfold_nd_f32(
&a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
a.dims(),
a.stride(),
&[2, 3],
&[1, 2],
&[2, 1],
&[0, 4],
&[2, 3],
),
a.device(),
)
.unwrap()
},
);
}
#[test]

View File

@@ -1,46 +1,5 @@
use crate::{op::Constant, prelude::*};
// impl GraphTensor {
// /// Cumulative sum last dimension
// pub fn cumsum_last_dim(mut self) -> Self {
// let axis = self.shape.len() - 1;
// if !self.shape.is_contiguous() {
// self = self.contiguous();
// }
// // Pad out length
// let orig_length = self.dims()[axis];
// self = self.pad_along(orig_length - 1, 0, axis).contiguous();
// // Pool
// self = self.pool_last_dim(orig_length, 1, 1);
// // Sum Reduce along new dimension
// self.sum(axis + 1)
// }
// /// Cumulative max last dimension
// pub fn cummax_last_dim(mut self) -> Self {
// let axis = self.shape.len() - 1;
// if !self.shape.is_contiguous() {
// self = self.contiguous();
// }
// // Pad out length
// let orig_length = self.dims()[axis];
// self.shape.padding[self.shape.indexes[axis]].0 = orig_length - 1;
// self = self.contiguous();
// // Pool
// self = self.pool_last_dim(orig_length, 1, 1);
// // Max Reduce along new dimension
// self.max(axis + 1)
// }
// /// Cumulative product last dimension
// pub fn cumprod_last_dim(self) -> Self {
// self.log().cumsum_last_dim().exp()
// }
// }
impl Graph {
/// A scalar expression constant
pub fn constant(&mut self, i: impl Into<Expression>) -> GraphTensor {
@@ -90,27 +49,25 @@ impl Graph {
self.iota((Expression::from('z') * step) + start, (end - start) / step)
}
// /// Lower left-hand triangle of 1s. Currently required to be square
// ///
// /// Same API as https://pytorch.org/docs/stable/generated/torch.tril
// pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
// let size = size.into();
// let horizontal = self.arange(size).expand_dim(0, size);
// let vertical = self.arange(size).expand_dim(1, size);
/// Lower left-hand triangle of 1s. Currently required to be square
///
/// Same API as https://pytorch.org/docs/stable/generated/torch.tril
pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
let size = size.into();
let horizontal = self.arange(size).expand_dim(0, size);
let vertical = self.arange(size).expand_dim(1, size);
(horizontal - (diagonal as f32 + 1.)).lt(vertical)
}
// (horizontal - (diagonal as f32 + 1.)).lt(vertical)
// }
// /// Upper right-hand triangle of 1s
// ///
// /// Same API as https://pytorch.org/docs/stable/generated/torch.triu
// pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
// let size = size.into();
// let horizontal = self.arange(size).expand_dim(0, size).contiguous();
// let vertical = self.arange(size).expand_dim(1, size).contiguous();
// (horizontal - (diagonal as f32 - 1.)).gt(vertical)
// }
/// Upper right-hand triangle of 1s
///
/// Same API as https://pytorch.org/docs/stable/generated/torch.triu
pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
let size = size.into();
let horizontal = self.arange(size).expand_dim(0, size);
let vertical = self.arange(size).expand_dim(1, size);
(horizontal - (diagonal as f32 - 1.)).gt(vertical)
}
}
impl GraphTensor {
@@ -194,70 +151,15 @@ mod tests {
);
}
// #[test]
// fn test_cumprod() {
// let mut cx = Graph::new();
// let a = cx.tensor(3).set(vec![3., 2., 5.]);
// let b = a.cumprod_last_dim().retrieve();
// cx.execute();
// assert_close(&b.data(), &[3., 6., 30.]);
// }
// #[test]
// fn test_tril() {
// let mut cx = Graph::new();
// let triangle = cx.tril(5, 1).retrieve();
// cx.execute();
// assert_exact(
// &triangle.data(),
// &[
// [1.00, 1.00, 0.00, 0.00, 0.00],
// [1.00, 1.00, 1.00, 0.00, 0.00],
// [1.00, 1.00, 1.00, 1.00, 0.00],
// [1.00, 1.00, 1.00, 1.00, 1.00],
// [1.00, 1.00, 1.00, 1.00, 1.00],
// ]
// .into_iter()
// .flatten()
// .collect::<Vec<_>>(),
// );
// }
// #[test]
// fn test_triu() {
// let mut cx = Graph::new();
// let a = cx.triu(3, -1).retrieve();
// let b = cx.triu(3, 0).retrieve();
// let c = cx.triu(3, 1).retrieve();
// cx.execute();
// assert_exact(
// &a.data(),
// &[[1.00, 1.00, 1.00], [1.00, 1.00, 1.00], [0.00, 1.00, 1.00]]
// .into_iter()
// .flatten()
// .collect::<Vec<_>>(),
// );
// assert_exact(
// &b.data(),
// &[[1.00, 1.00, 1.00], [0.00, 1.00, 1.00], [0.00, 0.00, 1.00]]
// .into_iter()
// .flatten()
// .collect::<Vec<_>>(),
// );
// assert_exact(
// &c.data(),
// &[[0.00, 1.00, 1.00], [0.00, 0.00, 1.00], [0.00, 0.00, 0.00]]
// .into_iter()
// .flatten()
// .collect::<Vec<_>>(),
// );
// }
#[test]
fn test_triangle_mask() {
test_init(
|cx| cx.tril(10, 0).cast(DType::F32),
|dev| Tensor::tril2(10, candle_core::DType::F32, dev).unwrap(),
);
test_init(
|cx| cx.triu(43, 0).cast(DType::F32),
|dev| Tensor::triu2(43, candle_core::DType::F32, dev).unwrap(),
);
}
}

View File

@@ -1,3 +1,5 @@
use itertools::Itertools;
use crate::{op, prelude::*};
use std::ops::{Add, Mul, Neg};
@@ -135,20 +137,15 @@ impl GraphTensor {
m - m.exp().sum(axes.to_axes()).log().expand(m.shape)
}
// /// Get the indicies of the max elements along the last axis
// pub fn argmax(self) -> GraphTensor {
// // Get one-hot along last dimension
// let x_equal = self.eq(self.max(self.shape.len() - 1).expand(self.shape));
// // Create index arange for last dimension
// let r = self
// .graph()
// .constant(1.)
// .expand(self.shape.dims.last().unwrap())
// .cumsum_last_dim()
// - 1.;
// // Multiply one-hot by expanded index arange
// (x_equal * r.expand(self.shape)).max(self.shape.len() - 1)
// }
/// Get the indicies of the max elements along an axis
pub fn argmax(self, axis: usize) -> GraphTensor {
// Get one-hot along last dimension
let x_equal = self.eq(self.max(axis).expand(self.shape));
// Create index arange for last dimension
let r = self.graph().arange(self.dims()[axis]).cast(self.dtype);
// Multiply one-hot by expanded index arange
(x_equal * r.expand(self.shape)).max(axis).cast(DType::Int)
}
/// Take the absolute value
pub fn abs(self) -> GraphTensor {
@@ -220,10 +217,58 @@ impl GraphTensor {
pub fn topk_indexes(self, k: usize, axis: usize) -> GraphTensor {
self.argsort_indexes(axis, true).slice_along(..k, axis)
}
/// Apply a cumulative reduction operation along dimensions
///
/// See `cumsum` or `cummax` for usage examples.
pub fn cumop(
mut self,
axes: impl ToAxes,
op: impl Fn(GraphTensor, &[usize]) -> GraphTensor,
) -> Self {
let n_dims = self.shape.len();
let axes = axes.to_axes();
// Pad out length
let mut kernel = vec![1.into(); n_dims];
let mut padding = vec![(Expression::from(0), Expression::from(0)); n_dims];
for &ax in &axes {
let orig_length = self.dims()[ax];
padding[ax] = ((orig_length - 1).into(), 0.into());
kernel[ax] = orig_length;
}
self = self.pad(padding);
// Unfold
self = self.unfold(kernel, vec![1; n_dims], vec![1; n_dims]);
// Remove non-cumulative dimensions
for i in (0..n_dims).rev() {
if !axes.contains(&i) {
self = self.squeeze(n_dims + i);
}
}
// apply operation along cumulative dimensions
op(self, &(0..axes.len()).map(|ax| ax + n_dims).collect_vec())
}
/// Apply a cumulative sum along dimensions
pub fn cumsum(self, axes: impl ToAxes) -> Self {
self.cumop(axes, |t, axes| t.sum(axes))
}
/// Apply a cumulative max along dimensions
pub fn cummax(self, axes: impl ToAxes) -> Self {
self.cumop(axes, |t, axes| t.max(axes))
}
/// Apply a cumulative product along dimensions
pub fn cumprod(self, axes: impl ToAxes) -> Self {
self.cumop(axes, |t, axes| t.prod(axes))
}
}
#[cfg(test)]
pub(super) mod tests {
use std::collections::BinaryHeap;
use crate::{
prelude::*,
tests::{assert_close, random_vec},
@@ -231,6 +276,7 @@ pub(super) mod tests {
use candle_core::{Device, Tensor};
use candle_nn::ops::softmax;
use itertools::Itertools;
use ordered_float::NotNan;
pub fn test_unary(
shape: impl ToShape,
@@ -327,4 +373,77 @@ pub(super) mod tests {
},
);
}
#[test]
fn test_cumsum() {
test_unary(27, |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
test_unary((27, 63), |a| a.cumsum(1), |a| a.cumsum(1).unwrap());
test_unary((27, 63), |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
}
#[test]
fn test_argmax() {
test_unary(
(9, 27),
|a| a.argmax(1).cast(DType::F32),
|a| {
a.argmax(1)
.unwrap()
.to_dtype(candle_core::DType::F32)
.unwrap()
},
);
test_unary(
(9, 27),
|a| a.argmax(0).cast(DType::F32),
|a| {
a.argmax(0)
.unwrap()
.to_dtype(candle_core::DType::F32)
.unwrap()
},
);
}
#[test]
fn test_topk() {
pub fn topk_sorted_indices(x: &[f32], k: usize) -> Vec<usize> {
if k == 0 {
return Vec::new();
}
let mut heap: BinaryHeap<std::cmp::Reverse<(NotNan<f32>, usize)>> =
BinaryHeap::with_capacity(k);
for (i, &v) in x.iter().enumerate() {
let v = NotNan::new(v).expect("NaN encountered in topk");
if heap.len() < k {
heap.push(std::cmp::Reverse((v, i)));
} else if let Some(&std::cmp::Reverse((min_v, _))) = heap.peek() {
if v > min_v {
heap.pop();
heap.push(std::cmp::Reverse((v, i)));
}
}
}
let mut out: Vec<(NotNan<f32>, usize)> =
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
out.into_iter().map(|(_, i)| i).collect()
}
test_unary(
(10, 9),
|a| a.topk_indexes(5, 1).cast(DType::F32) * 1.0,
|a| {
let data = a.flatten_all().unwrap().to_vec1::<f32>().unwrap();
let topk = data
.chunks_exact(9)
.into_iter()
.flat_map(|c| topk_sorted_indices(c, 5))
.into_iter()
.map(|i| i as f32)
.collect_vec();
Tensor::new(topk, a.device()).unwrap()
},
);
}
}