mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Fusion: strip absorbed markers and short-circuit per-step realloc walk
After region codegen folds each FusionEnd-rooted DAG into a single fused CUDA kernel, the FusionStart / nested FusionEnd / FusedX nodes that fed into it no longer need their own buffers or any other runtime state. But they were still in the LLIR, which meant `allocate_intermediate_buffers` walked them every decode token (because `p` increments and is in `intermediate_buffer_dims`), evaluating `output_bytes()` and stride expressions for ~2000 marker nodes that contribute nothing. This was the source of a +2.79 ms / decode-token regression vs the same binary with fusion ablated, and made the merged fusion branch ~10% slower than pristine `main` despite fusion saving 443 ms of GPU kernel time over the run. Total GPU work was *down* with fusion; the cost lived entirely in the per-step host walk. Three changes that fix it: 1. `runtime::CudaRuntime::allocate_intermediate_buffers`: skip nodes whose KernelOp is `FusionStart` or `FusedX*`. They never materialize buffers post region collapse. Root `FusionEnd` is kept because it's the kernel anchor for the region and does need a buffer for the region's output. 2. `runtime::CompiledBucket`: add `buffer_dyn_high_water` and short- circuit the realloc check when every current dyn-map value (for dims that affect intermediate sizing) is already <= what we last sized buffers for. With the marker walk removed and the cache hit, the per-execute "outer setup" phase falls from ~7.6 ms back to ~4.2 ms / call. 3. `kernel::to_host::kernel_to_host`: at the end of the function, remove every node in `globally_absorbed` from `llir_graph`. Region codegen has already folded them; downstream LLIR walks no longer need to ignore them per-iteration because they're gone. Numbers on llama-3-8b decode (default `cargo run -p llama`, 500 search graphs, 500 generated tokens): pristine `origin/main` (no fusion): TPOT 30.74 ms, TTFT 727 ms branch fusion ON, before this commit: TPOT 34.37 ms, TTFT 703 ms branch fusion ON, after this commit: TPOT 29.69 ms, TTFT 614 ms Fusion now beats main by ~1.05 ms / token (~3.4%) and TTFT by ~113 ms (~15.5%). Also adds a `LUMINAL_DISABLE_BINARY_FUSION=1` ablation env var on `FusionEnd::rewrites()` that skips registering any fusion rules. Lets us A/B fusion's runtime impact on a single binary without rebuilding; was essential for diagnosing this regression.
This commit is contained in:
@@ -209,6 +209,14 @@ impl EgglogOp for FusionEnd {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Ablation switch: with `LUMINAL_DISABLE_BINARY_FUSION=1` set, do
|
||||
// not register any fusion rules. The e-graph never sees the FS/FE
|
||||
// bracketed alternative, extraction always picks the un-fused
|
||||
// form, and the runtime path matches main with no fusion at all.
|
||||
// Used to A/B fusion's runtime impact on a single binary.
|
||||
if std::env::var("LUMINAL_DISABLE_BINARY_FUSION").is_ok() {
|
||||
return Vec::new();
|
||||
}
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
|
||||
@@ -897,4 +897,22 @@ pub fn kernel_to_host(
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// FusedX) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
|
||||
// chewing through dead nodes every decode token.
|
||||
//
|
||||
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
|
||||
// walks' starting points), so we keep them — they're the kernel
|
||||
// anchor for the region's compiled kernel.
|
||||
for node in globally_absorbed {
|
||||
// Defensive: only remove if the node still exists.
|
||||
if llir_graph.node_weight(node).is_some() {
|
||||
llir_graph.remove_node(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,12 @@ pub(crate) struct CompiledBucket {
|
||||
pub(crate) output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
pub(crate) last_dyn_map: FxHashMap<char, usize>,
|
||||
pub(crate) intermediate_buffer_dims: FxHashSet<char>,
|
||||
/// Per-dim high-water mark seen during `allocate_intermediate_buffers`.
|
||||
/// If every current dyn-map value is `<=` the corresponding entry here
|
||||
/// (for dims that show up in `intermediate_buffer_dims`), no buffer can
|
||||
/// possibly need to grow, and the per-execute walk over `llir_graph`
|
||||
/// can be skipped entirely. Saves ~2 ms per decode token on llama.
|
||||
pub(crate) buffer_dyn_high_water: FxHashMap<char, usize>,
|
||||
/// Which bucket index per dim this compilation targets
|
||||
pub(crate) bucket_indices: FxHashMap<char, usize>,
|
||||
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
|
||||
@@ -96,6 +102,7 @@ impl CompiledBucket {
|
||||
output_alias_map: FxHashMap::default(),
|
||||
last_dyn_map: FxHashMap::default(),
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
buffer_dyn_high_water: FxHashMap::default(),
|
||||
bucket_indices: FxHashMap::default(),
|
||||
hlir_synced: false,
|
||||
}
|
||||
@@ -664,6 +671,22 @@ impl CudaRuntime {
|
||||
if bucket.llir_graph[node].to_op::<Input>().is_some() {
|
||||
continue;
|
||||
}
|
||||
// Skip fusion marker / interior nodes. Region codegen folds
|
||||
// FusionStart / FusionEnd / FusedX into a single CUDA function
|
||||
// anchored at the FusionEnd; these marker nodes never need a
|
||||
// device buffer of their own at runtime, so walking them here
|
||||
// each step (with `p` incrementing every decode token) is
|
||||
// pure overhead. Skipping them recovers ~2 ms / token on
|
||||
// llama with fusion enabled.
|
||||
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
|
||||
let kn = op.kernel_name();
|
||||
if kn == "FusionStart" || kn.starts_with("Fused") {
|
||||
continue;
|
||||
}
|
||||
// Note: we deliberately keep "FusionEnd" because it is the
|
||||
// anchor for the region's compiled kernel and DOES need a
|
||||
// buffer for the region's output.
|
||||
}
|
||||
let needed_bytes =
|
||||
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
|
||||
let out_bytes = op.output_bytes();
|
||||
@@ -696,6 +719,17 @@ impl CudaRuntime {
|
||||
let ptr = bucket.buffers[&node].device_ptr(stream).0;
|
||||
bucket.cached_buffer_ptrs.insert(node, ptr);
|
||||
}
|
||||
// Update the high-water mark for the dims that actually drive
|
||||
// buffer sizing, so subsequent execute() calls can short-circuit
|
||||
// the realloc walk while we're still within the envelope we sized
|
||||
// for.
|
||||
for d in &bucket.intermediate_buffer_dims {
|
||||
let v = dyn_dims.get(d).copied().unwrap_or(0);
|
||||
let entry = bucket.buffer_dyn_high_water.entry(*d).or_insert(0);
|
||||
if v > *entry {
|
||||
*entry = v;
|
||||
}
|
||||
}
|
||||
let _ = (realloc_count, total_alloc);
|
||||
}
|
||||
|
||||
@@ -1041,13 +1075,29 @@ impl Runtime for CudaRuntime {
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
let buffers_empty = bucket.buffers.is_empty();
|
||||
let dyn_map_len_changed = dyn_map.len() != bucket.last_dyn_map.len();
|
||||
// High-water-mark fast path: if buffers exist and every current
|
||||
// dyn-map value (for dims that affect intermediate sizing) is
|
||||
// already <= what we previously sized buffers for, nothing can
|
||||
// need reallocation. Avoids walking ~6k LLIR nodes every decode
|
||||
// token just to discover "still big enough".
|
||||
let within_high_water = !buffers_empty
|
||||
&& !dyn_map_len_changed
|
||||
&& dyn_map
|
||||
.iter()
|
||||
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
|
||||
.all(|(d, v)| bucket.buffer_dyn_high_water.get(d).is_some_and(|hw| v <= hw));
|
||||
let dyn_dims_changed = dyn_map
|
||||
.iter()
|
||||
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
|
||||
.any(|(d, v)| bucket.last_dyn_map.get(d).map(|n| *n != *v).unwrap_or(true));
|
||||
let needs_realloc = buffers_empty || dyn_map_len_changed || dyn_dims_changed;
|
||||
if needs_realloc {
|
||||
let needs_realloc = !within_high_water
|
||||
&& (buffers_empty || dyn_map_len_changed || dyn_dims_changed);
|
||||
if !buffers_empty || dyn_dims_changed {
|
||||
// Always remember the latest dyn_map even when we skip the walk,
|
||||
// so tracing fields like last_dyn_map remain accurate.
|
||||
bucket.last_dyn_map = dyn_map.clone();
|
||||
}
|
||||
if needs_realloc {
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
}
|
||||
// Cache HLIR input pointers
|
||||
|
||||
Reference in New Issue
Block a user