Union small rolled loops with their unrolled form in egglog

The auto-roll prepass folds tiny scalar-mul chains (body=1, trips=2)
inside e.g. the gemma_gelu sigmoid expansion into a loop body. The
existing egglog fusion rules (GLUMoE GemmaGELU, etc.) pattern-match a
specific flat chain of binary ops and can't see through the
LoopStart/LoopInput/LoopEnd markers, so rolling silently disables the
fusion and the extracted graph is strictly worse than not rolling at
all.

Add narrow per-binary-op early rewrites that union a rolled
single-op-body loop (trips ≤ 4, state at body input position 0 or 1)
with its fully-unrolled equivalent in the same eclass. The cost-based
extractor then picks whichever representation downstream patterns
prefer — the unrolled form when fusions match through the flat chain,
the rolled form when nothing benefits. This needs no threshold or
special case in the rolling cost model; the egraph stays the source of
truth.
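
To illustrate the equivalence these rules assert, here is a minimal sketch in
plain Rust with no egglog involved. The names `rolled`, `unrolled3`, and
`modulo`, and the use of `i64` with `%` as a stand-in for Mod, are assumptions
made for this example and do not exist in the codebase. It shows that a
single-op loop body and its hand-unrolled chain compute the same value, and
that the state position matters for a non-commutative op.

```rust
/// Rolled form: one binary op applied once per trip, with the carried
/// state fed back into operand position `state_pos` (0 or 1).
/// (Hypothetical helper, for illustration only.)
fn rolled(op: fn(i64, i64) -> i64, initial: i64, inputs: &[i64], state_pos: usize) -> i64 {
    let mut state = initial;
    for &s in inputs {
        state = if state_pos == 0 { op(state, s) } else { op(s, state) };
    }
    state
}

/// Unrolled form for exactly three trips, written out as the chain
/// u0, u1, u2 with the same operand order as the loop body.
fn unrolled3(op: fn(i64, i64) -> i64, initial: i64, s: [i64; 3], state_pos: usize) -> i64 {
    if state_pos == 0 {
        let u0 = op(initial, s[0]);
        let u1 = op(u0, s[1]);
        op(u1, s[2])
    } else {
        let u0 = op(s[0], initial);
        let u1 = op(s[1], u0);
        op(s[2], u1)
    }
}

/// Stand-in for the Mod op on plain integers (divisor must be non-zero).
fn modulo(a: i64, b: i64) -> i64 {
    a % b
}
```

With `initial = 100` and per-iteration inputs `[7, 5, 3]`, both forms give 2
when the state sits at position 0 and 3 when it sits at position 1, which is
why each trip count needs one rule per state position.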

Fixes test_glumoe_gemma_gelu_matches_unfused_output (78 → 79 passing
in cuda_lite). All four binary HLIR ops (Add, Mul, Mod, LessThan)
opt in via early_rewrites().
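
For a sense of what each op contributes, the sketch below mirrors the string
construction in this commit's diff but returns a plain `String` rather than
the crate's `Rule`, so it compiles on its own. The names `unroll_rule_text`
and `rules_per_op` are hypothetical, and the whitespace inside the rule text
is chosen for readability rather than copied from the real rule.

```rust
/// Build the egglog text for one unroll rule (illustrative stand-in for
/// the diff's `binary_op_unroll_rule`, minus the `Rule::raw` wrapper).
fn unroll_rule_text(op_kind: &str, n_iters: usize, state_pos: usize) -> String {
    assert!(n_iters >= 1 && state_pos < 2);
    // LoopInput sources as a cons list: (ICons ?s0 (ICons ?s1 ... (INil)))
    let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
        format!("(ICons ?s{i} {acc})")
    });
    // Operand order of the loop body depends on where the state sits.
    let (in0, in1) = if state_pos == 0 { ("?ls", "?li") } else { ("?li", "?ls") };
    // Unrolled chain: u0 from ?initial, then each u_i from u_{i-1}.
    let mut chain = String::new();
    for i in 0..n_iters {
        let prev = if i == 0 { "?initial".to_string() } else { format!("?u{}", i - 1) };
        let s_i = format!("?s{i}");
        let (a, b) = if state_pos == 0 { (prev, s_i) } else { (s_i, prev) };
        chain.push_str(&format!(
            "    (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {a} (ICons {b} (INil)))))\n"
        ));
    }
    let last = n_iters - 1;
    format!(
        "(rule
  (
    (= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
    (= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
    (= ?body (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {in0} (ICons {in1} (INil)))))
    (= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
  )
  (
{chain}    (union ?le ?u{last})
  )
  :ruleset expr
  :name \"unroll {op_kind} body trips={n_iters} state={state_pos}\"
)"
    )
}

/// Rules generated per op: two state positions for each trip count in
/// `2..=max_trips`.
fn rules_per_op(max_trips: usize) -> usize {
    max_trips.saturating_sub(1) * 2
}
```

With `max_trips = 4`, `rules_per_op` returns 6, so the four ops together add
24 early rewrites. Calling `unroll_rule_text("Mul", 2, 0)` yields a rule whose
action builds `?u0` and `?u1` and ends with `(union ?le ?u1)`.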

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Joe Fioti
2026-04-26 04:26:24 +00:00
parent 7d68b62aa8
commit aba9627563


@@ -119,6 +119,89 @@ pub fn binary_sort(name: &str) -> SortDef {
    )
}
/// Generate egglog rewrite rules that union a small rolled `body=1, trips=N`
/// single-binary-op loop with its fully-unrolled equivalent in the same
/// eclass. Both representations coexist; the cost-based extractor picks
/// whichever one downstream patterns prefer — the unrolled form when fusions
/// (e.g. GLUMoE GemmaGELU) match through the flat chain, the rolled form
/// otherwise. Without these unions, rolling a tiny chain blocks the fusion
/// entirely and the extracted graph is strictly worse than not rolling.
///
/// Generates 2 rules per iter count (state at body input position 0 vs 1) for
/// every `n_iters` in `2..=max_trips`. Larger trips stay rolled-only — the
/// search-time cost of carrying both forms outweighs any fusion benefit at
/// that scale (and real transformer-block rolls are body ≫ 1 anyway).
///
/// Each rule matches:
/// `LoopStart(initial)` ← carries the state's initial value
/// `LoopInput(s0..s_{N-1})` ← per-iter non-state input
/// body op = `(<kind> ?ls ?li)` or `(<kind> ?li ?ls)` depending on state pos
/// `LoopEnd(body)` ← post-loop result
/// and unions `LoopEnd` with the chain
/// `u0 = <kind>(initial, s0); u1 = <kind>(u0, s1); … u_{N-1} = …`
/// (or the symmetric form for state at position 1).
pub fn binary_op_unroll_rules(op_kind: &str, max_trips: usize) -> Vec<Rule> {
    let mut rules = Vec::with_capacity((max_trips.saturating_sub(1)) * 2);
    for n_iters in 2..=max_trips {
        for state_pos in 0..2 {
            rules.push(binary_op_unroll_rule(op_kind, n_iters, state_pos));
        }
    }
    rules
}
fn binary_op_unroll_rule(op_kind: &str, n_iters: usize, state_pos: usize) -> Rule {
    // LoopInput's source IList: (ICons ?s0 (ICons ?s1 … (INil)))
    let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
        format!("(ICons ?s{i} {acc})")
    });
    // Body op pattern. For state_pos=0 the body reads (LoopStart, LoopInput);
    // for state_pos=1 it reads (LoopInput, LoopStart). The unrolled chain
    // mirrors the body's argument order so the reused ?as/?bs strides still
    // describe the operands in positions 0 and 1.
    let (body_in0, body_in1) = match state_pos {
        0 => ("?ls", "?li"),
        1 => ("?li", "?ls"),
        _ => unreachable!(),
    };
    let mut chain = String::new();
    for i in 0..n_iters {
        let prev = if i == 0 {
            "?initial".to_string()
        } else {
            format!("?u{}", i - 1)
        };
        let s_i = format!("?s{i}");
        let (a, b) = match state_pos {
            0 => (prev, s_i),
            1 => (s_i, prev),
            _ => unreachable!(),
        };
        chain.push_str(&format!(
            " (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {a} (ICons {b} (INil)))))\n"
        ));
    }
    let last = format!("?u{}", n_iters - 1);
    let name = format!("unroll {op_kind} body trips={n_iters} state={state_pos}");
    Rule::raw(format!(
        "(rule
            (
                (= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
                (= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
                (= ?body (Op ({op_kind} ?sh ?as ?bs ?os) (ICons {body_in0} (ICons {body_in1} (INil)))))
                (= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
            )
            (
{chain} (union ?le {last})
            )
            :ruleset expr
            :name \"{name}\"
        )"
    ))
}
/// Reduce op kind: (shape: EList, iters: Expression, strides: EList, iter_stride: Expression, out_strides: EList), IList: [inp]
pub fn reduce_sort(name: &str) -> SortDef {
    sort(
@@ -1596,6 +1679,9 @@ impl EgglogOp for Add {
    fn rewrites(&self) -> Vec<Rule> {
        vec![dtype_propagation_op(&self.sort())]
    }
    fn early_rewrites(&self) -> Vec<Rule> {
        binary_op_unroll_rules("Add", 4)
    }
    fn extract<'a>(
        &'a self,
        egraph: &'a SerializedEGraph,
@@ -1681,6 +1767,9 @@ impl EgglogOp for Mul {
    fn rewrites(&self) -> Vec<Rule> {
        vec![dtype_propagation_op(&self.sort())]
    }
    fn early_rewrites(&self) -> Vec<Rule> {
        binary_op_unroll_rules("Mul", 4)
    }
    fn extract<'a>(
        &'a self,
        egraph: &'a SerializedEGraph,
@@ -1766,6 +1855,9 @@ impl EgglogOp for Mod {
    fn rewrites(&self) -> Vec<Rule> {
        vec![dtype_propagation_op(&self.sort())]
    }
    fn early_rewrites(&self) -> Vec<Rule> {
        binary_op_unroll_rules("Mod", 4)
    }
    fn extract<'a>(
        &'a self,
        egraph: &'a SerializedEGraph,
@@ -1852,6 +1944,9 @@ impl EgglogOp for LessThan {
        // Comparison operations always output Bool
        vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)]
    }
    fn early_rewrites(&self) -> Vec<Rule> {
        binary_op_unroll_rules("LessThan", 4)
    }
    fn extract<'a>(
        &'a self,
        egraph: &'a SerializedEGraph,