mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Memory analysis post pass (#303)
* Simplify CUDA memory analysis and arena planning * Simplify CUDA memory planning and fix clippy warnings
This commit is contained in:
@@ -19,9 +19,9 @@ use crate::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
driver::CudaStream,
|
||||
},
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
|
||||
@@ -29,9 +29,9 @@ use crate::{
|
||||
cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
@@ -268,7 +268,7 @@ impl HostOp for CuBlasLt {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::cudarc::cublaslt::sys::{
|
||||
@@ -309,9 +309,9 @@ impl HostOp for CuBlasLt {
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
|
||||
// Clamp leading dimensions to minimum valid values.
|
||||
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
@@ -12,6 +12,44 @@ pub type Ops = (
|
||||
moe::GLUMoE,
|
||||
);
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
/// the reusable arena, or an external pointer. Host ops only need the pointer
|
||||
/// and the logical byte length.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct DeviceBuffer {
|
||||
ptr: u64,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl DeviceBuffer {
|
||||
pub fn new(ptr: u64, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
pub fn ptr(self) -> u64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn len(self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
|
||||
let mut host = vec![0u8; self.len];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(host)
|
||||
}
|
||||
}
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
@@ -29,7 +67,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
@@ -48,6 +86,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns relative lifetimes for extra buffer nodes within this host op.
|
||||
///
|
||||
/// The tuple is `(node, first_step, last_step)`, where steps are local to
|
||||
/// this host op's execution. Returning `None` tells the runtime to treat
|
||||
/// every extra buffer as live for the whole host op.
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
@@ -294,27 +294,140 @@ impl HostOp for GLUMoE {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
if inputs.len() < 6 {
|
||||
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
|
||||
}
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
// Resolve dimensions
|
||||
let hidden = self
|
||||
.gu_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
|
||||
let intermediate = self
|
||||
.dn_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
|
||||
let top_k = self
|
||||
.output_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
|
||||
let gu_io = self
|
||||
.gu_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
|
||||
let dn_io = self
|
||||
.dn_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
|
||||
|
||||
if hidden == 0 || intermediate == 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
|
||||
);
|
||||
}
|
||||
if top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
if gu_io % hidden != 0 {
|
||||
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
|
||||
}
|
||||
if dn_io % intermediate != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
|
||||
);
|
||||
}
|
||||
|
||||
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
|
||||
let down_hidden = dn_io / intermediate;
|
||||
if gate_up_dim != intermediate * 2 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
|
||||
gate_up_dim,
|
||||
intermediate * 2
|
||||
);
|
||||
}
|
||||
if down_hidden != hidden {
|
||||
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
|
||||
}
|
||||
|
||||
let output_bytes = self
|
||||
.output_bytes()
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
|
||||
if output_bytes % (hidden * 4) != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
|
||||
hidden * 4
|
||||
);
|
||||
}
|
||||
let seq = output_bytes / (hidden * 4);
|
||||
if seq == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
// Get input/output buffers
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
|
||||
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
|
||||
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
|
||||
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
if output_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
|
||||
output_buf.len()
|
||||
);
|
||||
}
|
||||
|
||||
let gu_stride_bytes = gate_up_dim * hidden * 2;
|
||||
let down_stride_bytes = hidden * intermediate * 2;
|
||||
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
|
||||
gate_up_buf.len()
|
||||
);
|
||||
}
|
||||
let num_experts = gate_up_buf.len() / gu_stride_bytes;
|
||||
if num_experts == 0 {
|
||||
anyhow::bail!("GLUMoE has no expert weights");
|
||||
}
|
||||
if down_buf.len() < num_experts * down_stride_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down weight buffer too small: have {} bytes, need {}",
|
||||
down_buf.len(),
|
||||
num_experts * down_stride_bytes
|
||||
);
|
||||
}
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
@@ -326,21 +439,17 @@ impl HostOp for GLUMoE {
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
|
||||
|
||||
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
@@ -350,9 +459,16 @@ impl HostOp for GLUMoE {
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
if per_expert_scale_host.len() < per_expert_scale_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
|
||||
per_expert_scale_host.len()
|
||||
);
|
||||
}
|
||||
let per_expert_scale_f32: &[f32] =
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
@@ -382,10 +498,10 @@ impl HostOp for GLUMoE {
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = slice_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = slice_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
@@ -404,8 +520,8 @@ impl HostOp for GLUMoE {
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
@@ -507,7 +623,11 @@ impl HostOp for GLUMoE {
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
|
||||
buf.ptr()
|
||||
}
|
||||
|
||||
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
|
||||
@@ -113,6 +113,9 @@ impl KernelOp for FusionStart {
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
|
||||
@@ -289,6 +289,24 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
|
||||
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
|
||||
let mut rules = vec![
|
||||
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail)))
|
||||
((consumed_buffer_ilist_contains ?list ?head))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-head\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail))
|
||||
(consumed_buffer_ilist_contains ?tail ?item))
|
||||
((consumed_buffer_ilist_contains ?list ?item))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-tail\"
|
||||
)",
|
||||
),
|
||||
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
@@ -324,13 +342,28 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?op1 (Op ?k1 ?ilist1))
|
||||
(= ?ilist1 (ICons ?cb ?rest1))
|
||||
(consumed_buffer_ilist_contains ?ilist1 ?cb)
|
||||
(= ?op2 (Op ?k2 ?ilist2))
|
||||
(!= ?op1 ?op2)
|
||||
(= ?ilist2 (ICons ?a ?t2)))
|
||||
(consumed_buffer_ilist_contains ?ilist2 ?a))
|
||||
((delete (ConsumedBuffer ?a)))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-cleanup-pos\"
|
||||
:name \"consumed-buffer-cleanup-shared-op-use\"
|
||||
)",
|
||||
));
|
||||
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
|
||||
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?dest))
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
|
||||
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
|
||||
:ruleset post_cleanup
|
||||
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
|
||||
)",
|
||||
));
|
||||
// Surviving ConsumedBuffers are valid — union with source and delete.
|
||||
|
||||
@@ -13,6 +13,7 @@ use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
@@ -22,7 +23,7 @@ use luminal::{
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
@@ -47,8 +48,12 @@ struct CompiledKernel {
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Human-readable labels for input nodes, for launch diagnostics.
|
||||
input_labels: Vec<String>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
|
||||
has_dyn_dims_param: bool,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
@@ -68,7 +73,9 @@ impl CompiledKernel {
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
input_labels: Vec<String>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
has_dyn_dims_param: bool,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
@@ -79,7 +86,9 @@ impl CompiledKernel {
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
@@ -226,7 +235,7 @@ impl HostOp for CudaGraphOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
@@ -258,6 +267,40 @@ impl HostOp for CudaGraphOp {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
let state = self.state.borrow();
|
||||
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
|
||||
let max_step = state.kernels.len().saturating_sub(1);
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
lifetimes
|
||||
.entry(node)
|
||||
.and_modify(|(first, last)| {
|
||||
*first = (*first).min(step);
|
||||
*last = (*last).max(step);
|
||||
})
|
||||
.or_insert((step, step));
|
||||
};
|
||||
|
||||
for (step, kernel) in state.kernels.iter().enumerate() {
|
||||
for &input in &kernel.inputs {
|
||||
touch(input, step);
|
||||
}
|
||||
touch(kernel.node, step);
|
||||
}
|
||||
|
||||
for node in self.extra_buffer_nodes() {
|
||||
lifetimes.entry(node).or_insert((0, max_step));
|
||||
}
|
||||
|
||||
Some(
|
||||
lifetimes
|
||||
.into_iter()
|
||||
.map(|(node, (start, end))| (node, start, end))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
@@ -268,11 +311,64 @@ impl HostOp for CudaGraphOp {
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
|
||||
match kernel_name {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_requires_output_buffer(
|
||||
kernel: &CompiledKernel,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> bool {
|
||||
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
|
||||
&& kernel.kernel_op.output_aliases_input().is_none()
|
||||
}
|
||||
|
||||
fn validate_kernel_pointers(
|
||||
kernel: &CompiledKernel,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
|
||||
if *input_ptr == 0 {
|
||||
let input_label = kernel
|
||||
.input_labels
|
||||
.get(idx)
|
||||
.map(String::as_str)
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!(
|
||||
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
input_node,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
@@ -343,7 +439,7 @@ impl CudaGraphOp {
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
current_buffer_ptrs.insert(node, buf.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,13 +487,26 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
kernel_dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
@@ -424,6 +533,19 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
@@ -452,7 +574,7 @@ impl CudaGraphOp {
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
@@ -474,7 +596,7 @@ impl CudaGraphOp {
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
buffer_ptrs.insert(node, buf.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,6 +643,19 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
@@ -529,18 +664,41 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
kernel_dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
|
||||
eprintln!(
|
||||
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
kernel.inputs.len(),
|
||||
kernel.has_dyn_dims_param,
|
||||
params.values.len(),
|
||||
);
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
@@ -662,6 +820,36 @@ pub fn kernel_to_host(
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
|
||||
graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
|
||||
name_of(graph, node) == Some("FusionStart")
|
||||
|| graph[node].to_op::<LoopStart>().is_some()
|
||||
|| graph[node].to_op::<LoopEnd>().is_some()
|
||||
|| graph[node].to_op::<LoopInput>().is_some()
|
||||
|| graph[node].to_op::<LoopInputStatic>().is_some()
|
||||
|| graph[node].to_op::<LoopOutput>().is_some()
|
||||
|| graph[node].to_op::<LoopOutputSelect>().is_some()
|
||||
};
|
||||
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
|
||||
let mut visited = FxHashSet::default();
|
||||
while visited.insert(node) && is_transparent_input(graph, node) {
|
||||
let Some(pred) = graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.next()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
node = pred;
|
||||
}
|
||||
node
|
||||
};
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
// Track all CudaGraphOp nodes and their subgraphs for edge creation
|
||||
@@ -678,6 +866,7 @@ pub fn kernel_to_host(
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
let mut external_inputs = FxHashSet::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
@@ -691,9 +880,7 @@ pub fn kernel_to_host(
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
@@ -711,14 +898,35 @@ pub fn kernel_to_host(
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect_vec();
|
||||
if let Some(expected_inputs) =
|
||||
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
|
||||
{
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
expected_inputs,
|
||||
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel_op_ref.kernel_name(),
|
||||
kernel_node_idx,
|
||||
);
|
||||
}
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
@@ -729,6 +937,12 @@ pub fn kernel_to_host(
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
@@ -739,7 +953,9 @@ pub fn kernel_to_host(
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op.clone(),
|
||||
has_dyn_dims_param,
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
@@ -752,6 +968,7 @@ pub fn kernel_to_host(
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
@@ -763,7 +980,20 @@ pub fn kernel_to_host(
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region.external_inputs.clone();
|
||||
let inputs: Vec<NodeIndex> = region
|
||||
.external_inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect();
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
@@ -771,6 +1001,12 @@ pub fn kernel_to_host(
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
@@ -781,7 +1017,9 @@ pub fn kernel_to_host(
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
@@ -826,16 +1064,17 @@ pub fn kernel_to_host(
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
// Find external inputs: nodes outside subgraph that have edges into
|
||||
// subgraph. Also include normalized FusionStart predecessors, because
|
||||
// the compiled kernels read from the concrete producer buffer rather
|
||||
// than the marker node.
|
||||
external_inputs.extend(subgraph.iter().flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
}));
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
|
||||
|
||||
use fixedbitset::FixedBitSet;
|
||||
use half::{bf16, f16};
|
||||
@@ -32,6 +32,8 @@ use std::{
|
||||
use tracing::{Level, span, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
const ARENA_ALIGNMENT: usize = 256;
|
||||
|
||||
pub enum CudaInput {
|
||||
Buffer(CudaSlice<u8>),
|
||||
Ptr(u64),
|
||||
@@ -70,13 +72,25 @@ pub(crate) struct BufferSpec {
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PlannedBuffer {
|
||||
node: NodeIndex,
|
||||
bytes: usize,
|
||||
start: usize,
|
||||
end: usize,
|
||||
}
|
||||
|
||||
/// Per-bucket compiled state. Each bucket holds its own executable graph,
|
||||
/// explicit runtime metadata, intermediate buffers, and node mappings.
|
||||
/// Weights (hlir_buffers) are shared.
|
||||
pub(crate) struct CompiledBucket {
|
||||
pub(crate) exec_graph: StableGraph<ExecutableHostOp, (), Directed>,
|
||||
pub(crate) node_to_exec: FxHashMap<NodeIndex, NodeIndex>,
|
||||
pub(crate) buffers: FxHashMap<NodeIndex, CudaSlice<u8>>,
|
||||
/// Single reusable arena for all intermediate buffers in this bucket.
|
||||
pub(crate) arena: Option<CudaSlice<u8>>,
|
||||
pub(crate) arena_bytes: usize,
|
||||
pub(crate) logical_buffer_offsets: FxHashMap<NodeIndex, usize>,
|
||||
pub(crate) logical_buffer_bytes: FxHashMap<NodeIndex, usize>,
|
||||
pub(crate) cached_buffer_ptrs: FxHashMap<NodeIndex, u64>,
|
||||
pub(crate) buffer_specs: FxHashMap<NodeIndex, BufferSpec>,
|
||||
pub(crate) llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
|
||||
@@ -99,7 +113,10 @@ impl CompiledBucket {
|
||||
CompiledBucket {
|
||||
exec_graph: StableGraph::default(),
|
||||
node_to_exec: FxHashMap::default(),
|
||||
buffers: FxHashMap::default(),
|
||||
arena: None,
|
||||
arena_bytes: 0,
|
||||
logical_buffer_offsets: FxHashMap::default(),
|
||||
logical_buffer_bytes: FxHashMap::default(),
|
||||
cached_buffer_ptrs: FxHashMap::default(),
|
||||
buffer_specs: FxHashMap::default(),
|
||||
llir_to_hlir: FxHashMap::default(),
|
||||
@@ -186,9 +203,71 @@ impl CudaRuntime {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Public access to the active intermediate buffers (for tests and diagnostics).
|
||||
pub fn buffers(&self) -> &FxHashMap<NodeIndex, CudaSlice<u8>> {
|
||||
&self.active().buffers
|
||||
fn bucket_buffer(
|
||||
bucket: &CompiledBucket,
|
||||
stream: &Arc<CudaStream>,
|
||||
logical_node: &NodeIndex,
|
||||
) -> Option<DeviceBuffer> {
|
||||
let arena = bucket.arena.as_ref()?;
|
||||
let offset = *bucket.logical_buffer_offsets.get(logical_node)?;
|
||||
let len = *bucket.logical_buffer_bytes.get(logical_node)?;
|
||||
let ptr = arena.device_ptr(stream).0.checked_add(offset as u64)?;
|
||||
Some(DeviceBuffer::new(ptr, len))
|
||||
}
|
||||
|
||||
fn copy_device_buffer_to_new_slice(
|
||||
stream: &Arc<CudaStream>,
|
||||
src: DeviceBuffer,
|
||||
) -> CudaSlice<u8> {
|
||||
let dst = stream.alloc_zeros::<u8>(src.len()).unwrap();
|
||||
let dst_ptr = dst.device_ptr(stream).0;
|
||||
unsafe {
|
||||
result::memcpy_dtod_async(dst_ptr, src.ptr(), src.len(), stream.cu_stream())
|
||||
.expect("cuMemcpyDtoDAsync failed");
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
dst
|
||||
}
|
||||
|
||||
fn resolve_runtime_buffer(
|
||||
bucket: &CompiledBucket,
|
||||
stream: &Arc<CudaStream>,
|
||||
hlir_buffers: &FxHashMap<NodeIndex, CudaInput>,
|
||||
external_buffers: &FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
external_output_buffers: &FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
mut node: NodeIndex,
|
||||
) -> Option<DeviceBuffer> {
|
||||
let mut visited = FxHashSet::default();
|
||||
loop {
|
||||
if !visited.insert(node) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(ext) = external_output_buffers.get(&node) {
|
||||
return Some(DeviceBuffer::new(ext.device_ptr(stream).0, ext.len()));
|
||||
}
|
||||
|
||||
if let Some(buf) = Self::bucket_buffer(bucket, stream, &node) {
|
||||
return Some(buf);
|
||||
}
|
||||
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&node) {
|
||||
match hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
return Some(DeviceBuffer::new(buf.device_ptr(stream).0, buf.len()));
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = external_buffers.get(hlir_node) {
|
||||
return Some(DeviceBuffer::new(ext.device_ptr(stream).0, ext.len()));
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
let alias_target = bucket.output_alias_map.get(&node)?;
|
||||
node = *alias_target;
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -324,6 +403,15 @@ impl CudaRuntime {
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
|
||||
let truncate_to_logical_bytes = |mut data: Vec<u8>| {
|
||||
if let Some(spec) = bucket.buffer_specs.get(&data_id)
|
||||
&& let Some(logical_bytes) = spec.bytes.exec(&bucket.last_dyn_map)
|
||||
{
|
||||
data.truncate(logical_bytes.min(data.len()));
|
||||
}
|
||||
data
|
||||
};
|
||||
|
||||
let _span = span!(Level::TRACE, "dtoh").entered();
|
||||
// If predecessor is an Input node, data lives in hlir_buffers
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
|
||||
@@ -345,40 +433,45 @@ impl CudaRuntime {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Predecessor is a computation node — data is in intermediate buffers
|
||||
self.cuda_stream
|
||||
.clone_dtoh(
|
||||
bucket
|
||||
.buffers
|
||||
.get(&data_id)
|
||||
.expect("Cannot find tensor in runtime!"),
|
||||
)
|
||||
.unwrap()
|
||||
if let Some(ext) = self.external_output_buffers.get(&data_id) {
|
||||
return truncate_to_logical_bytes(self.cuda_stream.clone_dtoh(&**ext).unwrap());
|
||||
}
|
||||
|
||||
// Predecessor is a computation node — data is in the intermediate arena.
|
||||
truncate_to_logical_bytes(
|
||||
Self::bucket_buffer(bucket, &self.cuda_stream, &data_id)
|
||||
.expect("Cannot find tensor in runtime!")
|
||||
.clone_dtoh(&self.cuda_stream)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the device-side CudaSlice for an output tensor without copying to host.
|
||||
/// Resolve the device-side buffer for an output tensor without copying to host.
|
||||
/// Used by copy_output_to_device_ptr for DtoD transfers.
|
||||
fn resolve_output_slice(&self, id: impl ToId) -> &CudaSlice<u8> {
|
||||
fn resolve_output_buffer(&self, id: impl ToId) -> DeviceBuffer {
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
if let Some(ext) = self.external_output_buffers.get(&data_id) {
|
||||
return DeviceBuffer::new(ext.device_ptr(&self.cuda_stream).0, ext.len());
|
||||
}
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
|
||||
match self
|
||||
.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Cannot find input tensor in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => buf,
|
||||
CudaInput::Buffer(buf) => {
|
||||
DeviceBuffer::new(buf.device_ptr(&self.cuda_stream).0, buf.len())
|
||||
}
|
||||
CudaInput::Ptr(_) => self
|
||||
.external_buffers
|
||||
.get(hlir_node)
|
||||
.map(|ext| &**ext)
|
||||
.map(|ext| DeviceBuffer::new(ext.device_ptr(&self.cuda_stream).0, ext.len()))
|
||||
.expect("Cannot read raw pointer input — no external_buffers entry for node"),
|
||||
}
|
||||
} else {
|
||||
bucket
|
||||
.buffers
|
||||
.get(&data_id)
|
||||
Self::bucket_buffer(bucket, &self.cuda_stream, &data_id)
|
||||
.expect("Cannot find tensor in runtime!")
|
||||
}
|
||||
}
|
||||
@@ -393,13 +486,12 @@ impl CudaRuntime {
|
||||
dest_ptr != 0,
|
||||
"copy_output_to_device_ptr called with null pointer"
|
||||
);
|
||||
let src_slice = self.resolve_output_slice(id);
|
||||
let src_ptr = src_slice.device_ptr(&self.cuda_stream).0;
|
||||
let copy_bytes = n_bytes.min(src_slice.len());
|
||||
let src = self.resolve_output_buffer(id);
|
||||
let copy_bytes = n_bytes.min(src.len());
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtod_async(
|
||||
result::memcpy_dtod_async(
|
||||
dest_ptr,
|
||||
src_ptr,
|
||||
src.ptr(),
|
||||
copy_bytes,
|
||||
self.cuda_stream.cu_stream(),
|
||||
)
|
||||
@@ -496,39 +588,36 @@ impl CudaRuntime {
|
||||
CudaInput::Ptr(p) => panic!("Cannot take raw pointer input (ptr=0x{:x})", p),
|
||||
}
|
||||
} else {
|
||||
self.compiled_buckets[bi]
|
||||
.buffers
|
||||
.remove(&lineage_node)
|
||||
.expect("Cannot find tensor in runtime!")
|
||||
let src = Self::bucket_buffer(
|
||||
&self.compiled_buckets[bi],
|
||||
&self.cuda_stream,
|
||||
&lineage_node,
|
||||
)
|
||||
.expect("Cannot find tensor in runtime!");
|
||||
Self::copy_device_buffer_to_new_slice(&self.cuda_stream, src)
|
||||
}
|
||||
} else {
|
||||
// Copy-then-modify: output data is in alias_node's buffer (intermediate),
|
||||
// but we want to extract the lineage HLIR buffer so intermediates stay intact.
|
||||
// while the lineage HLIR buffer has stale pre-op data. Return an owned
|
||||
// copy of the arena output and drop the stale HLIR buffer.
|
||||
let hlir_node = *self.compiled_buckets[bi]
|
||||
.llir_to_hlir
|
||||
.get(&lineage_node)
|
||||
.expect("output_data_input lineage must reach an HLIR input node");
|
||||
|
||||
// Take the intermediate buffer (has the actual output data)
|
||||
let output_buf = self.compiled_buckets[bi]
|
||||
.buffers
|
||||
.remove(&alias_node)
|
||||
.expect("Cannot find intermediate output buffer in runtime!");
|
||||
let output =
|
||||
Self::bucket_buffer(&self.compiled_buckets[bi], &self.cuda_stream, &alias_node)
|
||||
.expect("Cannot find intermediate output buffer in runtime!");
|
||||
let output_buf = Self::copy_device_buffer_to_new_slice(&self.cuda_stream, output);
|
||||
|
||||
// Take the HLIR buffer (has stale pre-op data)
|
||||
let hlir_buf = match self
|
||||
match self
|
||||
.hlir_buffers
|
||||
.remove(&hlir_node)
|
||||
.expect("Cannot find HLIR input buffer in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => buf,
|
||||
CudaInput::Buffer(_buf) => {}
|
||||
CudaInput::Ptr(p) => panic!("Cannot take raw pointer input (ptr=0x{:x})", p),
|
||||
};
|
||||
|
||||
// Put stale HLIR buffer into intermediate slot (keeps allocation alive)
|
||||
self.compiled_buckets[bi]
|
||||
.buffers
|
||||
.insert(alias_node, hlir_buf);
|
||||
}
|
||||
|
||||
// Return the output buffer (has correct data)
|
||||
output_buf
|
||||
@@ -595,20 +684,16 @@ impl CudaRuntime {
|
||||
.get(&input_id)
|
||||
.expect("Cannot find input in LLIR mapping!");
|
||||
|
||||
// Swap intermediate buffer <-> input buffer
|
||||
let intermediate_buf = self.compiled_buckets[bi]
|
||||
.buffers
|
||||
.get_mut(&data_llir_node)
|
||||
.expect("Output not in intermediate buffers");
|
||||
if let CudaInput::Buffer(input_buf) = self
|
||||
.hlir_buffers
|
||||
.get_mut(&input_id)
|
||||
.expect("Input not in hlir_buffers")
|
||||
{
|
||||
std::mem::swap(intermediate_buf, input_buf);
|
||||
} else {
|
||||
panic!("Input is a raw pointer, cannot swap");
|
||||
}
|
||||
let src = Self::bucket_buffer(
|
||||
&self.compiled_buckets[bi],
|
||||
&self.cuda_stream,
|
||||
&data_llir_node,
|
||||
)
|
||||
.expect("Output not in intermediate buffers");
|
||||
let input_buf = Self::copy_device_buffer_to_new_slice(&self.cuda_stream, src);
|
||||
self.hlir_buffers
|
||||
.insert(input_id, CudaInput::Buffer(input_buf));
|
||||
self.changed_hlir.insert(input_id);
|
||||
|
||||
// Update cached pointer for the input
|
||||
let ptr = match &self.hlir_buffers[&input_id] {
|
||||
@@ -624,7 +709,7 @@ impl CudaRuntime {
|
||||
/// They will be re-allocated on the next `execute()` call.
|
||||
pub fn free_intermediate_buffers(&mut self) {
|
||||
for bucket in &mut self.compiled_buckets {
|
||||
bucket.buffers.clear();
|
||||
bucket.arena = None;
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
}
|
||||
}
|
||||
@@ -635,42 +720,248 @@ impl CudaRuntime {
|
||||
stream: &Arc<CudaStream>,
|
||||
dyn_dims: &FxHashMap<char, usize>,
|
||||
) {
|
||||
let is_first_alloc = bucket.buffers.is_empty();
|
||||
|
||||
// Only sync if we might need to free/reallocate buffers
|
||||
if is_first_alloc {
|
||||
stream.synchronize().unwrap();
|
||||
let needs_new_plan = !Self::buffer_plan_matches(bucket, dyn_dims);
|
||||
if needs_new_plan {
|
||||
if bucket.arena.is_some() {
|
||||
stream.synchronize().unwrap();
|
||||
}
|
||||
Self::plan_intermediate_buffers(bucket, dyn_dims);
|
||||
}
|
||||
|
||||
if bucket.arena_bytes == 0 {
|
||||
bucket.arena = None;
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
if bucket
|
||||
.arena
|
||||
.as_ref()
|
||||
.is_none_or(|arena| arena.len() < bucket.arena_bytes)
|
||||
{
|
||||
bucket.arena = Some(stream.alloc_zeros(bucket.arena_bytes).unwrap());
|
||||
}
|
||||
|
||||
let arena_ptr = bucket.arena.as_ref().unwrap().device_ptr(stream).0;
|
||||
for (logical_node, &offset) in &bucket.logical_buffer_offsets {
|
||||
if let Some(ptr) = arena_ptr.checked_add(offset as u64) {
|
||||
bucket.cached_buffer_ptrs.insert(*logical_node, ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_plan_matches(bucket: &CompiledBucket, dyn_dims: &FxHashMap<char, usize>) -> bool {
|
||||
if bucket.buffer_specs.is_empty() {
|
||||
return true;
|
||||
}
|
||||
if bucket.logical_buffer_offsets.is_empty() && !bucket.buffer_specs.is_empty() {
|
||||
return false;
|
||||
}
|
||||
bucket
|
||||
.intermediate_buffer_dims
|
||||
.iter()
|
||||
.all(|dim| bucket.last_dyn_map.get(dim) == dyn_dims.get(dim))
|
||||
}
|
||||
|
||||
fn plan_intermediate_buffers(bucket: &mut CompiledBucket, dyn_dims: &FxHashMap<char, usize>) {
|
||||
bucket.logical_buffer_offsets.clear();
|
||||
bucket.logical_buffer_bytes.clear();
|
||||
bucket.arena_bytes = 0;
|
||||
bucket.intermediate_buffer_dims.clear();
|
||||
let mut total_alloc: usize = 0;
|
||||
let mut realloc_count: usize = 0;
|
||||
for (node, spec) in bucket.buffer_specs.clone() {
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
bucket.last_dyn_map = dyn_dims.clone();
|
||||
|
||||
let mut logical_bytes = FxHashMap::default();
|
||||
for (node, spec) in &bucket.buffer_specs {
|
||||
bucket
|
||||
.intermediate_buffer_dims
|
||||
.extend(spec.bytes.dyn_vars());
|
||||
let needed_bytes = spec.bytes.exec(dyn_dims).unwrap();
|
||||
|
||||
if needed_bytes == 0 {
|
||||
continue;
|
||||
let bytes = spec.bytes.exec(dyn_dims).unwrap();
|
||||
if bytes > 0 {
|
||||
logical_bytes.insert(*node, bytes);
|
||||
}
|
||||
|
||||
// Only allocate/reallocate if we don't have a buffer or existing one is too small
|
||||
let existing_len = bucket.buffers.get(&node).map(|b| b.len()).unwrap_or(0);
|
||||
if existing_len >= needed_bytes {
|
||||
continue; // Existing buffer is large enough, reuse it
|
||||
}
|
||||
|
||||
// Need to allocate (or reallocate)
|
||||
total_alloc += needed_bytes;
|
||||
realloc_count += 1;
|
||||
bucket
|
||||
.buffers
|
||||
.insert(node, stream.alloc_zeros(needed_bytes).unwrap());
|
||||
let ptr = bucket.buffers[&node].device_ptr(stream).0;
|
||||
bucket.cached_buffer_ptrs.insert(node, ptr);
|
||||
}
|
||||
let _ = (realloc_count, total_alloc);
|
||||
|
||||
if logical_bytes.is_empty() {
|
||||
bucket.arena = None;
|
||||
return;
|
||||
}
|
||||
let total_spec_count = logical_bytes.len();
|
||||
let total_spec_bytes = logical_bytes.values().copied().sum::<usize>();
|
||||
|
||||
let mut first_use: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
let mut last_use: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
let exec_order = toposort(&bucket.exec_graph, None).unwrap_or_default();
|
||||
let output_alias_map = bucket.output_alias_map.clone();
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
let Some(node) = resolve_logical_buffer_node(node, &logical_bytes, &output_alias_map)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
first_use
|
||||
.entry(node)
|
||||
.and_modify(|first| *first = (*first).min(step))
|
||||
.or_insert(step);
|
||||
last_use
|
||||
.entry(node)
|
||||
.and_modify(|last| *last = (*last).max(step))
|
||||
.or_insert(step);
|
||||
};
|
||||
|
||||
let mut time = 0usize;
|
||||
for exec_node in exec_order.iter().copied() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
let precise_extra_lifetimes = exec_op.internal.extra_buffer_lifetimes();
|
||||
let span = precise_extra_lifetimes
|
||||
.as_ref()
|
||||
.and_then(|lifetimes| lifetimes.iter().map(|(_, _, end)| *end).max())
|
||||
.map(|end| end + 1)
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let start_time = time;
|
||||
let end_time = time + span - 1;
|
||||
time += span;
|
||||
|
||||
let precise_nodes = precise_extra_lifetimes
|
||||
.as_ref()
|
||||
.map(|lifetimes| {
|
||||
lifetimes
|
||||
.iter()
|
||||
.filter_map(|(node, _, _)| {
|
||||
resolve_logical_buffer_node(*node, &logical_bytes, &output_alias_map)
|
||||
})
|
||||
.collect::<FxHashSet<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut touch_if_not_precise = |node: NodeIndex, step: usize| {
|
||||
if resolve_logical_buffer_node(node, &logical_bytes, &output_alias_map)
|
||||
.is_some_and(|node| precise_nodes.contains(&node))
|
||||
{
|
||||
return;
|
||||
}
|
||||
touch(node, step);
|
||||
};
|
||||
|
||||
touch_if_not_precise(exec_op.output, start_time);
|
||||
touch_if_not_precise(exec_op.output, end_time);
|
||||
for &input in &exec_op.inputs {
|
||||
touch_if_not_precise(input, start_time);
|
||||
touch_if_not_precise(input, end_time);
|
||||
}
|
||||
|
||||
if let Some(lifetimes) = precise_extra_lifetimes {
|
||||
for (node, start, end) in lifetimes {
|
||||
touch(node, start_time + start);
|
||||
touch(node, start_time + end);
|
||||
}
|
||||
} else {
|
||||
for extra_node in exec_op.internal.extra_buffer_nodes() {
|
||||
touch(extra_node, start_time);
|
||||
touch(extra_node, end_time);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for &producer in bucket.output_producers.values() {
|
||||
let mut alias_node = producer;
|
||||
while let Some(target) = bucket.output_alias_map.get(&alias_node) {
|
||||
alias_node = *target;
|
||||
}
|
||||
touch(alias_node, time);
|
||||
|
||||
let mut data_node = producer;
|
||||
while let Some(target) = bucket.output_data_map.get(&data_node) {
|
||||
data_node = *target;
|
||||
}
|
||||
touch(data_node, time);
|
||||
touch(producer, time);
|
||||
}
|
||||
|
||||
let mut planned = logical_bytes
|
||||
.into_iter()
|
||||
.filter(|(node, _)| first_use.contains_key(node) || last_use.contains_key(node))
|
||||
.map(|(node, bytes)| PlannedBuffer {
|
||||
node,
|
||||
bytes,
|
||||
start: first_use.get(&node).copied().unwrap_or(0),
|
||||
end: last_use.get(&node).copied().unwrap_or(0),
|
||||
})
|
||||
.collect_vec();
|
||||
planned.sort_by_key(|buf| (buf.start, std::cmp::Reverse(buf.bytes), buf.node.index()));
|
||||
let planned_logical_count = planned.len();
|
||||
let planned_logical_bytes = planned.iter().map(|buf| buf.bytes).sum::<usize>();
|
||||
let logical_peak = logical_interval_peak(&planned);
|
||||
|
||||
let mut arena_end = 0usize;
|
||||
let mut placed: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(planned.len());
|
||||
let mut placement_order = planned.iter().collect_vec();
|
||||
placement_order.sort_by_key(|buf| {
|
||||
(
|
||||
std::cmp::Reverse(buf.bytes),
|
||||
std::cmp::Reverse(buf.end.saturating_sub(buf.start)),
|
||||
buf.start,
|
||||
buf.node.index(),
|
||||
)
|
||||
});
|
||||
|
||||
for buf in placement_order {
|
||||
let allocation_bytes = align_up(buf.bytes, ARENA_ALIGNMENT);
|
||||
let mut candidates = vec![0usize];
|
||||
for &(placed_start, placed_end, placed_offset, placed_bytes) in &placed {
|
||||
if intervals_overlap(buf.start, buf.end, placed_start, placed_end) {
|
||||
candidates.push(align_up(placed_offset + placed_bytes, ARENA_ALIGNMENT));
|
||||
}
|
||||
}
|
||||
candidates.sort_unstable();
|
||||
candidates.dedup();
|
||||
|
||||
let offset = candidates
|
||||
.into_iter()
|
||||
.find(|&candidate| {
|
||||
placed
|
||||
.iter()
|
||||
.all(|&(placed_start, placed_end, placed_offset, placed_bytes)| {
|
||||
!intervals_overlap(buf.start, buf.end, placed_start, placed_end)
|
||||
|| !byte_ranges_overlap(
|
||||
candidate,
|
||||
allocation_bytes,
|
||||
placed_offset,
|
||||
placed_bytes,
|
||||
)
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
placed
|
||||
.iter()
|
||||
.filter(|(placed_start, placed_end, _, _)| {
|
||||
intervals_overlap(buf.start, buf.end, *placed_start, *placed_end)
|
||||
})
|
||||
.map(|(_, _, offset, bytes)| align_up(offset + bytes, ARENA_ALIGNMENT))
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
bucket.logical_buffer_offsets.insert(buf.node, offset);
|
||||
bucket.logical_buffer_bytes.insert(buf.node, buf.bytes);
|
||||
placed.push((buf.start, buf.end, offset, allocation_bytes));
|
||||
arena_end = arena_end.max(offset + allocation_bytes);
|
||||
}
|
||||
bucket.arena_bytes = arena_end;
|
||||
|
||||
if std::env::var_os("LUMINAL_CUDA_MEMORY_DEBUG").is_some() {
|
||||
eprintln!(
|
||||
" CUDA memory plan specs={total_spec_count} used={planned_logical_count} skipped={} spec_bytes={} used_bytes={} skipped_bytes={} logical_peak={} arena_plan={} allocations={}",
|
||||
total_spec_count.saturating_sub(planned_logical_count),
|
||||
total_spec_bytes,
|
||||
planned_logical_bytes,
|
||||
total_spec_bytes.saturating_sub(planned_logical_bytes),
|
||||
logical_peak,
|
||||
bucket.arena_bytes,
|
||||
bucket.logical_buffer_offsets.len(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-allocate buffers with the given dynamic dimension values.
|
||||
@@ -679,10 +970,7 @@ impl CudaRuntime {
|
||||
pub fn prebuild_graphs(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
// 1. Allocate intermediate buffers (needed for buffer pointers)
|
||||
if bucket.buffers.is_empty() {
|
||||
bucket.last_dyn_map = dyn_map.clone();
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
}
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
|
||||
// 2. Process changed HLIR inputs to get their buffer pointers
|
||||
if !self.changed_hlir.is_empty() || !bucket.hlir_synced {
|
||||
@@ -809,6 +1097,57 @@ fn format_duration_precise(d: &std::time::Duration) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_logical_buffer_node(
|
||||
mut node: NodeIndex,
|
||||
logical_bytes: &FxHashMap<NodeIndex, usize>,
|
||||
output_alias_map: &FxHashMap<NodeIndex, NodeIndex>,
|
||||
) -> Option<NodeIndex> {
|
||||
let mut visited = FxHashSet::default();
|
||||
while !logical_bytes.contains_key(&node) {
|
||||
if !visited.insert(node) {
|
||||
return None;
|
||||
}
|
||||
let target = output_alias_map.get(&node)?;
|
||||
node = *target;
|
||||
}
|
||||
|
||||
Some(node)
|
||||
}
|
||||
|
||||
fn align_up(value: usize, alignment: usize) -> usize {
|
||||
if alignment <= 1 {
|
||||
value
|
||||
} else {
|
||||
value.div_ceil(alignment) * alignment
|
||||
}
|
||||
}
|
||||
|
||||
fn intervals_overlap(a_start: usize, a_end: usize, b_start: usize, b_end: usize) -> bool {
|
||||
a_start <= b_end && b_start <= a_end
|
||||
}
|
||||
|
||||
fn byte_ranges_overlap(a_offset: usize, a_bytes: usize, b_offset: usize, b_bytes: usize) -> bool {
|
||||
a_offset < b_offset + b_bytes && b_offset < a_offset + a_bytes
|
||||
}
|
||||
|
||||
fn logical_interval_peak(planned: &[PlannedBuffer]) -> usize {
|
||||
let mut events = Vec::with_capacity(planned.len() * 2);
|
||||
for buf in planned {
|
||||
events.push((buf.start, buf.bytes as i128));
|
||||
events.push((buf.end.saturating_add(1), -(buf.bytes as i128)));
|
||||
}
|
||||
events.sort_by_key(|(step, delta)| (*step, *delta));
|
||||
|
||||
let mut current = 0i128;
|
||||
let mut peak = 0i128;
|
||||
for (_, delta) in events {
|
||||
current += delta;
|
||||
peak = peak.max(current);
|
||||
}
|
||||
|
||||
peak.max(0) as usize
|
||||
}
|
||||
|
||||
impl Runtime for CudaRuntime {
|
||||
type Ops = (crate::kernel::Ops, crate::host::Ops);
|
||||
type CompileArg = Arc<CudaStream>;
|
||||
@@ -817,9 +1156,14 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
fn late_egglog_passes(
|
||||
ops: &[Arc<Box<dyn luminal::op::EgglogOp>>],
|
||||
_options: &luminal::graph::BuildSearchSpaceOptions,
|
||||
options: &luminal::graph::BuildSearchSpaceOptions,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
|
||||
vec![crate::memory_analysis::cuda_memory_analysis_pass(ops)]
|
||||
vec![crate::memory_analysis::cuda_memory_analysis_pass(
|
||||
ops,
|
||||
options.max_memory_bytes,
|
||||
dyn_map,
|
||||
)]
|
||||
}
|
||||
|
||||
fn estimate_graph_memory<'a>(
|
||||
@@ -904,7 +1248,7 @@ impl Runtime for CudaRuntime {
|
||||
fn clear_intermediate_buffers(&mut self) {
|
||||
let _ = self.cuda_stream.synchronize();
|
||||
for bucket in &mut self.compiled_buckets {
|
||||
bucket.buffers.clear();
|
||||
bucket.arena = None;
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
}
|
||||
}
|
||||
@@ -912,14 +1256,37 @@ impl Runtime for CudaRuntime {
|
||||
fn intermediate_buffer_bytes(&self) -> usize {
|
||||
self.compiled_buckets
|
||||
.iter()
|
||||
.map(|b| b.buffers.values().map(|buf| buf.len()).sum::<usize>())
|
||||
.map(|b| b.arena.as_ref().map(|arena| arena.len()).unwrap_or(0))
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
self.compiled_buckets
|
||||
.get(self.active_bucket)
|
||||
.map(|bucket| bucket.arena_bytes)
|
||||
}
|
||||
|
||||
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
self.compiled_buckets
|
||||
.get(self.active_bucket)
|
||||
.map(|bucket| bucket.arena.as_ref().map(|arena| arena.len()).unwrap_or(0))
|
||||
}
|
||||
|
||||
fn has_nan_outputs(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {
|
||||
let _ = self.cuda_stream.synchronize();
|
||||
let bucket = self.active();
|
||||
for (node_id, buf) in &bucket.buffers {
|
||||
let mut checked = FxHashSet::default();
|
||||
for producer in bucket.output_producers.values().copied() {
|
||||
let mut node_id = producer;
|
||||
while let Some(alias_target) = bucket.output_alias_map.get(&node_id) {
|
||||
node_id = *alias_target;
|
||||
}
|
||||
if !checked.insert(node_id) {
|
||||
continue;
|
||||
}
|
||||
let Some(buf) = Self::bucket_buffer(bucket, &self.cuda_stream, &node_id) else {
|
||||
continue;
|
||||
};
|
||||
let n_bytes = buf.len();
|
||||
if n_bytes == 0 || n_bytes % 4 != 0 {
|
||||
continue;
|
||||
@@ -929,7 +1296,7 @@ impl Runtime for CudaRuntime {
|
||||
// and their bit patterns can produce false positives when reinterpreted as f32.
|
||||
let is_float = bucket
|
||||
.buffer_specs
|
||||
.get(node_id)
|
||||
.get(&node_id)
|
||||
.map(|spec| matches!(spec.dtype, DType::F32))
|
||||
.unwrap_or(true);
|
||||
|
||||
@@ -937,7 +1304,7 @@ impl Runtime for CudaRuntime {
|
||||
continue;
|
||||
}
|
||||
|
||||
let host_bytes: Vec<u8> = match self.cuda_stream.clone_dtoh(buf) {
|
||||
let host_bytes: Vec<u8> = match buf.clone_dtoh(&self.cuda_stream) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
@@ -957,9 +1324,9 @@ impl Runtime for CudaRuntime {
|
||||
_trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
// Clear active bucket's buffers before loading new LLIR for profiling
|
||||
// Clear active bucket's arena before loading new LLIR for profiling.
|
||||
if !self.compiled_buckets.is_empty() {
|
||||
self.active_mut().buffers.clear();
|
||||
self.active_mut().arena = None;
|
||||
}
|
||||
self.load_llir(llir_graph);
|
||||
self.profiling = true;
|
||||
@@ -1020,7 +1387,7 @@ impl Runtime for CudaRuntime {
|
||||
if idx != self.active_bucket {
|
||||
// Free the old bucket's intermediates to avoid holding 2 full sets in GPU memory
|
||||
let old = self.active_bucket;
|
||||
self.compiled_buckets[old].buffers.clear();
|
||||
self.compiled_buckets[old].arena = None;
|
||||
self.compiled_buckets[old].cached_buffer_ptrs.clear();
|
||||
self.active_bucket = idx;
|
||||
// Mark bucket as needing HLIR sync since it may have missed changes
|
||||
@@ -1029,17 +1396,7 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
let buffers_empty = bucket.buffers.is_empty();
|
||||
let dyn_map_len_changed = dyn_map.len() != bucket.last_dyn_map.len();
|
||||
let dyn_dims_changed = dyn_map
|
||||
.iter()
|
||||
.filter(|(d, _)| bucket.intermediate_buffer_dims.contains(*d))
|
||||
.any(|(d, v)| bucket.last_dyn_map.get(d).map(|n| *n != *v).unwrap_or(true));
|
||||
let needs_realloc = buffers_empty || dyn_map_len_changed || dyn_dims_changed;
|
||||
if needs_realloc {
|
||||
bucket.last_dyn_map = dyn_map.clone();
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
}
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
// Cache HLIR input pointers
|
||||
if !self.changed_hlir.is_empty() || !bucket.hlir_synced {
|
||||
let hlir_nodes: Vec<NodeIndex> = if !bucket.hlir_synced {
|
||||
@@ -1081,82 +1438,45 @@ impl Runtime for CudaRuntime {
|
||||
trace!("Executing: {:?}", exec_op);
|
||||
|
||||
// Build buffer map for the HostOp interface
|
||||
let mut buffer_map: FxHashMap<NodeIndex, &CudaSlice<u8>> = FxHashMap::default();
|
||||
let mut buffer_map: FxHashMap<NodeIndex, DeviceBuffer> = FxHashMap::default();
|
||||
|
||||
// Add output buffer -- prefer external output pointer if registered (zero copy)
|
||||
if let Some(ext) = self.external_output_buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, &**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
if let Some(buf) = Self::resolve_runtime_buffer(
|
||||
bucket,
|
||||
&self.cuda_stream,
|
||||
&self.hlir_buffers,
|
||||
&self.external_buffers,
|
||||
&self.external_output_buffers,
|
||||
exec_op.output,
|
||||
) {
|
||||
buffer_map.insert(exec_op.output, buf);
|
||||
}
|
||||
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
|
||||
for inp in exec_op.inputs.iter() {
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
buffer_map.insert(*inp, &**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
if !buffer_map.contains_key(inp)
|
||||
&& let Some(buf) = bucket.buffers.get(inp)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
} else if let Some(buf) = bucket.buffers.get(inp) {
|
||||
buffer_map.insert(*inp, buf);
|
||||
|
||||
for &inp in &exec_op.inputs {
|
||||
if let Some(buf) = Self::resolve_runtime_buffer(
|
||||
bucket,
|
||||
&self.cuda_stream,
|
||||
&self.hlir_buffers,
|
||||
&self.external_buffers,
|
||||
&self.external_output_buffers,
|
||||
inp,
|
||||
) {
|
||||
buffer_map.insert(inp, buf);
|
||||
}
|
||||
}
|
||||
// Add extra buffer nodes (for CudaGraphOp)
|
||||
|
||||
let extra_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
for extra_node in extra_nodes {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(ext) = self.external_output_buffers.get(&extra_node) {
|
||||
e.insert(&**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
e.insert(buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
e.insert(&**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Resolve output aliases
|
||||
for (&alias_node, &alias_target) in &bucket.output_alias_map {
|
||||
if !buffer_map.contains_key(&alias_node) {
|
||||
continue;
|
||||
}
|
||||
// Try HLIR buffer first (includes external device pointers)
|
||||
let resolved: Option<&CudaSlice<u8>> =
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => Some(buf),
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
self.external_buffers.get(hlir_node).map(|ext| &**ext)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(buf) = resolved {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node)
|
||||
&& let Some(buf) = Self::resolve_runtime_buffer(
|
||||
bucket,
|
||||
&self.cuda_stream,
|
||||
&self.hlir_buffers,
|
||||
&self.external_buffers,
|
||||
&self.external_output_buffers,
|
||||
extra_node,
|
||||
)
|
||||
{
|
||||
e.insert(buf);
|
||||
}
|
||||
}
|
||||
let _span = span!(
|
||||
@@ -1253,13 +1573,7 @@ impl Runtime for CudaRuntime {
|
||||
for (bucket_indices, representative_dyn_map, llir) in bucket_llirs {
|
||||
let mut bucket = self.compile_bucket(llir);
|
||||
bucket.bucket_indices = bucket_indices.clone();
|
||||
// Eagerly allocate intermediate buffers using the representative dyn_map
|
||||
bucket.last_dyn_map = representative_dyn_map.clone();
|
||||
Self::allocate_intermediate_buffers(
|
||||
&mut bucket,
|
||||
&self.cuda_stream,
|
||||
representative_dyn_map,
|
||||
);
|
||||
let _ = representative_dyn_map;
|
||||
self.compiled_buckets.push(bucket);
|
||||
}
|
||||
self.active_bucket = 0;
|
||||
@@ -1361,7 +1675,8 @@ impl CudaRuntime {
|
||||
}
|
||||
}
|
||||
});
|
||||
let allocated = !is_marker || has_external_consumer;
|
||||
let allocated = kernel_op.output_aliases_input().is_none()
|
||||
&& (!is_marker || has_external_consumer);
|
||||
if allocated {
|
||||
bucket.buffer_specs.insert(
|
||||
node,
|
||||
|
||||
@@ -41,9 +41,8 @@ fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
all_names
|
||||
}
|
||||
|
||||
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
|
||||
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
|
||||
/// the ConsumedBuffer (not in any other ICons).
|
||||
/// When dest is NOT shared with any other compute op, KernelScatterNoCopy should
|
||||
/// be the only scatter variant left after post-cleanup.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -62,12 +61,17 @@ fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should be available (dest is not shared)
|
||||
// KernelScatterNoCopy should be the only scatter variant (dest is not shared)
|
||||
assert!(
|
||||
names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy to be available but got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "Scatter"),
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid, got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
|
||||
@@ -109,8 +113,42 @@ fn test_scatter_nocopy_not_selected_when_dest_shared() {
|
||||
);
|
||||
}
|
||||
|
||||
/// Shared-use detection must catch the destination in non-first input
|
||||
/// positions too. Gather takes indexes first and data second, so this would
|
||||
/// miss the unsafe read if cleanup only inspected the head of the input list.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_when_dest_shared_as_later_input() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let scatter_indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
let read_indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
let scatter_result = src.scatter(scatter_indexes, dest);
|
||||
let _dest_also_read = dest.gather(read_indexes).output();
|
||||
let _result = scatter_result.output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest is read by another op, got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected regular Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// Actually execute the scatter and verify correctness.
|
||||
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
|
||||
/// Post-cleanup should force the valid no-copy extraction.
|
||||
#[test]
|
||||
fn test_scatter_execution_correctness() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -135,9 +173,8 @@ fn test_scatter_execution_correctness() {
|
||||
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
|
||||
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
|
||||
|
||||
// Try many random extractions to cover both Scatter and ScatterNoCopy
|
||||
// Try many random extractions; each valid choice should now use ScatterNoCopy.
|
||||
let mut rng = rand::rng();
|
||||
let mut tested_scatter = false;
|
||||
let mut tested_nocopy = false;
|
||||
|
||||
for _ in 0..50 {
|
||||
@@ -180,27 +217,24 @@ fn test_scatter_execution_correctness() {
|
||||
|
||||
let actual = rt.get_f32(result);
|
||||
|
||||
let variant = if has_nocopy {
|
||||
tested_nocopy = true;
|
||||
"ScatterNoCopy"
|
||||
} else if has_scatter {
|
||||
tested_scatter = true;
|
||||
"Scatter"
|
||||
} else {
|
||||
"Unknown"
|
||||
};
|
||||
assert!(
|
||||
has_nocopy,
|
||||
"Expected ScatterNoCopy after post-cleanup, got no no-copy scatter"
|
||||
);
|
||||
assert!(
|
||||
!has_scatter,
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid"
|
||||
);
|
||||
tested_nocopy = true;
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
|
||||
"Scatter result mismatch with ScatterNoCopy: got {:?}, expected {:?}",
|
||||
actual, expected
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
|
||||
tested_scatter, tested_nocopy
|
||||
);
|
||||
println!("Tested ScatterNoCopy: {}", tested_nocopy);
|
||||
assert!(
|
||||
tested_nocopy,
|
||||
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
|
||||
@@ -242,12 +276,28 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Print which scatter variant was selected
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Selected: {name}");
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
scatter_names.contains(&"ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy in KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
assert!(
|
||||
!scatter_names.contains(&"Scatter"),
|
||||
"Regular Scatter should be pruned from KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
@@ -342,17 +392,31 @@ fn test_scatter_dual_cache() {
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
// Use seeded search for deterministic variant selection.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Dual test selected: {name}");
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
!scatter_names.is_empty(),
|
||||
"Expected scatter kernels in dual-cache search result"
|
||||
);
|
||||
assert!(
|
||||
scatter_names.iter().all(|name| *name == "ScatterNoCopy"),
|
||||
"Expected only ScatterNoCopy in dual-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
|
||||
@@ -52,6 +52,9 @@ fn main() {
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(500),
|
||||
|
||||
@@ -275,7 +275,7 @@ fn main() {
|
||||
let weights_path = artifact_dir.join("weights.safetensors");
|
||||
let cli = cli_args(&artifact_dir);
|
||||
let image_path = cli.image_path.clone();
|
||||
let search_graphs = 1usize;
|
||||
let search_graphs = 50usize;
|
||||
|
||||
println!("Using artifact directory: {}", artifact_dir.display());
|
||||
|
||||
|
||||
@@ -444,6 +444,7 @@ pub fn base_expression_egglog() -> String {
|
||||
p.add_ruleset("expr");
|
||||
p.add_ruleset("dtype_prop");
|
||||
p.add_ruleset("cleanup");
|
||||
p.add_ruleset("post_cleanup");
|
||||
|
||||
// Register all sorts
|
||||
s.register(&mut p);
|
||||
|
||||
@@ -9,6 +9,8 @@ use std::hash::{Hash, Hasher};
|
||||
use std::{str, sync::Arc, time::Duration};
|
||||
use tracing::trace;
|
||||
|
||||
pub use egraph_serialize::{ClassId, NodeId};
|
||||
|
||||
pub mod api;
|
||||
pub mod base;
|
||||
|
||||
@@ -36,7 +38,9 @@ struct EgglogSchedulePhase {
|
||||
schedule: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub type EGraphPostprocess = Arc<dyn Fn(&mut SerializedEGraph) + Send + Sync + 'static>;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LateEgglogPass {
|
||||
/// Egglog declarations and rules for a backend-provided late pass.
|
||||
///
|
||||
@@ -49,6 +53,19 @@ pub struct LateEgglogPass {
|
||||
/// Backends can use this for analysis-only layers or for analysis followed
|
||||
/// by backend-specific cleanup rules.
|
||||
pub schedule: String,
|
||||
/// Optional Rust post-processing hook that runs on the serialized e-graph
|
||||
/// after egglog has finished and before the final empty-eclass cascade.
|
||||
pub postprocess: Option<EGraphPostprocess>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for LateEgglogPass {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("LateEgglogPass")
|
||||
.field("program", &self.program)
|
||||
.field("schedule", &self.schedule)
|
||||
.field("has_postprocess", &self.postprocess.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl LateEgglogPass {
|
||||
@@ -56,8 +73,17 @@ impl LateEgglogPass {
|
||||
Self {
|
||||
program: program.into(),
|
||||
schedule: schedule.into(),
|
||||
postprocess: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_postprocess(
|
||||
mut self,
|
||||
postprocess: impl Fn(&mut SerializedEGraph) + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
self.postprocess = Some(Arc::new(postprocess));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
@@ -130,6 +156,7 @@ pub struct OpTextParts {
|
||||
late_program: String,
|
||||
rewrites: String,
|
||||
late_phases: Vec<EgglogSchedulePhase>,
|
||||
late_postprocesses: Vec<EGraphPostprocess>,
|
||||
}
|
||||
|
||||
impl OpTextParts {
|
||||
@@ -181,6 +208,10 @@ impl OpTextParts {
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
late_postprocesses: late_passes
|
||||
.iter()
|
||||
.filter_map(|pass| pass.postprocess.clone())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,6 +261,10 @@ fn egglog_final_phases() -> Vec<EgglogSchedulePhase> {
|
||||
name: "cleanup".to_string(),
|
||||
schedule: "(saturate cleanup)".to_string(),
|
||||
},
|
||||
EgglogSchedulePhase {
|
||||
name: "post cleanup".to_string(),
|
||||
schedule: "(saturate post_cleanup)".to_string(),
|
||||
},
|
||||
EgglogSchedulePhase {
|
||||
name: "base cleanup".to_string(),
|
||||
schedule: "(saturate base_cleanup)".to_string(),
|
||||
@@ -297,9 +332,8 @@ use crate::{
|
||||
};
|
||||
use egglog::{ArcSort, CommandOutput, EGraph, Value};
|
||||
use egglog_reports::ReportLevel;
|
||||
use egraph_serialize::{ClassId, NodeId};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
/// This is snapshot of an EGraph with Rust native hash maps and sets for enabling more native traversal / algorithm writing.
|
||||
/// The name comes from the serialize egraph crates, which returns a ETermDAG, which caused issues, so this is a homebrew semi-static egraph
|
||||
pub struct SerializedEGraph {
|
||||
@@ -1909,6 +1943,20 @@ pub fn run_egglog_with_report_parts(
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
static WARNED_DEBUG_EGGLOG: AtomicBool = AtomicBool::new(false);
|
||||
if !WARNED_DEBUG_EGGLOG.swap(true, Ordering::Relaxed) {
|
||||
eprintln!(
|
||||
"{}",
|
||||
" Egglog warning: running in a debug build; model-sized saturation can be very slow. Use `cargo run --release ...` for CUDA model examples."
|
||||
.yellow()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
let full_start = std::time::Instant::now();
|
||||
@@ -2178,6 +2226,10 @@ pub fn run_egglog_with_report_parts(
|
||||
}
|
||||
}
|
||||
|
||||
for postprocess in &op_parts.late_postprocesses {
|
||||
postprocess(&mut egraph);
|
||||
}
|
||||
|
||||
// Cascade: remove enodes whose children reference empty eclasses
|
||||
loop {
|
||||
let mut to_remove = vec![];
|
||||
|
||||
56
src/graph.rs
56
src/graph.rs
@@ -1030,7 +1030,8 @@ impl Graph {
|
||||
let mut ops = Rt::Ops::into_vec();
|
||||
ops.extend(<crate::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup_hlir = TypeId::of::<Rt>() != TypeId::of::<NativeRuntime>();
|
||||
let late_passes = Rt::late_egglog_passes(&ops, &options);
|
||||
let memory_dyn_map = self.memory_limit_dyn_map();
|
||||
let late_passes = Rt::late_egglog_passes(&ops, &options, &memory_dyn_map);
|
||||
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![
|
||||
@@ -1074,7 +1075,8 @@ impl Graph {
|
||||
ops.retain(|o| !exclude_ops.contains(&o.sort().name));
|
||||
ops.extend(<crate::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup_hlir = TypeId::of::<Rt>() != TypeId::of::<NativeRuntime>();
|
||||
let late_passes = Rt::late_egglog_passes(&ops, &options);
|
||||
let memory_dyn_map = self.memory_limit_dyn_map();
|
||||
let late_passes = Rt::late_egglog_passes(&ops, &options, &memory_dyn_map);
|
||||
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![
|
||||
@@ -1176,6 +1178,16 @@ impl Graph {
|
||||
combos
|
||||
}
|
||||
|
||||
fn memory_limit_dyn_map(&self) -> FxHashMap<char, usize> {
|
||||
let mut dyn_map = self.dyn_map.clone();
|
||||
for (&dim, buckets) in &self.dim_buckets {
|
||||
if let Some(max) = buckets.iter().map(|bucket| bucket.max).max() {
|
||||
dyn_map.insert(dim, max);
|
||||
}
|
||||
}
|
||||
dyn_map
|
||||
}
|
||||
|
||||
/// Format a human-readable label for a bucket combination.
|
||||
fn format_bucket_label(&self, bucket_indices: &FxHashMap<char, usize>) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
@@ -1309,7 +1321,12 @@ impl Graph {
|
||||
let has_nan = runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
(
|
||||
rep_metric,
|
||||
append_memory_display(rep_display, memory_bytes),
|
||||
append_memory_display(
|
||||
rep_display,
|
||||
memory_bytes,
|
||||
runtime.planned_intermediate_buffer_bytes(),
|
||||
runtime.allocated_intermediate_buffer_bytes(),
|
||||
),
|
||||
has_nan,
|
||||
)
|
||||
}));
|
||||
@@ -1421,7 +1438,12 @@ impl Graph {
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
(
|
||||
rep_metric,
|
||||
append_memory_display(rep_display, memory_bytes),
|
||||
append_memory_display(
|
||||
rep_display,
|
||||
memory_bytes,
|
||||
runtime.planned_intermediate_buffer_bytes(),
|
||||
runtime.allocated_intermediate_buffer_bytes(),
|
||||
),
|
||||
has_nan,
|
||||
)
|
||||
}));
|
||||
@@ -1610,11 +1632,27 @@ fn stable_toposort_by_node_index(graph: &HLIRGraph) -> Option<Vec<NodeIndex>> {
|
||||
(ordered.len() == graph.node_count()).then_some(ordered)
|
||||
}
|
||||
|
||||
fn append_memory_display(display: String, memory_bytes: Option<usize>) -> String {
|
||||
let Some(bytes) = memory_bytes else {
|
||||
return display;
|
||||
};
|
||||
format!("{display} | MEM: {}", format_memory_bytes(bytes))
|
||||
fn append_memory_display(
|
||||
display: String,
|
||||
estimate_bytes: Option<usize>,
|
||||
planned_bytes: Option<usize>,
|
||||
allocated_bytes: Option<usize>,
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
if let Some(bytes) = estimate_bytes {
|
||||
parts.push(format!("EST: {}", format_memory_bytes(bytes)));
|
||||
}
|
||||
if let Some(bytes) = planned_bytes {
|
||||
parts.push(format!("PLAN: {}", format_memory_bytes(bytes)));
|
||||
}
|
||||
if let Some(bytes) = allocated_bytes {
|
||||
parts.push(format!("ALLOC: {}", format_memory_bytes(bytes)));
|
||||
}
|
||||
if parts.is_empty() {
|
||||
display
|
||||
} else {
|
||||
format!("{display} | {}", parts.join(" | "))
|
||||
}
|
||||
}
|
||||
|
||||
fn format_memory_bytes(bytes: usize) -> String {
|
||||
|
||||
@@ -19,6 +19,7 @@ pub trait Runtime {
|
||||
fn late_egglog_passes(
|
||||
_ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
_options: &crate::graph::BuildSearchSpaceOptions,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<crate::egglog_utils::LateEgglogPass>
|
||||
where
|
||||
Self: Sized,
|
||||
@@ -57,6 +58,14 @@ pub trait Runtime {
|
||||
fn intermediate_buffer_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
/// Total bytes in the active runtime memory plan, if the runtime has one.
|
||||
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Total active intermediate allocation bytes, if the runtime can report it.
|
||||
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Check if the most recent execution produced NaN in any output buffer.
|
||||
/// Used by the search to reject NaN-producing graph variants.
|
||||
fn has_nan_outputs(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {
|
||||
|
||||
Reference in New Issue
Block a user