forked from Rust-related/luminal
Update elementwise_fusion.rs
This commit is contained in:
@@ -446,7 +446,6 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
println!("Stacked: {:?}", stacked_shapes);
|
||||
// Stack index expressions
|
||||
let stacked_index_expressions_partial = stacked_shapes
|
||||
.iter()
|
||||
@@ -481,14 +480,12 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
.zip(&stacked_shapes)
|
||||
.map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify())
|
||||
.collect::<Vec<_>>();
|
||||
println!("M");
|
||||
|
||||
// Replace in subexpressions
|
||||
let n_subexpressions = subexpressions.len();
|
||||
for (i, ((subexp, _), stacked_shapes)) in
|
||||
subexpressions.iter_mut().zip(subexp_views).enumerate()
|
||||
{
|
||||
println!("{i}");
|
||||
// Index
|
||||
for (i, (ind_exp, val_exp)) in stacked_index_expressions
|
||||
.iter()
|
||||
@@ -501,13 +498,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
|
||||
input_regexes.get(&i).unwrap()
|
||||
};
|
||||
// println!("ORIG: {ind_exp}");
|
||||
let (ind, val) = (
|
||||
ind_exp.simplify_cache(simplification_cache),
|
||||
val_exp.simplify_cache(simplification_cache),
|
||||
);
|
||||
// println!("\n\n\n\n\n\nSIMP: {ind}");
|
||||
println!("S {} -> {}", ind_exp.len(), ind.len());
|
||||
*subexp = re
|
||||
.replace_all(
|
||||
subexp,
|
||||
|
||||
Reference in New Issue
Block a user