Merge branch 'cublaslt' of https://github.com/luminal-ai/luminal into cublaslt

This commit is contained in:
Austin Glover
2026-02-10 19:39:23 +00:00
2 changed files with 9 additions and 3 deletions

View File

@@ -448,6 +448,12 @@ impl<const N: usize> ToAxes for &[usize; N] {
}
}
impl ToAxes for &Vec<usize> {
fn to_axes(&self) -> Vec<usize> {
self.to_vec()
}
}
pub trait ToShape {
fn to_shape(self) -> Vec<Expression>;
}

View File

@@ -121,7 +121,8 @@ impl ShapeTracker {
}
/// Permute the dimensions
pub fn permute(&mut self, axes: &[usize]) {
pub fn permute(&mut self, axes: impl ToAxes) {
let axes = axes.to_axes();
assert!(
axes.len() == self.len(),
"Permute axes ({}) doesn't match shape axes ({})",
@@ -305,7 +306,7 @@ mod tests {
Expression::from(1)
]
);
tracker.permute(&[1, 2, 0]);
tracker.permute((1, 2, 0));
assert_eq!(
tracker.dims.as_slice(),
&[
@@ -322,7 +323,6 @@ mod tests {
Expression::from(b * c)
]
);
assert!(!tracker.is_contiguous());
tracker.expand_dim(1, 1);
assert_eq!(
tracker.dims.as_slice(),