Update elementwise_fusion.rs

This commit is contained in:
Joe Fioti
2024-10-09 06:47:19 -04:00
committed by GitHub
parent caa7e55524
commit d311003e8e

View File

@@ -446,7 +446,6 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.collect()
})
.collect();
println!("Stacked: {:?}", stacked_shapes);
// Stack index expressions
let stacked_index_expressions_partial = stacked_shapes
.iter()
@@ -481,14 +480,12 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.zip(&stacked_shapes)
.map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify())
.collect::<Vec<_>>();
println!("M");
// Replace in subexpressions
let n_subexpressions = subexpressions.len();
for (i, ((subexp, _), stacked_shapes)) in
subexpressions.iter_mut().zip(subexp_views).enumerate()
{
println!("{i}");
// Index
for (i, (ind_exp, val_exp)) in stacked_index_expressions
.iter()
@@ -501,13 +498,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
input_regexes.get(&i).unwrap()
};
// println!("ORIG: {ind_exp}");
let (ind, val) = (
ind_exp.simplify_cache(simplification_cache),
val_exp.simplify_cache(simplification_cache),
);
// println!("\n\n\n\n\n\nSIMP: {ind}");
println!("S {} -> {}", ind_exp.len(), ind.len());
*subexp = re
.replace_all(
subexp,