mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
main
...
flashinfer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62e86f9dc5 |
@@ -1,4 +1,7 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use half::{bf16, f16};
|
||||
use luminal::{
|
||||
@@ -15,6 +18,8 @@ use luminal::{
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::kernel::CudaGraphHandle;
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
@@ -33,7 +38,7 @@ use crate::{
|
||||
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
@@ -581,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum LtScalar {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) enum LtScalar {
|
||||
F64(f64),
|
||||
F32(f32),
|
||||
F16(f16),
|
||||
@@ -622,16 +627,16 @@ impl LtScalar {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulProblem {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulProblem {
|
||||
m: u64,
|
||||
n: u64,
|
||||
k: u64,
|
||||
batch_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatrixSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatrixSpec {
|
||||
dtype: cudaDataType,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
@@ -640,8 +645,8 @@ struct LtMatrixSpec {
|
||||
order: cublasLtOrder_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtComputeSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtComputeSpec {
|
||||
compute_type: cublasComputeType_t,
|
||||
scale_dtype: cudaDataType,
|
||||
alpha: LtScalar,
|
||||
@@ -649,8 +654,8 @@ struct LtComputeSpec {
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtMatmulSpec {
|
||||
problem: LtMatmulProblem,
|
||||
trans_a: cublasOperation_t,
|
||||
trans_b: cublasOperation_t,
|
||||
@@ -662,8 +667,8 @@ struct LtMatmulSpec {
|
||||
workspace_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulPointers {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulPointers {
|
||||
a: u64,
|
||||
b: u64,
|
||||
c: u64,
|
||||
@@ -673,7 +678,35 @@ struct LtMatmulPointers {
|
||||
b_scale: Option<u64>,
|
||||
}
|
||||
|
||||
struct LtRawDescriptors {
|
||||
impl LtMatmulPointers {
|
||||
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
|
||||
let mut fields = Vec::new();
|
||||
if self.a != other.a {
|
||||
fields.push("a");
|
||||
}
|
||||
if self.b != other.b {
|
||||
fields.push("b");
|
||||
}
|
||||
if self.c != other.c {
|
||||
fields.push("c");
|
||||
}
|
||||
if self.d != other.d {
|
||||
fields.push("d");
|
||||
}
|
||||
if self.bias != other.bias {
|
||||
fields.push("bias");
|
||||
}
|
||||
if self.a_scale != other.a_scale {
|
||||
fields.push("a_scale");
|
||||
}
|
||||
if self.b_scale != other.b_scale {
|
||||
fields.push("b_scale");
|
||||
}
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LtRawDescriptors {
|
||||
matmul_desc: cublasLtMatmulDesc_t,
|
||||
a_desc: cublasLtMatrixLayout_t,
|
||||
b_desc: cublasLtMatrixLayout_t,
|
||||
@@ -682,6 +715,23 @@ struct LtRawDescriptors {
|
||||
preference: cublasLtMatmulPreference_t,
|
||||
}
|
||||
|
||||
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
|
||||
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
|
||||
> = OnceLock::new();
|
||||
|
||||
/// Test-only counter of `prepare_cublaslt_matmul` invocations; starts at zero.
#[cfg(test)]
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Test-only helper: stores zero into the prepare counter.
#[cfg(test)]
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
    CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
}

/// Test-only helper: loads the current value of the prepare counter.
#[cfg(test)]
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
    CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
}
|
||||
|
||||
impl Default for LtRawDescriptors {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -720,6 +770,121 @@ impl Drop for LtRawDescriptors {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PreparedCuBlasLtMatmul {
|
||||
cublaslt: Arc<CudaBlasLT>,
|
||||
spec: LtMatmulSpec,
|
||||
resources: LtRawDescriptors,
|
||||
heuristic: cublasLtMatmulHeuristicResult_t,
|
||||
_workspace: CudaSlice<u8>,
|
||||
workspace_ptr: u64,
|
||||
_a_scale: Option<CudaSlice<f32>>,
|
||||
default_a_scale_ptr: Option<u64>,
|
||||
_b_scale: Option<CudaSlice<f32>>,
|
||||
default_b_scale_ptr: Option<u64>,
|
||||
_c_scale: Option<CudaSlice<f32>>,
|
||||
_d_scale: Option<CudaSlice<f32>>,
|
||||
}
|
||||
|
||||
impl PreparedCuBlasLtMatmul {
|
||||
fn update_descriptor_pointers(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.context().bind_to_thread()?;
|
||||
if let Some(bias_ptr) = ptrs.bias {
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
bias_ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
|
||||
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
|
||||
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn enqueue(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
self.update_descriptor_pointers(stream, ptrs)?;
|
||||
let alpha_ptr = self.spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = self.spec.compute.beta.as_ptr();
|
||||
unsafe {
|
||||
cublasLtMatmul(
|
||||
*self.cublaslt.handle(),
|
||||
self.resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
self.resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
self.resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
self.resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
self.resources.d_desc,
|
||||
&self.heuristic.algo,
|
||||
self.workspace_ptr as *mut std::ffi::c_void,
|
||||
self.spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtCaptureSignature {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtPrepareKey {
|
||||
spec: LtMatmulSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtResolvedGraphCall {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
impl CuBlasLtResolvedGraphCall {
|
||||
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
|
||||
CuBlasLtCaptureSignature {
|
||||
spec: self.spec,
|
||||
ptrs: self.ptrs,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
|
||||
CuBlasLtPrepareKey { spec: self.spec }
|
||||
}
|
||||
}
|
||||
|
||||
fn create_matrix_layout(
|
||||
desc: &mut cublasLtMatrixLayout_t,
|
||||
spec: LtMatrixSpec,
|
||||
@@ -796,12 +961,15 @@ fn set_scalar_scale_pointer(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
pub(crate) fn prepare_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
#[cfg(test)]
|
||||
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
|
||||
@@ -813,17 +981,17 @@ fn run_cublaslt_matmul(
|
||||
|
||||
let mut resources = LtRawDescriptors::default();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
|
||||
drop(workspace_guard);
|
||||
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
@@ -879,29 +1047,27 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &a_scale {
|
||||
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &b_scale {
|
||||
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
|
||||
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
|
||||
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
@@ -935,6 +1101,7 @@ fn run_cublaslt_matmul(
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
|
||||
|
||||
create_matrix_layout(&mut resources.a_desc, spec.a)?;
|
||||
create_matrix_layout(&mut resources.b_desc, spec.b)?;
|
||||
@@ -952,58 +1119,148 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
|
||||
let cached_heuristic = {
|
||||
let cache = heuristic_cache.lock().unwrap();
|
||||
cache
|
||||
.iter()
|
||||
.find(|(cached_spec, _)| cached_spec == spec)
|
||||
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
|
||||
};
|
||||
if let Some(cached) = cached_heuristic {
|
||||
heuristic = cached;
|
||||
} else {
|
||||
let mut algo_count: i32 = 0;
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
}
|
||||
|
||||
let alpha_ptr = spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = spec.compute.beta.as_ptr();
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
resources.d_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
heuristic_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(PreparedCuBlasLtMatmul {
|
||||
cublaslt: cublaslt.clone(),
|
||||
spec: *spec,
|
||||
resources,
|
||||
heuristic,
|
||||
_workspace: workspace,
|
||||
workspace_ptr,
|
||||
_a_scale: a_scale,
|
||||
default_a_scale_ptr,
|
||||
_b_scale: b_scale,
|
||||
default_b_scale_ptr,
|
||||
_c_scale: c_scale,
|
||||
_d_scale: d_scale,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
|
||||
prepared.enqueue(stream, ptrs)
|
||||
}
|
||||
|
||||
/// Test-only probe: reports whether a 1x1x1 f32 cuBLASLt matmul can be
/// enqueued while a second stream is capturing into a CUDA graph.
///
/// Any error along the way yields `false`. `stream` is synchronized before
/// returning and the outcome of that synchronization is ignored.
#[cfg(test)]
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
    /// Performs the probe, surfacing the first failure as an error.
    fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
        let capture_stream = stream.context().new_stream()?;
        let cublaslt = try_create_cublaslt(stream.clone())
            .map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;

        // One-element operands; only raw addresses are kept, guards are released.
        let lhs = stream.clone_htod(&[1.0f32])?;
        let rhs = stream.clone_htod(&[1.0f32])?;
        let out = unsafe { stream.alloc::<f32>(1)? };
        let (a, lhs_guard) = lhs.device_ptr(stream);
        let (b, rhs_guard) = rhs.device_ptr(stream);
        let (d, out_guard) = out.device_ptr(stream);
        drop(lhs_guard);
        drop(rhs_guard);
        drop(out_guard);

        let unit_matrix = LtMatrixSpec {
            dtype: cudaDataType::CUDA_R_32F,
            rows: 1,
            cols: 1,
            ld: 1,
            batch_stride: 1,
            order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
        };
        let problem = LtMatmulProblem {
            m: 1,
            n: 1,
            k: 1,
            batch_count: 1,
        };
        let compute = LtComputeSpec {
            compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
            scale_dtype: cudaDataType::CUDA_R_32F,
            alpha: LtScalar::F32(1.0),
            beta: LtScalar::F32(0.0),
            epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
        };
        let spec = LtMatmulSpec {
            problem,
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            a: unit_matrix,
            b: unit_matrix,
            c: unit_matrix,
            d: unit_matrix,
            compute,
            workspace_size: 1024 * 1024,
        };
        // C aliases D: the output buffer is passed for both.
        let ptrs = LtMatmulPointers {
            a,
            b,
            c: d,
            d,
            bias: None,
            a_scale: None,
            b_scale: None,
        };
        let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;

        let mut graph = CudaGraphHandle::new(stream.context().clone())?;
        let entry = graph.add_empty_node(&[])?;
        capture_stream.join(stream)?;
        graph.begin_capture_to_graph(&capture_stream, &[entry])?;
        // Capture is always ended, even when the enqueue failed; the enqueue
        // error is reported first.
        let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
        let end_result = graph.end_capture(&capture_stream);
        enqueue_result?;
        Ok(end_result?)
    }

    let supported = matches!(probe(stream), Ok(()));
    let _ = stream.synchronize();
    supported
}
|
||||
|
||||
fn resolve_cublaslt_pointers(
|
||||
@@ -1126,6 +1383,151 @@ impl CuBlasLt {
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
pub(crate) fn graph_inputs(&self) -> usize {
|
||||
self.n_inputs()
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_for_graph(
|
||||
&self,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
|
||||
|
||||
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
|
||||
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
|
||||
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
|
||||
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
|
||||
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
|
||||
let element_size = (self.d_dtype.bits() / 8) as u64;
|
||||
assert!(
|
||||
element_size > 0,
|
||||
"cuBLAS LT does not support sub-byte dtype {}",
|
||||
self.d_dtype
|
||||
);
|
||||
|
||||
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
|
||||
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
self_node,
|
||||
inputs,
|
||||
buffers,
|
||||
self.beta,
|
||||
self.epilogue,
|
||||
self.a_scale_input,
|
||||
self.b_scale_input,
|
||||
)?;
|
||||
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
|
||||
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
|
||||
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
|
||||
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT_resolve_graph",
|
||||
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
|
||||
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
|
||||
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
|
||||
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
|
||||
?self.epilogue,
|
||||
)
|
||||
.entered();
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
|
||||
let c_spec = LtMatrixSpec {
|
||||
dtype: c_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldc,
|
||||
batch_stride: stride_c,
|
||||
order: self.c_order,
|
||||
};
|
||||
let d_spec = LtMatrixSpec {
|
||||
dtype: d_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldd,
|
||||
batch_stride: stride_d,
|
||||
order: self.d_order,
|
||||
};
|
||||
let spec = LtMatmulSpec {
|
||||
problem: LtMatmulProblem {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_count,
|
||||
},
|
||||
trans_a: a_layout,
|
||||
trans_b: b_layout,
|
||||
a: LtMatrixSpec {
|
||||
dtype: a_cuda_dtype,
|
||||
rows: a_rows,
|
||||
cols: a_cols,
|
||||
ld: lda,
|
||||
batch_stride: stride_a,
|
||||
order: self.a_order,
|
||||
},
|
||||
b: LtMatrixSpec {
|
||||
dtype: b_cuda_dtype,
|
||||
rows: b_rows,
|
||||
cols: b_cols,
|
||||
ld: ldb,
|
||||
batch_stride: stride_b,
|
||||
order: self.b_order,
|
||||
},
|
||||
c: c_spec,
|
||||
d: d_spec,
|
||||
compute: LtComputeSpec {
|
||||
compute_type: self.compute_type,
|
||||
scale_dtype: scale_cuda_dtype,
|
||||
alpha,
|
||||
beta,
|
||||
epilogue: self.epilogue,
|
||||
},
|
||||
workspace_size: WORKSPACE_SIZE,
|
||||
};
|
||||
|
||||
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_resolved_for_graph(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
resolved: CuBlasLtResolvedGraphCall,
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
|
||||
(
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublaslt;
|
||||
pub(crate) mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
@@ -167,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns pairs of extra buffer nodes that must not share arena storage.
|
||||
///
|
||||
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
|
||||
/// two buffers may have disjoint positions in one topological order while
|
||||
/// still being unordered by real dependencies, so CUDA could overlap them.
|
||||
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -7,7 +7,10 @@ use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaContext, CudaFunction, CudaStream, DriverError,
|
||||
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
|
||||
sys::{
|
||||
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
|
||||
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
|
||||
},
|
||||
};
|
||||
|
||||
/// A CUDA graph that can be modified and instantiated.
|
||||
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates a kernel node in the mutable source graph.
|
||||
pub unsafe fn set_kernel_node_params(
|
||||
&mut self,
|
||||
node: CUgraphNode,
|
||||
func: CUfunction,
|
||||
grid_dim: (u32, u32, u32),
|
||||
block_dim: (u32, u32, u32),
|
||||
shared_mem_bytes: u32,
|
||||
kernel_params: *mut *mut c_void,
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let params = sys::CUDA_KERNEL_NODE_PARAMS {
|
||||
func,
|
||||
gridDimX: grid_dim.0,
|
||||
gridDimY: grid_dim.1,
|
||||
gridDimZ: grid_dim.2,
|
||||
blockDimX: block_dim.0,
|
||||
blockDimY: block_dim.1,
|
||||
blockDimZ: block_dim.2,
|
||||
sharedMemBytes: shared_mem_bytes,
|
||||
kernelParams: kernel_params,
|
||||
extra: std::ptr::null_mut(),
|
||||
kern: std::ptr::null_mut(),
|
||||
ctx: std::ptr::null_mut(),
|
||||
};
|
||||
|
||||
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, ¶ms).result() }
|
||||
}
|
||||
|
||||
/// Adds an empty dependency node to the graph.
|
||||
pub fn add_empty_node(
|
||||
&mut self,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<CUgraphNode, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut node = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphAddEmptyNode(
|
||||
node.as_mut_ptr(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
dependencies.len(),
|
||||
)
|
||||
.result()?;
|
||||
Ok(node.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Destroys a node in the mutable graph.
|
||||
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe { sys::cuGraphDestroyNode(node).result() }
|
||||
}
|
||||
|
||||
/// Adds dependency edges to the mutable graph.
|
||||
pub fn add_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Removes dependency edges from the mutable graph.
|
||||
pub fn remove_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Returns all nodes currently in the graph.
|
||||
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut nodes = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
nodes.truncate(count);
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
/// Returns the direct dependencies of a node.
|
||||
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Returns the direct dependents of a node.
|
||||
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Begins stream capture that appends captured work into this graph.
|
||||
pub fn begin_capture_to_graph(
|
||||
&mut self,
|
||||
stream: &CudaStream,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuStreamBeginCaptureToGraph(
|
||||
stream.cu_stream(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
std::ptr::null(),
|
||||
dependencies.len(),
|
||||
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
|
||||
)
|
||||
.result()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ends stream capture previously started by begin_capture_to_graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut graph = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
|
||||
let captured = graph.assume_init();
|
||||
if captured != self.cu_graph && !captured.is_null() {
|
||||
sys::cuGraphDestroy(captured).result()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds an event record node to the graph for timing.
|
||||
pub fn add_event_record_node(
|
||||
&mut self,
|
||||
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
|
||||
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, ¶ms) }
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Attempts to update this executable graph from an already-mutated source graph.
|
||||
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut result = CUgraphExecUpdateResultInfo {
|
||||
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
|
||||
errorNode: std::ptr::null_mut(),
|
||||
errorFromNode: std::ptr::null_mut(),
|
||||
};
|
||||
unsafe {
|
||||
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
|
||||
}
|
||||
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
|
||||
return Err(DriverError(
|
||||
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphExecHandle {
|
||||
@@ -480,6 +672,38 @@ mod tests {
|
||||
assert_eq!(result[0], 6.0f32);
|
||||
}
|
||||
|
||||
/// Builds the chain entry -> middle -> exit from empty nodes, rewires exit to
/// depend on entry directly, then removes and destroys the middle node.
/// Returns early (passing) when no CUDA context can be created.
#[test]
fn test_graph_empty_node_dependency_reconnect() {
    let Ok(ctx) = CudaContext::new(0) else { return };
    let mut graph = CudaGraphHandle::new(ctx).unwrap();

    let entry = graph.add_empty_node(&[]).unwrap();
    let middle = graph.add_empty_node(&[entry]).unwrap();
    let exit = graph.add_empty_node(&[middle]).unwrap();

    // All three nodes exist and the chain is wired as built.
    let all_nodes = graph.nodes().unwrap();
    assert!(all_nodes.contains(&entry));
    assert!(all_nodes.contains(&middle));
    assert!(all_nodes.contains(&exit));
    assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
    assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);

    // Add the shortcut edge entry -> exit.
    graph.add_dependencies(&[entry], &[exit]).unwrap();
    let with_shortcut = graph.dependencies(exit).unwrap();
    assert!(with_shortcut.contains(&entry));
    assert!(with_shortcut.contains(&middle));

    // Drop middle -> exit; only the shortcut remains.
    graph.remove_dependencies(&[middle], &[exit]).unwrap();
    let after_removal = graph.dependencies(exit).unwrap();
    assert_eq!(after_removal.len(), 1);
    assert!(after_removal.contains(&entry));

    unsafe {
        graph.destroy_node(middle).unwrap();
    }
    assert!(!graph.nodes().unwrap().contains(&middle));
}
|
||||
|
||||
// CUDA Graph Tests
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -304,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
|
||||
|
||||
// Kernel to host op compilation
|
||||
mod to_host;
|
||||
#[cfg(test)]
|
||||
pub(crate) use to_host::CudaGraphDebugSummary;
|
||||
pub use to_host::{CudaGraphOp, kernel_to_host};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -7,10 +7,12 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
host::{
|
||||
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
|
||||
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
|
||||
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
|
||||
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
|
||||
cublaslt_type_tuple,
|
||||
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
|
||||
expected
|
||||
}
|
||||
|
||||
/// CPU reference for the mixed chain `exp((a + pre) · b)`.
///
/// `a` and `pre` are indexed as row-major `m × k` matrices and `b` as a row-major
/// `k × n` matrix. The result has `m * n` elements in row-major order, where element
/// `(row, col)` is `exp(Σ_inner (a[row, inner] + pre[row, inner]) * b[inner, col])`.
/// The sum is accumulated in `f32`, in increasing `inner` order, starting from `0.0`.
///
/// # Panics
///
/// Panics on an out-of-bounds index if a slice is shorter than the accessed shape.
fn reference_mixed_chain(
    a: &[f32],
    pre: &[f32],
    b: &[f32],
    m: usize,
    n: usize,
    k: usize,
) -> Vec<f32> {
    let mut expected = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            // Left fold keeps the accumulation order of a plain `acc += ...` loop.
            let acc = (0..k).fold(0.0f32, |acc, inner| {
                acc + (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col]
            });
            expected.push(acc.exp());
        }
    }
    expected
}
|
||||
|
||||
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
|
||||
crate::try_create_cublaslt(stream.clone()).is_ok()
|
||||
}
|
||||
|
||||
fn build_mixed_chain_graph(
|
||||
m: impl Into<Expression>,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let m = m.into();
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let pre = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = ((a + pre).matmul(b).exp()).output();
|
||||
(cx, a.id, pre.id, b.id, out.id)
|
||||
}
|
||||
|
||||
fn add_in_place(values: &mut [f32], addends: &[f32]) {
|
||||
for (value, addend) in values.iter_mut().zip(addends) {
|
||||
*value += *addend;
|
||||
@@ -507,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the mixed chain with a forced cuBLASLt matmul and checks both the numeric
/// result and that the cuBLASLt call was absorbed into a CUDA graph op alongside
/// at least two kernels, leaving no standalone cuBLASLt host op.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (m, n, k) = (7, 11, 5);
    let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
    // Accept only lowerings whose cuBLASLt scale-value tuples include `(1.0, 0.0)`.
    let lowered = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |candidate| {
        cublaslt_scale_value_tuples(candidate).contains(&(1.0, 0.0))
    });

    let a_values = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
    let pre_values = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
    let b_values = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
    let want = reference_mixed_chain(&a_values, &pre_values, &b_values, m, n, k);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);
    runtime.set_data(a, a_values);
    runtime.set_data(pre, pre_values);
    runtime.set_data(b, b_values);
    runtime.execute(&cx.dyn_map);

    let got = runtime.get_f32(out);
    assert_close(&got, &want, 1e-5, 1e-5);

    let summaries = runtime.debug_cuda_graph_summaries();
    let mixed = summaries
        .iter()
        .find(|s| s.n_cublaslt == 1)
        .expect("expected one CudaGraphOp to capture the cuBLASLt island");
    assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
    assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
    assert_eq!(mixed.absorbed_host_nodes.len(), 1);
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// A graph holding only `a.matmul(b)`, forced through cuBLASLt, must match the CPU
/// reference and be summarised as a CUDA graph op with one step, one cuBLASLt call
/// and no kernels.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn cuda_graph_cublaslt_only_executes_correctly() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (m, n, k) = (7, 11, 5);
    let mut cx = Graph::new();
    let lhs = cx.tensor((m, k));
    let rhs = cx.tensor((k, n));
    let product = lhs.matmul(rhs).output();
    let lowered = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);

    let lhs_values = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
    let rhs_values = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
    let want = reference_matmul_2d(&lhs_values, &rhs_values, m, n, k);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);
    runtime.set_data(lhs.id, lhs_values);
    runtime.set_data(rhs.id, rhs_values);
    runtime.execute(&cx.dyn_map);

    let got = runtime.get_f32(product.id);
    assert_close(&got, &want, 1e-5, 1e-5);

    let captured = runtime
        .debug_cuda_graph_summaries()
        .into_iter()
        .find(|s| s.n_cublaslt == 1)
        .expect("expected a cuBLASLt-only CudaGraphOp");
    assert_eq!(captured.n_kernels, 0);
    assert_eq!(captured.n_steps, 1);
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Two dependent matmuls against the same `b`, both lowered to cuBLASLt with identical
/// matrix orders, must land in one CUDA graph op that reports a single prepared
/// cuBLASLt entry for the two calls.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (m, n, k) = (5, 8, 8);
    let mut cx = Graph::new();
    let lhs = cx.tensor((m, k));
    let rhs = cx.tensor((k, n));
    let first_product = lhs.matmul(rhs);
    let result = (lhs + first_product.sin()).matmul(rhs).output();
    let lowered = extract_forced_cublaslt_llir_where(
        &mut cx,
        "ordered matching cuBLASLt prepared reuse",
        |candidate| {
            // Exactly two cuBLASLt ops, and they must agree on matrix orders.
            let orders = cublaslt_matrix_order_tuples(candidate);
            if orders.len() != 2 {
                return false;
            }
            orders[0] == orders[1]
        },
    );

    let lhs_values = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
    let rhs_values = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
    let first_reference = reference_matmul_2d(&lhs_values, &rhs_values, m, n, k);
    let second_input = lhs_values
        .iter()
        .zip(first_reference.iter())
        .map(|(value, product)| value + product.sin())
        .collect::<Vec<_>>();
    let want = reference_matmul_2d(&second_input, &rhs_values, m, n, k);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);
    runtime.set_data(lhs.id, lhs_values);
    runtime.set_data(rhs.id, rhs_values);
    runtime.execute(&cx.dyn_map);

    let got = runtime.get_f32(result.id);
    assert_close(&got, &want, 1e-5, 1e-5);

    let captured = runtime
        .debug_cuda_graph_summaries()
        .into_iter()
        .find(|s| s.n_cublaslt == 2)
        .expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
    assert_eq!(
        captured.n_cublaslt_prepared, 1,
        "dependency-ordered matching cuBLASLt calls should share prepared resources"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Executes a static-shape cuBLASLt matmul twice, changing only the dynamic dimension
/// `'p'` in between, and checks that the test-only cuBLASLt prepare counter does not
/// advance on the second run.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (m, n, k) = (7, 11, 5);
    let mut cx = Graph::new();
    let lhs = cx.tensor((m, k));
    let rhs = cx.tensor((k, n));
    let product = lhs.matmul(rhs).output();
    // Both inputs are additionally marked as outputs.
    lhs.output();
    rhs.output();
    cx.set_dim('p', 1);
    let lowered = extract_forced_cublaslt_llir_where(
        &mut cx,
        "cuBLASLt unchanged under unrelated dyn dim",
        |_| true,
    );

    let lhs_values = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
    let rhs_values = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
    let want = reference_matmul_2d(&lhs_values, &rhs_values, m, n, k);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);
    runtime.set_data(lhs.id, lhs_values);
    runtime.set_data(rhs.id, rhs_values);

    // First run: the counter must move off zero.
    reset_cublaslt_prepare_count_for_test();
    runtime.execute(&cx.dyn_map);
    let first_result = runtime.get_f32(product.id);
    assert_close(&first_result, &want, 1e-5, 1e-5);
    let prepares_after_first_run = cublaslt_prepare_count_for_test();
    assert!(
        prepares_after_first_run > 0,
        "first execution should prepare the captured cuBLASLt island"
    );

    // Second run with only 'p' changed: same result, same counter value.
    cx.set_dim('p', 2);
    runtime.execute(&cx.dyn_map);
    let second_result = runtime.get_f32(product.id);
    assert_close(&second_result, &want, 1e-5, 1e-5);
    assert_eq!(
        cublaslt_prepare_count_for_test(),
        prepares_after_first_run,
        "unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Runs a cuBLASLt-only matmul whose row count is the dynamic dimension `'m'` at sizes
/// 7, 9 and 7 again, checking the result each time and that the final summaries still
/// contain a one-step, kernel-free cuBLASLt graph op.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (n, k) = (11, 5);
    let mut cx = Graph::new();
    let lhs = cx.tensor(('m', k));
    let rhs = cx.tensor((k, n));
    let product = lhs.matmul(rhs).output();
    cx.set_dim('m', 7);
    let lowered =
        extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);

    for (rows, seed) in [
        (7usize, 0xC002_0001),
        (9usize, 0xC002_0002),
        (7usize, 0xC002_0003),
    ] {
        cx.set_dim('m', rows);
        let lhs_values = random_f32_vec(rows * k, seed, -0.08, 0.08);
        let rhs_values = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
        let want = reference_matmul_2d(&lhs_values, &rhs_values, rows, n, k);

        runtime.set_data(lhs.id, lhs_values);
        runtime.set_data(rhs.id, rhs_values);
        runtime.execute(&cx.dyn_map);

        let got = runtime.get_f32(product.id);
        assert_close(&got, &want, 1e-5, 1e-5);
    }

    let captured = runtime
        .debug_cuda_graph_summaries()
        .into_iter()
        .find(|s| s.n_cublaslt == 1)
        .expect("expected a cuBLASLt-only CudaGraphOp after recapture");
    assert_eq!(captured.n_kernels, 0);
    assert_eq!(captured.n_steps, 1);
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Runs a cuBLASLt matmul whose row count is the dynamic dimension `'c'` at sizes 7 and
/// 9, then checks that a CUDA graph op containing one cuBLASLt call exists and that no
/// standalone cuBLASLt host op remains.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn cublaslt_with_dynamic_c_spec_is_captured() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream) {
        return;
    }
    // The assertions below require the cuBLASLt call to be captured into a CUDA graph,
    // so skip where capture is unavailable, as the other capture tests in this file do.
    if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
        return;
    }

    let (n, k) = (11, 5);
    let mut cx = Graph::new();
    let a = cx.tensor(('c', k));
    let b = cx.tensor((k, n));
    let out = a.matmul(b).output();
    cx.set_dim('c', 7);
    let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);

    let mut rt = CudaRuntime::initialize(stream);
    rt.load_llir(&llir);

    for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
        cx.set_dim('c', c);
        let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
        let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
        let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);

        rt.set_data(a.id, a_data);
        rt.set_data(b.id, b_data);
        rt.execute(&cx.dyn_map);
        assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
    }

    assert!(
        rt.debug_cuda_graph_summaries()
            .iter()
            .any(|summary| summary.n_cublaslt == 1),
        "c-dependent cuBLASLt should be absorbed into a CUDA graph"
    );
    assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Loads the same cuBLASLt lowering into two buckets of the dynamic dimension `'s'`
/// (`DimBucket::new(1, 1)` and `DimBucket::new(2, 4)`), executes once at `s = 1` and
/// once at `s = 3`, and after each run checks the numeric result, that a cuBLASLt call
/// is captured in a CUDA graph op, that no standalone cuBLASLt host op remains, and
/// that the active bucket reports stabilised intermediate pointers.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (n, k) = (11, 5);
    let mut cx = Graph::new();
    let lhs = cx.tensor(('s', k));
    let rhs = cx.tensor((k, n));
    let product = lhs.matmul(rhs).output();
    cx.set_dim('s', 1);
    let lowered =
        extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);

    let buckets_by_dim = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
        .into_iter()
        .collect();
    // Each entry: (bucket indices, representative dyn map, LLIR).
    let per_bucket_llir = vec![
        (
            [('s', 0usize)].into_iter().collect(),
            [('s', 1usize)].into_iter().collect(),
            lowered.clone(),
        ),
        (
            [('s', 1usize)].into_iter().collect(),
            [('s', 3usize)].into_iter().collect(),
            lowered,
        ),
    ];

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir_buckets(&buckets_by_dim, &per_bucket_llir);

    // --- s = 1 ---
    cx.set_dim('s', 1);
    let single_row = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
    let rhs_values = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
    let want_single = reference_matmul_2d(&single_row, &rhs_values, 1, n, k);
    runtime.set_data(lhs.id, single_row);
    runtime.set_data(rhs.id, rhs_values.clone());
    runtime.execute(&cx.dyn_map);
    let got_single = runtime.get_f32(product.id);
    assert_close(&got_single, &want_single, 1e-5, 1e-5);
    assert!(
        runtime
            .debug_cuda_graph_summaries()
            .iter()
            .any(|s| s.n_cublaslt == 1),
        "singleton s bucket should capture s-dependent cuBLASLt"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
    assert!(
        runtime.debug_active_bucket_stabilizes_intermediate_pointers(),
        "bucket with captured cuBLASLt needs stable intermediate pointers"
    );

    // --- s = 3 ---
    cx.set_dim('s', 3);
    let three_rows = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
    let want_range = reference_matmul_2d(&three_rows, &rhs_values, 3, n, k);
    runtime.set_data(lhs.id, three_rows);
    runtime.set_data(rhs.id, rhs_values);
    runtime.execute(&cx.dyn_map);
    let got_range = runtime.get_f32(product.id);
    assert_close(&got_range, &want_range, 1e-5, 1e-5);
    assert!(
        runtime
            .debug_cuda_graph_summaries()
            .iter()
            .any(|s| s.n_cublaslt == 1),
        "range s bucket should capture s-dependent cuBLASLt"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
    assert!(
        runtime.debug_active_bucket_stabilizes_intermediate_pointers(),
        "bucket with captured cuBLASLt needs stable intermediate pointers"
    );
}
|
||||
|
||||
/// Executes the mixed chain twice with freshly set input data and checks that the
/// test-only cuBLASLt prepare counter after the second run equals its value after the
/// first, and that the cuBLASLt call is still captured in a CUDA graph op.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (m, n, k) = (7, 11, 5);
    let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
    let lowered =
        extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);

    reset_cublaslt_prepare_count_for_test();
    // Counter value sampled once, right after the first iteration.
    let mut prepares_after_first_run = None;
    for seed in [0xCC00_0001, 0xCC00_0002] {
        let a_values = random_f32_vec(m * k, seed, -0.08, 0.08);
        let pre_values = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
        let b_values = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
        let want = reference_mixed_chain(&a_values, &pre_values, &b_values, m, n, k);

        runtime.set_data(a, a_values);
        runtime.set_data(pre, pre_values);
        runtime.set_data(b, b_values);
        runtime.execute(&cx.dyn_map);

        let got = runtime.get_f32(out);
        assert_close(&got, &want, 1e-5, 1e-5);

        let _ = prepares_after_first_run.get_or_insert_with(cublaslt_prepare_count_for_test);
    }
    assert_eq!(
        cublaslt_prepare_count_for_test(),
        prepares_after_first_run.unwrap(),
        "A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
    );

    let summaries = runtime.debug_cuda_graph_summaries();
    let still_captured = summaries.iter().any(|s| s.n_cublaslt == 1);
    assert!(
        still_captured,
        "expected cuBLASLt to remain captured after pointer recapture"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
/// Executes the mixed chain with the dynamic row dimension `'m'` set to 7 and then 9,
/// checking the result each time and that afterwards the cuBLASLt call is still
/// captured in a CUDA graph op with no standalone cuBLASLt host op.
///
/// Returns early without a CUDA stream, cuBLASLt, or cuBLASLt graph-capture support.
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    if !cublaslt_available_for_runtime(&stream)
        || !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream)
    {
        return;
    }

    let (n, k) = (11, 5);
    let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
    cx.set_dim('m', 7);
    let lowered =
        extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);

    let mut runtime = CudaRuntime::initialize(stream);
    runtime.load_llir(&lowered);

    for (rows, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
        cx.set_dim('m', rows);
        let a_values = random_f32_vec(rows * k, seed, -0.08, 0.08);
        let pre_values = random_f32_vec(rows * k, seed + 10, -0.03, 0.03);
        let b_values = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
        let want = reference_mixed_chain(&a_values, &pre_values, &b_values, rows, n, k);

        runtime.set_data(a, a_values);
        runtime.set_data(pre, pre_values);
        runtime.set_data(b, b_values);
        runtime.execute(&cx.dyn_map);

        let got = runtime.get_f32(out);
        assert_close(&got, &want, 1e-5, 1e-5);
    }

    let summaries = runtime.debug_cuda_graph_summaries();
    let still_captured = summaries.iter().any(|s| s.n_cublaslt == 1);
    assert!(
        still_captured,
        "expected cuBLASLt to remain captured after dynamic-shape recapture"
    );
    assert_eq!(runtime.debug_standalone_cublaslt_host_ops(), 0);
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {
|
||||
|
||||
@@ -109,7 +109,7 @@ fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
|
||||
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
68
src/graph.rs
68
src/graph.rs
@@ -82,6 +82,13 @@ struct SearchSpaceContext {
|
||||
intervals: DynDimIntervals,
|
||||
}
|
||||
|
||||
/// Owned description of the bucket a search pass is profiling for.
///
/// Its three fields are lent out field-by-field as a `ProfileBucketContext` when the
/// runtime's `profile_with_bucket_context` is called.
#[derive(Debug, Clone)]
|
||||
struct SearchProfileBucketContext {
|
||||
/// Bucket lists keyed by a `char` (presumably the dynamic-dimension symbol; confirm
/// against `search_space_dim_buckets`).
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Per-key bucket index identifying the bucket combination being profiled.
bucket_indices: FxHashMap<char, usize>,
|
||||
/// Concrete per-key values used as the representative dyn map for this bucket.
representative_dyn_map: FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
/// A compiled bucket: (bucket_indices, representative_dyn_map, stitched_llir).
|
||||
pub type BucketLLIR = (FxHashMap<char, usize>, FxHashMap<char, usize>, LLIRGraph);
|
||||
|
||||
@@ -1314,6 +1321,7 @@ impl Graph {
|
||||
rng,
|
||||
&self.dyn_map.clone(),
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
search_started_at,
|
||||
);
|
||||
@@ -1349,6 +1357,11 @@ impl Graph {
|
||||
rng,
|
||||
&context.representative_dyn_map,
|
||||
Some((combo_idx, n_combos)),
|
||||
Some(SearchProfileBucketContext {
|
||||
dim_buckets: self.search_space_dim_buckets.clone(),
|
||||
bucket_indices: context.bucket_indices.clone(),
|
||||
representative_dyn_map: context.representative_dyn_map.clone(),
|
||||
}),
|
||||
combo_idx,
|
||||
search_started_at,
|
||||
);
|
||||
@@ -1445,6 +1458,7 @@ impl Graph {
|
||||
rng: &mut G,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
bucket_progress: Option<(usize, usize)>,
|
||||
bucket_profile_context: Option<SearchProfileBucketContext>,
|
||||
egraph_index: usize,
|
||||
search_started_at: std::time::Instant,
|
||||
) -> LLIRGraph {
|
||||
@@ -1545,12 +1559,27 @@ impl Graph {
|
||||
collapse_loops_to_first_iter(&mut graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let (rep_metric, rep_display) =
|
||||
if let Some(bucket_context) = &bucket_profile_context {
|
||||
runtime.profile_with_bucket_context(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
ProfileBucketContext {
|
||||
dim_buckets: &bucket_context.dim_buckets,
|
||||
bucket_indices: &bucket_context.bucket_indices,
|
||||
representative_dyn_map: &bucket_context.representative_dyn_map,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
runtime.profile(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
)
|
||||
};
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
(
|
||||
@@ -1664,12 +1693,27 @@ impl Graph {
|
||||
collapse_loops_to_first_iter(&mut llir_graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let (rep_metric, rep_display) =
|
||||
if let Some(bucket_context) = &bucket_profile_context {
|
||||
runtime.profile_with_bucket_context(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
ProfileBucketContext {
|
||||
dim_buckets: &bucket_context.dim_buckets,
|
||||
bucket_indices: &bucket_context.bucket_indices,
|
||||
representative_dyn_map: &bucket_context.representative_dyn_map,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
runtime.profile(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
)
|
||||
};
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan =
|
||||
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
|
||||
20
src/op.rs
20
src/op.rs
@@ -7,6 +7,12 @@ use crate::prelude::*;
|
||||
use as_any::{AsAny, Downcast};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
/// Borrowed bucket information handed to `Runtime::profile_with_bucket_context`.
///
/// All three maps are keyed by `char` (presumably the dynamic-dimension symbol;
/// confirm against the caller).
pub struct ProfileBucketContext<'a> {
|
||||
/// Bucket lists per key.
pub dim_buckets: &'a FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Bucket index per key for the bucket being profiled.
pub bucket_indices: &'a FxHashMap<char, usize>,
|
||||
/// Representative concrete value per key for the bucket being profiled.
pub representative_dyn_map: &'a FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
pub trait Runtime {
|
||||
type Ops: IntoEgglogOp;
|
||||
type CompileArg;
|
||||
@@ -36,6 +42,20 @@ pub trait Runtime {
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String);
|
||||
/// Profile one candidate in the context of a specific dynamic-dimension
|
||||
/// bucket. Runtimes with bucket-sensitive lowering can override this so
|
||||
/// search ranks candidates under the same execution model used after
|
||||
/// final bucket compilation.
|
||||
fn profile_with_bucket_context(
|
||||
&mut self,
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
_bucket_context: ProfileBucketContext<'_>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.profile(llir_graph, dyn_map, trials, timeout)
|
||||
}
|
||||
/// Aggregate multiple profile metrics into one comparable metric.
|
||||
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
|
||||
Reference in New Issue
Block a user