mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
cuda_133
...
flashinfer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62e86f9dc5 | ||
|
|
75e4e6be0a |
@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
|
||||
let c = a.matmul(b).output();
|
||||
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
|
||||
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
|
||||
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, CompileOptions::new(5));
|
||||
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
|
||||
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, CompileOptions::new(5));
|
||||
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use half::{bf16, f16};
|
||||
use luminal::{
|
||||
@@ -15,6 +18,8 @@ use luminal::{
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::kernel::CudaGraphHandle;
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
@@ -33,7 +38,7 @@ use crate::{
|
||||
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
@@ -581,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum LtScalar {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) enum LtScalar {
|
||||
F64(f64),
|
||||
F32(f32),
|
||||
F16(f16),
|
||||
@@ -622,16 +627,16 @@ impl LtScalar {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulProblem {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulProblem {
|
||||
m: u64,
|
||||
n: u64,
|
||||
k: u64,
|
||||
batch_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatrixSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatrixSpec {
|
||||
dtype: cudaDataType,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
@@ -640,8 +645,8 @@ struct LtMatrixSpec {
|
||||
order: cublasLtOrder_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtComputeSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtComputeSpec {
|
||||
compute_type: cublasComputeType_t,
|
||||
scale_dtype: cudaDataType,
|
||||
alpha: LtScalar,
|
||||
@@ -649,8 +654,8 @@ struct LtComputeSpec {
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtMatmulSpec {
|
||||
problem: LtMatmulProblem,
|
||||
trans_a: cublasOperation_t,
|
||||
trans_b: cublasOperation_t,
|
||||
@@ -662,8 +667,8 @@ struct LtMatmulSpec {
|
||||
workspace_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulPointers {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulPointers {
|
||||
a: u64,
|
||||
b: u64,
|
||||
c: u64,
|
||||
@@ -673,7 +678,35 @@ struct LtMatmulPointers {
|
||||
b_scale: Option<u64>,
|
||||
}
|
||||
|
||||
struct LtRawDescriptors {
|
||||
impl LtMatmulPointers {
|
||||
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
|
||||
let mut fields = Vec::new();
|
||||
if self.a != other.a {
|
||||
fields.push("a");
|
||||
}
|
||||
if self.b != other.b {
|
||||
fields.push("b");
|
||||
}
|
||||
if self.c != other.c {
|
||||
fields.push("c");
|
||||
}
|
||||
if self.d != other.d {
|
||||
fields.push("d");
|
||||
}
|
||||
if self.bias != other.bias {
|
||||
fields.push("bias");
|
||||
}
|
||||
if self.a_scale != other.a_scale {
|
||||
fields.push("a_scale");
|
||||
}
|
||||
if self.b_scale != other.b_scale {
|
||||
fields.push("b_scale");
|
||||
}
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LtRawDescriptors {
|
||||
matmul_desc: cublasLtMatmulDesc_t,
|
||||
a_desc: cublasLtMatrixLayout_t,
|
||||
b_desc: cublasLtMatrixLayout_t,
|
||||
@@ -682,6 +715,23 @@ struct LtRawDescriptors {
|
||||
preference: cublasLtMatmulPreference_t,
|
||||
}
|
||||
|
||||
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
|
||||
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
|
||||
> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
|
||||
CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
|
||||
CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
impl Default for LtRawDescriptors {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -720,6 +770,121 @@ impl Drop for LtRawDescriptors {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PreparedCuBlasLtMatmul {
|
||||
cublaslt: Arc<CudaBlasLT>,
|
||||
spec: LtMatmulSpec,
|
||||
resources: LtRawDescriptors,
|
||||
heuristic: cublasLtMatmulHeuristicResult_t,
|
||||
_workspace: CudaSlice<u8>,
|
||||
workspace_ptr: u64,
|
||||
_a_scale: Option<CudaSlice<f32>>,
|
||||
default_a_scale_ptr: Option<u64>,
|
||||
_b_scale: Option<CudaSlice<f32>>,
|
||||
default_b_scale_ptr: Option<u64>,
|
||||
_c_scale: Option<CudaSlice<f32>>,
|
||||
_d_scale: Option<CudaSlice<f32>>,
|
||||
}
|
||||
|
||||
impl PreparedCuBlasLtMatmul {
|
||||
fn update_descriptor_pointers(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.context().bind_to_thread()?;
|
||||
if let Some(bias_ptr) = ptrs.bias {
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
bias_ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
|
||||
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
|
||||
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn enqueue(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
self.update_descriptor_pointers(stream, ptrs)?;
|
||||
let alpha_ptr = self.spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = self.spec.compute.beta.as_ptr();
|
||||
unsafe {
|
||||
cublasLtMatmul(
|
||||
*self.cublaslt.handle(),
|
||||
self.resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
self.resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
self.resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
self.resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
self.resources.d_desc,
|
||||
&self.heuristic.algo,
|
||||
self.workspace_ptr as *mut std::ffi::c_void,
|
||||
self.spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtCaptureSignature {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtPrepareKey {
|
||||
spec: LtMatmulSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtResolvedGraphCall {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
impl CuBlasLtResolvedGraphCall {
|
||||
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
|
||||
CuBlasLtCaptureSignature {
|
||||
spec: self.spec,
|
||||
ptrs: self.ptrs,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
|
||||
CuBlasLtPrepareKey { spec: self.spec }
|
||||
}
|
||||
}
|
||||
|
||||
fn create_matrix_layout(
|
||||
desc: &mut cublasLtMatrixLayout_t,
|
||||
spec: LtMatrixSpec,
|
||||
@@ -796,12 +961,15 @@ fn set_scalar_scale_pointer(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
pub(crate) fn prepare_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
#[cfg(test)]
|
||||
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
|
||||
@@ -813,17 +981,17 @@ fn run_cublaslt_matmul(
|
||||
|
||||
let mut resources = LtRawDescriptors::default();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
|
||||
drop(workspace_guard);
|
||||
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
@@ -879,29 +1047,27 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &a_scale {
|
||||
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &b_scale {
|
||||
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
|
||||
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
|
||||
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
@@ -935,6 +1101,7 @@ fn run_cublaslt_matmul(
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
|
||||
|
||||
create_matrix_layout(&mut resources.a_desc, spec.a)?;
|
||||
create_matrix_layout(&mut resources.b_desc, spec.b)?;
|
||||
@@ -952,58 +1119,148 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
|
||||
let cached_heuristic = {
|
||||
let cache = heuristic_cache.lock().unwrap();
|
||||
cache
|
||||
.iter()
|
||||
.find(|(cached_spec, _)| cached_spec == spec)
|
||||
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
|
||||
};
|
||||
if let Some(cached) = cached_heuristic {
|
||||
heuristic = cached;
|
||||
} else {
|
||||
let mut algo_count: i32 = 0;
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
}
|
||||
|
||||
let alpha_ptr = spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = spec.compute.beta.as_ptr();
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
resources.d_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
heuristic_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(PreparedCuBlasLtMatmul {
|
||||
cublaslt: cublaslt.clone(),
|
||||
spec: *spec,
|
||||
resources,
|
||||
heuristic,
|
||||
_workspace: workspace,
|
||||
workspace_ptr,
|
||||
_a_scale: a_scale,
|
||||
default_a_scale_ptr,
|
||||
_b_scale: b_scale,
|
||||
default_b_scale_ptr,
|
||||
_c_scale: c_scale,
|
||||
_d_scale: d_scale,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
|
||||
prepared.enqueue(stream, ptrs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
|
||||
fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
|
||||
let capture_stream = stream.context().new_stream()?;
|
||||
let cublaslt = try_create_cublaslt(stream.clone())
|
||||
.map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;
|
||||
|
||||
let a_buf = stream.clone_htod(&[1.0f32])?;
|
||||
let b_buf = stream.clone_htod(&[1.0f32])?;
|
||||
let d_buf = unsafe { stream.alloc::<f32>(1)? };
|
||||
let (a, a_guard) = a_buf.device_ptr(stream);
|
||||
let (b, b_guard) = b_buf.device_ptr(stream);
|
||||
let (d, d_guard) = d_buf.device_ptr(stream);
|
||||
drop((a_guard, b_guard, d_guard));
|
||||
|
||||
let matrix = LtMatrixSpec {
|
||||
dtype: cudaDataType::CUDA_R_32F,
|
||||
rows: 1,
|
||||
cols: 1,
|
||||
ld: 1,
|
||||
batch_stride: 1,
|
||||
order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
|
||||
};
|
||||
let spec = LtMatmulSpec {
|
||||
problem: LtMatmulProblem {
|
||||
m: 1,
|
||||
n: 1,
|
||||
k: 1,
|
||||
batch_count: 1,
|
||||
},
|
||||
trans_a: cublasOperation_t::CUBLAS_OP_N,
|
||||
trans_b: cublasOperation_t::CUBLAS_OP_N,
|
||||
a: matrix,
|
||||
b: matrix,
|
||||
c: matrix,
|
||||
d: matrix,
|
||||
compute: LtComputeSpec {
|
||||
compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
scale_dtype: cudaDataType::CUDA_R_32F,
|
||||
alpha: LtScalar::F32(1.0),
|
||||
beta: LtScalar::F32(0.0),
|
||||
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
},
|
||||
workspace_size: 1024 * 1024,
|
||||
};
|
||||
let ptrs = LtMatmulPointers {
|
||||
a,
|
||||
b,
|
||||
c: d,
|
||||
d,
|
||||
bias: None,
|
||||
a_scale: None,
|
||||
b_scale: None,
|
||||
};
|
||||
let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;
|
||||
|
||||
let mut graph = CudaGraphHandle::new(stream.context().clone())?;
|
||||
let entry = graph.add_empty_node(&[])?;
|
||||
capture_stream.join(stream)?;
|
||||
graph.begin_capture_to_graph(&capture_stream, &[entry])?;
|
||||
let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
|
||||
let end_result = graph.end_capture(&capture_stream);
|
||||
enqueue_result?;
|
||||
end_result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let supported = probe(stream).is_ok();
|
||||
let _ = stream.synchronize();
|
||||
supported
|
||||
}
|
||||
|
||||
fn resolve_cublaslt_pointers(
|
||||
@@ -1126,6 +1383,151 @@ impl CuBlasLt {
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
pub(crate) fn graph_inputs(&self) -> usize {
|
||||
self.n_inputs()
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_for_graph(
|
||||
&self,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
|
||||
|
||||
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
|
||||
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
|
||||
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
|
||||
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
|
||||
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
|
||||
let element_size = (self.d_dtype.bits() / 8) as u64;
|
||||
assert!(
|
||||
element_size > 0,
|
||||
"cuBLAS LT does not support sub-byte dtype {}",
|
||||
self.d_dtype
|
||||
);
|
||||
|
||||
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
|
||||
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
self_node,
|
||||
inputs,
|
||||
buffers,
|
||||
self.beta,
|
||||
self.epilogue,
|
||||
self.a_scale_input,
|
||||
self.b_scale_input,
|
||||
)?;
|
||||
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
|
||||
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
|
||||
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
|
||||
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT_resolve_graph",
|
||||
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
|
||||
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
|
||||
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
|
||||
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
|
||||
?self.epilogue,
|
||||
)
|
||||
.entered();
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
|
||||
let c_spec = LtMatrixSpec {
|
||||
dtype: c_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldc,
|
||||
batch_stride: stride_c,
|
||||
order: self.c_order,
|
||||
};
|
||||
let d_spec = LtMatrixSpec {
|
||||
dtype: d_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldd,
|
||||
batch_stride: stride_d,
|
||||
order: self.d_order,
|
||||
};
|
||||
let spec = LtMatmulSpec {
|
||||
problem: LtMatmulProblem {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_count,
|
||||
},
|
||||
trans_a: a_layout,
|
||||
trans_b: b_layout,
|
||||
a: LtMatrixSpec {
|
||||
dtype: a_cuda_dtype,
|
||||
rows: a_rows,
|
||||
cols: a_cols,
|
||||
ld: lda,
|
||||
batch_stride: stride_a,
|
||||
order: self.a_order,
|
||||
},
|
||||
b: LtMatrixSpec {
|
||||
dtype: b_cuda_dtype,
|
||||
rows: b_rows,
|
||||
cols: b_cols,
|
||||
ld: ldb,
|
||||
batch_stride: stride_b,
|
||||
order: self.b_order,
|
||||
},
|
||||
c: c_spec,
|
||||
d: d_spec,
|
||||
compute: LtComputeSpec {
|
||||
compute_type: self.compute_type,
|
||||
scale_dtype: scale_cuda_dtype,
|
||||
alpha,
|
||||
beta,
|
||||
epilogue: self.epilogue,
|
||||
},
|
||||
workspace_size: WORKSPACE_SIZE,
|
||||
};
|
||||
|
||||
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_resolved_for_graph(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
resolved: CuBlasLtResolvedGraphCall,
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
|
||||
(
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublaslt;
|
||||
pub(crate) mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
@@ -167,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns pairs of extra buffer nodes that must not share arena storage.
|
||||
///
|
||||
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
|
||||
/// two buffers may have disjoint positions in one topological order while
|
||||
/// still being unordered by real dependencies, so CUDA could overlap them.
|
||||
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -7,7 +7,10 @@ use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaContext, CudaFunction, CudaStream, DriverError,
|
||||
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
|
||||
sys::{
|
||||
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
|
||||
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
|
||||
},
|
||||
};
|
||||
|
||||
/// A CUDA graph that can be modified and instantiated.
|
||||
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates a kernel node in the mutable source graph.
|
||||
pub unsafe fn set_kernel_node_params(
|
||||
&mut self,
|
||||
node: CUgraphNode,
|
||||
func: CUfunction,
|
||||
grid_dim: (u32, u32, u32),
|
||||
block_dim: (u32, u32, u32),
|
||||
shared_mem_bytes: u32,
|
||||
kernel_params: *mut *mut c_void,
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let params = sys::CUDA_KERNEL_NODE_PARAMS {
|
||||
func,
|
||||
gridDimX: grid_dim.0,
|
||||
gridDimY: grid_dim.1,
|
||||
gridDimZ: grid_dim.2,
|
||||
blockDimX: block_dim.0,
|
||||
blockDimY: block_dim.1,
|
||||
blockDimZ: block_dim.2,
|
||||
sharedMemBytes: shared_mem_bytes,
|
||||
kernelParams: kernel_params,
|
||||
extra: std::ptr::null_mut(),
|
||||
kern: std::ptr::null_mut(),
|
||||
ctx: std::ptr::null_mut(),
|
||||
};
|
||||
|
||||
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, ¶ms).result() }
|
||||
}
|
||||
|
||||
/// Adds an empty dependency node to the graph.
|
||||
pub fn add_empty_node(
|
||||
&mut self,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<CUgraphNode, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut node = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphAddEmptyNode(
|
||||
node.as_mut_ptr(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
dependencies.len(),
|
||||
)
|
||||
.result()?;
|
||||
Ok(node.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Destroys a node in the mutable graph.
|
||||
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe { sys::cuGraphDestroyNode(node).result() }
|
||||
}
|
||||
|
||||
/// Adds dependency edges to the mutable graph.
|
||||
pub fn add_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Removes dependency edges from the mutable graph.
|
||||
pub fn remove_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Returns all nodes currently in the graph.
|
||||
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut nodes = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
nodes.truncate(count);
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
/// Returns the direct dependencies of a node.
|
||||
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Returns the direct dependents of a node.
|
||||
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Begins stream capture that appends captured work into this graph.
|
||||
pub fn begin_capture_to_graph(
|
||||
&mut self,
|
||||
stream: &CudaStream,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuStreamBeginCaptureToGraph(
|
||||
stream.cu_stream(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
std::ptr::null(),
|
||||
dependencies.len(),
|
||||
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
|
||||
)
|
||||
.result()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ends stream capture previously started by begin_capture_to_graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut graph = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
|
||||
let captured = graph.assume_init();
|
||||
if captured != self.cu_graph && !captured.is_null() {
|
||||
sys::cuGraphDestroy(captured).result()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds an event record node to the graph for timing.
|
||||
pub fn add_event_record_node(
|
||||
&mut self,
|
||||
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
|
||||
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, ¶ms) }
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Attempts to update this executable graph from an already-mutated source graph.
|
||||
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut result = CUgraphExecUpdateResultInfo {
|
||||
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
|
||||
errorNode: std::ptr::null_mut(),
|
||||
errorFromNode: std::ptr::null_mut(),
|
||||
};
|
||||
unsafe {
|
||||
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
|
||||
}
|
||||
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
|
||||
return Err(DriverError(
|
||||
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphExecHandle {
|
||||
@@ -480,6 +672,38 @@ mod tests {
|
||||
assert_eq!(result[0], 6.0f32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_empty_node_dependency_reconnect() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let mut graph = CudaGraphHandle::new(ctx).unwrap();
|
||||
|
||||
let entry = graph.add_empty_node(&[]).unwrap();
|
||||
let middle = graph.add_empty_node(&[entry]).unwrap();
|
||||
let exit = graph.add_empty_node(&[middle]).unwrap();
|
||||
|
||||
let nodes = graph.nodes().unwrap();
|
||||
assert!(nodes.contains(&entry));
|
||||
assert!(nodes.contains(&middle));
|
||||
assert!(nodes.contains(&exit));
|
||||
assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
|
||||
assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);
|
||||
|
||||
graph.add_dependencies(&[entry], &[exit]).unwrap();
|
||||
let exit_deps = graph.dependencies(exit).unwrap();
|
||||
assert!(exit_deps.contains(&entry));
|
||||
assert!(exit_deps.contains(&middle));
|
||||
|
||||
graph.remove_dependencies(&[middle], &[exit]).unwrap();
|
||||
let exit_deps = graph.dependencies(exit).unwrap();
|
||||
assert_eq!(exit_deps.len(), 1);
|
||||
assert!(exit_deps.contains(&entry));
|
||||
|
||||
unsafe {
|
||||
graph.destroy_node(middle).unwrap();
|
||||
}
|
||||
assert!(!graph.nodes().unwrap().contains(&middle));
|
||||
}
|
||||
|
||||
// CUDA Graph Tests
|
||||
|
||||
#[test]
|
||||
@@ -499,7 +723,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -531,7 +755,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -569,7 +793,7 @@ mod tests {
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
@@ -611,7 +835,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
@@ -642,7 +866,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
@@ -675,7 +899,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -304,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
|
||||
|
||||
// Kernel to host op compilation
|
||||
mod to_host;
|
||||
#[cfg(test)]
|
||||
pub(crate) use to_host::CudaGraphDebugSummary;
|
||||
pub use to_host::{CudaGraphOp, kernel_to_host};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,11 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -91,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -146,7 +154,11 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
|
||||
rt1 = cx1.search_with_rng(
|
||||
rt1,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng1,
|
||||
);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -158,7 +170,11 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
|
||||
rt2 = cx2.search_with_rng(
|
||||
rt2,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng2,
|
||||
);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -186,7 +202,11 @@ fn test_bucket_out_of_range_panics() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -211,7 +231,11 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -257,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -311,7 +339,11 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
|
||||
@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
|
||||
|
||||
// Use seeded search for deterministic variant selection.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
|
||||
rt.set_data(gather_idx, scatter);
|
||||
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(gathered), expected_gather);
|
||||
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(weights);
|
||||
|
||||
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(weights, weights_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(roundtrip);
|
||||
|
||||
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(w, w_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(topk);
|
||||
|
||||
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(argsort);
|
||||
|
||||
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(ranks);
|
||||
|
||||
@@ -1395,7 +1447,11 @@ fn test_dynamic_3d_flat_index_iota_rows() {
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(idx);
|
||||
|
||||
@@ -1438,7 +1494,11 @@ fn test_dynamic_2d_to_3d_gather_rows() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(data, data_values.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(out);
|
||||
|
||||
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
|
||||
rt.set_data(topk, topk_data.clone());
|
||||
rt.set_data(weights, weights_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -7,10 +7,12 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
host::{
|
||||
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
|
||||
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
|
||||
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
|
||||
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
|
||||
cublaslt_type_tuple,
|
||||
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
|
||||
expected
|
||||
}
|
||||
|
||||
fn reference_mixed_chain(
|
||||
a: &[f32],
|
||||
pre: &[f32],
|
||||
b: &[f32],
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut expected = vec![0.0; m * n];
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut acc = 0.0;
|
||||
for inner in 0..k {
|
||||
acc += (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col];
|
||||
}
|
||||
expected[row * n + col] = acc.exp();
|
||||
}
|
||||
}
|
||||
expected
|
||||
}
|
||||
|
||||
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
|
||||
crate::try_create_cublaslt(stream.clone()).is_ok()
|
||||
}
|
||||
|
||||
fn build_mixed_chain_graph(
|
||||
m: impl Into<Expression>,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let m = m.into();
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let pre = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = ((a + pre).matmul(b).exp()).output();
|
||||
(cx, a.id, pre.id, b.id, out.id)
|
||||
}
|
||||
|
||||
fn add_in_place(values: &mut [f32], addends: &[f32]) {
|
||||
for (value, addend) in values.iter_mut().zip(addends) {
|
||||
*value += *addend;
|
||||
@@ -507,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |llir| {
|
||||
cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
let mixed = summaries
|
||||
.iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected one CudaGraphOp to capture the cuBLASLt island");
|
||||
assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
|
||||
assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
|
||||
assert_eq!(mixed.absorbed_host_nodes.len(), 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_only_executes_correctly() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected a cuBLASLt-only CudaGraphOp");
|
||||
assert_eq!(summary.n_kernels, 0);
|
||||
assert_eq!(summary.n_steps, 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (5, 8, 8);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let first = a.matmul(b);
|
||||
let out = (a + first.sin()).matmul(b).output();
|
||||
let llir = extract_forced_cublaslt_llir_where(
|
||||
&mut cx,
|
||||
"ordered matching cuBLASLt prepared reuse",
|
||||
|llir| {
|
||||
let orders = cublaslt_matrix_order_tuples(llir);
|
||||
orders.len() == 2 && orders[0] == orders[1]
|
||||
},
|
||||
);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
|
||||
let first = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
let dep = a_data
|
||||
.iter()
|
||||
.zip(&first)
|
||||
.map(|(a, first)| a + first.sin())
|
||||
.collect::<Vec<_>>();
|
||||
let expected = reference_matmul_2d(&dep, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 2)
|
||||
.expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
|
||||
assert_eq!(
|
||||
summary.n_cublaslt_prepared, 1,
|
||||
"dependency-ordered matching cuBLASLt calls should share prepared resources"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
a.output();
|
||||
b.output();
|
||||
cx.set_dim('p', 1);
|
||||
let llir = extract_forced_cublaslt_llir_where(
|
||||
&mut cx,
|
||||
"cuBLASLt unchanged under unrelated dyn dim",
|
||||
|_| true,
|
||||
);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
|
||||
reset_cublaslt_prepare_count_for_test();
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let first_prepare_count = cublaslt_prepare_count_for_test();
|
||||
assert!(
|
||||
first_prepare_count > 0,
|
||||
"first execution should prepare the captured cuBLASLt island"
|
||||
);
|
||||
|
||||
cx.set_dim('p', 2);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert_eq!(
|
||||
cublaslt_prepare_count_for_test(),
|
||||
first_prepare_count,
|
||||
"unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('m', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('m', 7);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (m, seed) in [
|
||||
(7usize, 0xC002_0001),
|
||||
(9usize, 0xC002_0002),
|
||||
(7usize, 0xC002_0003),
|
||||
] {
|
||||
cx.set_dim('m', m);
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected a cuBLASLt-only CudaGraphOp after recapture");
|
||||
assert_eq!(summary.n_kernels, 0);
|
||||
assert_eq!(summary.n_steps, 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_with_dynamic_c_spec_is_captured() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('c', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('c', 7);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
|
||||
cx.set_dim('c', c);
|
||||
let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);
|
||||
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"c-dependent cuBLASLt should be absorbed into a CUDA graph"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('s', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('s', 1);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);
|
||||
|
||||
let dim_buckets = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
|
||||
.into_iter()
|
||||
.collect();
|
||||
let bucket_llirs = vec![
|
||||
(
|
||||
[('s', 0usize)].into_iter().collect(),
|
||||
[('s', 1usize)].into_iter().collect(),
|
||||
llir.clone(),
|
||||
),
|
||||
(
|
||||
[('s', 1usize)].into_iter().collect(),
|
||||
[('s', 3usize)].into_iter().collect(),
|
||||
llir,
|
||||
),
|
||||
];
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir_buckets(&dim_buckets, &bucket_llirs);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, 1, n, k);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"singleton s bucket should capture s-dependent cuBLASLt"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
assert!(
|
||||
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
|
||||
"bucket with captured cuBLASLt needs stable intermediate pointers"
|
||||
);
|
||||
|
||||
cx.set_dim('s', 3);
|
||||
let a_data = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, 3, n, k);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"range s bucket should capture s-dependent cuBLASLt"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
assert!(
|
||||
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
|
||||
"bucket with captured cuBLASLt needs stable intermediate pointers"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
reset_cublaslt_prepare_count_for_test();
|
||||
let mut first_prepare_count = None;
|
||||
for seed in [0xCC00_0001, 0xCC00_0002] {
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
if first_prepare_count.is_none() {
|
||||
first_prepare_count = Some(cublaslt_prepare_count_for_test());
|
||||
}
|
||||
}
|
||||
assert_eq!(
|
||||
cublaslt_prepare_count_for_test(),
|
||||
first_prepare_count.unwrap(),
|
||||
"A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
|
||||
);
|
||||
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
assert!(
|
||||
summaries.iter().any(|summary| summary.n_cublaslt == 1),
|
||||
"expected cuBLASLt to remain captured after pointer recapture"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
|
||||
cx.set_dim('m', 7);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (m, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
|
||||
cx.set_dim('m', m);
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
assert!(
|
||||
summaries.iter().any(|summary| summary.n_cublaslt == 1),
|
||||
"expected cuBLASLt to remain captured after dynamic-shape recapture"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {
|
||||
|
||||
@@ -89,7 +89,7 @@ fn run_reference_attention(
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
|
||||
@@ -61,7 +61,7 @@ fn generic_matmul_executes_noncontiguous_merged_head_projection() {
|
||||
rt.set_data(attn, attn_data.as_slice());
|
||||
rt.set_data(weight, weight_data.as_slice());
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
assert!(
|
||||
rt.kernel_names().contains(&"GenericMatmul"),
|
||||
"expected GenericMatmul to be selected, kernels: {:?}",
|
||||
@@ -109,7 +109,7 @@ fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
|
||||
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -518,7 +518,7 @@ mod gemma {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -655,7 +655,7 @@ mod qwen {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, CompileOptions::new(10));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out_dim0 = rt.get_i32(sorted_dim0.id);
|
||||
let out_dim1 = rt.get_i32(sorted_dim1.id);
|
||||
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
@@ -31,7 +31,7 @@ pub fn kernel_add_bandwidth_test() {
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -233,7 +233,9 @@ fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
rt.get_f32(model.output.id)
|
||||
@@ -278,7 +280,9 @@ fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
rt.get_f32(model.output.id)
|
||||
|
||||
@@ -50,7 +50,7 @@ fn rope_matches_cpu_reference() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
@@ -98,7 +98,7 @@ fn rope_flux2_shape() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
|
||||
|
||||
// Use minimal search iterations to avoid excessive graph rewriting
|
||||
// which can cause float drift through softmax/RMSNorm reordering
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
|
||||
.collect();
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(weight, weight_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
|
||||
rt.set_data(wk, wk_data.clone());
|
||||
rt.set_data(wv, wv_data.clone());
|
||||
rt.set_data(wo, wo_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -530,7 +530,7 @@ fn test_rolled_chained_scalar_muls() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -331,7 +331,7 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_with_rng(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::new(1),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
@@ -701,7 +701,7 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
@@ -776,7 +776,7 @@ pub fn test_binary_cuda<T: TestDType>(
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
@@ -844,7 +844,7 @@ pub fn test_mod(
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
@@ -503,7 +503,8 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, CompileOptions::new(SEARCH_GRAPHS));
|
||||
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
|
||||
@@ -41,7 +41,11 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
|
||||
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
cx.search_with_rng(rt, CompileOptions::new(limit), &mut rng)
|
||||
cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(limit),
|
||||
&mut rng,
|
||||
)
|
||||
}
|
||||
|
||||
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
|
||||
@@ -301,7 +305,7 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -322,7 +326,7 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
|
||||
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, vec![1.0f32; 4]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
@@ -350,7 +354,7 @@ fn metal_int_arithmetic_preserves_large_values() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -373,7 +377,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -395,7 +399,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -417,7 +421,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -449,7 +453,7 @@ fn metal_simple_add() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -469,7 +473,7 @@ fn metal_simple_mul() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -487,7 +491,7 @@ fn metal_simple_exp2() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -504,7 +508,7 @@ fn metal_simple_log2() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -529,7 +533,7 @@ fn metal_simple_sin() {
|
||||
3.0 * std::f32::consts::FRAC_PI_2,
|
||||
],
|
||||
);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -546,7 +550,7 @@ fn metal_simple_sqrt() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -563,7 +567,7 @@ fn metal_simple_recip() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -582,7 +586,7 @@ fn metal_simple_mod() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
|
||||
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -601,7 +605,7 @@ fn metal_simple_less_than() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -621,7 +625,7 @@ fn metal_simple_sum_reduce() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -640,7 +644,7 @@ fn metal_simple_max_reduce() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
|
||||
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -657,7 +661,7 @@ fn metal_f16_cast_roundtrip() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -678,7 +682,7 @@ fn metal_f16_intermediate_add_roundtrip() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
|
||||
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1026,7 +1030,7 @@ fn metal_rms_norm() {
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1066,7 +1070,7 @@ fn metal_self_attention() {
|
||||
rt.set_data(wk, &wk_data);
|
||||
rt.set_data(wv, &wv_data);
|
||||
rt.set_data(wo, &wo_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1125,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
|
||||
rt.set_data(wk, to_f16_vec(&wk_data));
|
||||
rt.set_data(wv, to_f16_vec(&wv_data));
|
||||
rt.set_data(wo, to_f16_vec(&wo_data));
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1169,7 +1173,7 @@ fn metal_swiglu_mlp() {
|
||||
rt.set_data(w_gate, &gate_data);
|
||||
rt.set_data(w_up, &up_data);
|
||||
rt.set_data(w_down, &down_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1219,7 +1223,7 @@ fn metal_mini_transformer_layer() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1285,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1327,7 +1331,7 @@ fn test_scatter_basic() {
|
||||
rt.set_data(src, &[10.0, 20.0, 30.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
|
||||
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1349,7 +1353,7 @@ fn test_scatter_buffer_roundtrip() {
|
||||
rt.set_data(src, &[0.0]);
|
||||
rt.set_data(indexes, &[0.0]);
|
||||
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
|
||||
for (pos, value, expected) in [
|
||||
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
|
||||
@@ -1383,7 +1387,7 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
|
||||
rt.set_data(weights, &[99.0, 99.0, 99.0]);
|
||||
rt.set_data(bias, &[0.5, 1.0, -1.5]);
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1448,7 +1452,7 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1475,7 +1479,7 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1495,7 +1499,7 @@ fn test_scatter_into_nonzero_dest() {
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1522,7 +1526,7 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
|
||||
rt.set_data(src, &[7.0, 8.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1551,7 +1555,7 @@ fn test_scatter_no_copy_handles_2d_destination() {
|
||||
rt.set_data(src, &[9.0, 8.0]);
|
||||
rt.set_data(indexes, &[2.0, 4.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1578,7 +1582,7 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1605,7 +1609,7 @@ fn test_scatter_all_positions() {
|
||||
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
|
||||
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1624,7 +1628,7 @@ fn test_gather_preserves_data_dtype() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -167,7 +167,10 @@ mod tests {
|
||||
let result = gather_rows(data, indices, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
|
||||
rt.set_data(
|
||||
@@ -193,7 +196,10 @@ mod tests {
|
||||
let result = scatter_rows(src, indices, dest, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
|
||||
rt.set_data(indices.id, vec![1, 3]);
|
||||
@@ -219,7 +225,10 @@ mod tests {
|
||||
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
|
||||
@@ -272,7 +281,10 @@ mod tests {
|
||||
let v_cache_new = v_cache_new.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
|
||||
rt.set_data(q.id, vec![1., 0., 1., 0.]);
|
||||
@@ -345,7 +357,10 @@ mod tests {
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
|
||||
// K cached at slot 0: [1, 0]
|
||||
@@ -417,7 +432,10 @@ mod tests {
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Cache has 1 token at slot 0
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
|
||||
@@ -184,7 +184,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 2.0, 3.0];
|
||||
// Router strongly favors expert 0
|
||||
@@ -239,7 +242,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 1.0];
|
||||
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
|
||||
@@ -293,7 +299,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![
|
||||
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
|
||||
@@ -350,7 +359,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
@@ -395,7 +407,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(batch * in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
|
||||
@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn search_options() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
|
||||
}
|
||||
|
||||
fn env_f32(name: &str, default: f32) -> f32 {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -187,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
|
||||
|
||||
println!("Compiling text encoder...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
println!("Encoding prompt...");
|
||||
@@ -345,10 +349,9 @@ fn run_full_pipeline(
|
||||
{
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
|
||||
let opts = luminal::graph::CompileOptions::new(env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search_with_rng(runtime, opts, &mut rng);
|
||||
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
|
||||
} else {
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
}
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
@@ -415,7 +418,7 @@ fn run_full_pipeline(
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
|
||||
runtime.set_data(latent_in, vae_input);
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let img = runtime.get_f32(out);
|
||||
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'
|
||||
|
||||
@@ -720,6 +720,10 @@ mod tests {
|
||||
out
|
||||
}
|
||||
|
||||
fn one_search() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(1)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_bias_matches_reference() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -747,7 +751,7 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
@@ -766,7 +770,7 @@ mod tests {
|
||||
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -790,7 +794,7 @@ mod tests {
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(ctx.default_stream());
|
||||
rt.set_data(input_t, input);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
@@ -821,7 +825,7 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
@@ -864,7 +868,7 @@ mod tests {
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
|
||||
@@ -84,7 +84,8 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
|
||||
@@ -36,14 +36,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -68,114 +70,11 @@ pub struct Gemma {
|
||||
|
||||
impl Gemma {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut w = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
w.push(GemmaLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
input_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
});
|
||||
}
|
||||
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers: w,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
|
||||
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,11 +84,7 @@ impl Gemma {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -226,6 +121,114 @@ struct GemmaLayer {
|
||||
rope_scaling_factor: f32,
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
input_layernorm: layer_norm(cx, l, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
qk_norm(q, self.q_norm, N_HEADS),
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
qk_norm(k, self.k_norm, N_KV_HEADS),
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = x + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
(
|
||||
x + self.post_feedforward_layernorm.forward(mlp_out),
|
||||
k_cache_out,
|
||||
v_cache_out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
|
||||
/// This produces far fewer e-graph nodes than the tanh-based expansion.
|
||||
#[allow(clippy::excessive_precision)]
|
||||
@@ -363,59 +366,3 @@ fn hlir_attention(
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// QK-norm + RoPE
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
|
||||
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
// O projection + post-attention norm + residual
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
|
||||
let x = x + attn_normed;
|
||||
|
||||
// Pre-feedforward norm + MLP + post-feedforward norm + residual
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
|
||||
(x + mlp_normed, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,11 +81,10 @@ fn main() {
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(
|
||||
runtime,
|
||||
CompileOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
|
||||
&mut rng,
|
||||
);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(search_graphs)
|
||||
.profile_timeout(Duration::from_secs(2));
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
|
||||
@@ -83,20 +83,16 @@ impl KVCache {
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS)
|
||||
.map(|layer| Gemma4Layer::init(cx, layer))
|
||||
.collect(),
|
||||
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,11 +127,7 @@ impl Gemma4MoE {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
fn init(cx: &mut Graph, layer: usize) -> Self {
|
||||
let spec = layer_spec(layer);
|
||||
Self {
|
||||
spec,
|
||||
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
|
||||
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
|
||||
v_proj: spec
|
||||
.has_v_proj
|
||||
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
|
||||
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
|
||||
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
|
||||
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
|
||||
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
|
||||
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
|
||||
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
|
||||
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
|
||||
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
|
||||
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
|
||||
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
|
||||
gate_up_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.gate_up_proj",
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
down_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.down_proj",
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_tensor(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
@@ -505,81 +499,6 @@ fn hlir_attention(
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
@@ -338,13 +338,11 @@ fn main() {
|
||||
println!(" Search trials: {SEARCH_TRIALS}");
|
||||
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(
|
||||
runtime,
|
||||
CompileOptions::new(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST);
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
|
||||
@@ -111,125 +111,18 @@ impl Llama {
|
||||
config: LlamaConfig,
|
||||
fp8_linears: bool,
|
||||
) -> Self {
|
||||
let mut layers = Vec::with_capacity(config.layers);
|
||||
for l in 0..config.layers {
|
||||
layers.push(LlamaLayer {
|
||||
config,
|
||||
up: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
up_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
gate: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
gate_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
down: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
(config.hidden, config.intermediate),
|
||||
fp8_linears,
|
||||
),
|
||||
down_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
attn_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
config,
|
||||
embedding: cx
|
||||
.named_tensor(
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
)
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
embedding: persist(
|
||||
cx,
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
),
|
||||
layers: (0..config.layers)
|
||||
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
|
||||
.collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
|
||||
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,12 +136,7 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let hidden = self.config.hidden;
|
||||
let mut x = self.embedding.gather(
|
||||
(input * hidden).expand_dim(1, hidden)
|
||||
+ input.graph().arange(hidden).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, input, self.config.hidden);
|
||||
let mut cache_outputs = Vec::with_capacity(self.config.layers);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
|
||||
weight: GraphTensor,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
|
||||
Self {
|
||||
config,
|
||||
up: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.up_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
|
||||
gate: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.gate_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
|
||||
down: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.down_proj",
|
||||
(config.hidden, config.intermediate),
|
||||
fp8,
|
||||
),
|
||||
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
|
||||
q_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.q_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
|
||||
k_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.k_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
|
||||
v_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.v_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
|
||||
o_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.o_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
|
||||
attn_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.input_layernorm.weight"),
|
||||
),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn linear_weight(
|
||||
cx: &mut Graph,
|
||||
prefix: impl ToString,
|
||||
@@ -325,6 +377,16 @@ fn linear_weight(
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_linear_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
fp8: bool,
|
||||
) -> GraphTensor {
|
||||
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
|
||||
}
|
||||
|
||||
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
|
||||
if !fp8 {
|
||||
return None;
|
||||
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
|
||||
})
|
||||
}
|
||||
|
||||
fn layer_linear_scales(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
fp8: bool,
|
||||
) -> Option<Fp8LinearScales> {
|
||||
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * hidden).expand_dim(1, hidden)
|
||||
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
|
||||
scale.expand_rhs(like.dims())
|
||||
}
|
||||
@@ -443,87 +526,3 @@ fn attention(
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,7 +241,8 @@ fn main() {
|
||||
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
|
||||
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
// Re-initialize KV cache after search (search consumes buffers)
|
||||
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
|
||||
|
||||
@@ -25,8 +25,8 @@ pub struct PagedKVCache {
|
||||
|
||||
impl PagedKVCache {
|
||||
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
|
||||
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
|
||||
@@ -44,78 +44,11 @@ pub struct Llama {
|
||||
|
||||
impl Llama {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,12 +74,8 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &PagedKVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = vec![];
|
||||
let mut x = token_embedding(self.embedding, input);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
@@ -177,6 +106,99 @@ struct LlamaLayer {
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
@@ -264,44 +286,3 @@ fn paged_attention(
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// Apply RoPE before scattering into cache
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,8 @@ where
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_i32_data(input.id, vec![1; search_s]);
|
||||
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(config.search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(config.search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
for i in 0..config.layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
|
||||
@@ -34,14 +34,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(layers);
|
||||
let mut v_caches = Vec::with_capacity(layers);
|
||||
for l in 0..layers {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -63,105 +65,10 @@ impl Qwen {
|
||||
layers <= LAYERS,
|
||||
"requested {layers} layers, but model has {LAYERS}"
|
||||
);
|
||||
let mut w = vec![];
|
||||
for l in 0..layers {
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
w.push(QwenLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
let lm_norm = LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
layers: w,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..layers).map(|l| QwenLayer::init(cx, l)).collect(),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,11 +79,7 @@ impl Qwen {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(self.layers.len());
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -209,6 +112,90 @@ struct QwenLayer {
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
impl QwenLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = qwen_rotary_embeddings(qk_norm(q, self.q_norm, N_HEADS), pos_ids, N_HEADS);
|
||||
let k_rope =
|
||||
qwen_rotary_embeddings(qk_norm(k, self.k_norm, N_KV_HEADS), pos_ids, N_KV_HEADS);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-head RMS normalization for QK-norm.
|
||||
/// Input: [seq, dim] where dim = n_heads * HEAD_DIM
|
||||
/// split_dims to [seq, n_heads, HEAD_DIM], RMS norm over last axis, multiply by weight, merge back.
|
||||
@@ -331,36 +318,3 @@ fn hlir_attention(
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl QwenLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// QK-norm: per-head RMS normalization
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
|
||||
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
|
||||
|
||||
// RoPE
|
||||
let q_rope = qwen_rotary_embeddings(q_normed, pos_ids, N_HEADS);
|
||||
let k_rope = qwen_rotary_embeddings(k_normed, pos_ids, N_KV_HEADS);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +82,8 @@ fn main() {
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(runtime, CompileOptions::new(search_graphs), &mut rng);
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
|
||||
@@ -29,17 +29,19 @@ pub struct KVCache {
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -58,111 +60,11 @@ pub struct Qwen3MoE {
|
||||
|
||||
impl Qwen3MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let router = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_up_weights"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_weights"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
layers.push(Qwen3MoELayer {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
moe: QwenMoE {
|
||||
router,
|
||||
gate_up_weights: gate_up_weights.as_dtype(DType::Bf16),
|
||||
down_weights: down_weights.as_dtype(DType::Bf16),
|
||||
},
|
||||
});
|
||||
}
|
||||
let lm_norm = LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
layers,
|
||||
lm_norm,
|
||||
lm_head,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| Qwen3MoELayer::init(cx, l)).collect(),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,11 +74,7 @@ impl Qwen3MoE {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -214,6 +112,39 @@ struct QwenMoE {
|
||||
}
|
||||
|
||||
impl Qwen3MoELayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
moe: QwenMoE {
|
||||
router: layer_weight(cx, l, "mlp.gate", (NUM_EXPERTS, HIDDEN)),
|
||||
gate_up_weights: layer_tensor(
|
||||
cx,
|
||||
l,
|
||||
"mlp.gate_up_weights",
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
down_weights: layer_tensor(
|
||||
cx,
|
||||
l,
|
||||
"mlp.down_weights",
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
@@ -247,6 +178,51 @@ impl Qwen3MoELayer {
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_tensor(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
impl QwenMoE {
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let n = x.dims().len(); // 2 for [s, H]
|
||||
|
||||
@@ -12,7 +12,7 @@ fn main() {
|
||||
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default());
|
||||
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
|
||||
@@ -96,7 +96,8 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1i32; max_prefill]);
|
||||
runtime.set_data(pos_ids, (0..max_prefill as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
// Reset the KV caches and re-set the mel after search (which executes test runs).
|
||||
for i in 0..N_TEXT_LAYER {
|
||||
|
||||
@@ -33,6 +33,14 @@ fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
x.matmul(w.t())
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
|
||||
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
|
||||
fn conv1d_bias(
|
||||
@@ -90,27 +98,13 @@ struct AttentionWeights {
|
||||
impl AttentionWeights {
|
||||
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
q_bias: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_bias: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
|
||||
.persist(),
|
||||
out_proj: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
out_bias: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
|
||||
.persist(),
|
||||
q_proj: persist(cx, format!("{prefix}.q_proj.weight"), (dim, dim)),
|
||||
q_bias: persist(cx, format!("{prefix}.q_proj.bias"), dim),
|
||||
k_proj: persist(cx, format!("{prefix}.k_proj.weight"), (dim, dim)),
|
||||
v_proj: persist(cx, format!("{prefix}.v_proj.weight"), (dim, dim)),
|
||||
v_bias: persist(cx, format!("{prefix}.v_proj.bias"), dim),
|
||||
out_proj: persist(cx, format!("{prefix}.out_proj.weight"), (dim, dim)),
|
||||
out_bias: persist(cx, format!("{prefix}.out_proj.bias"), dim),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -125,6 +119,14 @@ fn merge_heads(x: GraphTensor) -> GraphTensor {
|
||||
x.transpose(0, 1).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn embedding_lookup(embedding: GraphTensor, ids: GraphTensor) -> GraphTensor {
|
||||
let seq = ids.dims1();
|
||||
embedding.gather(
|
||||
(ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
|
||||
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
|
||||
let q = linear_with_bias(x, w.q_proj, w.q_bias);
|
||||
@@ -239,18 +241,10 @@ impl EncoderLayer {
|
||||
N_AUDIO_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
|
||||
.persist(),
|
||||
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE)),
|
||||
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
|
||||
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM)),
|
||||
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_AUDIO_STATE),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
|
||||
}
|
||||
}
|
||||
@@ -295,18 +289,10 @@ impl DecoderLayer {
|
||||
N_TEXT_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
|
||||
.persist(),
|
||||
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE)),
|
||||
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
|
||||
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM)),
|
||||
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_TEXT_STATE),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
|
||||
}
|
||||
}
|
||||
@@ -346,14 +332,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for l in 0..N_TEXT_LAYER {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_TEXT_HEAD, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_TEXT_HEAD, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -376,27 +364,23 @@ pub struct WhisperEncoder {
|
||||
impl WhisperEncoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
conv1_w: cx
|
||||
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
|
||||
.persist(),
|
||||
conv1_b: cx
|
||||
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
conv2_w: cx
|
||||
.named_tensor(
|
||||
"model.encoder.conv2.weight",
|
||||
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
|
||||
)
|
||||
.persist(),
|
||||
conv2_b: cx
|
||||
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
positional_embedding: cx
|
||||
.named_tensor(
|
||||
"model.encoder.embed_positions.weight",
|
||||
(N_AUDIO_CTX, N_AUDIO_STATE),
|
||||
)
|
||||
.persist(),
|
||||
conv1_w: persist(
|
||||
cx,
|
||||
"model.encoder.conv1.weight",
|
||||
(N_AUDIO_STATE, N_MELS * 3),
|
||||
),
|
||||
conv1_b: persist(cx, "model.encoder.conv1.bias", N_AUDIO_STATE),
|
||||
conv2_w: persist(
|
||||
cx,
|
||||
"model.encoder.conv2.weight",
|
||||
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
|
||||
),
|
||||
conv2_b: persist(cx, "model.encoder.conv2.bias", N_AUDIO_STATE),
|
||||
positional_embedding: persist(
|
||||
cx,
|
||||
"model.encoder.embed_positions.weight",
|
||||
(N_AUDIO_CTX, N_AUDIO_STATE),
|
||||
),
|
||||
layers: (0..N_AUDIO_LAYER)
|
||||
.map(|i| EncoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
@@ -427,15 +411,16 @@ pub struct WhisperDecoder {
|
||||
impl WhisperDecoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embed_tokens: cx
|
||||
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
|
||||
.persist(),
|
||||
embed_positions: cx
|
||||
.named_tensor(
|
||||
"model.decoder.embed_positions.weight",
|
||||
(N_TEXT_CTX, N_TEXT_STATE),
|
||||
)
|
||||
.persist(),
|
||||
embed_tokens: persist(
|
||||
cx,
|
||||
"model.decoder.embed_tokens.weight",
|
||||
(N_VOCAB, N_TEXT_STATE),
|
||||
),
|
||||
embed_positions: persist(
|
||||
cx,
|
||||
"model.decoder.embed_positions.weight",
|
||||
(N_TEXT_CTX, N_TEXT_STATE),
|
||||
),
|
||||
layers: (0..N_TEXT_LAYER)
|
||||
.map(|i| DecoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
@@ -450,18 +435,8 @@ impl WhisperDecoder {
|
||||
xa: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
// Token embedding gather
|
||||
let mut x = self.embed_tokens.gather(
|
||||
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
// Positional embedding gather (using pos_ids)
|
||||
let pos_emb = self.embed_positions.gather(
|
||||
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
x += pos_emb;
|
||||
let mut x = embedding_lookup(self.embed_tokens, token_ids);
|
||||
x += embedding_lookup(self.embed_positions, pos_ids);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
|
||||
@@ -335,7 +335,8 @@ fn main() {
|
||||
|
||||
println!("Compiling (search_graphs={search_graphs})...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
println!(" search took {:?}", t0.elapsed());
|
||||
|
||||
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)
|
||||
|
||||
@@ -174,7 +174,10 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
|
||||
let mut rt = graph.search(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(args.search_iters),
|
||||
);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
@@ -463,7 +463,10 @@ pub(super) mod tests {
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let lhs_values = lhs_transform(random_vec(a_shape.iter().copied().product()));
|
||||
let rhs_values = rhs_transform(random_vec(b_shape.iter().copied().product()));
|
||||
|
||||
@@ -968,7 +968,10 @@ mod tests {
|
||||
let zeros = cx.iota(Expression::from(0usize), 6);
|
||||
let inv = values.scatter(perm, zeros).cast(DType::F32).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.]);
|
||||
rt.set_data(indexes.id, vec![5, 0, 3, 2]);
|
||||
rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0]);
|
||||
@@ -985,7 +988,10 @@ mod tests {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![10., 20., 30.]);
|
||||
rt.set_data(indexes.id, vec![1, 3, 4]);
|
||||
rt.set_data(dest.id, vec![0., 0., 0., 0., 0.]);
|
||||
@@ -1001,7 +1007,10 @@ mod tests {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![99.]);
|
||||
rt.set_data(indexes.id, vec![2]);
|
||||
rt.set_data(dest.id, vec![1., 2., 3., 4., 5.]);
|
||||
@@ -1017,7 +1026,10 @@ mod tests {
|
||||
let dest = cx.tensor(4);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![40., 30., 20., 10.]);
|
||||
rt.set_data(indexes.id, vec![3, 2, 1, 0]);
|
||||
rt.set_data(dest.id, vec![1., 2., 3., 4.]);
|
||||
@@ -1045,7 +1057,10 @@ mod tests {
|
||||
let repeated = (a.repeat((2, 2)) * 1.0).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -126,7 +126,10 @@ mod tests {
|
||||
let b = func(&mut cx).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -211,7 +214,10 @@ mod tests {
|
||||
let stacked = cx.stack(&[a, b, c], 0).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let a_data = random_vec(6);
|
||||
let b_data = random_vec(6);
|
||||
|
||||
@@ -493,7 +493,10 @@ pub(super) mod tests {
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let v = random_vec(shape.iter().copied().product());
|
||||
rt.set_data(a.id, v.clone());
|
||||
|
||||
233
src/graph.rs
233
src/graph.rs
@@ -82,6 +82,13 @@ struct SearchSpaceContext {
|
||||
intervals: DynDimIntervals,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SearchProfileBucketContext {
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
representative_dyn_map: FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
/// A compiled bucket: (bucket_indices, representative_dyn_map, stitched_llir).
|
||||
pub type BucketLLIR = (FxHashMap<char, usize>, FxHashMap<char, usize>, LLIRGraph);
|
||||
|
||||
@@ -137,7 +144,9 @@ impl DimBucket {
|
||||
/// Use the builder pattern to configure search parameters:
|
||||
/// ```
|
||||
/// use luminal::prelude::CompileOptions;
|
||||
/// let opts = CompileOptions::new(5)
|
||||
/// let opts = CompileOptions::default()
|
||||
/// .search_graph_limit(5)
|
||||
/// .search_time_limit(std::time::Duration::from_secs(30))
|
||||
/// .generation_size(50)
|
||||
/// .mutations(40)
|
||||
/// .trials(15);
|
||||
@@ -146,6 +155,8 @@ impl DimBucket {
|
||||
pub struct CompileOptions {
|
||||
/// Maximum number of graphs to evaluate during search.
|
||||
pub limit: usize,
|
||||
/// Maximum wall-clock time to spend searching.
|
||||
pub search_time_limit: std::time::Duration,
|
||||
/// Optional maximum runtime-specific intermediate memory, in bytes.
|
||||
///
|
||||
/// When this is `None`, search does not apply a memory cap. Runtimes that
|
||||
@@ -163,8 +174,6 @@ pub struct CompileOptions {
|
||||
/// Per-candidate profiling timeout. If a profile call reaches this budget,
|
||||
/// that candidate is discarded and search continues.
|
||||
pub profile_timeout: Option<std::time::Duration>,
|
||||
/// Optional per-group search timeout.
|
||||
pub group_timeout: Option<std::time::Duration>,
|
||||
/// Optional profiling dimension overrides.
|
||||
pub profile_dims: FxHashMap<char, usize>,
|
||||
/// Bucket definitions per dynamic dimension. Dimensions without buckets use
|
||||
@@ -173,20 +182,16 @@ pub struct CompileOptions {
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
/// Create compile options with the given search limit. Other fields use defaults.
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
max_memory_bytes: None,
|
||||
generation_size: 10,
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
group_timeout: None,
|
||||
profile_dims: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
}
|
||||
/// Set the maximum number of graphs to evaluate during search.
|
||||
pub fn search_graph_limit(mut self, limit: usize) -> Self {
|
||||
self.limit = limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum wall-clock time to spend searching.
|
||||
pub fn search_time_limit(mut self, search_time_limit: std::time::Duration) -> Self {
|
||||
self.search_time_limit = search_time_limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a maximum intermediate memory budget in bytes.
|
||||
@@ -235,12 +240,6 @@ impl CompileOptions {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an optional per-group search timeout.
|
||||
pub fn group_timeout(mut self, group_timeout: std::time::Duration) -> Self {
|
||||
self.group_timeout = Some(group_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Override a dynamic dimension value used during search profiling.
|
||||
pub fn profile_dim(mut self, dim: char, value: usize) -> Self {
|
||||
self.profile_dims.insert(dim, value);
|
||||
@@ -261,7 +260,18 @@ impl CompileOptions {
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self::new(1)
|
||||
Self {
|
||||
limit: 100,
|
||||
search_time_limit: std::time::Duration::MAX,
|
||||
max_memory_bytes: None,
|
||||
generation_size: 10,
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
profile_dims: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1302,10 +1312,19 @@ impl Graph {
|
||||
"dim buckets must be configured in CompileOptions before build_search_space; search cannot change buckets after build",
|
||||
);
|
||||
|
||||
let search_started_at = std::time::Instant::now();
|
||||
if self.search_space_dim_buckets.is_empty() {
|
||||
// No buckets: existing single-search path
|
||||
let stitched =
|
||||
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None, 0);
|
||||
let stitched = self.search_single(
|
||||
&mut runtime,
|
||||
&options,
|
||||
rng,
|
||||
&self.dyn_map.clone(),
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
search_started_at,
|
||||
);
|
||||
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir(&stitched);
|
||||
@@ -1338,7 +1357,13 @@ impl Graph {
|
||||
rng,
|
||||
&context.representative_dyn_map,
|
||||
Some((combo_idx, n_combos)),
|
||||
Some(SearchProfileBucketContext {
|
||||
dim_buckets: self.search_space_dim_buckets.clone(),
|
||||
bucket_indices: context.bucket_indices.clone(),
|
||||
representative_dyn_map: context.representative_dyn_map.clone(),
|
||||
}),
|
||||
combo_idx,
|
||||
search_started_at,
|
||||
);
|
||||
bucket_llirs.push((
|
||||
context.bucket_indices,
|
||||
@@ -1425,6 +1450,7 @@ impl Graph {
|
||||
/// Run the genetic search and return the unrolled LLIR for the winning
|
||||
/// genome. `bucket_progress`: if `Some((current_bucket_idx, total_buckets))`
|
||||
/// adds a second "Bucket" progress bar.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn search_single<R: Runtime + 'static, G: rand::Rng>(
|
||||
&mut self,
|
||||
runtime: &mut R,
|
||||
@@ -1432,7 +1458,9 @@ impl Graph {
|
||||
rng: &mut G,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
bucket_progress: Option<(usize, usize)>,
|
||||
bucket_profile_context: Option<SearchProfileBucketContext>,
|
||||
egraph_index: usize,
|
||||
search_started_at: std::time::Instant,
|
||||
) -> LLIRGraph {
|
||||
let mut profile_dyn_map = dyn_map.clone();
|
||||
for (&dim, &value) in &options.profile_dims {
|
||||
@@ -1482,11 +1510,11 @@ impl Graph {
|
||||
}
|
||||
};
|
||||
|
||||
let group_start = std::time::Instant::now();
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
runtime.clear_intermediate_buffers();
|
||||
let search_time_limit_reached = || search_started_at.elapsed() >= options.search_time_limit;
|
||||
let profile_timed_out = |elapsed: std::time::Duration| {
|
||||
options
|
||||
.profile_timeout
|
||||
@@ -1531,12 +1559,27 @@ impl Graph {
|
||||
collapse_loops_to_first_iter(&mut graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let (rep_metric, rep_display) =
|
||||
if let Some(bucket_context) = &bucket_profile_context {
|
||||
runtime.profile_with_bucket_context(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
ProfileBucketContext {
|
||||
dim_buckets: &bucket_context.dim_buckets,
|
||||
bucket_indices: &bucket_context.bucket_indices,
|
||||
representative_dyn_map: &bucket_context.representative_dyn_map,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
runtime.profile(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
)
|
||||
};
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
(
|
||||
@@ -1561,11 +1604,8 @@ impl Graph {
|
||||
break;
|
||||
}
|
||||
Ok(_) | Err(_) => {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
panic!("Failed to find a viable initial genome before timeout");
|
||||
if search_time_limit_reached() {
|
||||
panic!("Failed to find a viable initial genome before search time limit");
|
||||
}
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
@@ -1586,10 +1626,7 @@ impl Graph {
|
||||
let mut resample_generation = false;
|
||||
|
||||
while n_graphs < search_limit {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
if search_time_limit_reached() {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1623,10 +1660,7 @@ impl Graph {
|
||||
let mut generation_found_non_timeout = false;
|
||||
|
||||
for genome in all_offspring {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
if search_time_limit_reached() {
|
||||
break;
|
||||
}
|
||||
n_graphs += 1;
|
||||
@@ -1659,12 +1693,27 @@ impl Graph {
|
||||
collapse_loops_to_first_iter(&mut llir_graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let (rep_metric, rep_display) =
|
||||
if let Some(bucket_context) = &bucket_profile_context {
|
||||
runtime.profile_with_bucket_context(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
ProfileBucketContext {
|
||||
dim_buckets: &bucket_context.dim_buckets,
|
||||
bucket_indices: &bucket_context.bucket_indices,
|
||||
representative_dyn_map: &bucket_context.representative_dyn_map,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
runtime.profile(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
)
|
||||
};
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan =
|
||||
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
@@ -2808,6 +2857,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::egglog_utils::hash_egglog_normalized;
|
||||
use crate::tests::{assert_close, random_vec};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestMemoryRuntime;
|
||||
@@ -2845,6 +2895,69 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
static PROFILE_CALLS: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingRuntime;
|
||||
|
||||
impl Runtime for CountingRuntime {
|
||||
type Ops = ();
|
||||
type CompileArg = ();
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = usize;
|
||||
|
||||
fn initialize(_: Self::CompileArg) -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
fn load_llir(&mut self, _: &LLIRGraph) {}
|
||||
|
||||
fn execute(&mut self, _: &FxHashMap<char, usize>) -> Self::ExecReturn {}
|
||||
|
||||
fn profile(
|
||||
&mut self,
|
||||
_: &LLIRGraph,
|
||||
_: &FxHashMap<char, usize>,
|
||||
_: usize,
|
||||
_: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
let count = PROFILE_CALLS.fetch_add(1, Ordering::SeqCst);
|
||||
(count, format!("{count} ms"))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_and_search_time_limit_builder() {
|
||||
let opts = CompileOptions::default();
|
||||
assert_eq!(opts.limit, 100);
|
||||
assert_eq!(opts.search_time_limit, std::time::Duration::MAX);
|
||||
|
||||
let time_limit = std::time::Duration::from_millis(25);
|
||||
let opts = CompileOptions::default()
|
||||
.search_graph_limit(7)
|
||||
.search_time_limit(time_limit);
|
||||
assert_eq!(opts.limit, 7);
|
||||
assert_eq!(opts.search_time_limit, time_limit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn search_time_limit_stops_after_initial_viable_candidate() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 8));
|
||||
let b = cx.tensor((8, 4));
|
||||
let c = cx.tensor((4, 4));
|
||||
let _ = (a.matmul(b) + c).relu().softmax(1).output();
|
||||
|
||||
cx.build_search_space::<CountingRuntime>(CompileOptions::default());
|
||||
|
||||
PROFILE_CALLS.store(0, Ordering::SeqCst);
|
||||
let _ = cx.search(
|
||||
CountingRuntime,
|
||||
CompileOptions::default().search_time_limit(std::time::Duration::ZERO),
|
||||
);
|
||||
assert_eq!(PROFILE_CALLS.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_search_space_without_explicit_max_memory_has_no_cap() {
|
||||
let mut cx = Graph::new();
|
||||
@@ -2862,7 +2975,10 @@ mod tests {
|
||||
let _ = cx.tensor(1).output();
|
||||
cx.build_search_space::<TestMemoryRuntime>(CompileOptions::default().max_memory_bytes(0));
|
||||
|
||||
let _ = cx.search(TestMemoryRuntime, CompileOptions::new(1));
|
||||
let _ = cx.search(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2870,7 +2986,10 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
let _ = cx.tensor(1).output();
|
||||
|
||||
let _ = cx.compile(TestMemoryRuntime, CompileOptions::new(1));
|
||||
let _ = cx.compile(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
assert_eq!(cx.egraphs.len(), 1);
|
||||
assert_eq!(cx.search_space_max_memory_bytes, None);
|
||||
@@ -2908,7 +3027,7 @@ mod tests {
|
||||
let mut rng = rand::rng();
|
||||
let _ = cx.search_with_rng(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::new(1).dim_buckets(
|
||||
CompileOptions::default().search_graph_limit(1).dim_buckets(
|
||||
's',
|
||||
&[DimBucket::new(1, 1), DimBucket::new(2, 4).representative(4)],
|
||||
),
|
||||
@@ -3019,7 +3138,7 @@ mod tests {
|
||||
let vals = random_vec(8);
|
||||
let mut rt = NativeRuntime::default();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(x.id, vals.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -3056,7 +3175,7 @@ mod tests {
|
||||
let vals = random_vec(8);
|
||||
let mut rt = NativeRuntime::default();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(x.id, vals.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
20
src/op.rs
20
src/op.rs
@@ -7,6 +7,12 @@ use crate::prelude::*;
|
||||
use as_any::{AsAny, Downcast};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
pub struct ProfileBucketContext<'a> {
|
||||
pub dim_buckets: &'a FxHashMap<char, Vec<DimBucket>>,
|
||||
pub bucket_indices: &'a FxHashMap<char, usize>,
|
||||
pub representative_dyn_map: &'a FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
pub trait Runtime {
|
||||
type Ops: IntoEgglogOp;
|
||||
type CompileArg;
|
||||
@@ -36,6 +42,20 @@ pub trait Runtime {
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String);
|
||||
/// Profile one candidate in the context of a specific dynamic-dimension
|
||||
/// bucket. Runtimes with bucket-sensitive lowering can override this so
|
||||
/// search ranks candidates under the same execution model used after
|
||||
/// final bucket compilation.
|
||||
fn profile_with_bucket_context(
|
||||
&mut self,
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
_bucket_context: ProfileBucketContext<'_>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.profile(llir_graph, dyn_map, trials, timeout)
|
||||
}
|
||||
/// Aggregate multiple profile metrics into one comparable metric.
|
||||
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
|
||||
@@ -24,7 +24,7 @@ proptest! {
|
||||
let d = (b * c / e).sin().output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
|
||||
rt.set_data(b.id, vals.clone());
|
||||
rt.set_data(c.id, vals.clone());
|
||||
@@ -58,7 +58,7 @@ proptest! {
|
||||
let a = b.matmul(c).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
let lhs = lhs.into_iter().take(m * k).collect::<Vec<f32>>();
|
||||
let rhs = rhs.into_iter().take(k * n).collect::<Vec<f32>>();
|
||||
rt.set_data(b.id, lhs.clone());
|
||||
@@ -82,7 +82,7 @@ proptest! {
|
||||
let a = cx.tensor((2, 2));
|
||||
let b = (a.permute((1, 0)) * 1.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(a.id, values.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -99,7 +99,7 @@ proptest! {
|
||||
let mask = a.ge(kth_largest.expand_dim(1, cols)).cast(crate::dtype::DType::F32);
|
||||
let filtered = (a * mask).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
let values = values.into_iter().take(rows * cols).collect::<Vec<f32>>();
|
||||
rt.set_data(a.id, values.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -465,7 +465,10 @@ fn test_inputs_consumed_after_execute() {
|
||||
let a = cx.tensor(3);
|
||||
let _b = (a * 2.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
// Second execute should panic — input 'a' was consumed
|
||||
@@ -481,7 +484,10 @@ fn test_passthrough_preserves_weights() {
|
||||
w.persist();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Iteration 1
|
||||
rt.set_data(w.id, vec![1.0, 2.0, 3.0]);
|
||||
@@ -503,7 +509,10 @@ fn test_only_outputs_remain() {
|
||||
let a = cx.tensor(3);
|
||||
let _b = (a * 2.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let output_count = rt
|
||||
@@ -556,7 +565,10 @@ fn integration_auto_loop_rolling_matches_reference_native_runtime() {
|
||||
|
||||
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
|
||||
graph.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = graph.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = graph.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(input_id, input);
|
||||
for (node, data) in weight_ids.iter().zip(weights.iter()) {
|
||||
rt.set_data(*node, data.clone());
|
||||
|
||||
Reference in New Issue
Block a user