mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Merge branch 'cublaslt' of https://github.com/luminal-ai/luminal into cublaslt
This commit is contained in:
@@ -448,6 +448,12 @@ impl<const N: usize> ToAxes for &[usize; N] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToAxes for &Vec<usize> {
|
||||||
|
fn to_axes(&self) -> Vec<usize> {
|
||||||
|
self.to_vec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait ToShape {
|
pub trait ToShape {
|
||||||
fn to_shape(self) -> Vec<Expression>;
|
fn to_shape(self) -> Vec<Expression>;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ impl ShapeTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Permute the dimensions
|
/// Permute the dimensions
|
||||||
pub fn permute(&mut self, axes: &[usize]) {
|
pub fn permute(&mut self, axes: impl ToAxes) {
|
||||||
|
let axes = axes.to_axes();
|
||||||
assert!(
|
assert!(
|
||||||
axes.len() == self.len(),
|
axes.len() == self.len(),
|
||||||
"Permute axes ({}) doesn't match shape axes ({})",
|
"Permute axes ({}) doesn't match shape axes ({})",
|
||||||
@@ -305,7 +306,7 @@ mod tests {
|
|||||||
Expression::from(1)
|
Expression::from(1)
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
tracker.permute(&[1, 2, 0]);
|
tracker.permute((1, 2, 0));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tracker.dims.as_slice(),
|
tracker.dims.as_slice(),
|
||||||
&[
|
&[
|
||||||
@@ -322,7 +323,6 @@ mod tests {
|
|||||||
Expression::from(b * c)
|
Expression::from(b * c)
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
assert!(!tracker.is_contiguous());
|
|
||||||
tracker.expand_dim(1, 1);
|
tracker.expand_dim(1, 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tracker.dims.as_slice(),
|
tracker.dims.as_slice(),
|
||||||
|
|||||||
Reference in New Issue
Block a user