Merge branch 'cublaslt' of https://github.com/luminal-ai/luminal into cublaslt

This commit is contained in:
Austin Glover
2026-02-10 19:39:23 +00:00
2 changed files with 9 additions and 3 deletions

View File

@@ -448,6 +448,12 @@ impl<const N: usize> ToAxes for &[usize; N] {
} }
} }
impl ToAxes for &Vec<usize> {
fn to_axes(&self) -> Vec<usize> {
self.to_vec()
}
}
pub trait ToShape { pub trait ToShape {
fn to_shape(self) -> Vec<Expression>; fn to_shape(self) -> Vec<Expression>;
} }

View File

@@ -121,7 +121,8 @@ impl ShapeTracker {
} }
/// Permute the dimensions /// Permute the dimensions
pub fn permute(&mut self, axes: &[usize]) { pub fn permute(&mut self, axes: impl ToAxes) {
let axes = axes.to_axes();
assert!( assert!(
axes.len() == self.len(), axes.len() == self.len(),
"Permute axes ({}) doesn't match shape axes ({})", "Permute axes ({}) doesn't match shape axes ({})",
@@ -305,7 +306,7 @@ mod tests {
Expression::from(1) Expression::from(1)
] ]
); );
tracker.permute(&[1, 2, 0]); tracker.permute((1, 2, 0));
assert_eq!( assert_eq!(
tracker.dims.as_slice(), tracker.dims.as_slice(),
&[ &[
@@ -322,7 +323,6 @@ mod tests {
Expression::from(b * c) Expression::from(b * c)
] ]
); );
assert!(!tracker.is_contiguous());
tracker.expand_dim(1, 1); tracker.expand_dim(1, 1);
assert_eq!( assert_eq!(
tracker.dims.as_slice(), tracker.dims.as_slice(),