mirror of
https://git.teahaven.kr/Rust-related/luminal.git
synced 2026-06-06 01:29:48 +09:00
Union small rolled loops with their unrolled form in egglog
The auto-roll prepass folds tiny scalar-mul chains (body=1, trips=2) inside e.g. the gemma_gelu sigmoid expansion into a loop body. The existing egglog fusion rules (GLUMoE GemmaGELU, etc.) pattern-match a specific flat chain of binary ops and can't see through the LoopStart/LoopInput/LoopEnd markers, so rolling silently disables the fusion and the extracted graph is strictly worse than not rolling at all. Add narrow per-binary-op early rewrites that union a rolled single-op-body loop (trips ≤ 4, state at body input position 0 or 1) with its fully-unrolled equivalent in the same eclass. The cost-based extractor then picks whichever representation downstream patterns prefer — the unrolled form when fusions match through the flat chain, the rolled form when nothing benefits. No threshold or special-case in the rolling cost model; the egraph stays the source of truth. Fixes test_glumoe_gemma_gelu_matches_unfused_output (78 → 79 passing in cuda_lite). All four binary HLIR ops (Add, Mul, Mod, LessThan) opt in via early_rewrites(). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
95
src/hlir.rs
95
src/hlir.rs
@@ -119,6 +119,89 @@ pub fn binary_sort(name: &str) -> SortDef {
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate egglog rewrite rules that union a small rolled `body=1, trips=N`
|
||||
/// single-binary-op loop with its fully-unrolled equivalent in the same
|
||||
/// eclass. Both representations coexist; the cost-based extractor picks
|
||||
/// whichever one downstream patterns prefer — the unrolled form when fusions
|
||||
/// (e.g. GLUMoE GemmaGELU) match through the flat chain, the rolled form
|
||||
/// otherwise. Without these unions, rolling a tiny chain blocks the fusion
|
||||
/// entirely and the extracted graph is strictly worse than not rolling.
|
||||
///
|
||||
/// Generates 2 rules per iter count (state at body input position 0 vs 1) for
|
||||
/// every `n_iters` in `2..=max_trips`. Larger trips stay rolled-only — the
|
||||
/// search-time cost of carrying both forms outweighs any fusion benefit at
|
||||
/// that scale (and real transformer-block rolls are body ≫ 1 anyway).
|
||||
///
|
||||
/// Each rule matches:
|
||||
/// `LoopStart(initial)` ← carries the state's initial value
|
||||
/// `LoopInput(s0..s_{N-1})` ← per-iter non-state input
|
||||
/// body op = `(<kind> ?ls ?li)` or `(<kind> ?li ?ls)` depending on state pos
|
||||
/// `LoopEnd(body)` ← post-loop result
|
||||
/// and unions `LoopEnd` with the chain
|
||||
/// `u0 = <kind>(initial, s0); u1 = <kind>(u0, s1); … u_{N-1} = …`
|
||||
/// (or the symmetric form for state at position 1).
|
||||
pub fn binary_op_unroll_rules(op_kind: &str, max_trips: usize) -> Vec<Rule> {
|
||||
let mut rules = Vec::with_capacity((max_trips.saturating_sub(1)) * 2);
|
||||
for n_iters in 2..=max_trips {
|
||||
for state_pos in 0..2 {
|
||||
rules.push(binary_op_unroll_rule(op_kind, n_iters, state_pos));
|
||||
}
|
||||
}
|
||||
rules
|
||||
}
|
||||
|
||||
fn binary_op_unroll_rule(op_kind: &str, n_iters: usize, state_pos: usize) -> Rule {
|
||||
// LoopInput's source IList: (ICons ?s0 (ICons ?s1 … (INil)))
|
||||
let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
|
||||
format!("(ICons ?s{i} {acc})")
|
||||
});
|
||||
|
||||
// Body op pattern. For state_pos=0 the body reads (LoopStart, LoopInput);
|
||||
// for state_pos=1 it reads (LoopInput, LoopStart). The unrolled chain
|
||||
// mirrors the body's argument order so a/b strides stay aligned.
|
||||
let (body_in0, body_in1) = match state_pos {
|
||||
0 => ("?ls", "?li"),
|
||||
1 => ("?li", "?ls"),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let mut chain = String::new();
|
||||
for i in 0..n_iters {
|
||||
let prev = if i == 0 {
|
||||
"?initial".to_string()
|
||||
} else {
|
||||
format!("?u{}", i - 1)
|
||||
};
|
||||
let s_i = format!("?s{i}");
|
||||
let (a, b) = match state_pos {
|
||||
0 => (prev, s_i),
|
||||
1 => (s_i, prev),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
chain.push_str(&format!(
|
||||
" (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {a} (ICons {b} (INil)))))\n"
|
||||
));
|
||||
}
|
||||
let last = format!("?u{}", n_iters - 1);
|
||||
let name = format!("unroll {op_kind} body trips={n_iters} state={state_pos}");
|
||||
|
||||
Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
|
||||
(= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
|
||||
(= ?body (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {body_in0} (ICons {body_in1} (INil)))))
|
||||
(= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
|
||||
)
|
||||
(
|
||||
{chain} (union ?le {last})
|
||||
)
|
||||
:ruleset expr
|
||||
:name \"{name}\"
|
||||
)"
|
||||
))
|
||||
}
|
||||
|
||||
/// Reduce op kind: (shape: EList, iters: Expression, strides: EList, iter_stride: Expression, out_strides: EList), IList: [inp]
|
||||
pub fn reduce_sort(name: &str) -> SortDef {
|
||||
sort(
|
||||
@@ -1596,6 +1679,9 @@ impl EgglogOp for Add {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Add", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
@@ -1681,6 +1767,9 @@ impl EgglogOp for Mul {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mul", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
@@ -1766,6 +1855,9 @@ impl EgglogOp for Mod {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mod", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
@@ -1852,6 +1944,9 @@ impl EgglogOp for LessThan {
|
||||
// Comparison operations always output Bool
|
||||
vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)]
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("LessThan", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
|
||||
Reference in New Issue
Block a user