Compare commits

...

1 Commits

Author SHA1 Message Date
Joe Fioti
62e86f9dc5 Reuse cuBLASLt prepares across matching graph ops 2026-06-01 00:25:30 +00:00
10 changed files with 3692 additions and 434 deletions

View File

@@ -1,4 +1,7 @@
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
use half::{bf16, f16};
use luminal::{
@@ -15,6 +18,8 @@ use luminal::{
},
};
#[cfg(test)]
use crate::kernel::CudaGraphHandle;
use crate::{
cudarc::{
cublas::sys::cublasOperation_t,
@@ -33,7 +38,7 @@ use crate::{
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
},
},
driver::{CudaStream, DevicePtr},
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::{DeviceBuffer, HostOp},
try_create_cublaslt,
@@ -581,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
}
}
#[derive(Debug, Clone, Copy)]
enum LtScalar {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum LtScalar {
F64(f64),
F32(f32),
F16(f16),
@@ -622,16 +627,16 @@ impl LtScalar {
}
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulProblem {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulProblem {
m: u64,
n: u64,
k: u64,
batch_count: i32,
}
#[derive(Debug, Clone, Copy)]
struct LtMatrixSpec {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatrixSpec {
dtype: cudaDataType,
rows: u64,
cols: u64,
@@ -640,8 +645,8 @@ struct LtMatrixSpec {
order: cublasLtOrder_t,
}
#[derive(Debug, Clone, Copy)]
struct LtComputeSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtComputeSpec {
compute_type: cublasComputeType_t,
scale_dtype: cudaDataType,
alpha: LtScalar,
@@ -649,8 +654,8 @@ struct LtComputeSpec {
epilogue: cublasLtEpilogue_t,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtMatmulSpec {
problem: LtMatmulProblem,
trans_a: cublasOperation_t,
trans_b: cublasOperation_t,
@@ -662,8 +667,8 @@ struct LtMatmulSpec {
workspace_size: usize,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulPointers {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulPointers {
a: u64,
b: u64,
c: u64,
@@ -673,7 +678,35 @@ struct LtMatmulPointers {
b_scale: Option<u64>,
}
struct LtRawDescriptors {
impl LtMatmulPointers {
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
let mut fields = Vec::new();
if self.a != other.a {
fields.push("a");
}
if self.b != other.b {
fields.push("b");
}
if self.c != other.c {
fields.push("c");
}
if self.d != other.d {
fields.push("d");
}
if self.bias != other.bias {
fields.push("bias");
}
if self.a_scale != other.a_scale {
fields.push("a_scale");
}
if self.b_scale != other.b_scale {
fields.push("b_scale");
}
fields
}
}
pub(crate) struct LtRawDescriptors {
matmul_desc: cublasLtMatmulDesc_t,
a_desc: cublasLtMatrixLayout_t,
b_desc: cublasLtMatrixLayout_t,
@@ -682,6 +715,23 @@ struct LtRawDescriptors {
preference: cublasLtMatmulPreference_t,
}
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
> = OnceLock::new();
#[cfg(test)]
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
}
#[cfg(test)]
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
}
impl Default for LtRawDescriptors {
fn default() -> Self {
Self {
@@ -720,6 +770,121 @@ impl Drop for LtRawDescriptors {
}
}
pub(crate) struct PreparedCuBlasLtMatmul {
cublaslt: Arc<CudaBlasLT>,
spec: LtMatmulSpec,
resources: LtRawDescriptors,
heuristic: cublasLtMatmulHeuristicResult_t,
_workspace: CudaSlice<u8>,
workspace_ptr: u64,
_a_scale: Option<CudaSlice<f32>>,
default_a_scale_ptr: Option<u64>,
_b_scale: Option<CudaSlice<f32>>,
default_b_scale_ptr: Option<u64>,
_c_scale: Option<CudaSlice<f32>>,
_d_scale: Option<CudaSlice<f32>>,
}
impl PreparedCuBlasLtMatmul {
fn update_descriptor_pointers(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
stream.context().bind_to_thread()?;
if let Some(bias_ptr) = ptrs.bias {
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias_ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
ptr,
)?;
}
Ok(())
}
pub(crate) fn enqueue(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
self.update_descriptor_pointers(stream, ptrs)?;
let alpha_ptr = self.spec.compute.alpha.as_ptr();
let beta_ptr = self.spec.compute.beta.as_ptr();
unsafe {
cublasLtMatmul(
*self.cublaslt.handle(),
self.resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
self.resources.a_desc,
ptrs.b as *const std::ffi::c_void,
self.resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
self.resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
self.resources.d_desc,
&self.heuristic.algo,
self.workspace_ptr as *mut std::ffi::c_void,
self.spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtCaptureSignature {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtPrepareKey {
spec: LtMatmulSpec,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtResolvedGraphCall {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
impl CuBlasLtResolvedGraphCall {
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
CuBlasLtCaptureSignature {
spec: self.spec,
ptrs: self.ptrs,
}
}
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
CuBlasLtPrepareKey { spec: self.spec }
}
}
fn create_matrix_layout(
desc: &mut cublasLtMatrixLayout_t,
spec: LtMatrixSpec,
@@ -796,12 +961,15 @@ fn set_scalar_scale_pointer(
Ok(())
}
fn run_cublaslt_matmul(
pub(crate) fn prepare_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
#[cfg(test)]
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
return Err(anyhow::anyhow!(
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
@@ -813,17 +981,17 @@ fn run_cublaslt_matmul(
let mut resources = LtRawDescriptors::default();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
drop(workspace_guard);
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
};
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
@@ -879,29 +1047,27 @@ fn run_cublaslt_matmul(
}
}
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
(Some(ptr), None)
} else if let Some(scale) = &a_scale {
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
(Some(ptr), None)
} else if let Some(scale) = &b_scale {
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
@@ -935,6 +1101,7 @@ fn run_cublaslt_matmul(
ptr,
)?;
}
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
create_matrix_layout(&mut resources.a_desc, spec.a)?;
create_matrix_layout(&mut resources.b_desc, spec.b)?;
@@ -952,58 +1119,148 @@ fn run_cublaslt_matmul(
}
}
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
let cached_heuristic = {
let cache = heuristic_cache.lock().unwrap();
cache
.iter()
.find(|(cached_spec, _)| cached_spec == spec)
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
};
if let Some(cached) = cached_heuristic {
heuristic = cached;
} else {
let mut algo_count: i32 = 0;
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
}
let alpha_ptr = spec.compute.alpha.as_ptr();
let beta_ptr = spec.compute.beta.as_ptr();
cublasLtMatmul(
*cublaslt.handle(),
resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
resources.a_desc,
ptrs.b as *const std::ffi::c_void,
resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
resources.d_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
heuristic_cache
.lock()
.unwrap()
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
}
Ok(())
Ok(PreparedCuBlasLtMatmul {
cublaslt: cublaslt.clone(),
spec: *spec,
resources,
heuristic,
_workspace: workspace,
workspace_ptr,
_a_scale: a_scale,
default_a_scale_ptr,
_b_scale: b_scale,
default_b_scale_ptr,
_c_scale: c_scale,
_d_scale: d_scale,
})
}
fn run_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
prepared.enqueue(stream, ptrs)
}
#[cfg(test)]
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
let capture_stream = stream.context().new_stream()?;
let cublaslt = try_create_cublaslt(stream.clone())
.map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;
let a_buf = stream.clone_htod(&[1.0f32])?;
let b_buf = stream.clone_htod(&[1.0f32])?;
let d_buf = unsafe { stream.alloc::<f32>(1)? };
let (a, a_guard) = a_buf.device_ptr(stream);
let (b, b_guard) = b_buf.device_ptr(stream);
let (d, d_guard) = d_buf.device_ptr(stream);
drop((a_guard, b_guard, d_guard));
let matrix = LtMatrixSpec {
dtype: cudaDataType::CUDA_R_32F,
rows: 1,
cols: 1,
ld: 1,
batch_stride: 1,
order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m: 1,
n: 1,
k: 1,
batch_count: 1,
},
trans_a: cublasOperation_t::CUBLAS_OP_N,
trans_b: cublasOperation_t::CUBLAS_OP_N,
a: matrix,
b: matrix,
c: matrix,
d: matrix,
compute: LtComputeSpec {
compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
scale_dtype: cudaDataType::CUDA_R_32F,
alpha: LtScalar::F32(1.0),
beta: LtScalar::F32(0.0),
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
},
workspace_size: 1024 * 1024,
};
let ptrs = LtMatmulPointers {
a,
b,
c: d,
d,
bias: None,
a_scale: None,
b_scale: None,
};
let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;
let mut graph = CudaGraphHandle::new(stream.context().clone())?;
let entry = graph.add_empty_node(&[])?;
capture_stream.join(stream)?;
graph.begin_capture_to_graph(&capture_stream, &[entry])?;
let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
let end_result = graph.end_capture(&capture_stream);
enqueue_result?;
end_result?;
Ok(())
}
let supported = probe(stream).is_ok();
let _ = stream.synchronize();
supported
}
fn resolve_cublaslt_pointers(
@@ -1126,6 +1383,151 @@ impl CuBlasLt {
Ok(created)
}
pub(crate) fn graph_inputs(&self) -> usize {
self.n_inputs()
}
pub(crate) fn resolve_for_graph(
&self,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
let element_size = (self.d_dtype.bits() / 8) as u64;
assert!(
element_size > 0,
"cuBLAS LT does not support sub-byte dtype {}",
self.d_dtype
);
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
let ptrs = resolve_cublaslt_pointers(
self_node,
inputs,
buffers,
self.beta,
self.epilogue,
self.a_scale_input,
self.b_scale_input,
)?;
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
let _span = span!(
Level::TRACE,
"cuBLASLT_resolve_graph",
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
?self.epilogue,
)
.entered();
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
let c_spec = LtMatrixSpec {
dtype: c_cuda_dtype,
rows: m,
cols: n,
ld: ldc,
batch_stride: stride_c,
order: self.c_order,
};
let d_spec = LtMatrixSpec {
dtype: d_cuda_dtype,
rows: m,
cols: n,
ld: ldd,
batch_stride: stride_d,
order: self.d_order,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m,
n,
k,
batch_count,
},
trans_a: a_layout,
trans_b: b_layout,
a: LtMatrixSpec {
dtype: a_cuda_dtype,
rows: a_rows,
cols: a_cols,
ld: lda,
batch_stride: stride_a,
order: self.a_order,
},
b: LtMatrixSpec {
dtype: b_cuda_dtype,
rows: b_rows,
cols: b_cols,
ld: ldb,
batch_stride: stride_b,
order: self.b_order,
},
c: c_spec,
d: d_spec,
compute: LtComputeSpec {
compute_type: self.compute_type,
scale_dtype: scale_cuda_dtype,
alpha,
beta,
epilogue: self.epilogue,
},
workspace_size: WORKSPACE_SIZE,
};
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
}
pub(crate) fn prepare_resolved_for_graph(
&self,
stream: &Arc<CudaStream>,
resolved: CuBlasLtResolvedGraphCall,
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
let cublaslt = self.get_cublaslt(stream)?;
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
}
#[cfg(test)]
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
(

View File

@@ -2,7 +2,7 @@ use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use luminal::{op::EgglogOp, prelude::*};
mod cublaslt;
pub(crate) mod cublaslt;
pub mod flashinfer;
pub mod moe;
@@ -167,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
None
}
/// Returns pairs of extra buffer nodes that must not share arena storage.
///
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
/// two buffers may have disjoint positions in one topological order while
/// still being unordered by real dependencies, so CUDA could overlap them.
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
None
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.

View File

@@ -7,7 +7,10 @@ use std::sync::Arc;
use cudarc::driver::{
CudaContext, CudaFunction, CudaStream, DriverError,
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
sys::{
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
},
};
/// A CUDA graph that can be modified and instantiated.
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
}
}
/// Updates a kernel node in the mutable source graph.
pub unsafe fn set_kernel_node_params(
&mut self,
node: CUgraphNode,
func: CUfunction,
grid_dim: (u32, u32, u32),
block_dim: (u32, u32, u32),
shared_mem_bytes: u32,
kernel_params: *mut *mut c_void,
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let params = sys::CUDA_KERNEL_NODE_PARAMS {
func,
gridDimX: grid_dim.0,
gridDimY: grid_dim.1,
gridDimZ: grid_dim.2,
blockDimX: block_dim.0,
blockDimY: block_dim.1,
blockDimZ: block_dim.2,
sharedMemBytes: shared_mem_bytes,
kernelParams: kernel_params,
extra: std::ptr::null_mut(),
kern: std::ptr::null_mut(),
ctx: std::ptr::null_mut(),
};
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, &params).result() }
}
/// Adds an empty dependency node to the graph.
pub fn add_empty_node(
&mut self,
dependencies: &[CUgraphNode],
) -> Result<CUgraphNode, DriverError> {
self.ctx.bind_to_thread()?;
let mut node = MaybeUninit::uninit();
unsafe {
sys::cuGraphAddEmptyNode(
node.as_mut_ptr(),
self.cu_graph,
dependencies.as_ptr(),
dependencies.len(),
)
.result()?;
Ok(node.assume_init())
}
}
/// Destroys a node in the mutable graph.
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe { sys::cuGraphDestroyNode(node).result() }
}
/// Adds dependency edges to the mutable graph.
pub fn add_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Removes dependency edges from the mutable graph.
pub fn remove_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Returns all nodes currently in the graph.
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut nodes = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
}
nodes.truncate(count);
Ok(nodes)
}
/// Returns the direct dependencies of a node.
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Returns the direct dependents of a node.
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Begins stream capture that appends captured work into this graph.
pub fn begin_capture_to_graph(
&mut self,
stream: &CudaStream,
dependencies: &[CUgraphNode],
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe {
sys::cuStreamBeginCaptureToGraph(
stream.cu_stream(),
self.cu_graph,
dependencies.as_ptr(),
std::ptr::null(),
dependencies.len(),
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
)
.result()
}
}
/// Ends stream capture previously started by begin_capture_to_graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut graph = MaybeUninit::uninit();
unsafe {
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
let captured = graph.assume_init();
if captured != self.cu_graph && !captured.is_null() {
sys::cuGraphDestroy(captured).result()?;
}
}
Ok(())
}
/// Adds an event record node to the graph for timing.
pub fn add_event_record_node(
&mut self,
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, &params) }
.result()
}
/// Attempts to update this executable graph from an already-mutated source graph.
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut result = CUgraphExecUpdateResultInfo {
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
errorNode: std::ptr::null_mut(),
errorFromNode: std::ptr::null_mut(),
};
unsafe {
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
}
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
return Err(DriverError(
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
));
}
Ok(())
}
}
impl Drop for CudaGraphExecHandle {
@@ -480,6 +672,38 @@ mod tests {
assert_eq!(result[0], 6.0f32);
}
#[test]
fn test_graph_empty_node_dependency_reconnect() {
let Ok(ctx) = CudaContext::new(0) else { return };
let mut graph = CudaGraphHandle::new(ctx).unwrap();
let entry = graph.add_empty_node(&[]).unwrap();
let middle = graph.add_empty_node(&[entry]).unwrap();
let exit = graph.add_empty_node(&[middle]).unwrap();
let nodes = graph.nodes().unwrap();
assert!(nodes.contains(&entry));
assert!(nodes.contains(&middle));
assert!(nodes.contains(&exit));
assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);
graph.add_dependencies(&[entry], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert!(exit_deps.contains(&entry));
assert!(exit_deps.contains(&middle));
graph.remove_dependencies(&[middle], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert_eq!(exit_deps.len(), 1);
assert!(exit_deps.contains(&entry));
unsafe {
graph.destroy_node(middle).unwrap();
}
assert!(!graph.nodes().unwrap().contains(&middle));
}
// CUDA Graph Tests
#[test]

View File

@@ -304,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
// Kernel to host op compilation
mod to_host;
#[cfg(test)]
pub(crate) use to_host::CudaGraphDebugSummary;
pub use to_host::{CudaGraphOp, kernel_to_host};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,12 @@ use luminal::{
prelude::*,
};
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;
use crate::{
host::{
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
cublaslt_type_tuple,
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
expected
}
fn reference_mixed_chain(
a: &[f32],
pre: &[f32],
b: &[f32],
m: usize,
n: usize,
k: usize,
) -> Vec<f32> {
let mut expected = vec![0.0; m * n];
for row in 0..m {
for col in 0..n {
let mut acc = 0.0;
for inner in 0..k {
acc += (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col];
}
expected[row * n + col] = acc.exp();
}
}
expected
}
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
crate::try_create_cublaslt(stream.clone()).is_ok()
}
fn build_mixed_chain_graph(
m: impl Into<Expression>,
n: usize,
k: usize,
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
let m = m.into();
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let pre = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = ((a + pre).matmul(b).exp()).output();
(cx, a.id, pre.id, b.id, out.id)
}
fn add_in_place(values: &mut [f32], addends: &[f32]) {
for (value, addend) in values.iter_mut().zip(addends) {
*value += *addend;
@@ -507,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
}
}
#[test]
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |llir| {
cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
});
let a_data = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
let b_data = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
let summaries = rt.debug_cuda_graph_summaries();
let mixed = summaries
.iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected one CudaGraphOp to capture the cuBLASLt island");
assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
assert_eq!(mixed.absorbed_host_nodes.len(), 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);
let a_data = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (5, 8, 8);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let first = a.matmul(b);
let out = (a + first.sin()).matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"ordered matching cuBLASLt prepared reuse",
|llir| {
let orders = cublaslt_matrix_order_tuples(llir);
orders.len() == 2 && orders[0] == orders[1]
},
);
let a_data = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
let first = reference_matmul_2d(&a_data, &b_data, m, n, k);
let dep = a_data
.iter()
.zip(&first)
.map(|(a, first)| a + first.sin())
.collect::<Vec<_>>();
let expected = reference_matmul_2d(&dep, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 2)
.expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
assert_eq!(
summary.n_cublaslt_prepared, 1,
"dependency-ordered matching cuBLASLt calls should share prepared resources"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
a.output();
b.output();
cx.set_dim('p', 1);
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"cuBLASLt unchanged under unrelated dyn dim",
|_| true,
);
let a_data = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
reset_cublaslt_prepare_count_for_test();
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let first_prepare_count = cublaslt_prepare_count_for_test();
assert!(
first_prepare_count > 0,
"first execution should prepare the captured cuBLASLt island"
);
cx.set_dim('p', 2);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count,
"unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('m', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('m', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [
(7usize, 0xC002_0001),
(9usize, 0xC002_0002),
(7usize, 0xC002_0003),
] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp after recapture");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cublaslt_with_dynamic_c_spec_is_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('c', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('c', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
cx.set_dim('c', c);
let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"c-dependent cuBLASLt should be absorbed into a CUDA graph"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('s', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('s', 1);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);
let dim_buckets = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
.into_iter()
.collect();
let bucket_llirs = vec![
(
[('s', 0usize)].into_iter().collect(),
[('s', 1usize)].into_iter().collect(),
llir.clone(),
),
(
[('s', 1usize)].into_iter().collect(),
[('s', 3usize)].into_iter().collect(),
llir,
),
];
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir_buckets(&dim_buckets, &bucket_llirs);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 1, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data.clone());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"singleton s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
cx.set_dim('s', 3);
let a_data = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 3, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"range s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
reset_cublaslt_prepare_count_for_test();
let mut first_prepare_count = None;
for seed in [0xCC00_0001, 0xCC00_0002] {
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
if first_prepare_count.is_none() {
first_prepare_count = Some(cublaslt_prepare_count_for_test());
}
}
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count.unwrap(),
"A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
);
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after pointer recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
cx.set_dim('m', 7);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
}
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after dynamic-shape recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {

View File

@@ -109,7 +109,7 @@ fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
);
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
let mut choices = random_initial_choice(egraph, &mut rng);
let kernel_class = &egraph.node_to_class[*kernel_node];
choices.insert(kernel_class, kernel_node);

