Update elementwise_fusion.rs

This commit is contained in:
Joe Fioti
2024-10-09 06:47:19 -04:00
committed by GitHub
parent caa7e55524
commit d311003e8e

View File

@@ -446,7 +446,6 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.collect() .collect()
}) })
.collect(); .collect();
println!("Stacked: {:?}", stacked_shapes);
// Stack index expressions // Stack index expressions
let stacked_index_expressions_partial = stacked_shapes let stacked_index_expressions_partial = stacked_shapes
.iter() .iter()
@@ -481,14 +480,12 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.zip(&stacked_shapes) .zip(&stacked_shapes)
.map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify()) .map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
println!("M");
// Replace in subexpressions // Replace in subexpressions
let n_subexpressions = subexpressions.len(); let n_subexpressions = subexpressions.len();
for (i, ((subexp, _), stacked_shapes)) in for (i, ((subexp, _), stacked_shapes)) in
subexpressions.iter_mut().zip(subexp_views).enumerate() subexpressions.iter_mut().zip(subexp_views).enumerate()
{ {
println!("{i}");
// Index // Index
for (i, (ind_exp, val_exp)) in stacked_index_expressions for (i, (ind_exp, val_exp)) in stacked_index_expressions
.iter() .iter()
@@ -501,13 +498,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap()); input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
input_regexes.get(&i).unwrap() input_regexes.get(&i).unwrap()
}; };
// println!("ORIG: {ind_exp}");
let (ind, val) = ( let (ind, val) = (
ind_exp.simplify_cache(simplification_cache), ind_exp.simplify_cache(simplification_cache),
val_exp.simplify_cache(simplification_cache), val_exp.simplify_cache(simplification_cache),
); );
// println!("\n\n\n\n\n\nSIMP: {ind}");
println!("S {} -> {}", ind_exp.len(), ind.len());
*subexp = re *subexp = re
.replace_all( .replace_all(
subexp, subexp,