Compare commits


2 Commits

Author SHA1 Message Date
Tucker Morgan
d000452167 graph: release search-time e-graphs + bump profile_timeout default to 5s
Two related changes that address gemma-scale torch.compile reliability
on hosts where the search-time memory budget is tight.

(1) New `Graph::release_search_state()` clears `self.egraphs` and
`self.egraph_contexts` (the saturated e-graphs and per-bucket search
contexts produced by `build_search_space`). Once `search()` returns a
winning genome and the runtime is built, the e-graphs are dead weight
— their tuples have been distilled into the chosen LLIR and there's no
re-entrant code path that reads them. `compile_backend` now calls this
right after `graph.search(...)` so a `process_pt2` call returns a
CompiledGraph whose `Graph` is free of those structures.
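
The release pattern in (1) can be sketched in isolation. This is a minimal illustration, not the library's code: `SearchState` and its `Rc<Vec<u8>>` payloads are hypothetical stand-ins for the real `Graph` fields, chosen so the drop is observable through a reference count.

```rust
use std::rc::Rc;

// Hypothetical stand-in for the graph's search-time storage. The real
// fields hold saturated e-graphs and per-bucket search contexts.
struct SearchState {
    egraphs: Vec<Rc<Vec<u8>>>,
    egraph_contexts: Vec<Rc<Vec<u8>>>,
}

impl SearchState {
    // Same shape as `release_search_state`: `clear()` drops every element,
    // `shrink_to_fit()` then asks the allocator to return the vector's
    // own backing buffer as well.
    fn release_search_state(&mut self) {
        self.egraphs.clear();
        self.egraphs.shrink_to_fit();
        self.egraph_contexts.clear();
        self.egraph_contexts.shrink_to_fit();
    }
}

// Returns the probe's reference count before and after the release.
fn demo_release() -> (usize, usize) {
    let probe = Rc::new(vec![0u8; 1024]);
    let mut state = SearchState {
        egraphs: vec![Rc::clone(&probe)],
        egraph_contexts: vec![Rc::clone(&probe)],
    };
    let before = Rc::strong_count(&probe);
    state.release_search_state();
    (before, Rc::strong_count(&probe))
}
```

The count falling back to one shows that the elements themselves were dropped, which is what frees the bulk of the memory; `shrink_to_fit()` only reclaims the comparatively small vector buffers.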

Why it matters: torch.compile typically issues two compiles for a
cached-decode loop (one for prefill at s=N, one for decode at s=1).
Each one runs `build_search_space` and produces its own e-graphs. On
Gemma-4-26B-A4B each e-graph saturates to ~140k tuples and the peak
host RSS for a single pass approaches 470 GiB on a 525 GiB box.
Without the release, pass 1's CompiledGraph still holds its e-graphs
when pass 2 starts saturating its own, and the Python process is
OOM-killed during pass 2's egglog cycle 001, well before saturation
completes.

With the release, pass 1's e-graphs are freed before pass 2 starts.
On Gemma the prefill compile now succeeds end-to-end where it
previously SIGKILLed mid-saturation. The decode-pass compile still
OOMs on this hardware (single-pass peak alone is ~470 GiB, leaving
no room above the 60 GiB Python/transformers baseline), but that's a
separate scaling issue: even one e-graph barely fits on a 525 GiB
host. Boxes with 1 TB+ host RAM now compile both passes cleanly.

(2) `CompileOptions::new()` default `profile_timeout` 1s → 5s. The
1s default is too tight for Gemma-26B and Qwen3-30B candidate
profiles on shared-GPU benchmark hosts — 100/100 GA initial-genome
candidates time out under it. luminal-bench's per-model rust crates
were already setting their own 5s budget for the same reason
(luminal-benchmarks PR #8). Bumping the library default spares every
downstream consumer from rediscovering the same workaround.
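
For readers unfamiliar with per-candidate timeouts, here is one way such a budget can be enforced. This is a sketch under assumptions: `profile_with_timeout` is a hypothetical helper, and the diff does not show how the library itself applies `profile_timeout`. Only the `Option<Duration>` shape of the setting is taken from the change.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// Hypothetical helper: run one candidate's profiling closure on a worker
// thread and reject it if no measurement arrives within `timeout`.
// `None` means "wait for the candidate however long it takes".
fn profile_with_timeout<F>(candidate: F, timeout: Option<Duration>) -> Option<Duration>
where
    F: FnOnce() -> Duration + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // After a timeout the receiver is gone, so a failed send is expected.
        let _ = tx.send(candidate());
    });
    match timeout {
        Some(limit) => rx.recv_timeout(limit).ok(),
        None => rx.recv().ok(),
    }
}
```

With a budget that is too tight, every candidate returns `None` and the search has nothing to rank, which is the failure mode described above. Note that in this sketch a timed-out worker thread is detached rather than cancelled.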

Verified on Gemma-4-26B-A4B end-to-end via `torch.compile(model,
backend=luminal_backend)`: combined with the prior bf16 codegen fixes
(d7e8629 or similar), the prefill compile now consistently passes
the initial-genome stage (1-2 timeout rejections out of the budget,
then a viable candidate) and completes the search.
2026-05-27 20:38:01 +00:00
Tucker Morgan
fc1345c0a6 luminal_cuda_lite: fix bf16 ambiguous-operator NVRTC errors
Two related codegen bugs that made it impossible for torch.compile to
compile bf16 models that hit Recip / Sigmoid / reduce-max kernels via
luminal_python's PT2 backend.

(1) reduce_max kernel ternary at hlir.rs:216 — `tid < cnt ? warp_sums[tid]
: NEG_INF_F` is unambiguous when {dtype} is float, but for bf16 NVRTC
errors with:

    error: ambiguous "?" operation:
      second operand of type "__nv_bfloat16" can be converted to third
      operand type "float", and vice versa
        __nv_bfloat16 block_max = tid < cnt ? warp_sums[tid] : NEG_INF_F;

`__nv_bfloat16` has both `operator float() const` (implicit) and
`explicit __nv_bfloat16(float)` constructor visible to NVRTC, so the
overload resolver can't decide which conversion to apply. Cast the
sentinel: `({dtype})NEG_INF_F`.
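
As an illustration of what the template emits after the cast is added, the following renders just that one kernel line for a given dtype. `block_max_line` is a hypothetical helper for this sketch; the real kernel is a larger format string in hlir.rs.

```rust
// Hypothetical helper that renders only the kernel line touched by the fix.
// `dtype` is the CUDA type name substituted into the kernel template.
fn block_max_line(dtype: &str) -> String {
    // Casting the sentinel gives both ternary arms the same type, so no
    // implicit conversion has to be chosen for `?:`.
    format!("{dtype} block_max = tid < cnt ? warp_sums[tid] : ({dtype})NEG_INF_F;")
}
```

For `float` the cast is a no-op; for `__nv_bfloat16` it converts the sentinel up front.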

(2) per-region elementwise codegen — `Recip` emits `1.0f / {}` and
`Sigmoid` emits `1.0f / (1.0f + expf(-{}))`. When the operand is bf16
the resulting expression has the same `float-or-bf16-overload`
ambiguity for the division operator. There's already a precedent: the
fp8 path uses `elementwise_value` to promote operands to float before
the math and `elementwise_init_expr` to cast the result back at
assignment, sidestepping the ambiguity entirely. Extend that helper to
also cover Bf16 and F16 — same memory layout assumptions, same kernel
semantics, no perf change (the float math was already happening
internally for transcendentals like expf anyway).
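
The promote-then-cast-back flow can be traced with a self-contained sketch. The three helper bodies follow the diff in this change; the local `DType` enum and the `recip_init` wrapper are assumptions added only to make the example standalone.

```rust
// Local stand-in for the library's dtype enum (assumption for this sketch).
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum DType { F32, Bf16, F16, F8E4M3, F8E5M2, F8UE8M0 }

fn needs_float_promotion(dtype: DType) -> bool {
    matches!(
        dtype,
        DType::Bf16 | DType::F16 | DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0
    )
}

// Wrap the operand so the inline math is done in float.
fn elementwise_value(local: &str, dtype: DType) -> String {
    if needs_float_promotion(dtype) {
        format!("static_cast<float>({local})")
    } else {
        local.to_string()
    }
}

// Cast the float result back to the storage type at assignment.
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
    if needs_float_promotion(dtype) {
        format!("{cuda_ty}({expr})")
    } else {
        expr.to_string()
    }
}

// Hypothetical wrapper: build the initializer for a Recip (`1.0f / x`).
fn recip_init(local: &str, dtype: DType, cuda_ty: &str) -> String {
    let operand = elementwise_value(local, dtype);
    elementwise_init_expr(&format!("1.0f / {operand}"), dtype, cuda_ty)
}
```

For a bf16 operand `v` this yields `__nv_bfloat16(1.0f / static_cast<float>(v))`, so the division only ever sees floats; for an f32 operand the expression stays `1.0f / v`.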

Verified end-to-end on Gemma-4-26B-A4B through
`torch.compile(model, backend=luminal_backend)`: previously 100/100 GA
initial-genome candidates were rejected with NVRTC compilation panics
caused by these two ambiguities. After the fix, candidates compile
cleanly and the prefill pass completes.
2026-05-27 20:37:10 +00:00
4 changed files with 43 additions and 4 deletions


@@ -336,8 +336,21 @@ fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
     })
 }
+// Dtypes for which we promote operands to float for the per-region inline
+// math expression and then cast back at assignment. Without this NVRTC
+// rejects mixed-precision literals like `1.0f / v_bf16` as ambiguous (the
+// `1.0f / bf16` operator overload set has two viable resolutions). Float
+// math then implicit-cast back to the storage dtype keeps codegen simple
+// without changing arithmetic semantics.
+fn needs_float_promotion(dtype: DType) -> bool {
+    matches!(
+        dtype,
+        DType::Bf16 | DType::F16 | DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0
+    )
+}
 fn elementwise_value(local: &str, dtype: DType) -> String {
-    if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
+    if needs_float_promotion(dtype) {
         format!("static_cast<float>({local})")
     } else {
         local.to_string()
@@ -345,7 +358,7 @@ fn elementwise_value(local: &str, dtype: DType) -> String {
 }
 fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
-    if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
+    if needs_float_promotion(dtype) {
         format!("{cuda_ty}({expr})")
     } else {
         expr.to_string()


@@ -213,7 +213,11 @@ extern \"C\" {{
     if (warp_id == 0) {{
         int cnt = THREADS_PER_BLOCK / WARP_SIZE;
-        {dtype} block_max = tid < cnt ? warp_sums[tid] : NEG_INF_F;
+        // C-style cast to {dtype} so the ternary is unambiguous for both
+        // float (no-op cast) and __nv_bfloat16 (calls the explicit bf16
+        // constructor). Without the cast NVRTC errors with ambiguous
+        // ternary on bf16 reductions.
+        {dtype} block_max = tid < cnt ? warp_sums[tid] : ({dtype})NEG_INF_F;
         #pragma unroll
         for (int s = cnt / 2; s > 0; s /= 2) {{


@@ -176,6 +176,13 @@ pub fn compile_backend<Rt: Runtime + 'static>(
     // Search
     let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
+    // Drop the saturated e-graphs now that search picked a winner; they
+    // hold 100k+ tuples each on big models and the runtime doesn't need
+    // them anymore. Without this, gemma-style two-pass torch.compile
+    // (prefill + decode) accumulates both passes' e-graphs on the heap
+    // and OOMs partway through pass 2.
+    graph.release_search_state();
     // Rebuild label map after search (graph may have changed)
     let label_map = build_label_map(graph);


@@ -182,7 +182,7 @@ impl CompileOptions {
             mutations: 10,
             trials: 3,
             keep_best: 1,
-            profile_timeout: Some(std::time::Duration::from_secs(1)),
+            profile_timeout: Some(std::time::Duration::from_secs(5)),
             group_timeout: None,
             profile_dims: FxHashMap::default(),
             dim_buckets: FxHashMap::default(),
@@ -1257,6 +1257,21 @@ impl Graph {
         self.egraphs.first()
     }
+    /// Drop the search-time e-graph storage. The saturated e-graphs and
+    /// their context vectors are only consulted during `build_search_space`
+    /// + `search`; once a winning genome has been picked and the runtime is
+    /// built, they're dead weight on the heap. For models like Gemma where
+    /// each e-graph reaches ~140k tuples, keeping them alive between the
+    /// two torch.compile passes (prefill + decode) blows past 525 GiB host
+    /// RAM and the OS OOM-kills the process partway through pass 2. Call
+    /// this after `search()` returns to release that headroom.
+    pub fn release_search_state(&mut self) {
+        self.egraphs.clear();
+        self.egraphs.shrink_to_fit();
+        self.egraph_contexts.clear();
+        self.egraph_contexts.shrink_to_fit();
+    }
     /// Get a reference to the available ops (if search space is built)
     pub fn egglog_ops(&self) -> Option<&Vec<Arc<Box<dyn EgglogOp>>>> {
         self.ops.as_ref()