forked from Rust-related/luminal
Update elementwise_fusion.rs
This commit is contained in:
@@ -446,7 +446,6 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
|||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
println!("Stacked: {:?}", stacked_shapes);
|
|
||||||
// Stack index expressions
|
// Stack index expressions
|
||||||
let stacked_index_expressions_partial = stacked_shapes
|
let stacked_index_expressions_partial = stacked_shapes
|
||||||
.iter()
|
.iter()
|
||||||
@@ -481,14 +480,12 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
|||||||
.zip(&stacked_shapes)
|
.zip(&stacked_shapes)
|
||||||
.map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify())
|
.map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial).simplify())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
println!("M");
|
|
||||||
|
|
||||||
// Replace in subexpressions
|
// Replace in subexpressions
|
||||||
let n_subexpressions = subexpressions.len();
|
let n_subexpressions = subexpressions.len();
|
||||||
for (i, ((subexp, _), stacked_shapes)) in
|
for (i, ((subexp, _), stacked_shapes)) in
|
||||||
subexpressions.iter_mut().zip(subexp_views).enumerate()
|
subexpressions.iter_mut().zip(subexp_views).enumerate()
|
||||||
{
|
{
|
||||||
println!("{i}");
|
|
||||||
// Index
|
// Index
|
||||||
for (i, (ind_exp, val_exp)) in stacked_index_expressions
|
for (i, (ind_exp, val_exp)) in stacked_index_expressions
|
||||||
.iter()
|
.iter()
|
||||||
@@ -501,13 +498,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
|||||||
input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
|
input_regexes.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
|
||||||
input_regexes.get(&i).unwrap()
|
input_regexes.get(&i).unwrap()
|
||||||
};
|
};
|
||||||
// println!("ORIG: {ind_exp}");
|
|
||||||
let (ind, val) = (
|
let (ind, val) = (
|
||||||
ind_exp.simplify_cache(simplification_cache),
|
ind_exp.simplify_cache(simplification_cache),
|
||||||
val_exp.simplify_cache(simplification_cache),
|
val_exp.simplify_cache(simplification_cache),
|
||||||
);
|
);
|
||||||
// println!("\n\n\n\n\n\nSIMP: {ind}");
|
|
||||||
println!("S {} -> {}", ind_exp.len(), ind.len());
|
|
||||||
*subexp = re
|
*subexp = re
|
||||||
.replace_all(
|
.replace_all(
|
||||||
subexp,
|
subexp,
|
||||||
|
|||||||
Reference in New Issue
Block a user