mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Merge branch 'cublaslt' of https://github.com/luminal-ai/luminal into cublaslt
This commit is contained in:
@@ -448,6 +448,12 @@ impl<const N: usize> ToAxes for &[usize; N] {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToAxes for &Vec<usize> {
|
||||
fn to_axes(&self) -> Vec<usize> {
|
||||
self.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ToShape {
|
||||
fn to_shape(self) -> Vec<Expression>;
|
||||
}
|
||||
|
||||
@@ -121,7 +121,8 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Permute the dimensions
|
||||
pub fn permute(&mut self, axes: &[usize]) {
|
||||
pub fn permute(&mut self, axes: impl ToAxes) {
|
||||
let axes = axes.to_axes();
|
||||
assert!(
|
||||
axes.len() == self.len(),
|
||||
"Permute axes ({}) doesn't match shape axes ({})",
|
||||
@@ -305,7 +306,7 @@ mod tests {
|
||||
Expression::from(1)
|
||||
]
|
||||
);
|
||||
tracker.permute(&[1, 2, 0]);
|
||||
tracker.permute((1, 2, 0));
|
||||
assert_eq!(
|
||||
tracker.dims.as_slice(),
|
||||
&[
|
||||
@@ -322,7 +323,6 @@ mod tests {
|
||||
Expression::from(b * c)
|
||||
]
|
||||
);
|
||||
assert!(!tracker.is_contiguous());
|
||||
tracker.expand_dim(1, 1);
|
||||
assert_eq!(
|
||||
tracker.dims.as_slice(),
|
||||
|
||||
Reference in New Issue
Block a user