Compare commits

...

2 Commits

Author SHA1 Message Date
Joe Fioti
62e86f9dc5 Reuse cuBLASLt prepares across matching graph ops 2026-06-01 00:25:30 +00:00
Joe Fioti
75e4e6be0a Simplify example mains and trim CUDA profiling output (#339)
* Simplify example mains and trim CUDA profiling output

* Simplify model examples and adjust CUDA profiling output

* Simplify example model setup and CUDA profiling output
2026-05-29 23:37:13 -04:00
51 changed files with 5034 additions and 1818 deletions

View File

@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
let c = a.matmul(b).output();
// Compile
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

View File

@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
}
}
let mut rt = cx.search(rt, CompileOptions::new(5));
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
let mut bench_metrics = None;

View File

@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
rt.set_data(*node, &data);
}
let rt = cx.search(rt, CompileOptions::new(5));
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
Some(PreparedBench {
rt,

View File

@@ -1,4 +1,7 @@
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
use half::{bf16, f16};
use luminal::{
@@ -15,6 +18,8 @@ use luminal::{
},
};
#[cfg(test)]
use crate::kernel::CudaGraphHandle;
use crate::{
cudarc::{
cublas::sys::cublasOperation_t,
@@ -33,7 +38,7 @@ use crate::{
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
},
},
driver::{CudaStream, DevicePtr},
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::{DeviceBuffer, HostOp},
try_create_cublaslt,
@@ -581,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
}
}
#[derive(Debug, Clone, Copy)]
enum LtScalar {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum LtScalar {
F64(f64),
F32(f32),
F16(f16),
@@ -622,16 +627,16 @@ impl LtScalar {
}
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulProblem {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulProblem {
m: u64,
n: u64,
k: u64,
batch_count: i32,
}
#[derive(Debug, Clone, Copy)]
struct LtMatrixSpec {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatrixSpec {
dtype: cudaDataType,
rows: u64,
cols: u64,
@@ -640,8 +645,8 @@ struct LtMatrixSpec {
order: cublasLtOrder_t,
}
#[derive(Debug, Clone, Copy)]
struct LtComputeSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtComputeSpec {
compute_type: cublasComputeType_t,
scale_dtype: cudaDataType,
alpha: LtScalar,
@@ -649,8 +654,8 @@ struct LtComputeSpec {
epilogue: cublasLtEpilogue_t,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtMatmulSpec {
problem: LtMatmulProblem,
trans_a: cublasOperation_t,
trans_b: cublasOperation_t,
@@ -662,8 +667,8 @@ struct LtMatmulSpec {
workspace_size: usize,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulPointers {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulPointers {
a: u64,
b: u64,
c: u64,
@@ -673,7 +678,35 @@ struct LtMatmulPointers {
b_scale: Option<u64>,
}
struct LtRawDescriptors {
impl LtMatmulPointers {
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
let mut fields = Vec::new();
if self.a != other.a {
fields.push("a");
}
if self.b != other.b {
fields.push("b");
}
if self.c != other.c {
fields.push("c");
}
if self.d != other.d {
fields.push("d");
}
if self.bias != other.bias {
fields.push("bias");
}
if self.a_scale != other.a_scale {
fields.push("a_scale");
}
if self.b_scale != other.b_scale {
fields.push("b_scale");
}
fields
}
}
pub(crate) struct LtRawDescriptors {
matmul_desc: cublasLtMatmulDesc_t,
a_desc: cublasLtMatrixLayout_t,
b_desc: cublasLtMatrixLayout_t,
@@ -682,6 +715,23 @@ struct LtRawDescriptors {
preference: cublasLtMatmulPreference_t,
}
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
> = OnceLock::new();
#[cfg(test)]
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
}
#[cfg(test)]
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
}
impl Default for LtRawDescriptors {
fn default() -> Self {
Self {
@@ -720,6 +770,121 @@ impl Drop for LtRawDescriptors {
}
}
pub(crate) struct PreparedCuBlasLtMatmul {
cublaslt: Arc<CudaBlasLT>,
spec: LtMatmulSpec,
resources: LtRawDescriptors,
heuristic: cublasLtMatmulHeuristicResult_t,
_workspace: CudaSlice<u8>,
workspace_ptr: u64,
_a_scale: Option<CudaSlice<f32>>,
default_a_scale_ptr: Option<u64>,
_b_scale: Option<CudaSlice<f32>>,
default_b_scale_ptr: Option<u64>,
_c_scale: Option<CudaSlice<f32>>,
_d_scale: Option<CudaSlice<f32>>,
}
impl PreparedCuBlasLtMatmul {
fn update_descriptor_pointers(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
stream.context().bind_to_thread()?;
if let Some(bias_ptr) = ptrs.bias {
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias_ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
ptr,
)?;
}
Ok(())
}
pub(crate) fn enqueue(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
self.update_descriptor_pointers(stream, ptrs)?;
let alpha_ptr = self.spec.compute.alpha.as_ptr();
let beta_ptr = self.spec.compute.beta.as_ptr();
unsafe {
cublasLtMatmul(
*self.cublaslt.handle(),
self.resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
self.resources.a_desc,
ptrs.b as *const std::ffi::c_void,
self.resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
self.resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
self.resources.d_desc,
&self.heuristic.algo,
self.workspace_ptr as *mut std::ffi::c_void,
self.spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtCaptureSignature {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtPrepareKey {
spec: LtMatmulSpec,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtResolvedGraphCall {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
impl CuBlasLtResolvedGraphCall {
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
CuBlasLtCaptureSignature {
spec: self.spec,
ptrs: self.ptrs,
}
}
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
CuBlasLtPrepareKey { spec: self.spec }
}
}
fn create_matrix_layout(
desc: &mut cublasLtMatrixLayout_t,
spec: LtMatrixSpec,
@@ -796,12 +961,15 @@ fn set_scalar_scale_pointer(
Ok(())
}
fn run_cublaslt_matmul(
pub(crate) fn prepare_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
#[cfg(test)]
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
return Err(anyhow::anyhow!(
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
@@ -813,17 +981,17 @@ fn run_cublaslt_matmul(
let mut resources = LtRawDescriptors::default();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
drop(workspace_guard);
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
};
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
@@ -879,29 +1047,27 @@ fn run_cublaslt_matmul(
}
}
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
(Some(ptr), None)
} else if let Some(scale) = &a_scale {
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
(Some(ptr), None)
} else if let Some(scale) = &b_scale {
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
@@ -935,6 +1101,7 @@ fn run_cublaslt_matmul(
ptr,
)?;
}
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
create_matrix_layout(&mut resources.a_desc, spec.a)?;
create_matrix_layout(&mut resources.b_desc, spec.b)?;
@@ -952,58 +1119,148 @@ fn run_cublaslt_matmul(
}
}
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
let cached_heuristic = {
let cache = heuristic_cache.lock().unwrap();
cache
.iter()
.find(|(cached_spec, _)| cached_spec == spec)
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
};
if let Some(cached) = cached_heuristic {
heuristic = cached;
} else {
let mut algo_count: i32 = 0;
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
}
let alpha_ptr = spec.compute.alpha.as_ptr();
let beta_ptr = spec.compute.beta.as_ptr();
cublasLtMatmul(
*cublaslt.handle(),
resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
resources.a_desc,
ptrs.b as *const std::ffi::c_void,
resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
resources.d_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
heuristic_cache
.lock()
.unwrap()
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
}
Ok(())
Ok(PreparedCuBlasLtMatmul {
cublaslt: cublaslt.clone(),
spec: *spec,
resources,
heuristic,
_workspace: workspace,
workspace_ptr,
_a_scale: a_scale,
default_a_scale_ptr,
_b_scale: b_scale,
default_b_scale_ptr,
_c_scale: c_scale,
_d_scale: d_scale,
})
}
fn run_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
prepared.enqueue(stream, ptrs)
}
#[cfg(test)]
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
let capture_stream = stream.context().new_stream()?;
let cublaslt = try_create_cublaslt(stream.clone())
.map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;
let a_buf = stream.clone_htod(&[1.0f32])?;
let b_buf = stream.clone_htod(&[1.0f32])?;
let d_buf = unsafe { stream.alloc::<f32>(1)? };
let (a, a_guard) = a_buf.device_ptr(stream);
let (b, b_guard) = b_buf.device_ptr(stream);
let (d, d_guard) = d_buf.device_ptr(stream);
drop((a_guard, b_guard, d_guard));
let matrix = LtMatrixSpec {
dtype: cudaDataType::CUDA_R_32F,
rows: 1,
cols: 1,
ld: 1,
batch_stride: 1,
order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m: 1,
n: 1,
k: 1,
batch_count: 1,
},
trans_a: cublasOperation_t::CUBLAS_OP_N,
trans_b: cublasOperation_t::CUBLAS_OP_N,
a: matrix,
b: matrix,
c: matrix,
d: matrix,
compute: LtComputeSpec {
compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
scale_dtype: cudaDataType::CUDA_R_32F,
alpha: LtScalar::F32(1.0),
beta: LtScalar::F32(0.0),
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
},
workspace_size: 1024 * 1024,
};
let ptrs = LtMatmulPointers {
a,
b,
c: d,
d,
bias: None,
a_scale: None,
b_scale: None,
};
let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;
let mut graph = CudaGraphHandle::new(stream.context().clone())?;
let entry = graph.add_empty_node(&[])?;
capture_stream.join(stream)?;
graph.begin_capture_to_graph(&capture_stream, &[entry])?;
let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
let end_result = graph.end_capture(&capture_stream);
enqueue_result?;
end_result?;
Ok(())
}
let supported = probe(stream).is_ok();
let _ = stream.synchronize();
supported
}
fn resolve_cublaslt_pointers(
@@ -1126,6 +1383,151 @@ impl CuBlasLt {
Ok(created)
}
pub(crate) fn graph_inputs(&self) -> usize {
self.n_inputs()
}
pub(crate) fn resolve_for_graph(
&self,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
let element_size = (self.d_dtype.bits() / 8) as u64;
assert!(
element_size > 0,
"cuBLAS LT does not support sub-byte dtype {}",
self.d_dtype
);
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
let ptrs = resolve_cublaslt_pointers(
self_node,
inputs,
buffers,
self.beta,
self.epilogue,
self.a_scale_input,
self.b_scale_input,
)?;
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
let _span = span!(
Level::TRACE,
"cuBLASLT_resolve_graph",
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
?self.epilogue,
)
.entered();
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
let c_spec = LtMatrixSpec {
dtype: c_cuda_dtype,
rows: m,
cols: n,
ld: ldc,
batch_stride: stride_c,
order: self.c_order,
};
let d_spec = LtMatrixSpec {
dtype: d_cuda_dtype,
rows: m,
cols: n,
ld: ldd,
batch_stride: stride_d,
order: self.d_order,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m,
n,
k,
batch_count,
},
trans_a: a_layout,
trans_b: b_layout,
a: LtMatrixSpec {
dtype: a_cuda_dtype,
rows: a_rows,
cols: a_cols,
ld: lda,
batch_stride: stride_a,
order: self.a_order,
},
b: LtMatrixSpec {
dtype: b_cuda_dtype,
rows: b_rows,
cols: b_cols,
ld: ldb,
batch_stride: stride_b,
order: self.b_order,
},
c: c_spec,
d: d_spec,
compute: LtComputeSpec {
compute_type: self.compute_type,
scale_dtype: scale_cuda_dtype,
alpha,
beta,
epilogue: self.epilogue,
},
workspace_size: WORKSPACE_SIZE,
};
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
}
pub(crate) fn prepare_resolved_for_graph(
&self,
stream: &Arc<CudaStream>,
resolved: CuBlasLtResolvedGraphCall,
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
let cublaslt = self.get_cublaslt(stream)?;
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
}
#[cfg(test)]
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
(

View File

@@ -2,7 +2,7 @@ use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use luminal::{op::EgglogOp, prelude::*};
mod cublaslt;
pub(crate) mod cublaslt;
pub mod flashinfer;
pub mod moe;
@@ -167,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
None
}
/// Returns pairs of extra buffer nodes that must not share arena storage.
///
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
/// two buffers may have disjoint positions in one topological order while
/// still being unordered by real dependencies, so CUDA could overlap them.
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
None
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.

View File

@@ -7,7 +7,10 @@ use std::sync::Arc;
use cudarc::driver::{
CudaContext, CudaFunction, CudaStream, DriverError,
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
sys::{
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
},
};
/// A CUDA graph that can be modified and instantiated.
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
}
}
/// Updates a kernel node in the mutable source graph.
pub unsafe fn set_kernel_node_params(
&mut self,
node: CUgraphNode,
func: CUfunction,
grid_dim: (u32, u32, u32),
block_dim: (u32, u32, u32),
shared_mem_bytes: u32,
kernel_params: *mut *mut c_void,
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let params = sys::CUDA_KERNEL_NODE_PARAMS {
func,
gridDimX: grid_dim.0,
gridDimY: grid_dim.1,
gridDimZ: grid_dim.2,
blockDimX: block_dim.0,
blockDimY: block_dim.1,
blockDimZ: block_dim.2,
sharedMemBytes: shared_mem_bytes,
kernelParams: kernel_params,
extra: std::ptr::null_mut(),
kern: std::ptr::null_mut(),
ctx: std::ptr::null_mut(),
};
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, &params).result() }
}
/// Adds an empty dependency node to the graph.
pub fn add_empty_node(
&mut self,
dependencies: &[CUgraphNode],
) -> Result<CUgraphNode, DriverError> {
self.ctx.bind_to_thread()?;
let mut node = MaybeUninit::uninit();
unsafe {
sys::cuGraphAddEmptyNode(
node.as_mut_ptr(),
self.cu_graph,
dependencies.as_ptr(),
dependencies.len(),
)
.result()?;
Ok(node.assume_init())
}
}
/// Destroys a node in the mutable graph.
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe { sys::cuGraphDestroyNode(node).result() }
}
/// Adds dependency edges to the mutable graph.
pub fn add_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Removes dependency edges from the mutable graph.
pub fn remove_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Returns all nodes currently in the graph.
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut nodes = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
}
nodes.truncate(count);
Ok(nodes)
}
/// Returns the direct dependencies of a node.
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Returns the direct dependents of a node.
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Begins stream capture that appends captured work into this graph.
pub fn begin_capture_to_graph(
&mut self,
stream: &CudaStream,
dependencies: &[CUgraphNode],
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe {
sys::cuStreamBeginCaptureToGraph(
stream.cu_stream(),
self.cu_graph,
dependencies.as_ptr(),
std::ptr::null(),
dependencies.len(),
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
)
.result()
}
}
/// Ends stream capture previously started by begin_capture_to_graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut graph = MaybeUninit::uninit();
unsafe {
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
let captured = graph.assume_init();
if captured != self.cu_graph && !captured.is_null() {
sys::cuGraphDestroy(captured).result()?;
}
}
Ok(())
}
/// Adds an event record node to the graph for timing.
pub fn add_event_record_node(
&mut self,
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, &params) }
.result()
}
/// Attempts to update this executable graph from an already-mutated source graph.
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut result = CUgraphExecUpdateResultInfo {
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
errorNode: std::ptr::null_mut(),
errorFromNode: std::ptr::null_mut(),
};
unsafe {
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
}
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
return Err(DriverError(
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
));
}
Ok(())
}
}
impl Drop for CudaGraphExecHandle {
@@ -480,6 +672,38 @@ mod tests {
assert_eq!(result[0], 6.0f32);
}
#[test]
fn test_graph_empty_node_dependency_reconnect() {
let Ok(ctx) = CudaContext::new(0) else { return };
let mut graph = CudaGraphHandle::new(ctx).unwrap();
let entry = graph.add_empty_node(&[]).unwrap();
let middle = graph.add_empty_node(&[entry]).unwrap();
let exit = graph.add_empty_node(&[middle]).unwrap();
let nodes = graph.nodes().unwrap();
assert!(nodes.contains(&entry));
assert!(nodes.contains(&middle));
assert!(nodes.contains(&exit));
assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);
graph.add_dependencies(&[entry], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert!(exit_deps.contains(&entry));
assert!(exit_deps.contains(&middle));
graph.remove_dependencies(&[middle], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert_eq!(exit_deps.len(), 1);
assert!(exit_deps.contains(&entry));
unsafe {
graph.destroy_node(middle).unwrap();
}
assert!(!graph.nodes().unwrap().contains(&middle));
}
// CUDA Graph Tests
#[test]
@@ -499,7 +723,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
rt.execute(&cx.dyn_map);
@@ -531,7 +755,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
let mut results = Vec::new();
for _ in 0..5 {
rt.execute(&cx.dyn_map);
@@ -569,7 +793,7 @@ mod tests {
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
.iter()
@@ -611,7 +835,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
@@ -642,7 +866,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
for _ in 0..10 {
rt.execute(&cx.dyn_map);
}
@@ -675,7 +899,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Initial execution
rt.execute(&cx.dyn_map);

View File

@@ -304,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
// Kernel to host op compilation
mod to_host;
#[cfg(test)]
pub(crate) use to_host::CudaGraphDebugSummary;
pub use to_host::{CudaGraphOp, kernel_to_host};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,11 @@ fn test_bucket_dispatch_simple() {
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Test bucket 1: s=1
cx.set_dim('s', 1);
@@ -91,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute at s=1
cx.set_dim('s', 1);
@@ -146,7 +154,11 @@ fn test_bucket_results_match_unbucketed() {
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
rt1.set_data(a1, input_data.clone());
let mut rng1 = SmallRng::seed_from_u64(seed);
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
rt1 = cx1.search_with_rng(
rt1,
CompileOptions::default().search_graph_limit(5),
&mut rng1,
);
rt1.set_data(a1, input_data.clone());
rt1.execute(&cx1.dyn_map);
let result_unbucketed = rt1.get_f32(b1);
@@ -158,7 +170,11 @@ fn test_bucket_results_match_unbucketed() {
let mut rt2 = CudaRuntime::initialize(stream.clone());
rt2.set_data(a2, input_data.clone());
let mut rng2 = SmallRng::seed_from_u64(seed);
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
rt2 = cx2.search_with_rng(
rt2,
CompileOptions::default().search_graph_limit(5),
&mut rng2,
);
rt2.set_data(a2, input_data.clone());
rt2.execute(&cx2.dyn_map);
let result_bucketed = rt2.get_f32(b2);
@@ -186,7 +202,11 @@ fn test_bucket_out_of_range_panics() {
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// s=10 is outside all buckets — should panic
cx.set_dim('s', 10);
@@ -211,7 +231,11 @@ fn test_bucket_no_buckets_backward_compat() {
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
rt.set_data(a, input_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
@@ -257,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute with bucket 1 (s=1)
cx.set_dim('s', 1);
@@ -311,7 +339,11 @@ fn test_bucket_multiple_executions_same_bucket() {
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// Execute at different sizes within the same bucket
for s in [1, 2, 4, 8] {

View File

@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Print and verify which scatter variant was selected
let scatter_names: Vec<_> = rt
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
// Use seeded search for deterministic variant selection.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Print and verify selected variants
let scatter_names: Vec<_> = rt
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
rt.set_data(gather_idx, scatter);
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
assert_eq!(rt.get_f32(gathered), expected_gather);
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(weights);
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
rt.set_data(v_full, v_data.clone());
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
rt.set_data(v_full, v_data.clone());
rt.set_data(weights, weights_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(x, x_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(roundtrip);
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
rt.set_data(x, x_data.clone());
rt.set_data(w, w_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(topk);
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(argsort);
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(ranks);
@@ -1395,7 +1447,11 @@ fn test_dynamic_3d_flat_index_iota_rows() {
let mut rt = CudaRuntime::initialize(stream);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(idx);
@@ -1438,7 +1494,11 @@ fn test_dynamic_2d_to_3d_gather_rows() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(data, data_values.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(out);
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
rt.set_data(topk, topk_data.clone());
rt.set_data(weights, weights_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);

View File

@@ -7,10 +7,12 @@ use luminal::{
prelude::*,
};
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;
use crate::{
host::{
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
cublaslt_type_tuple,
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
expected
}
fn reference_mixed_chain(
a: &[f32],
pre: &[f32],
b: &[f32],
m: usize,
n: usize,
k: usize,
) -> Vec<f32> {
let mut expected = vec![0.0; m * n];
for row in 0..m {
for col in 0..n {
let mut acc = 0.0;
for inner in 0..k {
acc += (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col];
}
expected[row * n + col] = acc.exp();
}
}
expected
}
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
crate::try_create_cublaslt(stream.clone()).is_ok()
}
fn build_mixed_chain_graph(
m: impl Into<Expression>,
n: usize,
k: usize,
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
let m = m.into();
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let pre = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = ((a + pre).matmul(b).exp()).output();
(cx, a.id, pre.id, b.id, out.id)
}
fn add_in_place(values: &mut [f32], addends: &[f32]) {
for (value, addend) in values.iter_mut().zip(addends) {
*value += *addend;
@@ -507,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
}
}
#[test]
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |llir| {
cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
});
let a_data = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
let b_data = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
let summaries = rt.debug_cuda_graph_summaries();
let mixed = summaries
.iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected one CudaGraphOp to capture the cuBLASLt island");
assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
assert_eq!(mixed.absorbed_host_nodes.len(), 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);
let a_data = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (5, 8, 8);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let first = a.matmul(b);
let out = (a + first.sin()).matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"ordered matching cuBLASLt prepared reuse",
|llir| {
let orders = cublaslt_matrix_order_tuples(llir);
orders.len() == 2 && orders[0] == orders[1]
},
);
let a_data = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
let first = reference_matmul_2d(&a_data, &b_data, m, n, k);
let dep = a_data
.iter()
.zip(&first)
.map(|(a, first)| a + first.sin())
.collect::<Vec<_>>();
let expected = reference_matmul_2d(&dep, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 2)
.expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
assert_eq!(
summary.n_cublaslt_prepared, 1,
"dependency-ordered matching cuBLASLt calls should share prepared resources"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
a.output();
b.output();
cx.set_dim('p', 1);
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"cuBLASLt unchanged under unrelated dyn dim",
|_| true,
);
let a_data = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
reset_cublaslt_prepare_count_for_test();
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let first_prepare_count = cublaslt_prepare_count_for_test();
assert!(
first_prepare_count > 0,
"first execution should prepare the captured cuBLASLt island"
);
cx.set_dim('p', 2);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count,
"unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('m', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('m', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [
(7usize, 0xC002_0001),
(9usize, 0xC002_0002),
(7usize, 0xC002_0003),
] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp after recapture");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cublaslt_with_dynamic_c_spec_is_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('c', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('c', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
cx.set_dim('c', c);
let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"c-dependent cuBLASLt should be absorbed into a CUDA graph"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('s', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('s', 1);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);
let dim_buckets = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
.into_iter()
.collect();
let bucket_llirs = vec![
(
[('s', 0usize)].into_iter().collect(),
[('s', 1usize)].into_iter().collect(),
llir.clone(),
),
(
[('s', 1usize)].into_iter().collect(),
[('s', 3usize)].into_iter().collect(),
llir,
),
];
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir_buckets(&dim_buckets, &bucket_llirs);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 1, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data.clone());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"singleton s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
cx.set_dim('s', 3);
let a_data = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 3, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"range s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
reset_cublaslt_prepare_count_for_test();
let mut first_prepare_count = None;
for seed in [0xCC00_0001, 0xCC00_0002] {
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
if first_prepare_count.is_none() {
first_prepare_count = Some(cublaslt_prepare_count_for_test());
}
}
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count.unwrap(),
"A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
);
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after pointer recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
cx.set_dim('m', 7);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
}
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after dynamic-shape recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {

View File

@@ -89,7 +89,7 @@ fn run_reference_attention(
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
rt.set_data(v_t, v.to_vec());
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());

View File

@@ -61,7 +61,7 @@ fn generic_matmul_executes_noncontiguous_merged_head_projection() {
rt.set_data(attn, attn_data.as_slice());
rt.set_data(weight, weight_data.as_slice());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
assert!(
rt.kernel_names().contains(&"GenericMatmul"),
"expected GenericMatmul to be selected, kernels: {:?}",
@@ -109,7 +109,7 @@ fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
);
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
let mut choices = random_initial_choice(egraph, &mut rng);
let kernel_class = &egraph.node_to_class[*kernel_node];
choices.insert(kernel_class, kernel_node);

View File

@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -518,7 +518,7 @@ mod gemma {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -655,7 +655,7 @@ mod qwen {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -259,7 +259,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data);
rt = cx.search(rt, CompileOptions::new(10));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&cx.dyn_map);
let out_dim0 = rt.get_i32(sorted_dim0.id);
let out_dim1 = rt.get_i32(sorted_dim1.id);
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
rt.set_data(token_ids, token_data.clone());
rt.set_data(embed_table, embed_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);

View File

@@ -31,7 +31,7 @@ pub fn kernel_add_bandwidth_test() {
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Warm up
rt.execute(&cx.dyn_map);

View File

@@ -233,7 +233,9 @@ fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
rt.set_data(model.router, router_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, CompileOptions::new(10));
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
rt.get_f32(model.output.id)
@@ -278,7 +280,9 @@ fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
rt.set_data(model.per_expert_scale, per_expert_scale_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, CompileOptions::new(10));
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
rt.get_f32(model.output.id)

View File

@@ -50,7 +50,7 @@ fn rope_matches_cpu_reference() {
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);
@@ -98,7 +98,7 @@ fn rope_flux2_shape() {
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);

View File

@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
// Use minimal search iterations to avoid excessive graph rewriting
// which can cause float drift through softmax/RMSNorm reordering
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
.collect();
rt.set_data(input, input_data.clone());
rt.set_data(weight, weight_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
rt.set_data(wk, wk_data.clone());
rt.set_data(wv, wv_data.clone());
rt.set_data(wo, wo_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -530,7 +530,7 @@ fn test_rolled_chained_scalar_muls() {
let mut rt = CudaRuntime::initialize(stream);
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
rt.set_data(x, x_data.clone());
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -331,7 +331,7 @@ pub fn fuzz_cuda_search_space_equivalence(
let mut native_rng = StdRng::seed_from_u64(config.seed);
let mut native_rt = cx.search_with_rng(
NativeRuntime::default(),
CompileOptions::new(1),
CompileOptions::default().search_graph_limit(1),
&mut native_rng,
);
for input in inputs {
@@ -701,7 +701,7 @@ pub fn test_unary_cuda<T: TestDType>(
let input_data = generator(n_elements, seed);
rt.set_data(a, input_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, b.id);
@@ -776,7 +776,7 @@ pub fn test_binary_cuda<T: TestDType>(
let b_data = b_generator(b_elements, seed.wrapping_add(1));
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, c.id);
@@ -844,7 +844,7 @@ pub fn test_mod(
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(c);

View File

@@ -503,7 +503,8 @@ fn main() -> Result<(), Box<dyn Error>> {
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, CompileOptions::new(SEARCH_GRAPHS));
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
runtime = cx.search(runtime, search_options);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -41,7 +41,11 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
let mut rng = StdRng::seed_from_u64(0);
cx.search_with_rng(rt, CompileOptions::new(limit), &mut rng)
cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(limit),
&mut rng,
)
}
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
@@ -301,7 +305,7 @@ fn dynamic_dim_sum_reduce_runs() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -322,7 +326,7 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, vec![1.0f32; 4]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
cx.set_dim('s', 1);
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
@@ -350,7 +354,7 @@ fn metal_int_arithmetic_preserves_large_values() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(token, &[16_385i32]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -373,7 +377,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -395,7 +399,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -417,7 +421,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -449,7 +453,7 @@ fn metal_simple_add() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -469,7 +473,7 @@ fn metal_simple_mul() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -487,7 +491,7 @@ fn metal_simple_exp2() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -504,7 +508,7 @@ fn metal_simple_log2() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -529,7 +533,7 @@ fn metal_simple_sin() {
3.0 * std::f32::consts::FRAC_PI_2,
],
);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -546,7 +550,7 @@ fn metal_simple_sqrt() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -563,7 +567,7 @@ fn metal_simple_recip() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -582,7 +586,7 @@ fn metal_simple_mod() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -601,7 +605,7 @@ fn metal_simple_less_than() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -621,7 +625,7 @@ fn metal_simple_sum_reduce() {
let mut rt = MetalRuntime::initialize(());
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -640,7 +644,7 @@ fn metal_simple_max_reduce() {
let mut rt = MetalRuntime::initialize(());
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -657,7 +661,7 @@ fn metal_f16_cast_roundtrip() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -678,7 +682,7 @@ fn metal_f16_intermediate_add_roundtrip() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1026,7 +1030,7 @@ fn metal_rms_norm() {
rt.set_data(input, &input_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1066,7 +1070,7 @@ fn metal_self_attention() {
rt.set_data(wk, &wk_data);
rt.set_data(wv, &wv_data);
rt.set_data(wo, &wo_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1125,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
rt.set_data(wk, to_f16_vec(&wk_data));
rt.set_data(wv, to_f16_vec(&wv_data));
rt.set_data(wo, to_f16_vec(&wo_data));
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1169,7 +1173,7 @@ fn metal_swiglu_mlp() {
rt.set_data(w_gate, &gate_data);
rt.set_data(w_up, &up_data);
rt.set_data(w_down, &down_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1219,7 +1223,7 @@ fn metal_mini_transformer_layer() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1285,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1327,7 +1331,7 @@ fn test_scatter_basic() {
rt.set_data(src, &[10.0, 20.0, 30.0]);
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1349,7 +1353,7 @@ fn test_scatter_buffer_roundtrip() {
rt.set_data(src, &[0.0]);
rt.set_data(indexes, &[0.0]);
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
for (pos, value, expected) in [
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
@@ -1383,7 +1387,7 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
rt.set_data(weights, &[99.0, 99.0, 99.0]);
rt.set_data(bias, &[0.5, 1.0, -1.5]);
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1448,7 +1452,7 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1475,7 +1479,7 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
);
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1495,7 +1499,7 @@ fn test_scatter_into_nonzero_dest() {
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[2f32]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1522,7 +1526,7 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
rt.set_data(src, &[7.0, 8.0]);
rt.set_data(indexes, &[1.0, 3.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1551,7 +1555,7 @@ fn test_scatter_no_copy_handles_2d_destination() {
rt.set_data(src, &[9.0, 8.0]);
rt.set_data(indexes, &[2.0, 4.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1578,7 +1582,7 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[1.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1605,7 +1609,7 @@ fn test_scatter_all_positions() {
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1624,7 +1628,7 @@ fn test_gather_preserves_data_dtype() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(data, &[1.25, 2.5]);
rt.set_data(indexes, &[1.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);

View File

@@ -167,7 +167,10 @@ mod tests {
let result = gather_rows(data, indices, 3).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
rt.set_data(
@@ -193,7 +196,10 @@ mod tests {
let result = scatter_rows(src, indices, dest, 3).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
rt.set_data(indices.id, vec![1, 3]);
@@ -219,7 +225,10 @@ mod tests {
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
@@ -272,7 +281,10 @@ mod tests {
let v_cache_new = v_cache_new.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
rt.set_data(q.id, vec![1., 0., 1., 0.]);
@@ -345,7 +357,10 @@ mod tests {
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
// K cached at slot 0: [1, 0]
@@ -417,7 +432,10 @@ mod tests {
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Cache has 1 token at slot 0
let mut k_cache_data = vec![0.; num_slots * kv_dim];

View File

@@ -184,7 +184,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 2.0, 3.0];
// Router strongly favors expert 0
@@ -239,7 +242,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 1.0];
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
@@ -293,7 +299,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
@@ -350,7 +359,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(in_dim);
let router_data = random_vec(in_dim * n_experts);
@@ -395,7 +407,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(batch * in_dim);
let router_data = random_vec(in_dim * n_experts);

View File

@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
.unwrap_or(default)
}
fn search_options() -> CompileOptions {
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
}
fn env_f32(name: &str, default: f32) -> f32 {
std::env::var(name)
.ok()
@@ -187,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
println!("Compiling text encoder...");
let t0 = Instant::now();
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
println!("Encoding prompt...");
@@ -345,10 +349,9 @@ fn run_full_pipeline(
{
use rand::SeedableRng;
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
let opts = luminal::graph::CompileOptions::new(env_usize("SEARCH_ITERS", 5));
runtime = cx.search_with_rng(runtime, opts, &mut rng);
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
} else {
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
}
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
@@ -415,7 +418,7 @@ fn run_full_pipeline(
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
runtime.set_data(latent_in, vae_input);
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
runtime.execute(&cx.dyn_map);
let img = runtime.get_f32(out);
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'

View File

@@ -720,6 +720,10 @@ mod tests {
out
}
fn one_search() -> CompileOptions {
CompileOptions::default().search_graph_limit(1)
}
#[test]
fn conv2d_bias_matches_reference() {
let mut cx = Graph::default();
@@ -747,7 +751,7 @@ mod tests {
);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -766,7 +770,7 @@ mod tests {
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.execute(&cx.dyn_map);
@@ -790,7 +794,7 @@ mod tests {
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(ctx.default_stream());
rt.set_data(input_t, input);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);
@@ -821,7 +825,7 @@ mod tests {
);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -864,7 +868,7 @@ mod tests {
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);

View File

@@ -84,7 +84,8 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -36,14 +36,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -68,114 +70,11 @@ pub struct Gemma {
impl Gemma {
pub fn init(cx: &mut Graph) -> Self {
let mut w = vec![];
for l in 0..LAYERS {
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(GemmaLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
input_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
cx,
),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
});
}
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
lm_head,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -185,11 +84,7 @@ impl Gemma {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -226,6 +121,114 @@ struct GemmaLayer {
rope_scaling_factor: f32,
}
impl GemmaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
input_layernorm: layer_norm(cx, l, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = gemma_rotary_embeddings(
qk_norm(q, self.q_norm, N_HEADS),
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
qk_norm(k, self.k_norm, N_KV_HEADS),
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = x + self.post_attention_layernorm.forward(attn_proj);
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
(
x + self.post_feedforward_layernorm.forward(mlp_out),
k_cache_out,
v_cache_out,
)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
/// This produces far fewer e-graph nodes than the tanh-based expansion.
#[allow(clippy::excessive_precision)]
@@ -363,59 +366,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl GemmaLayer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm + RoPE
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
let q_rope = gemma_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
k_normed,
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
// O projection + post-attention norm + residual
let attn_proj = attn_out.matmul(self.o_proj.t());
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
let x = x + attn_normed;
// Pre-feedforward norm + MLP + post-feedforward norm + residual
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
(x + mlp_normed, k_cache_out, v_cache_out)
}
}

View File

@@ -81,11 +81,10 @@ fn main() {
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(
runtime,
CompileOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(search_graphs)
.profile_timeout(Duration::from_secs(2));
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);

View File

@@ -83,20 +83,16 @@ impl KVCache {
let mut v_caches = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let k = cx
.named_tensor(
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
let v = cx
.named_tensor(
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
v_caches.push(persist(
cx,
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
}
Self {
k_caches,
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
impl Gemma4MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let gate = cx
.named_tensor(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let up = cx
.named_tensor(
format!("model.layers.{layer}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{layer}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
(spec.q_dim, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist();
let v_proj = spec.has_v_proj.then(|| {
cx.named_tensor(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist()
});
let o_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
(HIDDEN, spec.q_dim),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_norm.weight"),
spec.head_dim,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_norm.weight"),
spec.head_dim,
)
.persist();
let layer_scalar = cx
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
.persist();
let router_scale = cx
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
.persist();
let router_proj = cx
.named_tensor(
format!("model.layers.{layer}.router.proj.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let per_expert_scale = cx
.named_tensor(
format!("model.layers.{layer}.router.per_expert_scale"),
NUM_EXPERTS,
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.gate_up_proj"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist()
.as_dtype(DType::Bf16);
let down_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.down_proj"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist()
.as_dtype(DType::Bf16);
layers.push(Gemma4Layer {
spec,
gate,
up,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
layer_scalar,
input_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm_1: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
cx,
),
post_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
cx,
),
pre_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
cx,
),
moe: Gemma4SparseMoE {
router_scale,
router_proj,
per_expert_scale,
gate_up_weights,
down_weights,
},
});
}
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
Self {
embedding,
lm_head,
layers,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS)
.map(|layer| Gemma4Layer::init(cx, layer))
.collect(),
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -287,11 +127,7 @@ impl Gemma4MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (layer_idx, layer) in self.layers.iter().enumerate() {
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
down_weights: GraphTensor,
}
impl Gemma4Layer {
fn init(cx: &mut Graph, layer: usize) -> Self {
let spec = layer_spec(layer);
Self {
spec,
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
v_proj: spec
.has_v_proj
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
moe: Gemma4SparseMoE {
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
gate_up_weights: layer_tensor(
cx,
layer,
"experts.gate_up_proj",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
layer,
"experts.down_proj",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
}
@@ -505,81 +499,6 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl Gemma4Layer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
}

View File

@@ -338,13 +338,11 @@ fn main() {
println!(" Search trials: {SEARCH_TRIALS}");
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(
runtime,
CompileOptions::new(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -111,125 +111,18 @@ impl Llama {
config: LlamaConfig,
fp8_linears: bool,
) -> Self {
let mut layers = Vec::with_capacity(config.layers);
for l in 0..config.layers {
layers.push(LlamaLayer {
config,
up: linear_weight(
cx,
format!("model.layers.{l}.mlp.up_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
up_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.up_proj"),
fp8_linears,
),
gate: linear_weight(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
gate_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
fp8_linears,
),
down: linear_weight(
cx,
format!("model.layers.{l}.mlp.down_proj"),
(config.hidden, config.intermediate),
fp8_linears,
),
down_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.down_proj"),
fp8_linears,
),
q_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
q_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
fp8_linears,
),
k_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
k_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
fp8_linears,
),
v_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
v_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
fp8_linears,
),
o_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
o_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
fp8_linears,
),
attn_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
config,
embedding: cx
.named_tensor(
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
)
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
.persist(),
lm_norm: LayerNorm::new(
config.hidden,
Some("model.norm.weight"),
None,
false,
1e-5,
embedding: persist(
cx,
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
),
layers: (0..config.layers)
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
.collect(),
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
}
}
@@ -243,12 +136,7 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let hidden = self.config.hidden;
let mut x = self.embedding.gather(
(input * hidden).expand_dim(1, hidden)
+ input.graph().arange(hidden).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, input, self.config.hidden);
let mut cache_outputs = Vec::with_capacity(self.config.layers);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
weight: GraphTensor,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
Self {
config,
up: layer_linear_weight(
cx,
l,
"mlp.up_proj",
(config.intermediate, config.hidden),
fp8,
),
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
gate: layer_linear_weight(
cx,
l,
"mlp.gate_proj",
(config.intermediate, config.hidden),
fp8,
),
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
down: layer_linear_weight(
cx,
l,
"mlp.down_proj",
(config.hidden, config.intermediate),
fp8,
),
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
q_proj: layer_linear_weight(
cx,
l,
"self_attn.q_proj",
(config.hidden, config.hidden),
fp8,
),
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
k_proj: layer_linear_weight(
cx,
l,
"self_attn.k_proj",
(config.kv_dim(), config.hidden),
fp8,
),
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
v_proj: layer_linear_weight(
cx,
l,
"self_attn.v_proj",
(config.kv_dim(), config.hidden),
fp8,
),
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
o_proj: layer_linear_weight(
cx,
l,
"self_attn.o_proj",
(config.hidden, config.hidden),
fp8,
),
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
attn_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.input_layernorm.weight"),
),
mlp_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn linear_weight(
cx: &mut Graph,
prefix: impl ToString,
@@ -325,6 +377,16 @@ fn linear_weight(
}
}
fn layer_linear_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
fp8: bool,
) -> GraphTensor {
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
}
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
if !fp8 {
return None;
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
})
}
fn layer_linear_scales(
cx: &mut Graph,
layer: usize,
suffix: &str,
fp8: bool,
) -> Option<Fp8LinearScales> {
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
}
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * hidden).expand_dim(1, hidden)
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
)
}
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
scale.expand_rhs(like.dims())
}
@@ -443,87 +526,3 @@ fn attention(
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}

View File

@@ -241,7 +241,8 @@ fn main() {
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Re-initialize KV cache after search (search consumes buffers)
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();

View File

@@ -25,8 +25,8 @@ pub struct PagedKVCache {
impl PagedKVCache {
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
@@ -44,78 +44,11 @@ pub struct Llama {
impl Llama {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
layers.push(LlamaLayer {
up: cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
gate: cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
down: cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist(),
q_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
k_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
v_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
o_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
embedding: cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -141,12 +74,8 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &PagedKVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let mut x = self.embedding.gather(
(input * HIDDEN).expand_dim(1, HIDDEN)
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut cache_outputs = vec![];
let mut x = token_embedding(self.embedding, input);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
@@ -177,6 +106,99 @@ struct LlamaLayer {
mlp_rms: LayerNorm,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
1e-5,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
@@ -264,44 +286,3 @@ fn paged_attention(
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// Apply RoPE before scattering into cache
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -188,7 +188,8 @@ where
cx.set_dim('p', 0);
runtime.set_i32_data(input.id, vec![1; search_s]);
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(config.search_graphs));
let search_options = CompileOptions::default().search_graph_limit(config.search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..config.layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);

View File

@@ -34,14 +34,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(layers);
let mut v_caches = Vec::with_capacity(layers);
for l in 0..layers {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -63,105 +65,10 @@ impl Qwen {
layers <= LAYERS,
"requested {layers} layers, but model has {LAYERS}"
);
let mut w = vec![];
for l in 0..layers {
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(QwenLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..layers).map(|l| QwenLayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -172,11 +79,7 @@ impl Qwen {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(self.layers.len());
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -209,6 +112,90 @@ struct QwenLayer {
mlp_rms: LayerNorm,
}
impl QwenLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = qwen_rotary_embeddings(qk_norm(q, self.q_norm, N_HEADS), pos_ids, N_HEADS);
let k_rope =
qwen_rotary_embeddings(qk_norm(k, self.k_norm, N_KV_HEADS), pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// Per-head RMS normalization for QK-norm.
/// Input: [seq, dim] where dim = n_heads * HEAD_DIM
/// split_dims to [seq, n_heads, HEAD_DIM], RMS norm over last axis, multiply by weight, merge back.
@@ -331,36 +318,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl QwenLayer {
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm: per-head RMS normalization
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
// RoPE
let q_rope = qwen_rotary_embeddings(q_normed, pos_ids, N_HEADS);
let k_rope = qwen_rotary_embeddings(k_normed, pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -82,7 +82,8 @@ fn main() {
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(runtime, CompileOptions::new(search_graphs), &mut rng);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -29,17 +29,19 @@ pub struct KVCache {
impl KVCache {
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -58,111 +60,11 @@ pub struct Qwen3MoE {
impl Qwen3MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
let router = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_up_weights"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist();
let down_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_weights"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist();
layers.push(Qwen3MoELayer {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
moe: QwenMoE {
router,
gate_up_weights: gate_up_weights.as_dtype(DType::Bf16),
down_weights: down_weights.as_dtype(DType::Bf16),
},
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers,
lm_norm,
lm_head,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| Qwen3MoELayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
}
}
@@ -172,11 +74,7 @@ impl Qwen3MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -214,6 +112,39 @@ struct QwenMoE {
}
impl Qwen3MoELayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
moe: QwenMoE {
router: layer_weight(cx, l, "mlp.gate", (NUM_EXPERTS, HIDDEN)),
gate_up_weights: layer_tensor(
cx,
l,
"mlp.gate_up_weights",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
l,
"mlp.down_weights",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
mut x: GraphTensor,
@@ -247,6 +178,51 @@ impl Qwen3MoELayer {
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
impl QwenMoE {
fn forward(&self, x: GraphTensor) -> GraphTensor {
let n = x.dims().len(); // 2 for [s, H]

View File

@@ -12,7 +12,7 @@ fn main() {
// Compile
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

View File

@@ -96,7 +96,8 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1i32; max_prefill]);
runtime.set_data(pos_ids, (0..max_prefill as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Reset the KV caches and re-set the mel after search (which executes test runs).
for i in 0..N_TEXT_LAYER {

View File

@@ -33,6 +33,14 @@ fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
x.matmul(w.t())
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
fn conv1d_bias(
@@ -90,27 +98,13 @@ struct AttentionWeights {
impl AttentionWeights {
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
Self {
q_proj: cx
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
.persist(),
q_bias: cx
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
.persist(),
k_proj: cx
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
.persist(),
v_proj: cx
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
.persist(),
v_bias: cx
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
.persist(),
out_proj: cx
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
.persist(),
out_bias: cx
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
.persist(),
q_proj: persist(cx, format!("{prefix}.q_proj.weight"), (dim, dim)),
q_bias: persist(cx, format!("{prefix}.q_proj.bias"), dim),
k_proj: persist(cx, format!("{prefix}.k_proj.weight"), (dim, dim)),
v_proj: persist(cx, format!("{prefix}.v_proj.weight"), (dim, dim)),
v_bias: persist(cx, format!("{prefix}.v_proj.bias"), dim),
out_proj: persist(cx, format!("{prefix}.out_proj.weight"), (dim, dim)),
out_bias: persist(cx, format!("{prefix}.out_proj.bias"), dim),
}
}
}
@@ -125,6 +119,14 @@ fn merge_heads(x: GraphTensor) -> GraphTensor {
x.transpose(0, 1).merge_dims(1, 2)
}
fn embedding_lookup(embedding: GraphTensor, ids: GraphTensor) -> GraphTensor {
let seq = ids.dims1();
embedding.gather(
(ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
)
}
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
let q = linear_with_bias(x, w.q_proj, w.q_bias);
@@ -239,18 +241,10 @@ impl EncoderLayer {
N_AUDIO_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_AUDIO_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
}
}
@@ -295,18 +289,10 @@ impl DecoderLayer {
N_TEXT_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_TEXT_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
}
}
@@ -346,14 +332,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
for l in 0..N_TEXT_LAYER {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -376,27 +364,23 @@ pub struct WhisperEncoder {
impl WhisperEncoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
conv1_w: cx
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
.persist(),
conv1_b: cx
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
.persist(),
conv2_w: cx
.named_tensor(
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
)
.persist(),
conv2_b: cx
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
.persist(),
positional_embedding: cx
.named_tensor(
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
)
.persist(),
conv1_w: persist(
cx,
"model.encoder.conv1.weight",
(N_AUDIO_STATE, N_MELS * 3),
),
conv1_b: persist(cx, "model.encoder.conv1.bias", N_AUDIO_STATE),
conv2_w: persist(
cx,
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
),
conv2_b: persist(cx, "model.encoder.conv2.bias", N_AUDIO_STATE),
positional_embedding: persist(
cx,
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
),
layers: (0..N_AUDIO_LAYER)
.map(|i| EncoderLayer::new(i, cx))
.collect(),
@@ -427,15 +411,16 @@ pub struct WhisperDecoder {
impl WhisperDecoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
embed_tokens: cx
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
.persist(),
embed_positions: cx
.named_tensor(
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
)
.persist(),
embed_tokens: persist(
cx,
"model.decoder.embed_tokens.weight",
(N_VOCAB, N_TEXT_STATE),
),
embed_positions: persist(
cx,
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
),
layers: (0..N_TEXT_LAYER)
.map(|i| DecoderLayer::new(i, cx))
.collect(),
@@ -450,18 +435,8 @@ impl WhisperDecoder {
xa: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
// Token embedding gather
let mut x = self.embed_tokens.gather(
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
// Positional embedding gather (using pos_ids)
let pos_emb = self.embed_positions.gather(
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
x += pos_emb;
let mut x = embedding_lookup(self.embed_tokens, token_ids);
x += embedding_lookup(self.embed_positions, pos_ids);
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
for (i, layer) in self.layers.iter().enumerate() {

View File

@@ -335,7 +335,8 @@ fn main() {
println!("Compiling (search_graphs={search_graphs})...");
let t0 = Instant::now();
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
println!(" search took {:?}", t0.elapsed());
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)

View File

@@ -174,7 +174,10 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
let mut rt = graph.search(
rt,
CompileOptions::default().search_graph_limit(args.search_iters),
);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);

View File

@@ -463,7 +463,10 @@ pub(super) mod tests {
let c = func(a, b).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let lhs_values = lhs_transform(random_vec(a_shape.iter().copied().product()));
let rhs_values = rhs_transform(random_vec(b_shape.iter().copied().product()));

View File

@@ -968,7 +968,10 @@ mod tests {
let zeros = cx.iota(Expression::from(0usize), 6);
let inv = values.scatter(perm, zeros).cast(DType::F32).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.]);
rt.set_data(indexes.id, vec![5, 0, 3, 2]);
rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0]);
@@ -985,7 +988,10 @@ mod tests {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![10., 20., 30.]);
rt.set_data(indexes.id, vec![1, 3, 4]);
rt.set_data(dest.id, vec![0., 0., 0., 0., 0.]);
@@ -1001,7 +1007,10 @@ mod tests {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![99.]);
rt.set_data(indexes.id, vec![2]);
rt.set_data(dest.id, vec![1., 2., 3., 4., 5.]);
@@ -1017,7 +1026,10 @@ mod tests {
let dest = cx.tensor(4);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![40., 30., 20., 10.]);
rt.set_data(indexes.id, vec![3, 2, 1, 0]);
rt.set_data(dest.id, vec![1., 2., 3., 4.]);
@@ -1045,7 +1057,10 @@ mod tests {
let repeated = (a.repeat((2, 2)) * 1.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt.execute(&cx.dyn_map);

View File

@@ -126,7 +126,10 @@ mod tests {
let b = func(&mut cx).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.execute(&cx.dyn_map);
@@ -211,7 +214,10 @@ mod tests {
let stacked = cx.stack(&[a, b, c], 0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let a_data = random_vec(6);
let b_data = random_vec(6);

View File

@@ -493,7 +493,10 @@ pub(super) mod tests {
let b = func(a).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let v = random_vec(shape.iter().copied().product());
rt.set_data(a.id, v.clone());

View File

@@ -82,6 +82,13 @@ struct SearchSpaceContext {
intervals: DynDimIntervals,
}
#[derive(Debug, Clone)]
struct SearchProfileBucketContext {
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
bucket_indices: FxHashMap<char, usize>,
representative_dyn_map: FxHashMap<char, usize>,
}
/// A compiled bucket: (bucket_indices, representative_dyn_map, stitched_llir).
pub type BucketLLIR = (FxHashMap<char, usize>, FxHashMap<char, usize>, LLIRGraph);
@@ -137,7 +144,9 @@ impl DimBucket {
/// Use the builder pattern to configure search parameters:
/// ```
/// use luminal::prelude::CompileOptions;
/// let opts = CompileOptions::new(5)
/// let opts = CompileOptions::default()
/// .search_graph_limit(5)
/// .search_time_limit(std::time::Duration::from_secs(30))
/// .generation_size(50)
/// .mutations(40)
/// .trials(15);
@@ -146,6 +155,8 @@ impl DimBucket {
pub struct CompileOptions {
/// Maximum number of graphs to evaluate during search.
pub limit: usize,
/// Maximum wall-clock time to spend searching.
pub search_time_limit: std::time::Duration,
/// Optional maximum runtime-specific intermediate memory, in bytes.
///
/// When this is `None`, search does not apply a memory cap. Runtimes that
@@ -163,8 +174,6 @@ pub struct CompileOptions {
/// Per-candidate profiling timeout. If a profile call reaches this budget,
/// that candidate is discarded and search continues.
pub profile_timeout: Option<std::time::Duration>,
/// Optional per-group search timeout.
pub group_timeout: Option<std::time::Duration>,
/// Optional profiling dimension overrides.
pub profile_dims: FxHashMap<char, usize>,
/// Bucket definitions per dynamic dimension. Dimensions without buckets use
@@ -173,20 +182,16 @@ pub struct CompileOptions {
}
impl CompileOptions {
/// Create compile options with the given search limit. Other fields use defaults.
pub fn new(limit: usize) -> Self {
Self {
limit,
max_memory_bytes: None,
generation_size: 10,
mutations: 10,
trials: 3,
keep_best: 1,
profile_timeout: Some(std::time::Duration::from_secs(1)),
group_timeout: None,
profile_dims: FxHashMap::default(),
dim_buckets: FxHashMap::default(),
}
/// Set the maximum number of graphs to evaluate during search.
pub fn search_graph_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
/// Set the maximum wall-clock time to spend searching.
pub fn search_time_limit(mut self, search_time_limit: std::time::Duration) -> Self {
self.search_time_limit = search_time_limit;
self
}
/// Set a maximum intermediate memory budget in bytes.
@@ -235,12 +240,6 @@ impl CompileOptions {
self
}
/// Set an optional per-group search timeout.
pub fn group_timeout(mut self, group_timeout: std::time::Duration) -> Self {
self.group_timeout = Some(group_timeout);
self
}
/// Override a dynamic dimension value used during search profiling.
pub fn profile_dim(mut self, dim: char, value: usize) -> Self {
self.profile_dims.insert(dim, value);
@@ -261,7 +260,18 @@ impl CompileOptions {
impl Default for CompileOptions {
fn default() -> Self {
Self::new(1)
Self {
limit: 100,
search_time_limit: std::time::Duration::MAX,
max_memory_bytes: None,
generation_size: 10,
mutations: 10,
trials: 3,
keep_best: 1,
profile_timeout: Some(std::time::Duration::from_secs(1)),
profile_dims: FxHashMap::default(),
dim_buckets: FxHashMap::default(),
}
}
}
@@ -1302,10 +1312,19 @@ impl Graph {
"dim buckets must be configured in CompileOptions before build_search_space; search cannot change buckets after build",
);
let search_started_at = std::time::Instant::now();
if self.search_space_dim_buckets.is_empty() {
// No buckets: existing single-search path
let stitched =
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None, 0);
let stitched = self.search_single(
&mut runtime,
&options,
rng,
&self.dyn_map.clone(),
None,
None,
0,
search_started_at,
);
runtime.clear_intermediate_buffers();
runtime.load_llir(&stitched);
@@ -1338,7 +1357,13 @@ impl Graph {
rng,
&context.representative_dyn_map,
Some((combo_idx, n_combos)),
Some(SearchProfileBucketContext {
dim_buckets: self.search_space_dim_buckets.clone(),
bucket_indices: context.bucket_indices.clone(),
representative_dyn_map: context.representative_dyn_map.clone(),
}),
combo_idx,
search_started_at,
);
bucket_llirs.push((
context.bucket_indices,
@@ -1425,6 +1450,7 @@ impl Graph {
/// Run the genetic search and return the unrolled LLIR for the winning
/// genome. `bucket_progress`: if `Some((current_bucket_idx, total_buckets))`
/// adds a second "Bucket" progress bar.
#[allow(clippy::too_many_arguments)]
fn search_single<R: Runtime + 'static, G: rand::Rng>(
&mut self,
runtime: &mut R,
@@ -1432,7 +1458,9 @@ impl Graph {
rng: &mut G,
dyn_map: &FxHashMap<char, usize>,
bucket_progress: Option<(usize, usize)>,
bucket_profile_context: Option<SearchProfileBucketContext>,
egraph_index: usize,
search_started_at: std::time::Instant,
) -> LLIRGraph {
let mut profile_dyn_map = dyn_map.clone();
for (&dim, &value) in &options.profile_dims {
@@ -1482,11 +1510,11 @@ impl Graph {
}
};
let group_start = std::time::Instant::now();
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
runtime.clear_intermediate_buffers();
let search_time_limit_reached = || search_started_at.elapsed() >= options.search_time_limit;
let profile_timed_out = |elapsed: std::time::Duration| {
options
.profile_timeout
@@ -1531,12 +1559,27 @@ impl Graph {
collapse_loops_to_first_iter(&mut graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let (rep_metric, rep_display) =
if let Some(bucket_context) = &bucket_profile_context {
runtime.profile_with_bucket_context(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
ProfileBucketContext {
dim_buckets: &bucket_context.dim_buckets,
bucket_indices: &bucket_context.bucket_indices,
representative_dyn_map: &bucket_context.representative_dyn_map,
},
)
} else {
runtime.profile(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
)
};
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
(
@@ -1561,11 +1604,8 @@ impl Graph {
break;
}
Ok(_) | Err(_) => {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
panic!("Failed to find a viable initial genome before timeout");
if search_time_limit_reached() {
panic!("Failed to find a viable initial genome before search time limit");
}
list_cache.clear();
expr_cache.clear();
@@ -1586,10 +1626,7 @@ impl Graph {
let mut resample_generation = false;
while n_graphs < search_limit {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
if search_time_limit_reached() {
break;
}
@@ -1623,10 +1660,7 @@ impl Graph {
let mut generation_found_non_timeout = false;
for genome in all_offspring {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
if search_time_limit_reached() {
break;
}
n_graphs += 1;
@@ -1659,12 +1693,27 @@ impl Graph {
collapse_loops_to_first_iter(&mut llir_graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let (rep_metric, rep_display) =
if let Some(bucket_context) = &bucket_profile_context {
runtime.profile_with_bucket_context(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
ProfileBucketContext {
dim_buckets: &bucket_context.dim_buckets,
bucket_indices: &bucket_context.bucket_indices,
representative_dyn_map: &bucket_context.representative_dyn_map,
},
)
} else {
runtime.profile(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
)
};
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan =
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
@@ -2808,6 +2857,7 @@ mod tests {
use super::*;
use crate::egglog_utils::hash_egglog_normalized;
use crate::tests::{assert_close, random_vec};
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Default)]
struct TestMemoryRuntime;
@@ -2845,6 +2895,69 @@ mod tests {
}
}
static PROFILE_CALLS: AtomicUsize = AtomicUsize::new(0);
#[derive(Default)]
struct CountingRuntime;
impl Runtime for CountingRuntime {
type Ops = ();
type CompileArg = ();
type ExecReturn = ();
type ProfileMetric = usize;
fn initialize(_: Self::CompileArg) -> Self {
Self
}
fn load_llir(&mut self, _: &LLIRGraph) {}
fn execute(&mut self, _: &FxHashMap<char, usize>) -> Self::ExecReturn {}
fn profile(
&mut self,
_: &LLIRGraph,
_: &FxHashMap<char, usize>,
_: usize,
_: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String) {
let count = PROFILE_CALLS.fetch_add(1, Ordering::SeqCst);
(count, format!("{count} ms"))
}
}
#[test]
fn compile_options_defaults_and_search_time_limit_builder() {
let opts = CompileOptions::default();
assert_eq!(opts.limit, 100);
assert_eq!(opts.search_time_limit, std::time::Duration::MAX);
let time_limit = std::time::Duration::from_millis(25);
let opts = CompileOptions::default()
.search_graph_limit(7)
.search_time_limit(time_limit);
assert_eq!(opts.limit, 7);
assert_eq!(opts.search_time_limit, time_limit);
}
#[test]
fn search_time_limit_stops_after_initial_viable_candidate() {
let mut cx = Graph::new();
let a = cx.tensor((4, 8));
let b = cx.tensor((8, 4));
let c = cx.tensor((4, 4));
let _ = (a.matmul(b) + c).relu().softmax(1).output();
cx.build_search_space::<CountingRuntime>(CompileOptions::default());
PROFILE_CALLS.store(0, Ordering::SeqCst);
let _ = cx.search(
CountingRuntime,
CompileOptions::default().search_time_limit(std::time::Duration::ZERO),
);
assert_eq!(PROFILE_CALLS.load(Ordering::SeqCst), 1);
}
#[test]
fn build_search_space_without_explicit_max_memory_has_no_cap() {
let mut cx = Graph::new();
@@ -2862,7 +2975,10 @@ mod tests {
let _ = cx.tensor(1).output();
cx.build_search_space::<TestMemoryRuntime>(CompileOptions::default().max_memory_bytes(0));
let _ = cx.search(TestMemoryRuntime, CompileOptions::new(1));
let _ = cx.search(
TestMemoryRuntime,
CompileOptions::default().search_graph_limit(1),
);
}
#[test]
@@ -2870,7 +2986,10 @@ mod tests {
let mut cx = Graph::new();
let _ = cx.tensor(1).output();
let _ = cx.compile(TestMemoryRuntime, CompileOptions::new(1));
let _ = cx.compile(
TestMemoryRuntime,
CompileOptions::default().search_graph_limit(1),
);
assert_eq!(cx.egraphs.len(), 1);
assert_eq!(cx.search_space_max_memory_bytes, None);
@@ -2908,7 +3027,7 @@ mod tests {
let mut rng = rand::rng();
let _ = cx.search_with_rng(
TestMemoryRuntime,
CompileOptions::new(1).dim_buckets(
CompileOptions::default().search_graph_limit(1).dim_buckets(
's',
&[DimBucket::new(1, 1), DimBucket::new(2, 4).representative(4)],
),
@@ -3019,7 +3138,7 @@ mod tests {
let vals = random_vec(8);
let mut rt = NativeRuntime::default();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.set_data(x.id, vals.clone());
rt.execute(&cx.dyn_map);
@@ -3056,7 +3175,7 @@ mod tests {
let vals = random_vec(8);
let mut rt = NativeRuntime::default();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.set_data(x.id, vals.clone());
rt.execute(&cx.dyn_map);

View File

@@ -7,6 +7,12 @@ use crate::prelude::*;
use as_any::{AsAny, Downcast};
use rustc_hash::FxHashMap;
pub struct ProfileBucketContext<'a> {
pub dim_buckets: &'a FxHashMap<char, Vec<DimBucket>>,
pub bucket_indices: &'a FxHashMap<char, usize>,
pub representative_dyn_map: &'a FxHashMap<char, usize>,
}
pub trait Runtime {
type Ops: IntoEgglogOp;
type CompileArg;
@@ -36,6 +42,20 @@ pub trait Runtime {
trials: usize,
timeout: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String);
/// Profile one candidate in the context of a specific dynamic-dimension
/// bucket. Runtimes with bucket-sensitive lowering can override this so
/// search ranks candidates under the same execution model used after
/// final bucket compilation.
fn profile_with_bucket_context(
&mut self,
llir_graph: &LLIRGraph,
dyn_map: &FxHashMap<char, usize>,
trials: usize,
timeout: Option<std::time::Duration>,
_bucket_context: ProfileBucketContext<'_>,
) -> (Self::ProfileMetric, String) {
self.profile(llir_graph, dyn_map, trials, timeout)
}
/// Aggregate multiple profile metrics into one comparable metric.
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {

View File

@@ -24,7 +24,7 @@ proptest! {
let d = (b * c / e).sin().output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
rt.set_data(b.id, vals.clone());
rt.set_data(c.id, vals.clone());
@@ -58,7 +58,7 @@ proptest! {
let a = b.matmul(c).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
let lhs = lhs.into_iter().take(m * k).collect::<Vec<f32>>();
let rhs = rhs.into_iter().take(k * n).collect::<Vec<f32>>();
rt.set_data(b.id, lhs.clone());
@@ -82,7 +82,7 @@ proptest! {
let a = cx.tensor((2, 2));
let b = (a.permute((1, 0)) * 1.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
rt.set_data(a.id, values.clone());
rt.execute(&cx.dyn_map);
@@ -99,7 +99,7 @@ proptest! {
let mask = a.ge(kth_largest.expand_dim(1, cols)).cast(crate::dtype::DType::F32);
let filtered = (a * mask).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
let values = values.into_iter().take(rows * cols).collect::<Vec<f32>>();
rt.set_data(a.id, values.clone());
rt.execute(&cx.dyn_map);
@@ -465,7 +465,10 @@ fn test_inputs_consumed_after_execute() {
let a = cx.tensor(3);
let _b = (a * 2.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
rt.execute(&cx.dyn_map);
// Second execute should panic — input 'a' was consumed
@@ -481,7 +484,10 @@ fn test_passthrough_preserves_weights() {
w.persist();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Iteration 1
rt.set_data(w.id, vec![1.0, 2.0, 3.0]);
@@ -503,7 +509,10 @@ fn test_only_outputs_remain() {
let a = cx.tensor(3);
let _b = (a * 2.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
rt.execute(&cx.dyn_map);
let output_count = rt
@@ -556,7 +565,10 @@ fn integration_auto_loop_rolling_matches_reference_native_runtime() {
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
graph.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = graph.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = graph.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(input_id, input);
for (node, data) in weight_ids.iter().zip(weights.iter()) {
rt.set_data(*node, data.clone());