Memory analysis post pass (#303)

* Simplify CUDA memory analysis and arena planning

* Simplify CUDA memory planning and fix clippy warnings
This commit is contained in:
Joe Fioti
2026-05-08 08:24:37 -07:00
committed by GitHub
parent 53f7960130
commit 1279dca4e6
16 changed files with 2511 additions and 387 deletions

View File

@@ -19,9 +19,9 @@ use crate::{
CudaBlas,
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
},
driver::{CudaSlice, CudaStream, DevicePtr},
driver::CudaStream,
},
host::HostOp,
host::{DeviceBuffer, HostOp},
};
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
let b_buf = buffers[&inputs[1]];
// Get device pointers
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
let a_ptr = a_buf.ptr();
let b_ptr = b_buf.ptr();
let c_ptr = c_buf.ptr();
// Debug: Check buffer sizes
trace!(

View File

@@ -29,9 +29,9 @@ use crate::{
cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy, cudaDataType,
},
},
driver::{CudaSlice, CudaStream, DevicePtr},
driver::{CudaStream, DevicePtr},
},
host::{HostOp, cublas::parse_cublas_op},
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
try_create_cublaslt,
};
@@ -268,7 +268,7 @@ impl HostOp for CuBlasLt {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
use crate::cudarc::cublaslt::sys::{
@@ -309,9 +309,9 @@ impl HostOp for CuBlasLt {
let b_buf = buffers[&inputs[1]];
// Get device pointers
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
let a_ptr = a_buf.ptr();
let b_ptr = b_buf.ptr();
let c_ptr = c_buf.ptr();
// Clamp leading dimensions to minimum valid values.
// When a dimension is 1 (e.g., k=1 outer product), the stride along that

View File

@@ -1,6 +1,6 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaSlice, CudaStream};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use luminal::{op::EgglogOp, prelude::*};
mod cublas;
mod cublaslt;
@@ -12,6 +12,44 @@ pub type Ops = (
moe::GLUMoE,
);
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
/// the reusable arena, or an external pointer. Host ops only need the pointer
/// and the logical byte length.
///
/// The handle is plain `Copy` data: it does not keep the underlying
/// allocation alive and does not free it. Whoever constructs a handle is
/// responsible for making sure `ptr..ptr + len` stays valid device memory for
/// as long as the handle is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceBuffer {
    /// Raw device address of the first byte.
    ptr: u64,
    /// Logical length of the buffer in bytes.
    len: usize,
}

impl DeviceBuffer {
    /// Wraps a raw device address and a byte length without taking ownership.
    pub fn new(ptr: u64, len: usize) -> Self {
        Self { ptr, len }
    }

    /// Raw device address of the first byte.
    pub fn ptr(self) -> u64 {
        self.ptr
    }

    /// Logical length in bytes.
    pub fn len(self) -> usize {
        self.len
    }

    /// `true` when the logical length is zero bytes.
    pub fn is_empty(self) -> bool {
        self.len == 0
    }

    /// Copies the buffer contents to a freshly allocated host vector.
    ///
    /// The copy is enqueued on `stream` and the stream is synchronized before
    /// returning, so the returned bytes are fully populated.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the copy cannot be enqueued or if the
    /// stream synchronization fails.
    pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
        let mut host = vec![0u8; self.len];
        // A zero-length handle may carry a null or placeholder address; there
        // is nothing to copy, so do not hand that address to the driver.
        if !host.is_empty() {
            // SAFETY: `host` is exactly `self.len` bytes long and outlives the
            // copy because the stream is synchronized below before `host` is
            // returned. The device range `ptr..ptr + len` is valid by the
            // contract of whoever constructed this handle; it is not checked
            // here.
            unsafe {
                result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
            }
        }
        stream.synchronize()?;
        Ok(host)
    }
}
/// Host operations that execute on the CPU but orchestrate GPU work.
///
/// This includes operations like cuBLAS calls and CUDA graph executions.
@@ -29,7 +67,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()>;
@@ -48,6 +86,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
vec![]
}
/// Returns relative lifetimes for extra buffer nodes within this host op.
///
/// The tuple is `(node, first_step, last_step)`, where steps are local to
/// this host op's execution. Returning `None` tells the runtime to treat
/// every extra buffer as live for the whole host op.
///
/// The default implementation returns `None`; implementors override it only
/// when they can report tighter per-buffer ranges.
///
/// NOTE(review): `last_step` looks like an inclusive index (a buffer touched
/// by a single step would be `(node, s, s)`) — confirm against the planner
/// that consumes these tuples.
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
None
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.

View File