View File

@@ -82,6 +82,13 @@ struct SearchSpaceContext {
intervals: DynDimIntervals,
}
#[derive(Debug, Clone)]
struct SearchProfileBucketContext {
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
bucket_indices: FxHashMap<char, usize>,
representative_dyn_map: FxHashMap<char, usize>,
}
/// A compiled bucket: (bucket_indices, representative_dyn_map, stitched_llir).
pub type BucketLLIR = (FxHashMap<char, usize>, FxHashMap<char, usize>, LLIRGraph);
@@ -1314,6 +1321,7 @@ impl Graph {
rng,
&self.dyn_map.clone(),
None,
None,
0,
search_started_at,
);
@@ -1349,6 +1357,11 @@ impl Graph {
rng,
&context.representative_dyn_map,
Some((combo_idx, n_combos)),
Some(SearchProfileBucketContext {
dim_buckets: self.search_space_dim_buckets.clone(),
bucket_indices: context.bucket_indices.clone(),
representative_dyn_map: context.representative_dyn_map.clone(),
}),
combo_idx,
search_started_at,
);
@@ -1445,6 +1458,7 @@ impl Graph {
rng: &mut G,
dyn_map: &FxHashMap<char, usize>,
bucket_progress: Option<(usize, usize)>,
bucket_profile_context: Option<SearchProfileBucketContext>,
egraph_index: usize,
search_started_at: std::time::Instant,
) -> LLIRGraph {
@@ -1545,12 +1559,27 @@ impl Graph {
collapse_loops_to_first_iter(&mut graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let (rep_metric, rep_display) =
if let Some(bucket_context) = &bucket_profile_context {
runtime.profile_with_bucket_context(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
ProfileBucketContext {
dim_buckets: &bucket_context.dim_buckets,
bucket_indices: &bucket_context.bucket_indices,
representative_dyn_map: &bucket_context.representative_dyn_map,
},
)
} else {
runtime.profile(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
)
};
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
(
@@ -1664,12 +1693,27 @@ impl Graph {
collapse_loops_to_first_iter(&mut llir_graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let (rep_metric, rep_display) =
if let Some(bucket_context) = &bucket_profile_context {
runtime.profile_with_bucket_context(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
ProfileBucketContext {
dim_buckets: &bucket_context.dim_buckets,
bucket_indices: &bucket_context.bucket_indices,
representative_dyn_map: &bucket_context.representative_dyn_map,
},
)
} else {
runtime.profile(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
)
};
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan =
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);

View File

@@ -7,6 +7,12 @@ use crate::prelude::*;
use as_any::{AsAny, Downcast};
use rustc_hash::FxHashMap;
pub struct ProfileBucketContext<'a> {
pub dim_buckets: &'a FxHashMap<char, Vec<DimBucket>>,
pub bucket_indices: &'a FxHashMap<char, usize>,
pub representative_dyn_map: &'a FxHashMap<char, usize>,
}
pub trait Runtime {
type Ops: IntoEgglogOp;
type CompileArg;
@@ -36,6 +42,20 @@ pub trait Runtime {
trials: usize,
timeout: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String);
/// Profile one candidate in the context of a specific dynamic-dimension
/// bucket. Runtimes with bucket-sensitive lowering can override this so
/// search ranks candidates under the same execution model used after
/// final bucket compilation.
fn profile_with_bucket_context(
&mut self,
llir_graph: &LLIRGraph,
dyn_map: &FxHashMap<char, usize>,
trials: usize,
timeout: Option<std::time::Duration>,
_bucket_context: ProfileBucketContext<'_>,
) -> (Self::ProfileMetric, String) {
self.profile(llir_graph, dyn_map, trials, timeout)
}
/// Aggregate multiple profile metrics into one comparable metric.
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {