mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
tucker/cub
...
tucker/bf1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d000452167 | ||
|
|
fc1345c0a6 |
@@ -336,8 +336,21 @@ fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
|
||||
})
|
||||
}
|
||||
|
||||
// Dtypes for which we promote operands to float for the per-region inline
|
||||
// math expression and then cast back at assignment. Without this NVRTC
|
||||
// rejects mixed-precision literals like `1.0f / v_bf16` as ambiguous (the
|
||||
// `1.0f / bf16` operator overload set has two viable resolutions). Float
|
||||
// math then implicit-cast back to the storage dtype keeps codegen simple
|
||||
// without changing arithmetic semantics.
|
||||
fn needs_float_promotion(dtype: DType) -> bool {
|
||||
matches!(
|
||||
dtype,
|
||||
DType::Bf16 | DType::F16 | DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0
|
||||
)
|
||||
}
|
||||
|
||||
fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
if needs_float_promotion(dtype) {
|
||||
format!("static_cast<float>({local})")
|
||||
} else {
|
||||
local.to_string()
|
||||
@@ -345,7 +358,7 @@ fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
}
|
||||
|
||||
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
if needs_float_promotion(dtype) {
|
||||
format!("{cuda_ty}({expr})")
|
||||
} else {
|
||||
expr.to_string()
|
||||
|
||||
@@ -213,7 +213,11 @@ extern \"C\" {{
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
{dtype} block_max = tid < cnt ? warp_sums[tid] : NEG_INF_F;
|
||||
// C-style cast to {dtype} so the ternary is unambiguous for both
|
||||
// float (no-op cast) and __nv_bfloat16 (calls the explicit bf16
|
||||
// constructor). Without the cast NVRTC errors with ambiguous
|
||||
// ternary on bf16 reductions.
|
||||
{dtype} block_max = tid < cnt ? warp_sums[tid] : ({dtype})NEG_INF_F;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
|
||||
@@ -176,6 +176,13 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
// Search
|
||||
let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
|
||||
|
||||
// Drop the saturated e-graphs now that search picked a winner — they
|
||||
// hold 100k+ tuples each on big models and the runtime doesn't need
|
||||
// them anymore. Without this, gemma-style two-pass torch.compile
|
||||
// (prefill + decode) accumulates both passes' e-graphs on the heap
|
||||
// and OOMs partway through pass 2.
|
||||
graph.release_search_state();
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
|
||||
17
src/graph.rs
17
src/graph.rs
@@ -182,7 +182,7 @@ impl CompileOptions {
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
profile_timeout: Some(std::time::Duration::from_secs(5)),
|
||||
group_timeout: None,
|
||||
profile_dims: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
@@ -1257,6 +1257,21 @@ impl Graph {
|
||||
self.egraphs.first()
|
||||
}
|
||||
|
||||
/// Drop the search-time e-graph storage. The saturated e-graphs and
|
||||
/// their context vectors are only consulted during `build_search_space`
|
||||
/// + `search`; once a winning genome has been picked and the runtime is
|
||||
/// built, they're dead weight on the heap. For models like Gemma where
|
||||
/// each e-graph reaches ~140k tuples, keeping them alive between the
|
||||
/// two torch.compile passes (prefill + decode) blows past 525 GiB host
|
||||
/// RAM and the OS OOM-kills the process partway through pass 2. Call
|
||||
/// this after `search()` returns to release that headroom.
|
||||
pub fn release_search_state(&mut self) {
|
||||
self.egraphs.clear();
|
||||
self.egraphs.shrink_to_fit();
|
||||
self.egraph_contexts.clear();
|
||||
self.egraph_contexts.shrink_to_fit();
|
||||
}
|
||||
|
||||
/// Get a reference to the available ops (if search space is built)
|
||||
pub fn egglog_ops(&self) -> Option<&Vec<Arc<Box<dyn EgglogOp>>>> {
|
||||
self.ops.as_ref()
|
||||
|
||||
Reference in New Issue
Block a user