Fusion: strip absorbed markers and short-circuit per-step realloc walk

After region codegen folds each FusionEnd-rooted DAG into a single fused
CUDA kernel, the FusionStart / nested FusionEnd / FusedX nodes that fed
into it no longer need their own buffers or any other runtime state.
But they were still in the LLIR, which meant `allocate_intermediate_buffers`
walked them every decode token (because `p` increments and is in
`intermediate_buffer_dims`), evaluating `output_bytes()` and stride
expressions for ~2000 marker nodes that contribute nothing.

This was the source of a +2.79 ms / decode-token regression vs the same
binary with fusion ablated, and made the merged fusion branch ~10%
slower than pristine `main` despite fusion saving 443 ms of GPU kernel
time over the run. Total GPU work was *down* with fusion; the cost
lived entirely in the per-step host walk.

Three changes that fix it:

1. `runtime::CudaRuntime::allocate_intermediate_buffers`: skip nodes
   whose KernelOp is `FusionStart` or `FusedX*`. They never materialize
   buffers post region collapse. Root `FusionEnd` is kept because it's
   the kernel anchor for the region and does need a buffer for the
   region's output.

2. `runtime::CompiledBucket`: add `buffer_dyn_high_water` and short-
   circuit the realloc check when every current dyn-map value (for
   dims that affect intermediate sizing) is already <= what we last
   sized buffers for. With the marker walk removed and the cache hit,
   the per-execute "outer setup" phase falls from ~7.6 ms back to
   ~4.2 ms / call.

3. `kernel::to_host::kernel_to_host`: at the end of the function,
   remove every node in `globally_absorbed` from `llir_graph`. Region
   codegen has already folded them; downstream LLIR walks no longer
   need to ignore them per-iteration because they're gone.

Numbers on llama-3-8b decode (default `cargo run -p llama`,
500 search graphs, 500 generated tokens):

  pristine `origin/main` (no fusion):     TPOT 30.74 ms, TTFT 727 ms
  branch fusion ON, before this commit:   TPOT 34.37 ms, TTFT 703 ms
  branch fusion ON, after this commit:    TPOT 29.69 ms, TTFT 614 ms

Fusion now beats main by ~1.05 ms / token (~3.4%) and TTFT by
~113 ms (~15.5%).

Also adds a `LUMINAL_DISABLE_BINARY_FUSION=1` ablation env var on
`FusionEnd::rewrites()` that skips registering any fusion rules.
Lets us A/B fusion's runtime impact on a single binary without
rebuilding; was essential for diagnosing this regression.
This commit is contained in:
Matthew Gunton
2026-04-29 04:05:11 +00:00
parent 8bdcae291c
commit 88bcd12a96
3 changed files with 78 additions and 2 deletions

View File

@@ -209,6 +209,14 @@ impl EgglogOp for FusionEnd {
}
fn rewrites(&self) -> Vec<Rule> {
// Ablation switch: with `LUMINAL_DISABLE_BINARY_FUSION=1` set, do
// not register any fusion rules. The e-graph never sees the FS/FE
// bracketed alternative, extraction always picks the un-fused
// form, and the runtime path matches main with no fusion at all.
// Used to A/B fusion's runtime impact on a single binary.
if std::env::var("LUMINAL_DISABLE_BINARY_FUSION").is_ok() {
return Vec::new();
}
// Seven rule families build and extend FE-bracketed regions. Each
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
// RHS produces `FusedX` variants in a different egglog sort, so the

View File

@@ -897,4 +897,22 @@ pub fn kernel_to_host(
llir_graph.add_edge(src, dst, ());
}
}
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
// FusedX) from the LLIR. Region codegen has already folded them into
// a single fused CUDA function anchored at each region's root
// FusionEnd; the absorbed nodes have no consumers outside the region
// and never need their own buffers. Removing them keeps later
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
// chewing through dead nodes every decode token.
//
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
// walks' starting points), so we keep them — they're the kernel
// anchor for the region's compiled kernel.
for node in globally_absorbed {
// Defensive: only remove if the node still exists.
if llir_graph.node_weight(node).is_some() {
llir_graph.remove_node(node);
}
}
}

View File

@@ -77,6 +77,12 @@ pub(crate) struct CompiledBucket {
pub(crate) output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
pub(crate) last_dyn_map: FxHashMap<char, usize>,
pub(crate) intermediate_buffer_dims: FxHashSet<char>,
/// Per-dim high-water mark seen during `allocate_intermediate_buffers`.
/// If every current dyn-map value is `<=` the corresponding entry here
/// (for dims that show up in `intermediate_buffer_dims`), no buffer can
/// possibly need to grow, and the per-execute walk over `llir_graph`
/// can be skipped entirely. Saves ~2 ms per decode token on llama.
pub(crate) buffer_dyn_high_water: FxHashMap<char, usize>,
/// Which bucket index per dim this compilation targets
pub(crate) bucket_indices: FxHashMap<char, usize>,
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
@@ -96,6 +102,7 @@ impl CompiledBucket {
output_alias_map: FxHashMap::default(),
last_dyn_map: FxHashMap::default(),
intermediate_buffer_dims: FxHashSet::default(),
buffer_dyn_high_water: FxHashMap::default(),
bucket_indices: FxHashMap::default(),
hlir_synced: false,
}
@@ -664,6 +671,22 @@ impl CudaRuntime {
if bucket.llir_graph[node].to_op::<Input>().is_some() {
continue;
}
// Skip fusion marker / interior nodes. Region codegen folds
// FusionStart / FusionEnd / FusedX into a single CUDA function
// anchored at the FusionEnd; these marker nodes never need a
// device buffer of their own at runtime, so walking them here
// each step (with `p` incrementing every decode token) is
// pure overhead. Skipping them recovers ~2 ms / token on
// llama with fusion enabled.
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
let kn = op.kernel_name();
if kn == "FusionStart" || kn.starts_with("Fused") {
continue;
}
// Note: we deliberately keep "FusionEnd" because it is the
// anchor for the region's compiled kernel and DOES need a
// buffer for the region's output.
}
let needed_bytes =
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
let out_bytes = op.output_bytes();
@@ -696,6 +719,17 @@ impl CudaRuntime {
let ptr = bucket.buffers[&node].device_ptr(stream).0;
bucket.cached_buffer_ptrs.insert(node, ptr);
}
// Update the high-water mark for the dims that actually drive
// buffer sizing, so subsequent execute() calls can short-circuit
// the realloc walk while we're still within the envelope we sized
// for.
for d in &bucket.intermediate_buffer_dims {
let v = dyn_dims.get(d).copied().unwrap_or(0);
let entry = bucket.buffer_dyn_high_water.entry(*d).or_insert(0);
if v > *entry {
*entry = v;
}
}
let _ = (realloc_count, total_alloc);
}
@@ -1041,13 +1075,29 @@ impl Runtime for CudaRuntime {
let bucket = &mut self.compiled_buckets[self.active_bucket];
let buffers_empty = bucket.buffers.is_empty();
let dyn_map_len_changed = dyn_map.len() != bucket.last_dyn_map.len();
// High-water-mark fast path: if buffers exist and every current
// dyn-map value (for dims that affect intermediate sizing) is
// already <= what we previously sized buffers for, nothing can
// need reallocation. Avoids walking ~6k LLIR nodes every decode
// token just to discover "still big enough".
let within_high_water = !buffers_empty
&& !dyn_map_len_changed
&& dyn_map
.iter()
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
.all(|(d, v)| bucket.buffer_dyn_high_water.get(d).is_some_and(|hw| v <= hw));
let dyn_dims_changed = dyn_map
.iter()
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
.any(|(d, v)| bucket.last_dyn_map.get(d).map(|n| *n != *v).unwrap_or(true));
let needs_realloc = buffers_empty || dyn_map_len_changed || dyn_dims_changed;
if needs_realloc {
let needs_realloc = !within_high_water
&& (buffers_empty || dyn_map_len_changed || dyn_dims_changed);
if !buffers_empty || dyn_dims_changed {
// Always remember the latest dyn_map even when we skip the walk,
// so tracing fields like last_dyn_map remain accurate.
bucket.last_dyn_map = dyn_map.clone();
}
if needs_realloc {
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
}
// Cache HLIR input pointers