fusion: family-gating env var + subsume inner FE in grow rules

Two changes around the elementwise fusion blowup that OOMs the host CPU at
538 GB RSS on the full 32B flux2 transformer.

1. LUMINAL_FUSION_FAMILIES env var: comma-separated subset of
   {uu, bu, ub, bb}. When set, only those families' pair-fuse rules are
   emitted. Default (env unset) keeps all four families as before. Confirmed
   on flux2 transformer:
     - all four families   → 538 GB CPU (OOM)
     - uu                  → 128 GB CPU, slower at runtime (rare U-U in flux2)
     - uu + bu + ub        → 141 GB CPU, matches no-fusion runtime (4.1 s/step)
     - bb only             → 538+ GB CPU (killed)
   So bb is the binding combinatorial constraint — each bb match adds 6
   enodes (3 FusionStart + 2 FusedBinary + 1 FusionEnd) and the pair-fuse
   matcher enumerates O(B²) binary-binary pairs in one pass.

2. Subsume the inner FusionEnd in all `grow-FE-*` rules. Once an FE has been
   extended by a downstream op, the smaller (partially-fused) FE has no
   value — the un-fused KernelX chain is still extractable via the
   pair-fuse union, so multi-consumer fan-out still works. This matches the
   "only the un-fused or the fully-fused variant" search-space design intent
   from the discussion. Note: subsume here does *not* fix the BB OOM (which
   happens in pair-fuse before any grow rule fires); it just cleans up the
   eclass alternatives.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Joe Fioti
2026-05-11 16:17:36 +00:00
parent 1231addcdb
commit 1edd4cfea3

View File

@@ -153,6 +153,24 @@ impl EgglogOp for FusionEnd {
// out, no transpose); a binary feeding a downstream op binds the
// binary's out-stride to the downstream op's in-stride along the
// connecting side.
//
// The blowup we see in fusion_pair on the 32B flux2 transformer is
// dominated by families 3 and 4 (U-B and B-B) because each match adds
// 2-3 new FusionStart enodes; with many binary-binary chains in the
// graph, the egraph saturate-iteration count balloons. Use
// `LUMINAL_FUSION_FAMILIES=uu` to keep only family 1 (U-U) — the
// safest subset — or `LUMINAL_FUSION_FAMILIES=uu,bu,ub` to drop the
// B-B family while keeping the rest. Default: all families on.
let families = std::env::var("LUMINAL_FUSION_FAMILIES").ok();
let allowed: Option<std::collections::HashSet<String>> = families
.as_ref()
.map(|s| s.split(',').map(|p| p.trim().to_lowercase()).collect());
let family_on = |name: &str| -> bool {
allowed
.as_ref()
.map(|set| set.contains(name))
.unwrap_or(true)
};
let mut rules = Vec::new();
// (KernelX kind, FusedX kind)
@@ -171,6 +189,7 @@ impl EgglogOp for FusionEnd {
];
// 1. Pair-fuse U → U: U2(U1(x)) → FE(FU2(FU1(FS(x)))).
if family_on("uu") {
for (ki1, fi1) in unaries {
for (ko2, fo2) in unaries {
rules.push(Rule::raw(format!(
@@ -187,8 +206,10 @@ impl EgglogOp for FusionEnd {
)));
}
}
}
// 2. Pair-fuse B → U: U(B(a, b)) → FE(FU(FB(FS(a), FS(b)))).
if family_on("bu") {
for (kb, fb, lb) in binaries {
for (ku, fu) in unaries {
rules.push(Rule::raw(format!(
@@ -208,10 +229,12 @@ impl EgglogOp for FusionEnd {
)));
}
}
}
// 3. Pair-fuse U → B (lhs / rhs): unary feeds binary's A or B input.
// LHS: B(U(a), b) → FE(FB(FU(FS(a)), FS(b))).
// RHS: B(a, U(b)) → FE(FB(FS(a), FU(FS(b)))).
if family_on("ub") {
for (ku, fu) in unaries {
for (kb, fb, lb) in binaries {
rules.push(Rule::raw(format!(
@@ -246,8 +269,10 @@ impl EgglogOp for FusionEnd {
)));
}
}
}
// 4. Pair-fuse B → B (lhs / rhs): inner binary feeds outer's A or B.
if family_on("bb") {
for (kbi, fbi, lbi) in binaries {
for (kbo, fbo, lbo) in binaries {
rules.push(Rule::raw(format!(
@@ -288,8 +313,17 @@ impl EgglogOp for FusionEnd {
)));
}
}
}
// 5. Grow FE → U: U(FE(inner)) → FE(FU(inner)). No new FS.
//
// Subsume the inner FE — once it's been extended, the partially-fused
// version is dominated by the larger region. The un-fused KernelX
// chain stays via union (pair-fuse never subsumed it), so multi-use
// consumers can still fall back to the un-fused alternative. Without
// this, every intermediate region size coexists in the egraph as an
// alternative, which is the partial-fusion explosion we see on the
// 32B flux2 transformer.
for (ku, fu) in unaries {
rules.push(Rule::raw(format!(
"(rule (
@@ -299,11 +333,13 @@ impl EgglogOp for FusionEnd {
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
(union ?u ?new_fe)
(subsume (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
) :ruleset fusion_grow :name \"grow-FE-U-{ku}\")"
)));
}
// 6. Grow FE → B (lhs / rhs): one input is the FE, the other external.
// Same subsume rationale as Grow FE → U.
for (kb, fb, lb) in binaries {
rules.push(Rule::raw(format!(
"(rule (
@@ -316,6 +352,7 @@ impl EgglogOp for FusionEnd {
(ICons ?inner_a (ICons ?fs_b (INil)))))
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
(union ?bin ?new_fe)
(subsume (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{lb}\")"
)));
rules.push(Rule::raw(format!(
@@ -329,6 +366,7 @@ impl EgglogOp for FusionEnd {
(ICons ?fs_a (ICons ?inner_b (INil)))))
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
(union ?bin ?new_fe)
(subsume (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{lb}\")"
)));
}