// Source: luminal/src/hl_ops/movement.rs (repository file view)
// Last modified: 2025-12-23 10:46:47 -05:00
// Size: 737 lines, 26 KiB
// Language: Rust

use crate::{prelude::*, utils::flatten_strides};
impl GraphTensor {
/// Reorder the dimensions of the tensor according to `axes`.
///
/// Only the tensor's `shape` is permuted; the tensor itself is returned unchanged otherwise.
pub fn permute(self, axes: impl ToAxes) -> GraphTensor {
    let order = axes.to_axes();
    let mut permuted = self;
    permuted.shape.permute(&order);
    permuted
}
/// Swap the two dimensions `dim0` and `dim1`, leaving every other dimension in place.
///
/// # Panics
/// Panics if `dim0` or `dim1` is not a valid dimension index for this tensor.
pub fn transpose(self, dim0: usize, dim1: usize) -> GraphTensor {
    let num_dims = self.shape.len();
    assert!(
        dim0 < num_dims && dim1 < num_dims,
        "transpose dimensions ({dim0}, {dim1}) out of bounds for tensor with {num_dims} dimensions"
    );
    // Every axis maps to itself, except that `dim0` and `dim1` trade places.
    let perm_axes: Vec<usize> = (0..num_dims)
        .map(|axis| match axis {
            a if a == dim0 => dim1,
            a if a == dim1 => dim0,
            a => a,
        })
        .collect();
    self.permute(perm_axes)
}
/// Broadcast the tensor along a new dimension of the given `size`, inserted at `axis`.
pub fn expand_dim(self, axis: usize, size: impl Into<Expression>) -> GraphTensor {
    let mut broadcast = self;
    broadcast.shape.expand_dim(axis, size);
    broadcast
}
/// Broadcast the tensor along new dimensions so that it matches the destination `shape`.
///
/// Walking the destination shape from the front, a broadcast dimension is inserted at every
/// position where the tensor has no dimension yet or where its dimension differs from the target.
pub fn expand(mut self, shape: impl ToShape) -> GraphTensor {
    for (axis, target) in shape.to_shape().into_iter().enumerate() {
        let already_present = axis < self.shape.len() && self.shape.dims[axis] == target;
        if !already_present {
            self.shape.expand_dim(axis, target);
        }
    }
    self
}
/// Merge the dimensions `axis1` and `axis2` into a single dimension.
pub fn merge_dims(self, axis1: usize, axis2: usize) -> GraphTensor {
    let mut merged = self;
    merged.shape.merge_dims(axis1, axis2);
    merged
}
/// Split a dim into 2 dims; the new dim is placed directly after the original dim.
pub fn split_dims(self, axis: usize, new_dim_size: impl Into<Expression>) -> GraphTensor {
    let mut split = self;
    split.shape.split_dims(axis, new_dim_size);
    split
}
/// Add a new dimension of size 1 at position `dim`.
///
/// # Panics
/// Panics if the tensor already has 10 or more dimensions.
pub fn unsqueeze(self, dim: usize) -> GraphTensor {
    let rank = self.shape.len();
    assert!(rank < 10, "Shape is maxed out at 10 dimensions");
    let mut widened = self;
    widened.shape.expand_dim(dim, 1);
    widened
}
/// Remove the dimension at `axis`, which must have size 1.
///
/// # Panics
/// Panics if the dimension at `axis` is not of size 1, or if `axis` is out of range.
pub fn squeeze(self, axis: usize) -> GraphTensor {
    let extent = self.dims()[axis];
    assert_eq!(
        extent,
        Expression::from(1),
        "Only dimensions of size 1 can be squeezed!"
    );
    let mut squeezed = self;
    squeezed.shape.remove_dim(axis);
    squeezed
}
/// Gather elements of `self` at the positions given by the integer tensor `indexes`.
///
/// The result takes the (contiguous) shape of `indexes` and the dtype of `self`.
/// NOTE(review): the callers in this file (`slice`, `pad`, `unfold`) build `indexes` as flat
/// element offsets computed from row-major strides, so the indexes presumably address the
/// flattened data of `self` — confirm against the `Gather` op.
///
/// # Panics
/// Panics if `indexes` does not have dtype `DType::Int`.
pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
assert_eq!(
indexes.dtype,
DType::Int,
"Gather indexes must have an integer dtype!"
);
// The op receives the indexes as its first input and the data tensor as its second.
let id = self
.graph()
.add_op(Gather::default())
.input(indexes.id, 0, indexes.shape)
.input(self.id, 0, self.shape)
.finish();
GraphTensor::from_id(id, indexes.shape.contiguous(), self.graph_ref, self.dtype)
}
/// Given a tensor of non-repeating indexes along a dimension, generate an inverse permutation
/// x = [3, 2, 4, 1, 5, 0]
/// inv_perm(x) = [5, 3, 1, 0, 2, 4]
///
/// # Panics
/// Panics if the tensor's dtype is not `DType::Int`.
pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
    // TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
    assert_eq!(self.dtype, DType::Int);
    // Working shape: the original dims with one extra `ax_size` dim inserted at `axis`,
    // giving [..., candidate, position, ...].
    let mut work_dims = self.dims();
    let ax_size = work_dims[axis];
    work_dims.insert(axis, ax_size);
    // Broadcast a 1-D range up to the working shape so that it varies along `kept` only.
    let spread = |range: GraphTensor, kept: usize| {
        (0..work_dims.len())
            .filter(|&i| i != kept)
            .fold(range, |t, i| t.expand_dim(i, work_dims[i]))
    };
    // candidate: varies along the candidate dim (axis).
    let candidate = spread(self.graph().arange(ax_size), axis);
    // position: varies along the position dim (axis + 1).
    let position = spread(self.graph().arange(ax_size), axis + 1);
    // one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
    let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
    // inv[candidate, ...] = Σ_pos one_hot * position
    (one_hot * position).sum(axis + 1)
}
/// Extracts sliding local windows from an input tensor.
///
/// `kernel`, `strides` and `dilation` each give one value per tensor dimension. The result has
/// shape `[windows..., kernel...]`: one leading axis per dimension counting window positions,
/// followed by one axis per dimension indexing inside the window.
///
/// # Panics
/// Panics if `kernel`, `strides` or `dilation` does not have one entry per tensor dimension.
pub fn unfold(
    self,
    kernel: impl ToShape,
    strides: impl ToShape,
    dilation: impl ToShape,
) -> GraphTensor {
    let kernel = kernel.to_shape();
    let strides = strides.to_shape();
    let dilation = dilation.to_shape();
    assert_eq!(
        self.shape.len(),
        kernel.len(),
        "Kernel must be same number of dimensions as tensor!"
    );
    assert_eq!(
        self.shape.len(),
        strides.len(),
        "Strides must be same number of dimensions as tensor!"
    );
    assert_eq!(
        self.shape.len(),
        dilation.len(),
        "Dilation must be same number of dimensions as tensor!"
    );
    let dims = self.dims();
    let n = dims.len();
    // Input strides (row-major contiguous): walk the dims back to front, accumulating sizes.
    let mut in_strides = vec![Expression::from(1); n];
    let mut running = Expression::from(1);
    for axis in (0..n).rev() {
        in_strides[axis] = running;
        running *= dims[axis];
    }
    // Output shape is [win..., kernel...]; start with the per-dim window counts.
    let mut final_shape: Vec<Expression> = dims
        .iter()
        .zip(&kernel)
        .zip(&strides)
        .zip(&dilation)
        .map(|(((dim, k), s), d)| {
            let effective_window = *d * (*k - 1) + 1;
            (((*dim - effective_window) / s) + 1).simplify()
        })
        .collect();
    final_shape.extend(kernel.iter().copied());
    // Axis exprs must match final_shape axis order: first w axes, then k axes.
    // idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
    let window_axes = (0..n).map(|i| Expression::from('z') * strides[i] * in_strides[i]);
    let kernel_axes = (0..n).map(|i| Expression::from('z') * dilation[i] * in_strides[i]);
    let axis_exprs: Vec<Expression> = window_axes.chain(kernel_axes).collect();
    let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
    let iota = self.graph().iota(index_expression, final_shape);
    self.gather(iota)
}
/// Take a slice of a tensor along multiple dimensions.
///
/// Dimensions without an explicit range are kept whole, and range ends are clamped to the
/// dimension size.
///
/// ```
/// # use luminal::prelude::*;
/// # let mut cx = Graph::new();
/// let a = cx.tensor((5, 10));
/// let b = a.slice((2..4, 1..)); // 2x9 tensor
/// assert_eq!(b.dims(), vec![Expression::from(2), Expression::from(9)]);
/// ```
pub fn slice(mut self, slice: impl ToSlice) -> GraphTensor {
    let mut ranges = slice.to_range_vec();
    // Make sure we have a range per dim: trailing dims default to their full extent.
    let given = ranges.len();
    ranges.extend(
        self.dims()
            .into_iter()
            .skip(given)
            .map(|d| (0.into(), d)),
    );
    let needs_offset = ranges.iter().any(|(start, _)| *start != 0);
    if !needs_offset {
        // No start offsets, so no iota is needed: just shrink the shape.
        for (extent, (_, end)) in self.shape.dims.iter_mut().zip(ranges) {
            *extent = extent.min(end);
        }
        return self;
    }
    // A start offset is present; tensors don't carry offsets, so gather through an iota of
    // flat indexes instead.
    let dims = self.dims();
    let rank = dims.len();
    let mut new_dims = vec![Expression::from(0); rank];
    let mut index_expressions = vec![Expression::from(0); rank];
    // Walk back to front so `phys_size` is the row-major stride of the current axis.
    let mut phys_size = Expression::from(1);
    for axis in (0..rank).rev() {
        let (start, end) = ranges[axis];
        index_expressions[axis] = (Expression::from('z') + start) * phys_size;
        phys_size *= dims[axis];
        new_dims[axis] = dims[axis].min(end) - start;
    }
    let index_expression = flatten_strides(&new_dims, &index_expressions);
    let iota = self.graph().iota(index_expression, new_dims);
    self.gather(iota)
}
/// Take a slice of a tensor along a single dimension, keeping every other dimension whole.
///
/// ```
/// # use luminal::prelude::*;
/// # let mut cx = Graph::new();
/// let a = cx.tensor((5, 10));
/// let b = a.slice_along(4.., 1); // 5x6 tensor
/// assert_eq!(b.dims(), vec![Expression::from(5), Expression::from(6)]);
/// ```
pub fn slice_along(self, slice: impl SliceRange, axis: usize) -> GraphTensor {
    // Axes before `axis` get a full range; `i32::MAX` stands in for "to the end" because
    // `slice` clamps every range end to the dimension size.
    let full = (Expression::from(0), Expression::from(i32::MAX));
    let mut ranges: Vec<_> = std::iter::repeat(full).take(axis).collect();
    ranges.push(slice.bounds());
    self.slice(ranges)
}
// /// Cut out 'size' elements every 'spacing' elements on a dimension. 'size' must be smaller than the dimension
// pub fn excise(mut self, spacing: usize, size: usize) -> GraphTensor {
// let n_dims = self.shape.len();
// // Pad out to a multiple of spacing + size
// let total_size = (self.shape.dims[n_dims - 1] + ((spacing + size) - 1))
// / (spacing + size)
// * (spacing + size);
// let padding = total_size - self.shape.dims[self.shape.indexes[n_dims - 1]];
// self.shape.padding[self.shape.indexes[n_dims - 1]].1 = padding;
// self = self.contiguous();
// // Expand a new dimension to do the slicing on
// let n_rows = total_size / (spacing + size);
// self.shape.expand_dim(n_dims, spacing + size);
// // self = self.contiguous();
// self.shape.dims[self.shape.indexes[n_dims - 1]] = n_rows;
// self.shape.fake[self.shape.indexes[n_dims]] = false;
// // Slice
// self.shape.mask[self.shape.indexes[n_dims]].1 = spacing.into();
// self = self.contiguous();
// self.shape.remove_dim(n_dims);
// self
// }
/// Pad out dimensions of a tensor with 0
///
/// `padding` supplies a `(before, after)` pair per dimension, starting from the first
/// dimension; dimensions without a pair are left unpadded.
///
/// The padded tensor is produced by gathering through an iota of source indexes (clamped into
/// the valid range) and then multiplying by a mask that is zero on the padded elements.
///
/// # Panics
/// Panics if more padding pairs are supplied than the tensor has dimensions.
pub fn pad(self, padding: impl ToPad) -> GraphTensor {
    let mut padding = padding.to_pad_vec();
    let rank = self.shape.len();
    // Checked up front: `rank - padding.len()` below would otherwise underflow.
    assert!(
        padding.len() <= rank,
        "pad received {} padding pairs but the tensor only has {} dimensions",
        padding.len(),
        rank
    );
    padding.extend(vec![(0.into(), 0.into()); rank - padding.len()]); // Make sure we have a padding per dim
    let mut index_expressions = vec![];
    let mut phys_size = Expression::from(1);
    let mut new_dims = vec![];
    // Walk back to front so `phys_size` is the row-major stride of the current axis.
    for (dim, (start, end)) in self.dims().into_iter().zip(&padding).rev() {
        // Shift the output coordinate back by the leading pad and clamp it into [0, dim - 1];
        // clamped positions read a valid element that the mask below zeroes out.
        index_expressions
            .push(((Expression::from('z') - *start).max(0).min(dim - 1)) * phys_size);
        phys_size *= dim;
        new_dims.push(dim + *start + *end);
    }
    new_dims.reverse();
    index_expressions.reverse();
    let index_expression = flatten_strides(&new_dims, &index_expressions);
    // get indexed tensor
    let new_tensor = self.gather(self.graph().iota(index_expression, new_dims.clone()));
    // mask out padded elements: an axis coordinate is kept only when start <= z < start + dim
    let mut mask_expressions = vec![];
    for ((start, _), dim) in padding.into_iter().zip(self.dims()) {
        mask_expressions
            .push(Expression::from('z').gte(start) * Expression::from('z').lt(start + dim));
    }
    // NOTE(review): `flatten_strides_mask` is defined outside this file; presumably it combines
    // the per-axis predicates into one mask expression over the padded shape — confirm.
    let mask_expression = flatten_strides_mask(&new_dims, &mask_expressions);
    let mask = self.graph().iota(mask_expression, new_dims);
    new_tensor * mask
}
/// Pad along an existing dimension: `left` zeros before and `right` zeros after, on `axis`.
pub fn pad_along(
    self,
    left: impl Into<Expression>,
    right: impl Into<Expression>,
    axis: usize,
) -> GraphTensor {
    // Axes before `axis` get no padding; `pad` fills in the axes after it.
    let unpadded = (Expression::from(0), Expression::from(0));
    let mut per_axis: Vec<_> = std::iter::repeat(unpadded).take(axis).collect();
    per_axis.push((left.into(), right.into()));
    self.pad(per_axis)
}
/// Concat along an existing dimension
///
/// Both operands are zero-padded to the combined length along `axis` (`self` is padded at the
/// back, `rhs` at the front) and the padded tensors are added together.
pub fn concat_along(self, rhs: GraphTensor, axis: usize) -> GraphTensor {
    let rhs_len = rhs.dims()[axis];
    let front = self.pad_along(0, rhs_len, axis);
    let lhs_len = self.dims()[axis];
    let back = rhs.pad_along(lhs_len, 0, axis);
    front + back
}
}
#[cfg(test)]
mod tests {
use crate::{
hl_ops::{binary::tests::test_binary, unary::tests::test_unary},
prelude::*,
};
use candle_core::{IndexOp, Tensor};
// Checks `pad` against candle's `pad_with_zeros` on a 1-D and a 2-D input.
#[test]
fn test_pad() {
    // 1-D: 2 zeros in front, 6 behind.
    test_unary(
        23,
        |a| a.pad((2, 6)),
        |a| a.pad_with_zeros(0, 2, 6).unwrap(),
    );
    // 2-D: independent padding on each axis.
    test_unary(
        (18, 72),
        |a| a.pad(((4, 31), (2, 9))),
        |a| {
            a.pad_with_zeros(0, 4, 31)
                .and_then(|rows| rows.pad_with_zeros(1, 2, 9))
                .unwrap()
        },
    );
}
// Checks that a slice followed by a pad matches the equivalent candle operations.
#[test]
fn test_slice_pad() {
    test_unary(
        (17, 26),
        |a| a.slice((1..3, 14..32)).pad(((3, 14), (5, 0))),
        |a| {
            // candle slice ranges can't go off end, so the reference stops at 26 instead of 32
            a.i((1..3, 14..26))
                .and_then(|sliced| sliced.pad_with_zeros(0, 3, 14))
                .and_then(|padded| padded.pad_with_zeros(1, 5, 0))
                .unwrap()
        },
    );
}
// Checks `transpose` on a 5x10 input against candle's `transpose`.
// NOTE(review): `test_unary` is defined in another module; presumably it builds an input of the
// given shape, runs both closures and compares the results — confirm there.
#[test]
fn test_transpose() {
test_unary(
(5, 10),
// NOTE(review): the `* 1.0` presumably forces a real op so the permuted shape is actually
// read by a kernel — confirm.
|a| a.transpose(0, 1) * 1.0,
|a| a.transpose(0, 1).unwrap(),
);
}
#[test]
fn test_unfold() {
// Need all this code because candle doesn't do unfold
/// Reference N-dimensional unfold over a flat `f32` buffer.
///
/// `shape` and `strides` describe how `x` is laid out; `kernel`, `step` and `dilation` describe
/// the windows; `pad_before` / `pad_after` add virtual zero padding per dimension. The output is
/// flat and row-major: all window positions first, then the kernel elements of each window.
/// Returns an empty vector if any padded dimension is smaller than its effective kernel.
pub fn unfold_nd_f32(
    x: &[f32],
    shape: &[usize],
    strides: &[usize],
    kernel: &[usize],
    step: &[usize],
    dilation: &[usize],
    pad_before: &[usize],
    pad_after: &[usize],
) -> Vec<f32> {
    let n = shape.len();
    assert!(n > 0);
    assert_eq!(strides.len(), n);
    assert_eq!(kernel.len(), n);
    assert_eq!(step.len(), n);
    assert_eq!(dilation.len(), n);
    assert_eq!(pad_before.len(), n);
    assert_eq!(pad_after.len(), n);
    for d in 0..n {
        assert!(kernel[d] > 0);
        assert!(step[d] > 0);
        assert!(dilation[d] > 0);
        assert!(shape[d] > 0);
    }
    // Effective kernel size per dim: (K-1)*d + 1
    let eff_kernel: Vec<usize> = kernel
        .iter()
        .zip(dilation)
        .map(|(&k, &dil)| (k - 1) * dil + 1)
        .collect();
    // Output spatial shape (number of windows) per dim
    let mut out_shape = Vec::with_capacity(n);
    for d in 0..n {
        let padded = shape[d] + pad_before[d] + pad_after[d];
        if padded < eff_kernel[d] {
            return Vec::new();
        }
        out_shape.push((padded - eff_kernel[d]) / step[d] + 1);
    }
    let windows = prod(&out_shape);
    let window_elems = prod(kernel);
    let mut out = vec![0.0f32; windows * window_elems];
    // Row-major multipliers for decoding a flat kernel index.
    let k_mul = row_major_multipliers(kernel);
    // Current output window position (row-major)
    let mut out_pos = vec![0usize; n];
    for w in 0..windows {
        if w > 0 {
            incr_row_major(&mut out_pos, &out_shape);
        }
        // Window start in padded coordinates
        let start_padded: Vec<usize> = out_pos
            .iter()
            .zip(step)
            .map(|(&pos, &st)| pos * st)
            .collect();
        // Iterate kernel elements (flattened)
        for ke in 0..window_elems {
            let k_idx = unravel_row_major(ke, kernel, &k_mul);
            // Flat source offset, or `None` as soon as one coordinate lands in the padding.
            let source = (0..n).try_fold(0isize, |flat, d| {
                let p = start_padded[d] + k_idx[d] * dilation[d];
                let logical = p as isize - pad_before[d] as isize;
                if logical < 0 || logical >= shape[d] as isize {
                    None
                } else {
                    Some(flat + logical * strides[d] as isize)
                }
            });
            out[w * window_elems + ke] = match source {
                Some(flat) => x[flat as usize],
                None => 0.0,
            };
        }
    }
    out
}
// -------- helpers --------
/// Product of all entries of `xs`; an empty slice yields 1.
fn prod(xs: &[usize]) -> usize {
    xs.iter().fold(1, |acc, &x| acc * x)
}
/// Row-major multipliers for `shape`: entry `d` is the product of all extents after `d`.
fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
    // Scan the extents back to front, emitting the running product *before* each extent.
    let mut mul: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |acc, &extent| {
            let current = *acc;
            *acc *= extent;
            Some(current)
        })
        .collect();
    mul.reverse();
    mul
}
/// Decode the flat row-major index `idx` into one coordinate per dimension of `shape`,
/// using the multipliers `mul` (one per dimension).
fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
    mul[..shape.len()]
        .iter()
        .map(|&m| {
            let coord = idx / m;
            idx %= m;
            coord
        })
        .collect()
}
/// Advance `pos` to the next row-major position within `shape`, in place.
///
/// The last coordinate moves fastest; advancing past the final position wraps to all zeros.
fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
    let mut d = pos.len();
    while d > 0 {
        d -= 1;
        pos[d] += 1;
        if pos[d] < shape[d] {
            return;
        }
        // This coordinate overflowed: reset it and carry into the next slower one.
        pos[d] = 0;
    }
}
// 1-D: kernel 3, stride 1, dilation 1, no padding.
test_unary(
    5,
    |a| a.unfold(3, 1, 1),
    |a| {
        Tensor::new(
            unfold_nd_f32(
                &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
                a.dims(),
                a.stride(),
                &[3],
                &[1],
                &[1],
                &[0],
                &[0],
            ),
            a.device(),
        )
        .unwrap()
    },
);
// 2-D with padding: the tensor under test is padded by (0, 2) on dim 0 and (4, 4) on dim 1,
// and the reference is given exactly the same padding as `pad_before` / `pad_after`.
test_unary(
    (8, 10),
    |a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
    |a| {
        Tensor::new(
            unfold_nd_f32(
                &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
                a.dims(),
                a.stride(),
                &[2, 3],
                &[1, 2],
                &[2, 1],
                &[0, 4],
                &[2, 4],
            ),
            a.device(),
        )
        .unwrap()
    },
);
}
// Checks that `unsqueeze` inserts a size-1 dimension in the middle and at the end of a shape.
#[test]
fn test_unsqueeze() {
    let mut cx = Graph::new();
    let inp = cx.tensor((2, 2, 3));
    // New dimension between existing ones.
    assert_eq!(inp.unsqueeze(1).dims(), &[2, 1, 2, 3]);
    // New dimension appended after the last one.
    assert_eq!(inp.unsqueeze(3).dims(), &[2, 2, 3, 1]);
}
// Checks `concat_along` against candle's `Tensor::cat`.
// NOTE(review): `test_binary` / `test_unary` are defined in other modules; presumably they build
// inputs of the given shapes, run both closures and compare the results — confirm there.
#[test]
fn test_concat() {
// 1-D inputs of different lengths.
test_binary(
17,
32,
|a, b| a.concat_along(b, 0),
|a, b| Tensor::cat(&[a, b], 0).unwrap(),
);
// 2-D inputs joined along the last axis.
test_binary(
(10, 4),
(10, 6),
|a, b| a.concat_along(b, 1),
|a, b| Tensor::cat(&[a, b], 1).unwrap(),
);
// 2-D inputs joined along the first axis.
test_binary(
(4, 10),
(6, 10),
|a, b| a.concat_along(b, 0),
|a, b| Tensor::cat(&[a, b], 0).unwrap(),
);
// A tensor concatenated with itself.
test_unary(
(4, 10),
|a| a.concat_along(a, 0),
|a| Tensor::cat(&[a.clone(), a], 0).unwrap(),
);
}
// // #[test]
// // fn test_cumsum() {
// // let mut cx = Graph::new();
// // let a = cx.constant(1.).expand_dim(0, 3);
// // let b = a.cumsum_last_dim().retrieve();
// // let c = a
// // .expand_dim(1, 3)
// // .permute((1, 0))
// // .cumsum_last_dim()
// // .permute((1, 0))
// // .retrieve();
// // cx.execute();
// // assert_exact(&b.data(), &[1., 2., 3.]);
// // assert_exact(&c.data(), &[1., 1., 1., 2., 2., 2., 3., 3., 3.]);
// // }
// // #[test]
// // fn test_pool_1d() {
// // let mut cx = Graph::new();
// // let inp1 = cx.tensor(5).set([1., 2., 3., 4., 5.]);
// // let inp2 = cx
// // .tensor((2, 5))
// // .set([[15., 14., 13., 12., 11.], [1., 2., 3., 4., 5.]]);
// // // Stride 1
// // let out1 = inp1.pool_last_dim(3, 1, 1).retrieve();
// // // Stride 2
// // let out2 = inp1.pool_last_dim(3, 2, 1).retrieve();
// // // Stride 3
// // let out3 = inp1.pool_last_dim(3, 3, 1).retrieve();
// // // Dilation 2
// // let out4 = inp1.pool_last_dim(3, 1, 2).retrieve();
// // // Dilation 2 Padding 1
// // let out5 = inp1.pad(((1, 1),)).pool_last_dim(3, 1, 2).retrieve();
// // // Stride 1 Batch 2
// // let out6 = inp2.pool_last_dim(3, 1, 1).retrieve();
// // // Stride 3
// // let out7 = inp2.pool_last_dim(3, 3, 1).retrieve();
// // // Dilation 2
// // let out8 = inp2.pool_last_dim(3, 1, 2).retrieve();
// // // Dilation 2 Padding 1
// // let out9 = inp2.pad(((0, 0), (1, 1))).pool_last_dim(3, 1, 2).retrieve();
// // cx.execute();
// // assert_exact(&out1.data(), &[1., 2., 3., 2., 3., 4., 3., 4., 5.]);
// // assert_exact(&out2.data(), &[1., 2., 3., 3., 4., 5.]);
// // assert_exact(&out3.data(), &[1., 2., 3.]);
// // assert_exact(&out4.data(), &[1., 3., 5.]);
// // assert_exact(&out5.data(), &[0., 2., 4., 1., 3., 5., 2., 4., 0.]);
// // assert_exact(
// // &out6.data(),
// // &[
// // 15., 14., 13., 14., 13., 12., 13., 12., 11., 1., 2., 3., 2., 3., 4., 3., 4., 5.,
// // ],
// // );
// // assert_exact(&out7.data(), &[15., 14., 13., 1., 2., 3.]);
// // assert_exact(&out8.data(), &[15., 13., 11., 1., 3., 5.]);
// // assert_exact(
// // &out9.data(),
// // &[
// // 0., 14., 12., 15., 13., 11., 14., 12., 0., 0., 2., 4., 1., 3., 5., 2., 4., 0.,
// // ],
// // );
// // }
// // #[test]
// // fn test_pool_1d_dims() {
// // let mut cx = Graph::new();
// // let inp1 = cx.tensor((4, 4)).set(vec![
// // 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
// // ]);
// // // Stride 1
// // let out1 = inp1.pool_last_dim(3, 1, 1).retrieve();
// // cx.execute();
// // assert_exact(
// // &out1.data(),
// // &[
// // 1., 2., 3., 2., 3., 4., 5., 6., 7., 6., 7., 8., 9., 10., 11., 10., 11., 12., 13.,
// // 14., 15., 14., 15., 16.,
// // ],
// // );
// // }
// // #[test]
// // fn test_pool_2d() {
// // let mut cx = Graph::new();
// // let inp1 = cx.tensor((4, 4)).set(vec![
// // 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
// // ]);
// // // 3x3 kernel
// // let out1 = inp1
// // // Pool first dim first by moving it to end
// // .permute((1, 0))
// // .pool_last_dim(3, 1, 1)
// // // Now move other dim to end
// // .permute((1, 2, 0))
// // .pool_last_dim(3, 1, 1)
// // // Now swap middle two dims
// // .permute((0, 2, 1, 3))
// // // Now merge both pooled dimensions
// // .reshape((4, 3, 3))
// // .retrieve();
// // cx.execute();
// // assert_exact(
// // &out1.data(),
// // &[
// // 1.00, 2.00, 3.00, 5.00, 6.00, 7.00, 9.00, 10.00, 11.00, 2.00, 3.00, 4.00, 6.00,
// // 7.00, 8.00, 10.00, 11.00, 12.00, 5.00, 6.00, 7.00, 9.00, 10.00, 11.00, 13.00,
// // 14.00, 15.00, 6.00, 7.00, 8.00, 10.00, 11.00, 12.00, 14.00, 15.00, 16.00,
// // ],
// // );
// // }
// // #[test]
// // fn test_pool_1d_dilation() {
// // let mut cx = Graph::new();
// // let inp1 = cx.tensor(5).set(vec![1., 2., 3., 4., 5.]);
// // // Stride 1
// // let out1 = inp1.pool_last_dim(2, 1, 2).retrieve();
// // // Stride 2
// // let out2 = inp1.pool_last_dim(2, 2, 2).retrieve();
// // // Stride 3
// // let out3 = inp1.pool_last_dim(2, 3, 2).retrieve();
// // cx.execute();
// // assert_exact(&out1.data(), &[1., 3., 2., 4., 3., 5.]);
// // assert_exact(&out2.data(), &[1., 3., 3., 5.]);
// // assert_exact(&out3.data(), &[1., 3.]);
// // }
// // #[test]
// // fn test_rotate_half() {
// // let mut cx = Graph::new();
// // let a = cx.tensor((3, 2));
// // a.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
// // let x1 = a.slice((.., ..1)).contiguous();
// // let x2 = a.slice((.., 1..)).contiguous();
// // let c = (-x2).concat_along(x1, 1);
// // c.retrieve();
// // cx.execute();
// // let d_dev = Cpu::default();
// // let d_a = d_dev.tensor_from_vec(
// // vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854],
// // (dfdx::shapes::Const::<3>, dfdx::shapes::Const::<2>),
// // );
// // let d_x1 = d_a.clone().slice((.., ..1));
// // let d_x2 = d_a.slice((.., 1..));
// // let d_c = (-d_x2, d_x1)
// // .concat_along(dfdx::shapes::Axis::<1>)
// // .realize::<Rank2<3, 2>>();
// // assert_close(&c.data(), &d_c.as_vec());
// // }
}