@@ -32,7 +32,7 @@ use crate::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
},
host::HostOp,
host::{DeviceBuffer, HostOp},
try_create_cublaslt,
};
@@ -294,27 +294,140 @@ impl HostOp for GLUMoE {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// Resolve dimensions
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
if inputs.len() < 6 {
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
}
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
let x_buf = buffers[&inputs[0]];
let seq = x_buf.len() / (hidden * 4);
// Resolve dimensions
let hidden = self
.gu_matmul_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
let intermediate = self
.dn_matmul_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
let top_k = self
.output_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
let gu_io = self
.gu_io
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
let dn_io = self
.dn_io
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
if hidden == 0 || intermediate == 0 {
anyhow::bail!(
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
);
}
if top_k == 0 {
return Ok(());
}
if gu_io % hidden != 0 {
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
}
if dn_io % intermediate != 0 {
anyhow::bail!(
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
);
}
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
let down_hidden = dn_io / intermediate;
if gate_up_dim != intermediate * 2 {
anyhow::bail!(
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
gate_up_dim,
intermediate * 2
);
}
if down_hidden != hidden {
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
}
let output_bytes = self
.output_bytes()
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
if output_bytes % (hidden * 4) != 0 {
anyhow::bail!(
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
hidden * 4
);
}
let seq = output_bytes / (hidden * 4);
if seq == 0 {
return Ok(());
}
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
buffers.get(&node).copied().ok_or_else(|| {
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
})
};
// Get input/output buffers
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
let mode_aux_buf = buffers[&inputs[5]];
let output_buf = buffers[&self_node]; // [seq, hidden] F32
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
let topk_bytes = seq * top_k * 4;
if x_buf.len() < output_bytes {
anyhow::bail!(
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
x_buf.len()
);
}
if topk_idx_buf.len() < topk_bytes {
anyhow::bail!(
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
topk_idx_buf.len()
);
}
if topk_vals_buf.len() < topk_bytes {
anyhow::bail!(
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
topk_vals_buf.len()
);
}
if output_buf.len() < output_bytes {
anyhow::bail!(
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
output_buf.len()
);
}
let gu_stride_bytes = gate_up_dim * hidden * 2;
let down_stride_bytes = hidden * intermediate * 2;
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
anyhow::bail!(
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
gate_up_buf.len()
);
}
let num_experts = gate_up_buf.len() / gu_stride_bytes;
if num_experts == 0 {
anyhow::bail!("GLUMoE has no expert weights");
}
if down_buf.len() < num_experts * down_stride_bytes {
anyhow::bail!(
"GLUMoE down weight buffer too small: have {} bytes, need {}",
down_buf.len(),
num_experts * down_stride_bytes
);
}
// Get raw device pointer addresses
let x_ptr = buf_ptr(x_buf, stream);
@@ -326,21 +439,17 @@ impl HostOp for GLUMoE {
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
// Read top-k routing values from GPU
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
let idx_k = topk_idx_i32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let val_k = topk_vals_f32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let top_k = idx_k.min(val_k);
if seq > 0 && top_k == 0 {
return Ok(());
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
if expert_idx < 0 || expert_idx as usize >= num_experts {
anyhow::bail!(
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
);
}
}
// Mode-dependent expert weights used for the final reduction:
@@ -350,9 +459,16 @@ impl HostOp for GLUMoE {
let expert_weights_f32: &[f32] = match self.mode {
GLUMoEMode::SwiGLU => topk_vals_f32,
GLUMoEMode::GemmaGELU => {
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
debug_assert!(per_expert_scale_f32.len() >= num_experts);
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
let per_expert_scale_bytes = num_experts * 4;
if per_expert_scale_host.len() < per_expert_scale_bytes {
anyhow::bail!(
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
per_expert_scale_host.len()
);
}
let per_expert_scale_f32: &[f32] =
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
expert_weights_storage.resize(seq * top_k, 0.0);
for t in 0..seq {
let base = t * top_k;
@@ -382,10 +498,10 @@ impl HostOp for GLUMoE {
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
let hid_ptr = buf_ptr(&hidden_tmp, stream);
let ws_ptr = buf_ptr(&workspace, stream);
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
let hid_ptr = slice_ptr(&hidden_tmp, stream);
let ws_ptr = slice_ptr(&workspace, stream);
// Cast x F32 → BF16
let n_cast = (seq * hidden) as i32;
@@ -404,8 +520,8 @@ impl HostOp for GLUMoE {
}
// Per-token expert computation
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
for t in 0..seq {
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
@@ -507,7 +623,11 @@ impl HostOp for GLUMoE {
// Helpers
// ============================================================
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
/// Returns the raw device address held by a `DeviceBuffer` handle.
///
/// `_stream` is not used by the body; the handle already carries its address.
/// NOTE(review): the parameter is presumably kept so existing
/// `buf_ptr(buf, stream)` call sites stay unchanged — confirm.
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
buf.ptr()
}
/// Returns the raw device address of an owned `CudaSlice`.
///
/// `device_ptr` hands back the address together with a guard value; only the
/// address is kept, and the guard is dropped when this function returns.
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
    buf.device_ptr(stream).0
}

View File

@@ -113,6 +113,9 @@ impl KernelOp for FusionStart {
fn kernel_name(&self) -> &'static str {
"FusionStart"
}
/// Reports `Some(0)` for this op.
///
/// NOTE(review): presumably this means the op's output reuses the buffer of
/// its input at index 0 instead of needing its own allocation — confirm
/// against the `KernelOp::output_aliases_input` contract.
fn output_aliases_input(&self) -> Option<usize> {
Some(0)
}
}
// =========================================================================

View File

@@ -289,6 +289,24 @@ impl EgglogOp for KernelScatterNoCopy {
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
let mut rules = vec![
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
Rule::raw(
"(rule
((= ?list (ICons ?head ?tail)))
((consumed_buffer_ilist_contains ?list ?head))
:ruleset cleanup
:name \"consumed-buffer-ilist-contains-head\"
)",
),
Rule::raw(
"(rule
((= ?list (ICons ?head ?tail))
(consumed_buffer_ilist_contains ?tail ?item))
((consumed_buffer_ilist_contains ?list ?item))
:ruleset cleanup
:name \"consumed-buffer-ilist-contains-tail\"
)",
),
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
Rule::raw(
"(rule
@@ -324,13 +342,28 @@ impl EgglogOp for KernelScatterNoCopy {
"(rule
((= ?cb (ConsumedBuffer ?a))
(= ?op1 (Op ?k1 ?ilist1))
(= ?ilist1 (ICons ?cb ?rest1))
(consumed_buffer_ilist_contains ?ilist1 ?cb)
(= ?op2 (Op ?k2 ?ilist2))
(!= ?op1 ?op2)
(= ?ilist2 (ICons ?a ?t2)))
(consumed_buffer_ilist_contains ?ilist2 ?a))
((delete (ConsumedBuffer ?a)))
:ruleset cleanup
:name \"consumed-buffer-cleanup-pos\"
:name \"consumed-buffer-cleanup-shared-op-use\"
)",
));
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
rules.push(Rule::raw(
"(rule
((= ?cb (ConsumedBuffer ?dest))
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
:ruleset post_cleanup
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
)",
));
// Surviving ConsumedBuffers are valid — union with source and delete.

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use luminal::{
egglog_utils::{api::Rule, base::OP_KIND},
graph::LLIRGraph,
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
op::{EgglogOp, LLIROp},
prelude::{
petgraph::{Direction, algo::toposort, visit::EdgeRef},
@@ -22,7 +23,7 @@ use luminal::{
use tracing::{Level, enabled, span};
use crate::{
host::HostOp,
host::{DeviceBuffer, HostOp},
kernel::{
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
destroy_cuda_event,
@@ -47,8 +48,12 @@ struct CompiledKernel {
shared_mem: Expression,
/// Input node indices (for buffer lookup)
inputs: Vec<NodeIndex>,
/// Human-readable labels for input nodes, for launch diagnostics.
input_labels: Vec<String>,
/// Reference to the KernelOp for trait methods
kernel_op: Arc<Box<dyn KernelOp>>,
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
has_dyn_dims_param: bool,
/// Internal buffers allocated for this kernel
internal_bufs: Vec<CudaSlice<u8>>,
/// Device constants from compile()
@@ -68,7 +73,9 @@ impl CompiledKernel {
block: (Expression, Expression, Expression),
shared_mem: Expression,
inputs: Vec<NodeIndex>,
input_labels: Vec<String>,
kernel_op: Arc<Box<dyn KernelOp>>,
has_dyn_dims_param: bool,
constants: FxHashMap<char, CudaSlice<u8>>,
kernel_name: &'static str,
) -> Self {
@@ -79,7 +86,9 @@ impl CompiledKernel {
block,
shared_mem,
inputs,
input_labels,
kernel_op,
has_dyn_dims_param,
internal_bufs: Vec::new(),
constants,
graph_node: None,
@@ -226,7 +235,7 @@ impl HostOp for CudaGraphOp {
stream: &Arc<CudaStream>,
_self_node: NodeIndex,
_inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
self.execute_internal(stream, buffers, dyn_map)
@@ -258,6 +267,40 @@ impl HostOp for CudaGraphOp {
.collect()
}
/// Reports, for every buffer node this graph op touches, the first and last
/// kernel step (index into `state.kernels`) at which the node is used.
///
/// A node counts as used at a step when it is one of that kernel's inputs or
/// is the kernel's own node. Nodes returned by `extra_buffer_nodes()` that no
/// kernel touches are reported as live from step 0 through the final step.
/// Always returns `Some`.
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
    let state = self.state.borrow();
    let final_step = state.kernels.len().saturating_sub(1);
    let mut spans: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();

    for (step, kernel) in state.kernels.iter().enumerate() {
        // Inputs first, then the kernel's own node, all at this step.
        let touched = kernel
            .inputs
            .iter()
            .copied()
            .chain(std::iter::once(kernel.node));
        for node in touched {
            let span = spans.entry(node).or_insert((step, step));
            span.0 = span.0.min(step);
            span.1 = span.1.max(step);
        }
    }

    // Anything never touched by a kernel stays live for the whole op.
    for node in self.extra_buffer_nodes() {
        spans.entry(node).or_insert((0, final_step));
    }

    let flattened = spans
        .into_iter()
        .map(|(node, (first, last))| (node, first, last))
        .collect();
    Some(flattened)
}
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
self.buffer_sizes.clone()
}
@@ -268,11 +311,64 @@ impl HostOp for CudaGraphOp {
}
impl CudaGraphOp {
/// Looks up how many inputs a kernel with the given name is expected to have.
///
/// Returns `None` for names that are not in the table, in which case no
/// arity expectation is known for that kernel.
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
    // (expected input count, kernel names with that count)
    const ARITY_TABLE: [(usize, &[&str]); 4] = [
        (0, &["Constant", "Iota"]),
        (
            1,
            &[
                "MaxReduce",
                "MeanReduce",
                "SumReduce",
                "Cast",
                "Exp",
                "Exp2",
                "Log2",
                "Sin",
                "Recip",
                "Sigmoid",
                "Softmax",
                "Sqrt",
            ],
        ),
        (
            2,
            &[
                "Add",
                "BatchMatMul",
                "BatchMatVec",
                "Embed",
                "Gather",
                "LessThan",
                "Mod",
                "Mul",
            ],
        ),
        (3, &["Scatter", "ScatterNoCopy"]),
    ];

    ARITY_TABLE
        .iter()
        .find(|(_, names)| names.iter().any(|&name| name == kernel_name))
        .map(|&(arity, _)| arity)
}
/// Decides whether `kernel` must be given its own output buffer.
///
/// A buffer is required when the kernel's output size is non-zero under
/// `dyn_map` (an unresolved size is treated as non-zero) and the kernel does
/// not report that its output aliases one of its inputs.
fn kernel_requires_output_buffer(
    kernel: &CompiledKernel,
    dyn_map: &FxHashMap<char, usize>,
) -> bool {
    let output_size = kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1);
    if output_size == 0 {
        return false;
    }
    kernel.kernel_op.output_aliases_input().is_none()
}
/// Rejects a kernel launch whose resolved device pointers contain a null.
///
/// `output_ptr` and `input_ptrs` are the addresses already looked up for
/// `kernel`, with `0` standing for "no buffer found". The output pointer is
/// only checked when `kernel_requires_output_buffer` says one is needed;
/// input pointers are checked pairwise against `kernel.inputs`.
///
/// # Errors
///
/// Fails on a missing required output buffer, or on the first input whose
/// pointer is `0`, naming the kernel, its LLIR node and the offending input.
fn validate_kernel_pointers(
    kernel: &CompiledKernel,
    output_ptr: u64,
    input_ptrs: &[u64],
    dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
    let needs_output = Self::kernel_requires_output_buffer(kernel, dyn_map);
    if needs_output && output_ptr == 0 {
        anyhow::bail!(
            "missing output buffer for CUDA kernel {} at LLIR node {:?}",
            kernel.kernel_name,
            kernel.node,
        );
    }

    // Only the first null input is reported.
    let first_null = kernel
        .inputs
        .iter()
        .zip(input_ptrs)
        .position(|(_, &ptr)| ptr == 0);
    let Some(idx) = first_null else {
        return Ok(());
    };

    let input_node = &kernel.inputs[idx];
    // Labels are diagnostic only; fall back when none was recorded.
    let input_label = kernel
        .input_labels
        .get(idx)
        .map_or("unknown", String::as_str);
    anyhow::bail!(
        "missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
        kernel.kernel_name,
        kernel.node,
        input_node,
    )
}
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
fn execute_internal(
&self,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let mut state = self.state.borrow_mut();
@@ -343,7 +439,7 @@ impl CudaGraphOp {
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
current_buffer_ptrs.insert(node, buf.ptr());
}
}
@@ -391,13 +487,26 @@ impl CudaGraphOp {
.iter()
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
dyn_dims_ptr
} else {
0
};
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
anyhow::bail!(
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
kernel.kernel_name,
kernel.node,
);
}
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
kernel_dyn_dims_ptr,
);
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
}
@@ -424,6 +533,19 @@ impl CudaGraphOp {
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
if grid_dim.0 == 0
|| grid_dim.1 == 0
|| grid_dim.2 == 0
|| block_dim.0 == 0
|| block_dim.1 == 0
|| block_dim.2 == 0
{
anyhow::bail!(
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
kernel.kernel_name,
kernel.node,
);
}
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let cu_func = unsafe { kernel.function.raw_function() };
@@ -452,7 +574,7 @@ impl CudaGraphOp {
&self,
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let ctx = stream.context().clone();
@@ -474,7 +596,7 @@ impl CudaGraphOp {
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
buffer_ptrs.insert(node, buf.ptr());
}
}
@@ -521,6 +643,19 @@ impl CudaGraphOp {
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
if grid_dim.0 == 0
|| grid_dim.1 == 0
|| grid_dim.2 == 0
|| block_dim.0 == 0
|| block_dim.1 == 0
|| block_dim.2 == 0
{
anyhow::bail!(
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
kernel.kernel_name,
kernel.node,
);
}
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
@@ -529,18 +664,41 @@ impl CudaGraphOp {
.iter()
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
dyn_dims_ptr
} else {
0
};
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
anyhow::bail!(
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
kernel.kernel_name,
kernel.node,
);
}
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
kernel_dyn_dims_ptr,
);
let mut params = UnifiedKernelParams::new(param_values);
let cu_func = unsafe { kernel.function.raw_function() };
let kernel_node = kernel.node;
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
eprintln!(
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
kernel.kernel_name,
kernel.node,
kernel.inputs.len(),
kernel.has_dyn_dims_param,
params.values.len(),
);
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
@@ -662,6 +820,36 @@ pub fn kernel_to_host(
// live in a different convex subgraph than the FS itself.
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
graph
.node_weight(idx)
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
};
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
name_of(graph, node) == Some("FusionStart")
|| graph[node].to_op::<LoopStart>().is_some()
|| graph[node].to_op::<LoopEnd>().is_some()
|| graph[node].to_op::<LoopInput>().is_some()
|| graph[node].to_op::<LoopInputStatic>().is_some()
|| graph[node].to_op::<LoopOutput>().is_some()
|| graph[node].to_op::<LoopOutputSelect>().is_some()
};
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
let mut visited = FxHashSet::default();
while visited.insert(node) && is_transparent_input(graph, node) {
let Some(pred) = graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.next()
else {
break;
};
node = pred;
}
node
};
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
// Track all CudaGraphOp nodes and their subgraphs for edge creation
@@ -678,6 +866,7 @@ pub fn kernel_to_host(
let mut all_dyn_dims = FxHashSet::default();
let mut all_buffer_nodes = FxHashSet::default();
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
let mut external_inputs = FxHashSet::default();
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
@@ -691,9 +880,7 @@ pub fn kernel_to_host(
// Set global dyn dims ordering so compiles use consistent indices
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
global_dyn_dims.sort();
if !global_dyn_dims.is_empty() {
set_global_dyn_dims(global_dyn_dims.clone());
}
set_global_dyn_dims(global_dyn_dims.clone());
// Group the topo order into compile units: each FusionEnd-rooted
// region collapses to a single CompileUnit::Region (one fused
@@ -711,14 +898,35 @@ pub fn kernel_to_host(
.to_dialect::<dyn KernelOp>()
.unwrap();
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
kernel_op_ref.compile(cuda_stream, kernel_cache);
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
// Collect inputs from graph edges
let inputs: Vec<NodeIndex> = llir_graph
.edges_directed(*kernel_node_idx, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.map(|input| resolve_transparent_input(llir_graph, input))
.collect_vec();
if let Some(expected_inputs) =
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
{
assert_eq!(
inputs.len(),
expected_inputs,
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
kernel_op_ref.kernel_name(),
kernel_node_idx,
);
}
let input_labels = inputs
.iter()
.map(|&input| {
name_of(llir_graph, input)
.map(str::to_string)
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
})
.collect_vec();
// Collect buffer nodes and sizes
@@ -729,6 +937,12 @@ pub fn kernel_to_host(
all_buffer_sizes.insert(*kernel_node_idx, output_size);
}
all_buffer_nodes.extend(inputs.iter().copied());
external_inputs.extend(
inputs
.iter()
.copied()
.filter(|input| !subgraph.contains(input)),
);
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
@@ -739,7 +953,9 @@ pub fn kernel_to_host(
block,
shared_mem,
inputs,
input_labels,
kernel_op.clone(),
has_dyn_dims_param,
constants,
kernel_op.kernel_name(),
));
@@ -752,6 +968,7 @@ pub fn kernel_to_host(
cuda_stream,
kernel_cache,
);
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
// The region's CompiledKernel is keyed on the FE node
// (so FE provides trait methods like output_size /
@@ -763,7 +980,20 @@ pub fn kernel_to_host(
.to_dialect::<dyn KernelOp>()
.unwrap();
let inputs: Vec<NodeIndex> = region.external_inputs.clone();
let inputs: Vec<NodeIndex> = region
.external_inputs
.iter()
.copied()
.map(|input| resolve_transparent_input(llir_graph, input))
.collect();
let input_labels = inputs
.iter()
.map(|&input| {
name_of(llir_graph, input)
.map(str::to_string)
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
})
.collect_vec();
let output_size = fe_op_ref.output_size();
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
@@ -771,6 +1001,12 @@ pub fn kernel_to_host(
all_buffer_sizes.insert(region.fe_node, output_size);
}
all_buffer_nodes.extend(inputs.iter().copied());
external_inputs.extend(
inputs
.iter()
.copied()
.filter(|input| !subgraph.contains(input)),
);
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
@@ -781,7 +1017,9 @@ pub fn kernel_to_host(
compiled.block,
compiled.shared_mem,
inputs,
input_labels,
kernel_op,
has_dyn_dims_param,
compiled.constants,
"FusedRegion",
));
@@ -826,16 +1064,17 @@ pub fn kernel_to_host(
}
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
// Find external inputs: nodes outside subgraph that have edges into subgraph
let external_inputs: FxHashSet<NodeIndex> = subgraph
.iter()
.flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.filter(|src| !subgraph.contains(src))
})
.collect();
// Find external inputs: nodes outside subgraph that have edges into
// subgraph. Also include normalized FusionStart predecessors, because
// the compiled kernels read from the concrete producer buffer rather
// than the marker node.
external_inputs.extend(subgraph.iter().flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.map(|input| resolve_transparent_input(llir_graph, input))
.filter(|src| !subgraph.contains(src))
}));
// Add edges from external inputs to CudaGraphOp
for input in &external_inputs {

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,8 @@
use crate::{
host::HostOp,
host::{DeviceBuffer, HostOp},
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
use fixedbitset::FixedBitSet;
use half::{bf16, f16};
@@ -32,6 +32,8 @@ use std::{
use tracing::{Level, span, trace};
use uuid::Uuid;
// NOTE(review): presumably the byte alignment applied to sub-allocations
// placed inside the per-bucket arena — confirm against the arena planner.
const ARENA_ALIGNMENT: usize = 256;
pub enum CudaInput {
Buffer(CudaSlice<u8>),
Ptr(u64),
@@ -70,13 +72,25 @@ pub(crate) struct BufferSpec {
dtype: DType,
}
/// One buffer considered by the memory planner.
///
/// NOTE(review): field meanings below are inferred from names; the planner
/// that fills and reads this struct is not visible here — confirm.
#[derive(Debug, Clone)]
struct PlannedBuffer {
// Graph node the buffer belongs to.
node: NodeIndex,
// Size of the buffer in bytes.
bytes: usize,
// Presumably the first step at which the buffer is live — TODO confirm.
start: usize,
// Presumably the last step at which the buffer is live — TODO confirm
// (including whether it is inclusive).
end: usize,
}
/// Per-bucket compiled state. Each bucket holds its own executable graph,
/// explicit runtime metadata, intermediate buffers, and node mappings.
/// Weights (hlir_buffers) are shared.
pub(crate) struct CompiledBucket {
pub(crate) exec_graph: StableGraph<ExecutableHostOp, (), Directed>,
pub(crate) node_to_exec: FxHashMap<NodeIndex, NodeIndex>,
pub(crate) buffers: FxHashMap<NodeIndex, CudaSlice<u8>>,
/// Single reusable arena for all intermediate buffers in this bucket.
pub(crate) arena: Option<CudaSlice<u8>>,
pub(crate) arena_bytes: usize,
pub(crate) logical_buffer_offsets: FxHashMap<NodeIndex, usize>,
pub(crate) logical_buffer_bytes: FxHashMap<NodeIndex, usize>,
pub(crate) cached_buffer_ptrs: FxHashMap<NodeIndex, u64>,
pub(crate) buffer_specs: FxHashMap<NodeIndex, BufferSpec>,
pub(crate) llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
@@ -99,7 +113,10 @@ impl CompiledBucket {
CompiledBucket {
exec_graph: StableGraph::default(),
node_to_exec: FxHashMap::default(),
buffers: FxHashMap::default(),
arena: None,
arena_bytes: 0,
logical_buffer_offsets: FxHashMap::default(),
logical_buffer_bytes: FxHashMap::default(),
cached_buffer_ptrs: FxHashMap::default(),
buffer_specs: FxHashMap::default(),
llir_to_hlir: FxHashMap::default(),
@@ -186,9 +203,71 @@ impl CudaRuntime {
.collect()
}
/// Public access to the active intermediate buffers (for tests and diagnostics).
pub fn buffers(&self) -> &FxHashMap<NodeIndex, CudaSlice<u8>> {
&self.active().buffers
/// Looks up the arena-backed device buffer of `logical_node` in `bucket`.
///
/// Returns `None` when the bucket has no arena, when the node has no recorded
/// offset or byte length, or when adding the offset to the arena base pointer
/// would overflow.
fn bucket_buffer(
    bucket: &CompiledBucket,
    stream: &Arc<CudaStream>,
    logical_node: &NodeIndex,
) -> Option<DeviceBuffer> {
    // Resolve every lookup first; the arena's device pointer is only requested
    // once all three are known to exist.
    let (Some(arena), Some(&offset), Some(&len)) = (
        bucket.arena.as_ref(),
        bucket.logical_buffer_offsets.get(logical_node),
        bucket.logical_buffer_bytes.get(logical_node),
    ) else {
        return None;
    };
    let base = arena.device_ptr(stream).0;
    base.checked_add(offset as u64)
        .map(|ptr| DeviceBuffer::new(ptr, len))
}
/// Copies `src` into a freshly allocated, owned device slice of the same length.
///
/// The copy is enqueued on `stream` and the stream is synchronized before
/// returning, so the copy has completed by the time the slice is handed back.
///
/// # Panics
/// Panics if the allocation, the device-to-device copy, or the stream
/// synchronization fails.
fn copy_device_buffer_to_new_slice(
stream: &Arc<CudaStream>,
src: DeviceBuffer,
) -> CudaSlice<u8> {
let dst = stream.alloc_zeros::<u8>(src.len()).unwrap();
let dst_ptr = dst.device_ptr(stream).0;
// SAFETY: `dst` was just allocated with exactly `src.len()` bytes, so the
// destination range is valid for the copy. NOTE(review): this assumes `src`
// still refers to `src.len()` live device bytes — confirm callers uphold that.
unsafe {
result::memcpy_dtod_async(dst_ptr, src.ptr(), src.len(), stream.cu_stream())
.expect("cuMemcpyDtoDAsync failed");
}
// Wait for the async copy so the returned slice is fully populated.
stream.synchronize().unwrap();
dst
}
/// Resolves `node` to the device buffer it should use at execution time.
///
/// For the current node the sources are tried in this order:
/// 1. an entry in `external_output_buffers`,
/// 2. the node's slot in the bucket arena,
/// 3. the HLIR buffer mapped through `bucket.llir_to_hlir` (an owned buffer, or
///    for a raw-pointer input the matching `external_buffers` entry).
///
/// If none of these yields a buffer, the node is replaced by its target in
/// `bucket.output_alias_map` and the search repeats. Returns `None` when the
/// node has no alias target or when the alias chain revisits a node.
fn resolve_runtime_buffer(
bucket: &CompiledBucket,
stream: &Arc<CudaStream>,
hlir_buffers: &FxHashMap<NodeIndex, CudaInput>,
external_buffers: &FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
external_output_buffers: &FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
mut node: NodeIndex,
) -> Option<DeviceBuffer> {
// Nodes already visited while following aliases; guards against cycles.
let mut visited = FxHashSet::default();
loop {
if !visited.insert(node) {
return None;
}
// Externally registered output buffers take priority over everything else.
if let Some(ext) = external_output_buffers.get(&node) {
return Some(DeviceBuffer::new(ext.device_ptr(stream).0, ext.len()));
}
// Next, the node's own slot in the intermediate arena.
if let Some(buf) = Self::bucket_buffer(bucket, stream, &node) {
return Some(buf);
}
// Then HLIR-backed inputs.
if let Some(hlir_node) = bucket.llir_to_hlir.get(&node) {
match hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => {
return Some(DeviceBuffer::new(buf.device_ptr(stream).0, buf.len()));
}
Some(CudaInput::Ptr(_)) => {
// A raw pointer is only usable through its `external_buffers` entry;
// without one, fall through to the alias lookup below.
if let Some(ext) = external_buffers.get(hlir_node) {
return Some(DeviceBuffer::new(ext.device_ptr(stream).0, ext.len()));
}
}
None => {}
}
}
// Nothing found for this node: follow its alias, or give up if it has none.
let alias_target = bucket.output_alias_map.get(&node)?;
node = *alias_target;
}
}
#[tracing::instrument(skip_all)]
@@ -324,6 +403,15 @@ impl CudaRuntime {
let data_id = self.resolve_data_node(id);
let bucket = self.active();
let truncate_to_logical_bytes = |mut data: Vec<u8>| {
if let Some(spec) = bucket.buffer_specs.get(&data_id)
&& let Some(logical_bytes) = spec.bytes.exec(&bucket.last_dyn_map)
{
data.truncate(logical_bytes.min(data.len()));
}
data
};
let _span = span!(Level::TRACE, "dtoh").entered();
// If predecessor is an Input node, data lives in hlir_buffers
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
@@ -345,40 +433,45 @@ impl CudaRuntime {
}
}
} else {
// Predecessor is a computation node — data is in intermediate buffers
self.cuda_stream
.clone_dtoh(
bucket
.buffers
.get(&data_id)
.expect("Cannot find tensor in runtime!"),
)
.unwrap()
if let Some(ext) = self.external_output_buffers.get(&data_id) {
return truncate_to_logical_bytes(self.cuda_stream.clone_dtoh(&**ext).unwrap());
}
// Predecessor is a computation node — data is in the intermediate arena.
truncate_to_logical_bytes(
Self::bucket_buffer(bucket, &self.cuda_stream, &data_id)
.expect("Cannot find tensor in runtime!")
.clone_dtoh(&self.cuda_stream)
.unwrap(),
)
}
}
/// Resolve the device-side CudaSlice for an output tensor without copying to host.
/// Resolve the device-side buffer for an output tensor without copying to host.
/// Used by copy_output_to_device_ptr for DtoD transfers.
fn resolve_output_slice(&self, id: impl ToId) -> &CudaSlice<u8> {
fn resolve_output_buffer(&self, id: impl ToId) -> DeviceBuffer {
let data_id = self.resolve_data_node(id);
let bucket = self.active();
if let Some(ext) = self.external_output_buffers.get(&data_id) {
return DeviceBuffer::new(ext.device_ptr(&self.cuda_stream).0, ext.len());
}
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
match self
.hlir_buffers
.get(hlir_node)
.expect("Cannot find input tensor in runtime!")
{
CudaInput::Buffer(buf) => buf,
CudaInput::Buffer(buf) => {
DeviceBuffer::new(buf.device_ptr(&self.cuda_stream).0, buf.len())
}
CudaInput::Ptr(_) => self
.external_buffers
.get(hlir_node)
.map(|ext| &**ext)
.map(|ext| DeviceBuffer::new(ext.device_ptr(&self.cuda_stream).0, ext.len()))
.expect("Cannot read raw pointer input — no external_buffers entry for node"),
}
} else {
bucket
.buffers
.get(&data_id)
Self::bucket_buffer(bucket, &self.cuda_stream, &data_id)
.expect("Cannot find tensor in runtime!")
}
}
@@ -393,13 +486,12 @@ impl CudaRuntime {
dest_ptr != 0,
"copy_output_to_device_ptr called with null pointer"
);
let src_slice = self.resolve_output_slice(id);
let src_ptr = src_slice.device_ptr(&self.cuda_stream).0;
let copy_bytes = n_bytes.min(src_slice.len());
let src = self.resolve_output_buffer(id);
let copy_bytes = n_bytes.min(src.len());
unsafe {
cudarc::driver::result::memcpy_dtod_async(
result::memcpy_dtod_async(
dest_ptr,
src_ptr,
src.ptr(),
copy_bytes,
self.cuda_stream.cu_stream(),
)
@@ -496,39 +588,36 @@ impl CudaRuntime {
CudaInput::Ptr(p) => panic!("Cannot take raw pointer input (ptr=0x{:x})", p),
}
} else {
self.compiled_buckets[bi]
.buffers
.remove(&lineage_node)
.expect("Cannot find tensor in runtime!")
let src = Self::bucket_buffer(
&self.compiled_buckets[bi],
&self.cuda_stream,
&lineage_node,
)
.expect("Cannot find tensor in runtime!");
Self::copy_device_buffer_to_new_slice(&self.cuda_stream, src)
}
} else {
// Copy-then-modify: output data is in alias_node's buffer (intermediate),
// but we want to extract the lineage HLIR buffer so intermediates stay intact.
// while the lineage HLIR buffer has stale pre-op data. Return an owned
// copy of the arena output and drop the stale HLIR buffer.
let hlir_node = *self.compiled_buckets[bi]
.llir_to_hlir
.get(&lineage_node)
.expect("output_data_input lineage must reach an HLIR input node");
// Take the intermediate buffer (has the actual output data)
let output_buf = self.compiled_buckets[bi]
.buffers
.remove(&alias_node)
.expect("Cannot find intermediate output buffer in runtime!");
let output =
Self::bucket_buffer(&self.compiled_buckets[bi], &self.cuda_stream, &alias_node)
.expect("Cannot find intermediate output buffer in runtime!");
let output_buf = Self::copy_device_buffer_to_new_slice(&self.cuda_stream, output);
// Take the HLIR buffer (has stale pre-op data)
let hlir_buf = match self
match self
.hlir_buffers
.remove(&hlir_node)
.expect("Cannot find HLIR input buffer in runtime!")
{
CudaInput::Buffer(buf) => buf,
CudaInput::Buffer(_buf) => {}
CudaInput::Ptr(p) => panic!("Cannot take raw pointer input (ptr=0x{:x})", p),
};
// Put stale HLIR buffer into intermediate slot (keeps allocation alive)
self.compiled_buckets[bi]
.buffers
.insert(alias_node, hlir_buf);
}
// Return the output buffer (has correct data)
output_buf
@@ -595,20 +684,16 @@ impl CudaRuntime {
.get(&input_id)
.expect("Cannot find input in LLIR mapping!");
// Swap intermediate buffer <-> input buffer
let intermediate_buf = self.compiled_buckets[bi]
.buffers
.get_mut(&data_llir_node)
.expect("Output not in intermediate buffers");
if let CudaInput::Buffer(input_buf) = self
.hlir_buffers
.get_mut(&input_id)
.expect("Input not in hlir_buffers")
{
std::mem::swap(intermediate_buf, input_buf);
} else {
panic!("Input is a raw pointer, cannot swap");
}
let src = Self::bucket_buffer(
&self.compiled_buckets[bi],
&self.cuda_stream,
&data_llir_node,
)
.expect("Output not in intermediate buffers");
let input_buf = Self::copy_device_buffer_to_new_slice(&self.cuda_stream, src);
self.hlir_buffers
.insert(input_id, CudaInput::Buffer(input_buf));
self.changed_hlir.insert(input_id);
// Update cached pointer for the input
let ptr = match &self.hlir_buffers[&input_id] {
@@ -624,7 +709,7 @@ impl CudaRuntime {
/// They will be re-allocated on the next `execute()` call.
pub fn free_intermediate_buffers(&mut self) {
for bucket in &mut self.compiled_buckets {
bucket.buffers.clear();
bucket.arena = None;
bucket.cached_buffer_ptrs.clear();
}
}
@@ -635,42 +720,248 @@ impl CudaRuntime {
stream: &Arc<CudaStream>,
dyn_dims: &FxHashMap<char, usize>,
) {
let is_first_alloc = bucket.buffers.is_empty();
// Only sync if we might need to free/reallocate buffers
if is_first_alloc {
stream.synchronize().unwrap();
let needs_new_plan = !Self::buffer_plan_matches(bucket, dyn_dims);
if needs_new_plan {
if bucket.arena.is_some() {
stream.synchronize().unwrap();
}
Self::plan_intermediate_buffers(bucket, dyn_dims);
}
if bucket.arena_bytes == 0 {
bucket.arena = None;
bucket.cached_buffer_ptrs.clear();
return;
}
if bucket
.arena
.as_ref()
.is_none_or(|arena| arena.len() < bucket.arena_bytes)
{
bucket.arena = Some(stream.alloc_zeros(bucket.arena_bytes).unwrap());
}
let arena_ptr = bucket.arena.as_ref().unwrap().device_ptr(stream).0;
for (logical_node, &offset) in &bucket.logical_buffer_offsets {
if let Some(ptr) = arena_ptr.checked_add(offset as u64) {
bucket.cached_buffer_ptrs.insert(*logical_node, ptr);
}
}
}
/// Reports whether the bucket's existing buffer plan is still valid for
/// `dyn_dims`.
///
/// A bucket without buffer specs always matches. A bucket that has specs but
/// no recorded offsets never matches. Otherwise the plan matches when every
/// dimension in `intermediate_buffer_dims` has the same value (or the same
/// absence) in `last_dyn_map` and in `dyn_dims`.
fn buffer_plan_matches(bucket: &CompiledBucket, dyn_dims: &FxHashMap<char, usize>) -> bool {
    if bucket.buffer_specs.is_empty() {
        return true;
    }
    // Specs exist past this point, so an empty offset table means nothing has
    // been planned yet.
    if bucket.logical_buffer_offsets.is_empty() {
        return false;
    }
    for dim in bucket.intermediate_buffer_dims.iter() {
        if bucket.last_dyn_map.get(dim) != dyn_dims.get(dim) {
            return false;
        }
    }
    true
}
fn plan_intermediate_buffers(bucket: &mut CompiledBucket, dyn_dims: &FxHashMap<char, usize>) {
bucket.logical_buffer_offsets.clear();
bucket.logical_buffer_bytes.clear();
bucket.arena_bytes = 0;
bucket.intermediate_buffer_dims.clear();
let mut total_alloc: usize = 0;
let mut realloc_count: usize = 0;
for (node, spec) in bucket.buffer_specs.clone() {
bucket.cached_buffer_ptrs.clear();
bucket.last_dyn_map = dyn_dims.clone();
let mut logical_bytes = FxHashMap::default();
for (node, spec) in &bucket.buffer_specs {
bucket
.intermediate_buffer_dims
.extend(spec.bytes.dyn_vars());
let needed_bytes = spec.bytes.exec(dyn_dims).unwrap();
if needed_bytes == 0 {
continue;
let bytes = spec.bytes.exec(dyn_dims).unwrap();
if bytes > 0 {
logical_bytes.insert(*node, bytes);
}
// Only allocate/reallocate if we don't have a buffer or existing one is too small
let existing_len = bucket.buffers.get(&node).map(|b| b.len()).unwrap_or(0);
if existing_len >= needed_bytes {
continue; // Existing buffer is large enough, reuse it
}
// Need to allocate (or reallocate)
total_alloc += needed_bytes;
realloc_count += 1;
bucket
.buffers
.insert(node, stream.alloc_zeros(needed_bytes).unwrap());
let ptr = bucket.buffers[&node].device_ptr(stream).0;
bucket.cached_buffer_ptrs.insert(node, ptr);
}
let _ = (realloc_count, total_alloc);
if logical_bytes.is_empty() {
bucket.arena = None;
return;
}
let total_spec_count = logical_bytes.len();
let total_spec_bytes = logical_bytes.values().copied().sum::<usize>();
let mut first_use: FxHashMap<NodeIndex, usize> = FxHashMap::default();
let mut last_use: FxHashMap<NodeIndex, usize> = FxHashMap::default();
let exec_order = toposort(&bucket.exec_graph, None).unwrap_or_default();
let output_alias_map = bucket.output_alias_map.clone();
let mut touch = |node: NodeIndex, step: usize| {
let Some(node) = resolve_logical_buffer_node(node, &logical_bytes, &output_alias_map)
else {
return;
};
first_use
.entry(node)
.and_modify(|first| *first = (*first).min(step))
.or_insert(step);
last_use
.entry(node)
.and_modify(|last| *last = (*last).max(step))
.or_insert(step);
};
let mut time = 0usize;
for exec_node in exec_order.iter().copied() {
let exec_op = &bucket.exec_graph[exec_node];
let precise_extra_lifetimes = exec_op.internal.extra_buffer_lifetimes();
let span = precise_extra_lifetimes
.as_ref()
.and_then(|lifetimes| lifetimes.iter().map(|(_, _, end)| *end).max())
.map(|end| end + 1)
.unwrap_or(1)
.max(1);
let start_time = time;
let end_time = time + span - 1;
time += span;
let precise_nodes = precise_extra_lifetimes
.as_ref()
.map(|lifetimes| {
lifetimes
.iter()
.filter_map(|(node, _, _)| {
resolve_logical_buffer_node(*node, &logical_bytes, &output_alias_map)
})
.collect::<FxHashSet<_>>()
})
.unwrap_or_default();
let mut touch_if_not_precise = |node: NodeIndex, step: usize| {
if resolve_logical_buffer_node(node, &logical_bytes, &output_alias_map)
.is_some_and(|node| precise_nodes.contains(&node))
{
return;
}
touch(node, step);
};
touch_if_not_precise(exec_op.output, start_time);
touch_if_not_precise(exec_op.output, end_time);
for &input in &exec_op.inputs {
touch_if_not_precise(input, start_time);
touch_if_not_precise(input, end_time);
}
if let Some(lifetimes) = precise_extra_lifetimes {
for (node, start, end) in lifetimes {
touch(node, start_time + start);
touch(node, start_time + end);
}
} else {
for extra_node in exec_op.internal.extra_buffer_nodes() {
touch(extra_node, start_time);
touch(extra_node, end_time);
}
}
}
for &producer in bucket.output_producers.values() {
let mut alias_node = producer;
while let Some(target) = bucket.output_alias_map.get(&alias_node) {
alias_node = *target;
}
touch(alias_node, time);
let mut data_node = producer;
while let Some(target) = bucket.output_data_map.get(&data_node) {
data_node = *target;
}
touch(data_node, time);
touch(producer, time);
}
let mut planned = logical_bytes
.into_iter()
.filter(|(node, _)| first_use.contains_key(node) || last_use.contains_key(node))
.map(|(node, bytes)| PlannedBuffer {
node,
bytes,
start: first_use.get(&node).copied().unwrap_or(0),
end: last_use.get(&node).copied().unwrap_or(0),
})
.collect_vec();
planned.sort_by_key(|buf| (buf.start, std::cmp::Reverse(buf.bytes), buf.node.index()));
let planned_logical_count = planned.len();
let planned_logical_bytes = planned.iter().map(|buf| buf.bytes).sum::<usize>();
let logical_peak = logical_interval_peak(&planned);
let mut arena_end = 0usize;
let mut placed: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(planned.len());
let mut placement_order = planned.iter().collect_vec();
placement_order.sort_by_key(|buf| {
(
std::cmp::Reverse(buf.bytes),
std::cmp::Reverse(buf.end.saturating_sub(buf.start)),
buf.start,
buf.node.index(),
)
});
for buf in placement_order {
let allocation_bytes = align_up(buf.bytes, ARENA_ALIGNMENT);
let mut candidates = vec![0usize];
for &(placed_start, placed_end, placed_offset, placed_bytes) in &placed {
if intervals_overlap(buf.start, buf.end, placed_start, placed_end) {
candidates.push(align_up(placed_offset + placed_bytes, ARENA_ALIGNMENT));
}
}
candidates.sort_unstable();
candidates.dedup();
let offset = candidates
.into_iter()
.find(|&candidate| {
placed
.iter()
.all(|&(placed_start, placed_end, placed_offset, placed_bytes)| {
!intervals_overlap(buf.start, buf.end, placed_start, placed_end)
|| !byte_ranges_overlap(
candidate,
allocation_bytes,
placed_offset,
placed_bytes,
)
})
})
.unwrap_or_else(|| {
placed
.iter()
.filter(|(placed_start, placed_end, _, _)| {
intervals_overlap(buf.start, buf.end, *placed_start, *placed_end)
})
.map(|(_, _, offset, bytes)| align_up(offset + bytes, ARENA_ALIGNMENT))
.max()
.unwrap_or(0)
});
bucket.logical_buffer_offsets.insert(buf.node, offset);
bucket.logical_buffer_bytes.insert(buf.node, buf.bytes);
placed.push((buf.start, buf.end, offset, allocation_bytes));
arena_end = arena_end.max(offset + allocation_bytes);
}
bucket.arena_bytes = arena_end;
if std::env::var_os("LUMINAL_CUDA_MEMORY_DEBUG").is_some() {
eprintln!(
" CUDA memory plan specs={total_spec_count} used={planned_logical_count} skipped={} spec_bytes={} used_bytes={} skipped_bytes={} logical_peak={} arena_plan={} allocations={}",
total_spec_count.saturating_sub(planned_logical_count),
total_spec_bytes,
planned_logical_bytes,
total_spec_bytes.saturating_sub(planned_logical_bytes),
logical_peak,
bucket.arena_bytes,
bucket.logical_buffer_offsets.len(),
);
}
}
/// Pre-allocate buffers with the given dynamic dimension values.
@@ -679,10 +970,7 @@ impl CudaRuntime {
pub fn prebuild_graphs(&mut self, dyn_map: &FxHashMap<char, usize>) {
let bucket = &mut self.compiled_buckets[self.active_bucket];
// 1. Allocate intermediate buffers (needed for buffer pointers)
if bucket.buffers.is_empty() {
bucket.last_dyn_map = dyn_map.clone();
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
}
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
// 2. Process changed HLIR inputs to get their buffer pointers
if !self.changed_hlir.is_empty() || !bucket.hlir_synced {
@@ -809,6 +1097,57 @@ fn format_duration_precise(d: &std::time::Duration) -> String {
}
}
/// Follows `output_alias_map` from `node` until reaching a node that has an
/// entry in `logical_bytes`, and returns that node.
///
/// Returns `None` if the chain ends at a node with no alias target, or if the
/// chain loops back onto a node it has already visited.
fn resolve_logical_buffer_node(
    mut node: NodeIndex,
    logical_bytes: &FxHashMap<NodeIndex, usize>,
    output_alias_map: &FxHashMap<NodeIndex, NodeIndex>,
) -> Option<NodeIndex> {
    let mut seen = FxHashSet::default();
    loop {
        if logical_bytes.contains_key(&node) {
            return Some(node);
        }
        // A repeated node means the alias chain is cyclic.
        if !seen.insert(node) {
            return None;
        }
        node = *output_alias_map.get(&node)?;
    }
}
/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// An `alignment` of 0 or 1 leaves `value` unchanged.
fn align_up(value: usize, alignment: usize) -> usize {
    if alignment < 2 {
        return value;
    }
    match value % alignment {
        0 => value,
        // Add whatever is missing to reach the next multiple.
        remainder => value + (alignment - remainder),
    }
}
/// Reports whether the closed intervals `[a_start, a_end]` and
/// `[b_start, b_end]` share at least one point.
fn intervals_overlap(a_start: usize, a_end: usize, b_start: usize, b_end: usize) -> bool {
    let a_ends_before_b = a_end < b_start;
    let b_ends_before_a = b_end < a_start;
    !(a_ends_before_b || b_ends_before_a)
}
/// Reports whether the half-open byte ranges `[a_offset, a_offset + a_bytes)`
/// and `[b_offset, b_offset + b_bytes)` intersect.
fn byte_ranges_overlap(a_offset: usize, a_bytes: usize, b_offset: usize, b_bytes: usize) -> bool {
    // `a` starts at or after the end of `b`: no intersection.
    if a_offset >= b_offset + b_bytes {
        return false;
    }
    b_offset < a_offset + a_bytes
}
/// Computes the largest total of `bytes` over buffers whose `[start, end]`
/// step ranges are active at the same time.
///
/// Each buffer contributes `+bytes` at `start` and `-bytes` at `end + 1`
/// (saturating); the result is the maximum running sum, never below zero.
fn logical_interval_peak(planned: &[PlannedBuffer]) -> usize {
    let mut events: Vec<(usize, i128)> = planned
        .iter()
        .flat_map(|buf| {
            let bytes = buf.bytes as i128;
            [(buf.start, bytes), (buf.end.saturating_add(1), -bytes)]
        })
        .collect();
    // Order by step, with releases (negative deltas) before acquisitions at the
    // same step.
    events.sort_unstable();
    events
        .into_iter()
        .scan(0i128, |live, (_, delta)| {
            *live += delta;
            Some(*live)
        })
        .max()
        .unwrap_or(0)
        .max(0) as usize
}
impl Runtime for CudaRuntime {
type Ops = (crate::kernel::Ops, crate::host::Ops);
type CompileArg = Arc<CudaStream>;
@@ -817,9 +1156,14 @@ impl Runtime for CudaRuntime {
fn late_egglog_passes(
ops: &[Arc<Box<dyn luminal::op::EgglogOp>>],
_options: &luminal::graph::BuildSearchSpaceOptions,
options: &luminal::graph::BuildSearchSpaceOptions,
dyn_map: &FxHashMap<char, usize>,
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
vec![crate::memory_analysis::cuda_memory_analysis_pass(ops)]
vec![crate::memory_analysis::cuda_memory_analysis_pass(
ops,
options.max_memory_bytes,
dyn_map,
)]
}
fn estimate_graph_memory<'a>(
@@ -904,7 +1248,7 @@ impl Runtime for CudaRuntime {
fn clear_intermediate_buffers(&mut self) {
let _ = self.cuda_stream.synchronize();
for bucket in &mut self.compiled_buckets {
bucket.buffers.clear();
bucket.arena = None;
bucket.cached_buffer_ptrs.clear();
}
}
@@ -912,14 +1256,37 @@ impl Runtime for CudaRuntime {
fn intermediate_buffer_bytes(&self) -> usize {
self.compiled_buckets
.iter()
.map(|b| b.buffers.values().map(|buf| buf.len()).sum::<usize>())
.map(|b| b.arena.as_ref().map(|arena| arena.len()).unwrap_or(0))
.sum()
}
/// Returns the planned arena size (`arena_bytes`) of the active bucket, or
/// `None` if `active_bucket` does not index a compiled bucket.
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
    let bucket = self.compiled_buckets.get(self.active_bucket)?;
    Some(bucket.arena_bytes)
}
/// Returns the length of the active bucket's currently allocated arena (0 when
/// no arena is allocated), or `None` if `active_bucket` does not index a
/// compiled bucket.
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
    let bucket = self.compiled_buckets.get(self.active_bucket)?;
    Some(bucket.arena.as_ref().map_or(0, |arena| arena.len()))
}
fn has_nan_outputs(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {
let _ = self.cuda_stream.synchronize();
let bucket = self.active();
for (node_id, buf) in &bucket.buffers {
let mut checked = FxHashSet::default();
for producer in bucket.output_producers.values().copied() {
let mut node_id = producer;
while let Some(alias_target) = bucket.output_alias_map.get(&node_id) {
node_id = *alias_target;
}
if !checked.insert(node_id) {
continue;
}
let Some(buf) = Self::bucket_buffer(bucket, &self.cuda_stream, &node_id) else {
continue;
};
let n_bytes = buf.len();
if n_bytes == 0 || n_bytes % 4 != 0 {
continue;
@@ -929,7 +1296,7 @@ impl Runtime for CudaRuntime {
// and their bit patterns can produce false positives when reinterpreted as f32.
let is_float = bucket
.buffer_specs
.get(node_id)
.get(&node_id)
.map(|spec| matches!(spec.dtype, DType::F32))
.unwrap_or(true);
@@ -937,7 +1304,7 @@ impl Runtime for CudaRuntime {
continue;
}
let host_bytes: Vec<u8> = match self.cuda_stream.clone_dtoh(buf) {
let host_bytes: Vec<u8> = match buf.clone_dtoh(&self.cuda_stream) {
Ok(v) => v,
Err(_) => continue,
};
@@ -957,9 +1324,9 @@ impl Runtime for CudaRuntime {
_trials: usize,
_timeout: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String) {
// Clear active bucket's buffers before loading new LLIR for profiling
// Clear active bucket's arena before loading new LLIR for profiling.
if !self.compiled_buckets.is_empty() {
self.active_mut().buffers.clear();
self.active_mut().arena = None;
}
self.load_llir(llir_graph);
self.profiling = true;
@@ -1020,7 +1387,7 @@ impl Runtime for CudaRuntime {
if idx != self.active_bucket {
// Free the old bucket's intermediates to avoid holding 2 full sets in GPU memory
let old = self.active_bucket;
self.compiled_buckets[old].buffers.clear();
self.compiled_buckets[old].arena = None;
self.compiled_buckets[old].cached_buffer_ptrs.clear();
self.active_bucket = idx;
// Mark bucket as needing HLIR sync since it may have missed changes
@@ -1029,17 +1396,7 @@ impl Runtime for CudaRuntime {
}
let bucket = &mut self.compiled_buckets[self.active_bucket];
let buffers_empty = bucket.buffers.is_empty();
let dyn_map_len_changed = dyn_map.len() != bucket.last_dyn_map.len();
let dyn_dims_changed = dyn_map
.iter()
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
.any(|(d, v)| bucket.last_dyn_map.get(d).map(|n| *n != *v).unwrap_or(true));
let needs_realloc = buffers_empty || dyn_map_len_changed || dyn_dims_changed;
if needs_realloc {
bucket.last_dyn_map = dyn_map.clone();
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
}
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
// Cache HLIR input pointers
if !self.changed_hlir.is_empty() || !bucket.hlir_synced {
let hlir_nodes: Vec<NodeIndex> = if !bucket.hlir_synced {
@@ -1081,82 +1438,45 @@ impl Runtime for CudaRuntime {
trace!("Executing: {:?}", exec_op);
// Build buffer map for the HostOp interface
let mut buffer_map: FxHashMap<NodeIndex, &CudaSlice<u8>> = FxHashMap::default();
let mut buffer_map: FxHashMap<NodeIndex, DeviceBuffer> = FxHashMap::default();
// Add output buffer -- prefer external output pointer if registered (zero copy)
if let Some(ext) = self.external_output_buffers.get(&exec_op.output) {
buffer_map.insert(exec_op.output, &**ext);
} else if let Some(buf) = bucket.buffers.get(&exec_op.output) {
if let Some(buf) = Self::resolve_runtime_buffer(
bucket,
&self.cuda_stream,
&self.hlir_buffers,
&self.external_buffers,
&self.external_output_buffers,
exec_op.output,
) {
buffer_map.insert(exec_op.output, buf);
}
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
for inp in exec_op.inputs.iter() {
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => {
buffer_map.insert(*inp, buf);
}
Some(CudaInput::Ptr(_)) => {
if let Some(ext) = self.external_buffers.get(hlir_node) {
buffer_map.insert(*inp, &**ext);
}
}
None => {}
}
if !buffer_map.contains_key(inp)
&& let Some(buf) = bucket.buffers.get(inp)
{
buffer_map.insert(*inp, buf);
}
} else if let Some(buf) = bucket.buffers.get(inp) {
buffer_map.insert(*inp, buf);
for &inp in &exec_op.inputs {
if let Some(buf) = Self::resolve_runtime_buffer(
bucket,
&self.cuda_stream,
&self.hlir_buffers,
&self.external_buffers,
&self.external_output_buffers,
inp,
) {
buffer_map.insert(inp, buf);
}
}
// Add extra buffer nodes (for CudaGraphOp)
let extra_nodes = exec_op.internal.extra_buffer_nodes();
for extra_node in extra_nodes {
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
if let Some(ext) = self.external_output_buffers.get(&extra_node) {
e.insert(&**ext);
} else if let Some(buf) = bucket.buffers.get(&extra_node) {
e.insert(buf);
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => {
e.insert(buf);
}
Some(CudaInput::Ptr(_)) => {
if let Some(ext) = self.external_buffers.get(hlir_node) {
e.insert(&**ext);
}
}
None => {}
}
}
}
}
// Resolve output aliases
for (&alias_node, &alias_target) in &bucket.output_alias_map {
if !buffer_map.contains_key(&alias_node) {
continue;
}
// Try HLIR buffer first (includes external device pointers)
let resolved: Option<&CudaSlice<u8>> =
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => Some(buf),
Some(CudaInput::Ptr(_)) => {
self.external_buffers.get(hlir_node).map(|ext| &**ext)
}
None => None,
}
} else {
None
};
if let Some(buf) = resolved {
buffer_map.insert(alias_node, buf);
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
buffer_map.insert(alias_node, buf);
if let Entry::Vacant(e) = buffer_map.entry(extra_node)
&& let Some(buf) = Self::resolve_runtime_buffer(
bucket,
&self.cuda_stream,
&self.hlir_buffers,
&self.external_buffers,
&self.external_output_buffers,
extra_node,
)
{
e.insert(buf);
}
}
let _span = span!(
@@ -1253,13 +1573,7 @@ impl Runtime for CudaRuntime {
for (bucket_indices, representative_dyn_map, llir) in bucket_llirs {
let mut bucket = self.compile_bucket(llir);
bucket.bucket_indices = bucket_indices.clone();
// Eagerly allocate intermediate buffers using the representative dyn_map
bucket.last_dyn_map = representative_dyn_map.clone();
Self::allocate_intermediate_buffers(
&mut bucket,
&self.cuda_stream,
representative_dyn_map,
);
let _ = representative_dyn_map;
self.compiled_buckets.push(bucket);
}
self.active_bucket = 0;
@@ -1361,7 +1675,8 @@ impl CudaRuntime {
}
}
});
let allocated = !is_marker || has_external_consumer;
let allocated = kernel_op.output_aliases_input().is_none()
&& (!is_marker || has_external_consumer);
if allocated {
bucket.buffer_specs.insert(
node,

View File

@@ -41,9 +41,8 @@ fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
all_names
}
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
/// the ConsumedBuffer (not in any other ICons).
/// When dest is NOT shared with any other compute op, KernelScatterNoCopy should
/// be the only scatter variant left after post-cleanup.
#[test]
fn test_scatter_nocopy_selected_when_dest_unshared() {
let ctx = CudaContext::new(0).unwrap();
@@ -62,12 +61,17 @@ fn test_scatter_nocopy_selected_when_dest_unshared() {
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should be available (dest is not shared)
// KernelScatterNoCopy should be the only scatter variant (dest is not shared)
assert!(
names.iter().any(|n| n == "ScatterNoCopy"),
"Expected ScatterNoCopy to be available but got: {:?}",
names
);
assert!(
!names.iter().any(|n| n == "Scatter"),
"Regular Scatter should be pruned when ScatterNoCopy is valid, got: {:?}",
names
);
}
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
@@ -109,8 +113,42 @@ fn test_scatter_nocopy_not_selected_when_dest_shared() {
);
}
/// Shared-use detection must catch the destination in non-first input
/// positions too. Gather takes indexes first and data second, so this would
/// miss the unsafe read if cleanup only inspected the head of the input list.
#[test]
fn test_scatter_nocopy_not_selected_when_dest_shared_as_later_input() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let scatter_indexes = cx.tensor(3).as_dtype(DType::Int).persist();
let read_indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// `dest` is the scatter destination...
let scatter_result = src.scatter(scatter_indexes, dest);
// ...and is also read by an independent gather, so it is shared.
let _dest_also_read = dest.gather(read_indexes).output();
let _result = scatter_result.output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// The in-place variant must not be offered for a shared destination.
assert!(
!names.iter().any(|n| n == "ScatterNoCopy"),
"ScatterNoCopy should NOT be available when dest is read by another op, got: {:?}",
names
);
// The copying variant must still be available.
assert!(
names.iter().any(|n| n == "Scatter"),
"Expected regular Scatter but got: {:?}",
names
);
}
/// Actually execute the scatter and verify correctness.
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
/// Post-cleanup should force the valid no-copy extraction.
#[test]
fn test_scatter_execution_correctness() {
let ctx = CudaContext::new(0).unwrap();
@@ -135,9 +173,8 @@ fn test_scatter_execution_correctness() {
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
// Try many random extractions to cover both Scatter and ScatterNoCopy
// Try many random extractions; each valid choice should now use ScatterNoCopy.
let mut rng = rand::rng();
let mut tested_scatter = false;
let mut tested_nocopy = false;
for _ in 0..50 {
@@ -180,27 +217,24 @@ fn test_scatter_execution_correctness() {
let actual = rt.get_f32(result);
let variant = if has_nocopy {
tested_nocopy = true;
"ScatterNoCopy"
} else if has_scatter {
tested_scatter = true;
"Scatter"
} else {
"Unknown"
};
assert!(
has_nocopy,
"Expected ScatterNoCopy after post-cleanup, got no no-copy scatter"
);
assert!(
!has_scatter,
"Regular Scatter should be pruned when ScatterNoCopy is valid"
);
tested_nocopy = true;
assert_eq!(
actual, expected,
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
"Scatter result mismatch with ScatterNoCopy: got {:?}, expected {:?}",
actual, expected
);
}
println!(
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
tested_scatter, tested_nocopy
);
println!("Tested ScatterNoCopy: {}", tested_nocopy);
assert!(
tested_nocopy,
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
@@ -242,12 +276,28 @@ fn test_scatter_kv_cache_roundtrip() {
rt = cx.search(rt, 5);
// Print which scatter variant was selected
// Print and verify which scatter variant was selected
let scatter_names: Vec<_> = rt
.kernel_names()
.iter()
.copied()
.filter(|name| name.contains("catter"))
.collect();
for name in rt.kernel_names() {
if name.contains("catter") {
println!("Selected: {name}");
}
}
assert!(
scatter_names.contains(&"ScatterNoCopy"),
"Expected ScatterNoCopy in KV-cache search result, got: {:?}",
scatter_names
);
assert!(
!scatter_names.contains(&"Scatter"),
"Regular Scatter should be pruned from KV-cache search result, got: {:?}",
scatter_names
);
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
@@ -342,17 +392,31 @@ fn test_scatter_dual_cache() {
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
// Use seeded search for deterministic scatter variant selection.
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
// Use seeded search for deterministic variant selection.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
// Print selected variants
// Print and verify selected variants
let scatter_names: Vec<_> = rt
.kernel_names()
.iter()
.copied()
.filter(|name| name.contains("catter"))
.collect();
for name in rt.kernel_names() {
if name.contains("catter") {
println!("Dual test selected: {name}");
}
}
assert!(
!scatter_names.is_empty(),
"Expected scatter kernels in dual-cache search result"
);
assert!(
scatter_names.iter().all(|name| *name == "ScatterNoCopy"),
"Expected only ScatterNoCopy in dual-cache search result, got: {:?}",
scatter_names
);
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);

View File

@@ -52,6 +52,9 @@ fn main() {
v_out.output();
}
cx.set_dim('s', 1);
cx.set_dim('p', 1);
println!("Building E-Graph...");
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(500),

View File

@@ -275,7 +275,7 @@ fn main() {
let weights_path = artifact_dir.join("weights.safetensors");
let cli = cli_args(&artifact_dir);
let image_path = cli.image_path.clone();
let search_graphs = 1usize;
let search_graphs = 50usize;
println!("Using artifact directory: {}", artifact_dir.display());

View File

@@ -444,6 +444,7 @@ pub fn base_expression_egglog() -> String {
p.add_ruleset("expr");
p.add_ruleset("dtype_prop");
p.add_ruleset("cleanup");
p.add_ruleset("post_cleanup");
// Register all sorts
s.register(&mut p);

View File

@@ -9,6 +9,8 @@ use std::hash::{Hash, Hasher};
use std::{str, sync::Arc, time::Duration};
use tracing::trace;
pub use egraph_serialize::{ClassId, NodeId};
pub mod api;
pub mod base;
@@ -36,7 +38,9 @@ struct EgglogSchedulePhase {
schedule: String,
}
#[derive(Debug, Clone, Default)]
pub type EGraphPostprocess = Arc<dyn Fn(&mut SerializedEGraph) + Send + Sync + 'static>;
#[derive(Clone, Default)]
pub struct LateEgglogPass {
/// Egglog declarations and rules for a backend-provided late pass.
///
@@ -49,6 +53,19 @@ pub struct LateEgglogPass {
/// Backends can use this for analysis-only layers or for analysis followed
/// by backend-specific cleanup rules.
pub schedule: String,
/// Optional Rust post-processing hook that runs on the serialized e-graph
/// after egglog has finished and before the final empty-eclass cascade.
pub postprocess: Option<EGraphPostprocess>,
}
/// Manual `Debug` implementation for [`LateEgglogPass`].
///
/// The `postprocess` hook is not printed directly; the output instead carries
/// a `has_postprocess` flag saying whether a hook is attached.
impl std::fmt::Debug for LateEgglogPass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hook_present = self.postprocess.is_some();
        let mut builder = f.debug_struct("LateEgglogPass");
        builder.field("program", &self.program);
        builder.field("schedule", &self.schedule);
        builder.field("has_postprocess", &hook_present);
        builder.finish()
    }
}
impl LateEgglogPass {
@@ -56,8 +73,17 @@ impl LateEgglogPass {
Self {
program: program.into(),
schedule: schedule.into(),
postprocess: None,
}
}
pub fn with_postprocess(
mut self,
postprocess: impl Fn(&mut SerializedEGraph) + Send + Sync + 'static,
) -> Self {
self.postprocess = Some(Arc::new(postprocess));
self
}
}
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
@@ -130,6 +156,7 @@ pub struct OpTextParts {
late_program: String,
rewrites: String,
late_phases: Vec<EgglogSchedulePhase>,
late_postprocesses: Vec<EGraphPostprocess>,
}
impl OpTextParts {
@@ -181,6 +208,10 @@ impl OpTextParts {
})
})
.collect(),
late_postprocesses: late_passes
.iter()
.filter_map(|pass| pass.postprocess.clone())
.collect(),
}
}
}
@@ -230,6 +261,10 @@ fn egglog_final_phases() -> Vec<EgglogSchedulePhase> {
name: "cleanup".to_string(),
schedule: "(saturate cleanup)".to_string(),
},
EgglogSchedulePhase {
name: "post cleanup".to_string(),
schedule: "(saturate post_cleanup)".to_string(),
},
EgglogSchedulePhase {
name: "base cleanup".to_string(),
schedule: "(saturate base_cleanup)".to_string(),
@@ -297,9 +332,8 @@ use crate::{
};
use egglog::{ArcSort, CommandOutput, EGraph, Value};
use egglog_reports::ReportLevel;
use egraph_serialize::{ClassId, NodeId};
#[derive(Debug)]
#[derive(Debug, Clone)]
/// This is a snapshot of an EGraph with Rust-native hash maps and sets, enabling more native traversal / algorithm writing.
/// The name comes from the serialize egraph crates, which return an ETermDAG, which caused issues, so this is a homebrew semi-static egraph
pub struct SerializedEGraph {
@@ -1909,6 +1943,20 @@ pub fn run_egglog_with_report_parts(
root: &str,
op_parts: &OpTextParts,
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
#[cfg(debug_assertions)]
{
use std::sync::atomic::{AtomicBool, Ordering};
static WARNED_DEBUG_EGGLOG: AtomicBool = AtomicBool::new(false);
if !WARNED_DEBUG_EGGLOG.swap(true, Ordering::Relaxed) {
eprintln!(
"{}",
" Egglog warning: running in a debug build; model-sized saturation can be very slow. Use `cargo run --release ...` for CUDA model examples."
.yellow()
);
}
}
let total_start = std::time::Instant::now();
let full_start = std::time::Instant::now();
@@ -2178,6 +2226,10 @@ pub fn run_egglog_with_report_parts(
}
}
for postprocess in &op_parts.late_postprocesses {
postprocess(&mut egraph);
}
// Cascade: remove enodes whose children reference empty eclasses
loop {
let mut to_remove = vec![];

View File

@@ -1030,7 +1030,8 @@ impl Graph {
let mut ops = Rt::Ops::into_vec();
ops.extend(<crate::hlir::HLIROps as IntoEgglogOp>::into_vec());
let cleanup_hlir = TypeId::of::<Rt>() != TypeId::of::<NativeRuntime>();
let late_passes = Rt::late_egglog_passes(&ops, &options);
let memory_dyn_map = self.memory_limit_dyn_map();
let late_passes = Rt::late_egglog_passes(&ops, &options, &memory_dyn_map);
let (program, root) = hlir_to_egglog(self);
self.egraphs = vec![
@@ -1074,7 +1075,8 @@ impl Graph {
ops.retain(|o| !exclude_ops.contains(&o.sort().name));
ops.extend(<crate::hlir::HLIROps as IntoEgglogOp>::into_vec());
let cleanup_hlir = TypeId::of::<Rt>() != TypeId::of::<NativeRuntime>();
let late_passes = Rt::late_egglog_passes(&ops, &options);
let memory_dyn_map = self.memory_limit_dyn_map();
let late_passes = Rt::late_egglog_passes(&ops, &options, &memory_dyn_map);
let (program, root) = hlir_to_egglog(self);
self.egraphs = vec![
@@ -1176,6 +1178,16 @@ impl Graph {
combos
}
/// Builds the dynamic-dimension map passed to the late egglog passes.
///
/// Starts from a copy of `self.dyn_map`, then for every dimension in
/// `self.dim_buckets` that has at least one bucket, overrides the entry with
/// the largest `max` among that dimension's buckets. Dimensions with no
/// buckets keep whatever value `dyn_map` already held.
fn memory_limit_dyn_map(&self) -> FxHashMap<char, usize> {
    let mut limits = self.dyn_map.clone();
    limits.extend(self.dim_buckets.iter().filter_map(|(&dim, buckets)| {
        let largest = buckets.iter().map(|bucket| bucket.max).max()?;
        Some((dim, largest))
    }));
    limits
}
/// Format a human-readable label for a bucket combination.
fn format_bucket_label(&self, bucket_indices: &FxHashMap<char, usize>) -> String {
let mut parts: Vec<String> = Vec::new();
@@ -1309,7 +1321,12 @@ impl Graph {
let has_nan = runtime.has_nan_outputs(&graph, &profile_dyn_map);
(
rep_metric,
append_memory_display(rep_display, memory_bytes),
append_memory_display(
rep_display,
memory_bytes,
runtime.planned_intermediate_buffer_bytes(),
runtime.allocated_intermediate_buffer_bytes(),
),
has_nan,
)
}));
@@ -1421,7 +1438,12 @@ impl Graph {
let has_nan = runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
(
rep_metric,
append_memory_display(rep_display, memory_bytes),
append_memory_display(
rep_display,
memory_bytes,
runtime.planned_intermediate_buffer_bytes(),
runtime.allocated_intermediate_buffer_bytes(),
),
has_nan,
)
}));
@@ -1610,11 +1632,27 @@ fn stable_toposort_by_node_index(graph: &HLIRGraph) -> Option<Vec<NodeIndex>> {
(ordered.len() == graph.node_count()).then_some(ordered)
}
fn append_memory_display(display: String, memory_bytes: Option<usize>) -> String {
let Some(bytes) = memory_bytes else {
return display;
};
format!("{display} | MEM: {}", format_memory_bytes(bytes))
/// Appends the available memory figures to a display string.
///
/// Each figure that is `Some` contributes one labelled segment, always in the
/// order `EST`, `PLAN`, `ALLOC`, and the segments are joined onto `display`
/// with `" | "`. When all three figures are `None`, `display` is returned
/// unchanged.
fn append_memory_display(
    display: String,
    estimate_bytes: Option<usize>,
    planned_bytes: Option<usize>,
    allocated_bytes: Option<usize>,
) -> String {
    // An `Option` iterates as zero or one item, so chaining the three keeps
    // only the figures that are present while preserving their order.
    let segments: Vec<String> = estimate_bytes
        .map(|bytes| format!("EST: {}", format_memory_bytes(bytes)))
        .into_iter()
        .chain(planned_bytes.map(|bytes| format!("PLAN: {}", format_memory_bytes(bytes))))
        .chain(allocated_bytes.map(|bytes| format!("ALLOC: {}", format_memory_bytes(bytes))))
        .collect();

    if segments.is_empty() {
        return display;
    }
    format!("{display} | {}", segments.join(" | "))
}
fn format_memory_bytes(bytes: usize) -> String {

View File

@@ -19,6 +19,7 @@ pub trait Runtime {
fn late_egglog_passes(
_ops: &[Arc<Box<dyn EgglogOp>>],
_options: &crate::graph::BuildSearchSpaceOptions,
_dyn_map: &FxHashMap<char, usize>,
) -> Vec<crate::egglog_utils::LateEgglogPass>
where
Self: Sized,
@@ -57,6 +58,14 @@ pub trait Runtime {
fn intermediate_buffer_bytes(&self) -> usize {
0
}
/// Total bytes in the active runtime memory plan, if the runtime has one.
///
/// The default implementation returns `None`; a runtime that has a memory
/// plan can override this to report the plan's size.
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
None
}
/// Total active intermediate allocation bytes, if the runtime can report it.
///
/// The default implementation returns `None`; a runtime that tracks its
/// intermediate allocations can override this to report them.
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
None
}
/// Check if the most recent execution produced NaN in any output buffer.
/// Used by the search to reject NaN-producing graph variants.
fn has_nan_outputs(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {