From cd61de9362db6d0b41ad8055f268d2477d2fbaad Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Tue, 10 Feb 2026 19:32:23 +0000 Subject: [PATCH] fixed tests --- src/shape/mod.rs | 6 ++++++ src/shape/tracker.rs | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/shape/mod.rs b/src/shape/mod.rs index db9d328d..dd84ad27 100644 --- a/src/shape/mod.rs +++ b/src/shape/mod.rs @@ -448,6 +448,12 @@ impl ToAxes for &[usize; N] { } } +impl ToAxes for &Vec { + fn to_axes(&self) -> Vec { + self.to_vec() + } +} + pub trait ToShape { fn to_shape(self) -> Vec; } diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index c243c5ce..7f1490be 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -121,7 +121,8 @@ impl ShapeTracker { } /// Permute the dimensions - pub fn permute(&mut self, axes: &[usize]) { + pub fn permute(&mut self, axes: impl ToAxes) { + let axes = axes.to_axes(); assert!( axes.len() == self.len(), "Permute axes ({}) doesn't match shape axes ({})", @@ -305,7 +306,7 @@ mod tests { Expression::from(1) ] ); - tracker.permute(&[1, 2, 0]); + tracker.permute((1, 2, 0)); assert_eq!( tracker.dims.as_slice(), &[ @@ -322,7 +323,6 @@ mod tests { Expression::from(b * c) ] ); - assert!(!tracker.is_contiguous()); tracker.expand_dim(1, 1); assert_eq!( tracker.dims.as_slice(),