Compare commits

...

4 Commits

Author SHA1 Message Date
Tucker Morgan
d6b0eb0ec1 Add recommender model compile coverage 2026-05-13 21:40:15 +00:00
June
1dcd0370ce feat: add CUDA 13.2 support via cudarc 0.19.4 (#312)
* Update cudarc to 0.19.4 to support CUDA 13.2

Fixes #291

Changes:
- Upgrade cudarc from 0.18.2 to 0.19.4
- Remove get_global call for __constant__ memory tracking

Rationale:
cudarc 0.19.0 changed get_global to return CudaViewMut instead of
CudaSlice to prevent double-free of __constant__ memory managed by
the CUDA module. The old code worked around this by storing the
CudaSlice and calling std::mem::forget on cleanup. With the new API,
the view's lifetime is tied to the module borrow, making the
workaround unnecessary. Since the constants HashMap was only used
for this workaround and never accessed otherwise, we now return an
empty HashMap.

CUDA 13.2 support was added in cudarc 0.19.4.

* fix: migrate embed kernel to shared dyn_dims buffer

The cudarc 0.18→0.19 bump removed get_global, but simply dropping the
call left __constant__ memory declared-but-never-written, producing
wrong results for models with dynamic-shape embeddings. Migrate to
the same dyn_dims parameter + #define pattern every other kernel uses.
2026-05-13 13:43:36 -04:00
Ali
6757a4e37b pack scatter kernel into 256-thread blocks (#309) 2026-05-13 13:43:15 -04:00
Joe Fioti
631451f8b8 Remove Testing section from README (#313)
Removed the Testing section from the README.
2026-05-12 17:36:33 -04:00
24 changed files with 2203 additions and 154 deletions

View File

@@ -45,49 +45,6 @@ cd ./examples/llama
cargo run --release
```
## Testing
The CUDA test suites have a fast default path and explicit opt-in slow coverage.
Slow tests include exhaustive cuBLASLt rewrite sweeps, CUDA search/genome fuzzing,
large Python model compiles, and pretrained/full-width model tests.
Rust CUDA unit tests, matching the default CI path:
```bash
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
cargo test -p luminal_cuda_lite -- --test-threads=1
```
Rust CUDA slow tests only:
```bash
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
cargo test -p luminal_cuda_lite -- --ignored --test-threads=1
```
Rust CUDA full suite, including ignored slow tests:
```bash
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
cargo test -p luminal_cuda_lite -- --include-ignored --test-threads=1
```
Python CUDA tests, matching the default CI path:
```bash
cd crates/luminal_python
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda MATURIN_PEP517_ARGS="--features cuda --profile release" CUDARC_CUDA_VERSION=12080 \
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
```
Python CUDA full suite, including slow tests:
```bash
cd crates/luminal_python
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda MATURIN_PEP517_ARGS="--features cuda --profile release" CUDARC_CUDA_VERSION=12080 \
uv run --group dev python -m pytest tests/ -v -s
```
## Features
### Speed

View File

@@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
luminal = { path = "../.." }
luminal_tracing = { path = "../luminal_tracing" }
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
anyhow = "1.0"
as-any = "0.3.2"
itertools = "0.12.1"

View File

@@ -5,6 +5,7 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
use luminal::prelude::*;
use crate::cudarc::driver::CudaContext;
use crate::host::describe_host_op;
use crate::runtime::CudaRuntime;
/// [`DynBackend`] wrapper for [`CudaRuntime`].
@@ -39,6 +40,26 @@ impl DynBackend for CudaLiteDynBackend {
self.runtime.execute(dyn_map);
}
fn kernel_names(&self) -> Vec<String> {
self.runtime
.kernel_names()
.iter()
.map(|name| (*name).to_string())
.collect()
}
fn host_op_names(&self) -> Vec<String> {
self.runtime
.host_ops()
.iter()
.map(|op| describe_host_op(*op))
.collect()
}
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
fn supports_device_ptrs(&self) -> bool {
true
}

View File

@@ -247,6 +247,10 @@ impl HostOp for CuBlasSgemmV2 {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasSgemmV2")
}
fn output_size(&self) -> Expression {
self.m * self.n
}

View File

@@ -419,7 +419,6 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
}
}
#[cfg(test)]
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
match epilogue {
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
@@ -978,6 +977,18 @@ impl CuBlasLt {
&& normalize(self.stride_c) == normalize(self.stride_d)
&& self.c_order == self.d_order
}
pub(crate) fn debug_summary(&self) -> String {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
format!(
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
resolve(&self.m),
resolve(&self.n),
resolve(&self.k),
resolve(&self.batch_count),
epilogue_name(self.epilogue),
)
}
}
impl HostOp for CuBlasLt {
@@ -1114,6 +1125,10 @@ impl HostOp for CuBlasLt {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasLt")
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)

View File

@@ -1,6 +1,7 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use crate::kernel::CudaGraphOp;
use luminal::{op::EgglogOp, prelude::*};
pub mod compute_attn_mask;
mod cublas;
@@ -79,6 +80,24 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
.map(cublaslt::CuBlasLt::c_d_layouts_match)
}
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
return op.debug_summary();
}
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
let mut summary = op.debug_summary();
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& let Some(timing) = op.debug_timing_summary()
{
summary.push_str(" [");
summary.push_str(&timing);
summary.push(']');
}
return summary;
}
op.stats_name().unwrap_or("unknown").to_string()
}
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside

View File

@@ -12,7 +12,10 @@ use luminal::{
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
extract_dtype, extract_expr, extract_expr_list,
},
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
hlir::{
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
Sin, Sqrt, SumReduce,
},
op::*,
prelude::*,
};
@@ -65,6 +68,8 @@ pub type Ops = (
KernelConstant,
KernelCast,
KernelEmbed,
KernelConcat2D,
KernelEmbeddingBagSum,
);
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
@@ -3266,14 +3271,20 @@ impl KernelOp for KernelEmbed {
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.embed_dim.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let kernel = format!(
"
{}
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table) {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
@@ -3284,10 +3295,7 @@ extern \"C\" {{
int token_id = token_ids[token_offset];
out[out_offset + embed_idx] = embed_table[(long long)token_id * embed_dim + embed_idx];
}}
}}",
vars.iter()
.map(|i| format!("__constant__ int const_{i}[1];"))
.join("\n"),
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
@@ -3298,10 +3306,8 @@ extern \"C\" {{
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let constants = vars
.into_iter()
.map(|d| (d, module.get_global(&format!("const_{d}"), stream).unwrap()))
.collect();
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
let total_threads = batch_size * self.embed_dim;
(
func,
@@ -3353,3 +3359,361 @@ extern \"C\" {{
"Embed"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelEmbeddingBagSum {
n_bags: Expression,
n_indices: Expression,
hidden_dim: Expression,
num_embeddings: Expression,
dtype: DType,
}
impl EgglogOp for KernelEmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelEmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[4]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelEmbeddingBagSum {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
self.dtype
);
let vars = self
.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_bags = self.n_bags.to_kernel();
let n_indices = self.n_indices.to_kernel();
let hidden_dim = self.hidden_dim.to_kernel();
let num_embeddings = self.num_embeddings.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long bag = blockIdx.y;
long long hidden_dim = {hidden_dim};
long long n_bags = {n_bags};
long long n_indices = {n_indices};
long long num_embeddings = {num_embeddings};
if (bag >= n_bags || dim >= hidden_dim) return;
int start_raw = offsets[bag];
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
int start = max(0, min(start_raw, (int)n_indices));
int end = max(start, min(end_raw, (int)n_indices));
float sum = 0.0f;
for (int pos = start; pos < end; ++pos) {{
int row = indices[pos];
row = max(0, min(row, (int)num_embeddings - 1));
sum += weight[(long long)row * hidden_dim + dim];
}}
out[bag * hidden_dim + dim] = sum;
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embedding_bag_sum").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
(self.hidden_dim.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.n_bags * self.hidden_dim
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
// Approximate: weights + indices + offsets
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.n_indices * self.hidden_dim
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"EmbeddingBagSum"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelConcat2D {
rows: Expression,
lhs_cols: Expression,
rhs_cols: Expression,
dtype: DType,
}
impl EgglogOp for KernelConcat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelConcat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<Concat2D, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelConcat2D {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelConcat2D only supports F32 today, got {:?}",
self.dtype
);
let vars = self
.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let rows = self.rows.to_kernel();
let lhs_cols = self.lhs_cols.to_kernel();
let rhs_cols = self.rhs_cols.to_kernel();
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = {total};
if (idx >= total) return;
long long rows = {rows};
long long lhs_cols = {lhs_cols};
long long rhs_cols = {rhs_cols};
long long out_cols = lhs_cols + rhs_cols;
if (rows == 0 || out_cols == 0) return;
long long row = idx / out_cols;
long long col = idx - row * out_cols;
if (col < lhs_cols) {{
out[idx] = lhs[row * lhs_cols + col];
}} else {{
long long rhs_col = col - lhs_cols;
out[idx] = rhs[row * rhs_cols + rhs_col];
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("concat_2d").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let output_size = self.output_size();
(
func,
module,
kernel,
(output_size.ceil_div(256), 1.into(), 1.into()),
(output_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.rows * (self.lhs_cols + self.rhs_cols)
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
0.into()
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Concat2D"
}
}

View File

@@ -509,8 +509,8 @@ extern \"C\" {{
func,
module,
scatter_kernel,
(n_src, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(n_src.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)

View File

@@ -4,6 +4,8 @@
//! that can be executed like any other HostOp.
use std::cell::RefCell;
use std::cmp::Reverse;
use std::collections::BTreeMap;
use std::sync::Arc;
use cudarc::driver::{
@@ -141,6 +143,8 @@ struct CudaGraphOpState {
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
/// Timing events for profiling
timing_events: Vec<cudarc::driver::sys::CUevent>,
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
last_kernel_timings_us: Vec<(&'static str, f64)>,
}
impl CudaGraphOpState {
@@ -155,6 +159,7 @@ impl CudaGraphOpState {
last_dyn_values: FxHashMap::default(),
last_buffer_ptrs: FxHashMap::default(),
timing_events: Vec::new(),
last_kernel_timings_us: Vec::new(),
}
}
}
@@ -192,6 +197,41 @@ impl CudaGraphOp {
state: RefCell::new(state),
}
}
pub fn debug_summary(&self) -> String {
let state = self.state.borrow();
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
for kernel in &state.kernels {
*counts.entry(kernel.kernel_name).or_default() += 1;
}
let mut counts: Vec<_> = counts.into_iter().collect();
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
let top = counts
.into_iter()
.take(4)
.map(|(name, count)| format!("{name}x{count}"))
.join(", ");
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
}
pub fn debug_timing_summary(&self) -> Option<String> {
let state = self.state.borrow();
if state.last_kernel_timings_us.is_empty() {
return None;
}
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
for (name, us) in &state.last_kernel_timings_us {
*totals.entry(*name).or_default() += *us;
}
let mut totals: Vec<_> = totals.into_iter().collect();
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
let top = totals
.into_iter()
.take(4)
.map(|(name, us)| format!("{name}={us:.0}us"))
.join(", ");
Some(top)
}
}
impl std::fmt::Debug for CudaGraphOp {
@@ -566,6 +606,23 @@ impl CudaGraphOp {
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& state.timing_events.len() >= state.kernels.len() + 1
{
stream.synchronize()?;
let ctx = stream.context().clone();
state.last_kernel_timings_us.clear();
for idx in 0..state.kernels.len() {
let start_event = state.timing_events[idx];
let end_event = state.timing_events[idx + 1];
let kernel_name = state.kernels[idx].kernel_name;
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
state.last_kernel_timings_us.push((kernel_name, us));
}
}
Ok(())
}
@@ -584,8 +641,9 @@ impl CudaGraphOp {
state.kernel_params.clear();
state.kernel_params.reserve(num_kernels);
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
let tracing_enabled = enabled!(Level::TRACE);
if tracing_enabled {
if tracing_enabled || profile_cuda_graph {
let needed_events = num_kernels + 1;
while state.timing_events.len() < needed_events {
state.timing_events.push(create_cuda_event(&ctx)?);
@@ -701,7 +759,7 @@ impl CudaGraphOp {
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
let timing_event = if tracing_enabled || profile_cuda_graph {
Some(state.timing_events[idx])
} else {
None
@@ -739,7 +797,9 @@ impl CudaGraphOp {
prev_graph_node = Some(graph_node);
}
if tracing_enabled && let Some(prev) = prev_graph_node {
if (tracing_enabled || profile_cuda_graph)
&& let Some(prev) = prev_graph_node
{
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}

View File

@@ -1,6 +1,9 @@
use crate::{
host::{DeviceBuffer, HostOp},
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
host::{DeviceBuffer, HostOp, describe_host_op},
kernel::{
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
record_cuda_graph_timings, record_event_on_stream,
},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
@@ -60,6 +63,12 @@ pub struct KernelStats {
pub tflops: f64,
}
#[derive(Debug, Clone)]
pub struct HostOpStats {
pub name: String,
pub execution_time_us: f64,
}
impl Debug for ExecutableHostOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HostOp: ({:?})", self.internal)
@@ -141,6 +150,7 @@ pub struct CudaRuntime {
changed_hlir: FxHashSet<NodeIndex>,
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
pub last_kernel_stats: Vec<KernelStats>,
pub last_host_op_stats: Vec<HostOpStats>,
pub last_total_time_us: f64,
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
/// When true, execute() skips input buffer consumption (used during search/profile)
@@ -1203,6 +1213,7 @@ impl Runtime for CudaRuntime {
changed_hlir: FxHashSet::default(),
cuda_graph_timings: vec![],
last_kernel_stats: vec![],
last_host_op_stats: vec![],
last_total_time_us: 0.0,
kernel_cache: FxHashMap::default(),
profiling: false,
@@ -1454,6 +1465,8 @@ impl Runtime for CudaRuntime {
let total_start = std::time::Instant::now();
let bucket = &self.compiled_buckets[self.active_bucket];
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
let mut host_timing_events = Vec::new();
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
let exec_op = &bucket.exec_graph[exec_node];
@@ -1507,6 +1520,16 @@ impl Runtime for CudaRuntime {
n_inputs = exec_op.inputs.len()
)
.entered();
let host_op_timing = if profile_host_ops {
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
let ctx = exec_op.stream.context().clone();
let start_event = create_cuda_event(&ctx).unwrap();
let end_event = create_cuda_event(&ctx).unwrap();
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
Some((name, ctx, start_event, end_event))
} else {
None
};
exec_op
.internal
.execute(
@@ -1522,10 +1545,28 @@ impl Runtime for CudaRuntime {
exec_op.internal.stats_name().unwrap_or("unknown")
);
});
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
host_timing_events.push((name, ctx, start_event, end_event));
}
}
// Single sync at end - CUDA stream ordering guarantees sequential execution
self.cuda_stream.synchronize().unwrap();
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
self.last_host_op_stats.clear();
if profile_host_ops {
for (name, ctx, start_event, end_event) in host_timing_events {
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
self.last_host_op_stats.push(HostOpStats {
name,
execution_time_us,
});
destroy_cuda_event(&ctx, start_event);
destroy_cuda_event(&ctx, end_event);
}
}
// Populate last_kernel_stats from HostOps that report stats
self.last_kernel_stats.clear();
@@ -1861,6 +1902,16 @@ impl CudaRuntime {
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
if !self.last_host_op_stats.is_empty() {
println!("\n=== Host Operation Statistics ===\n");
println!("{:<20} {:>12}", "HostOp", "Time (us)");
println!("{}", "-".repeat(34));
for s in &self.last_host_op_stats {
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
}
println!("{}", "-".repeat(34));
}
// Print kernel stats
if !self.last_kernel_stats.is_empty() {
println!("\n=== Kernel Execution Statistics ===\n");

View File

@@ -248,6 +248,23 @@ impl CompiledGraph {
self.runtime.device_type()
}
/// Names of kernels compiled into the active runtime bucket, if available.
#[getter]
fn kernel_names(&self) -> Vec<String> {
self.runtime.kernel_names()
}
/// Names of host ops in the active runtime bucket, if available.
#[getter]
fn host_op_names(&self) -> Vec<String> {
self.runtime.host_op_names()
}
/// Print backend execution statistics for the last run, if supported.
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
#[getter]
fn supports_device_ptrs(&self) -> bool {

View File

@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Result, bail};
use luminal::prelude::*;
use crate::pt2_schema::*;
@@ -8,21 +8,62 @@ use super::Translator;
impl<'a> Translator<'a> {
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let arg1 = &node.inputs[1].arg;
if let Some(name) = arg1.as_tensor_name() {
let b = self.get_tensor(name)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
} else {
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
let alpha = match op {
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
BinaryOp::Mul | BinaryOp::Div => 1.0,
};
let lhs = node.inputs[0]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
let rhs = node.inputs[1]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
match (lhs, rhs) {
(Some(a), Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(Some(a), None) => {
let mut val = self.get_float_arg(node, 1)? as f32;
if alpha != 1.0 {
val *= alpha;
}
Ok(self.apply_scalar_op(a, val, op))
}
(None, Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let lhs_val = self.get_float_arg(node, 0)? as f32;
let a = self
.graph
.constant_float(lhs_val)
.cast(b.dtype)
.expand_rhs(b.shape);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(None, None) => bail!("{} expects at least one tensor operand", node.target),
}
}

View File

@@ -111,6 +111,7 @@ impl<'a> Translator<'a> {
result
}
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
"torch.ops.aten.clone.default" => {
let a = self.get_input_tensor(node, 0)?;
if !a.shape.is_contiguous() { a + 0.0 } else { a }
@@ -133,8 +134,28 @@ impl<'a> Translator<'a> {
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
if alpha == 0.0 && beta == 0.0 {
self.graph
.constant_float(0.0)
.cast(mm.dtype)
.expand_rhs(mm.shape)
} else if beta == 0.0 {
if alpha == 1.0 { mm } else { mm * alpha }
} else if alpha == 0.0 {
let input = if beta == 1.0 { input } else { input * beta };
let zero = self
.graph
.constant_float(0.0)
.cast(input.dtype)
.expand_rhs(mm.shape);
let (input, _) = broadcast_binary(input, zero);
input
} else {
let input = if beta == 1.0 { input } else { input * beta };
let mm = if alpha == 1.0 { mm } else { mm * alpha };
let (input, mm) = broadcast_binary(input, mm);
input + mm
}
}
// Convolution
@@ -147,8 +168,14 @@ impl<'a> Translator<'a> {
// Slice/index ops
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
"torch.ops.aten._embedding_bag.default"
| "torch.ops.aten._embedding_bag_forward_only.default" => {
self.translate_embedding_bag(node)?
}
"<built-in function getitem>" => self.translate_getitem(node)?,
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
@@ -163,6 +190,9 @@ impl<'a> Translator<'a> {
// LayerNorm
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
self.translate_native_batch_norm_no_training(node)?
}
// Where
"torch.ops.aten.where.self" => self.translate_where(node)?,

View File

@@ -12,6 +12,64 @@ const SCATTER_INDEX_ARG: usize = 2;
const SCATTER_VALUE_ARG: usize = 3;
impl<'a> Translator<'a> {
fn try_concat_2d_fast(
&mut self,
lhs: GraphTensor,
rhs: GraphTensor,
axis: usize,
) -> Option<GraphTensor> {
if axis != 1
|| lhs.dtype != DType::F32
|| rhs.dtype != DType::F32
|| lhs.shape.len() != 2
|| rhs.shape.len() != 2
|| !lhs.shape.is_contiguous()
|| !rhs.shape.is_contiguous()
|| lhs.shape.dims[0] != rhs.shape.dims[0]
{
return None;
}
let rows = lhs.shape.dims[0];
let lhs_cols = lhs.shape.dims[1];
let rhs_cols = rhs.shape.dims[1];
let id = self.graph.add_op(
luminal::hlir::Concat2D {
rows,
lhs_cols,
rhs_cols,
},
&[lhs.id, rhs.id],
);
Some(GraphTensor::from_id(
id,
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
lhs.graph_ref,
lhs.dtype,
))
}
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
let index = self
.get_int_arg(node, 2)
.context("select.int: missing index")?;
let dim_size = a.shape.dims[dim]
.to_usize()
.context("select.int: symbolic dims are not supported for negative indices")?;
let normalized_index = if index < 0 {
(dim_size as i64 + index) as usize
} else {
index as usize
};
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
.squeeze(dim))
}
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -80,6 +138,43 @@ impl<'a> Translator<'a> {
Ok(a)
}
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
let mut a = self.get_input_tensor(node, 0)?;
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
sizes
.into_iter()
.map(|size| {
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
Ok(Expression::from(size as usize))
})
.collect::<Result<_>>()?
} else {
self.get_exprs_arg(node, 1)?
};
anyhow::ensure!(
repeats.len() >= a.shape.len(),
"repeat: repeats rank {} is smaller than input rank {}",
repeats.len(),
a.shape.len()
);
while a.shape.len() < repeats.len() {
a = a.unsqueeze(0);
}
Ok(a.repeat(repeats))
}
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
let index = self.get_int_arg(node, 1)?;
anyhow::ensure!(
index == 0,
"getitem: only tuple[0] access is supported today, got index={index}"
);
self.get_input_tensor(node, 0)
}
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1).unwrap_or(0);
@@ -161,7 +256,11 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, tensors[0].shape.len());
let mut result = tensors[0];
for t in &tensors[1..] {
result = result.concat_along(*t, dim);
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
result = fast;
} else {
result = result.concat_along(*t, dim);
}
}
Ok(result)
}
@@ -218,6 +317,79 @@ impl<'a> Translator<'a> {
bail!("index.Tensor: no index tensors in optional_tensors list");
}
index_names = found_tensors;
// Multiple explicit index tensors after leading `None`s mean
// "keep the prefix dims, then advanced-index the contiguous
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
if first_non_none_dim > 0
&& index_names.len() > 1
&& first_non_none_dim + index_names.len() == source.shape.len()
{
let src_dims = source.shape.dims;
let indexed_dims = &src_dims[first_non_none_dim..];
let n_indexed = index_names.len();
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * indexed_dims[i + 1];
}
let mut flat_idx: Option<GraphTensor> = None;
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
let axis_size = indexed_dims[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
idx_int
} else {
idx_int * stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (acc_b, w_b) = broadcast_binary(acc, weighted);
acc_b + w_b
}
None => weighted,
});
}
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
let idx_shape = flat_idx.shape.dims.to_vec();
let mut idx_numel = Expression::from(1usize);
for dim in &idx_shape {
idx_numel *= *dim;
}
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
let mut indexed_size = Expression::from(1usize);
for dim in indexed_dims {
indexed_size *= *dim;
}
let mut flat_source_shape = prefix_dims.clone();
flat_source_shape.push(indexed_size);
let flat_source = reshape_tensor(source, flat_source_shape);
let mut expanded_idx = flat_idx;
for _ in 0..prefix_dims.len() {
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
}
let mut target = prefix_dims.clone();
target.push(idx_numel);
expanded_idx.shape.expand(target);
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
let mut result_shape = prefix_dims;
result_shape.extend_from_slice(&idx_shape);
return Ok(reshape_tensor(gathered, result_shape));
}
// Simple case: single non-None index on a specific dim → gather_elements
if first_non_none_dim > 0 && index_names.len() == 1 {
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);

View File

@@ -28,6 +28,45 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
impl<'a> Translator<'a> {
fn translate_embedding_bag_generic(
&mut self,
weight: GraphTensor,
indices: GraphTensor,
offsets: GraphTensor,
) -> Result<GraphTensor> {
let hidden_dim = weight.shape.dims[1];
let n_indices = indices.shape.dims[0];
let n_bags = offsets.shape.dims[0];
// Gather per-index embeddings: [E] -> [E, D].
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
let gathered = weight.gather(ids_expanded + arange);
// Bag assignment per position:
// bag_id[pos] = count(offsets <= pos) - 1
// This supports empty bags too, because repeated offsets simply skip a
// bag id when no positions land in that interval.
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
let starts = offsets.expand_dim(1, n_indices);
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
- self
.graph
.constant_float(1.0)
.cast(DType::Int)
.expand_rhs(vec![n_indices]);
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
let bag_ids = bag_ids.expand_dim(0, n_bags);
let mask = bag_ids
.eq(bag_axis)
.expand_dim(2, hidden_dim)
.cast(gathered.dtype);
let gathered = gathered.expand_dim(0, n_bags);
Ok((gathered * mask).sum(1))
}
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
let positional_args: Vec<Expression> = node
.inputs
@@ -180,6 +219,45 @@ impl<'a> Translator<'a> {
Ok(value.expand_rhs(reference.shape))
}
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
let mode = self.get_int_arg(node, 4).unwrap_or(0);
anyhow::ensure!(
mode == 0,
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
);
anyhow::ensure!(
indices.shape.len() == 1 && offsets.shape.len() == 1,
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
indices.shape.len(),
offsets.shape.len()
);
if weight.dtype == DType::F32 {
let id = self.graph.add_op(
luminal::hlir::EmbeddingBagSum {
n_bags: offsets.shape.dims[0],
n_indices: indices.shape.dims[0],
hidden_dim: weight.shape.dims[1],
num_embeddings: weight.shape.dims[0],
},
&[weight.id, indices.id, offsets.id],
);
return Ok(GraphTensor::from_id(
id,
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
weight.graph_ref,
DType::F32,
));
}
self.translate_embedding_bag_generic(weight, indices, offsets)
}
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
let output_name = node
.outputs

View File

@@ -21,6 +21,30 @@ const DIV_MODE_INPUT_ARG: usize = 0;
const DIV_MODE_OTHER_ARG: usize = 1;
impl<'a> Translator<'a> {
fn expand_channel_parameter(
&self,
input: GraphTensor,
parameter: GraphTensor,
) -> Result<GraphTensor> {
anyhow::ensure!(
input.shape.len() >= 2,
"batch_norm: expected rank >= 2 input, got rank {}",
input.shape.len()
);
anyhow::ensure!(
parameter.shape.len() == 1,
"batch_norm: expected 1D channel parameter, got rank {}",
parameter.shape.len()
);
let mut expanded = parameter.unsqueeze(0);
for axis in 2..input.shape.len() {
expanded = expanded.unsqueeze(axis);
}
expanded.shape.expand(input.dims().to_vec());
Ok(expanded)
}
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
@@ -101,6 +125,30 @@ impl<'a> Translator<'a> {
Ok(result)
}
pub(crate) fn translate_native_batch_norm_no_training(
&mut self,
node: &Node,
) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
let mut result = (input - running_mean) / (running_var + eps).sqrt();
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
result = result * weight;
}
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
result = result + bias;
}
Ok(result)
}
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let zero = self

View File

@@ -8,12 +8,14 @@ from .compiled_model import CompiledModel
# Import Rust extension components (built by maturin)
from .luminal import CompiledGraph, process_pt2
from .main import luminal_backend, register_backend
from .pt2 import compile
_register_cache_serialization()
# Re-export everything for clean package interface
__all__ = [
"CompiledModel",
"compile",
"luminal_backend",
"register_backend",
"CompiledGraph",

View File

@@ -1,9 +1,8 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
from typing import List
from typing import Any, List
import torch
from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
@@ -28,6 +27,14 @@ class CompiledModel:
self._input_names = input_names or graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
output_dtype_codes = graph_result.output_dtypes
self._output_dtypes = [
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
for i in range(len(self._output_names))
]
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
self._weight_refs = weight_refs or []
self._user_indices = user_indices
@@ -35,6 +42,9 @@ class CompiledModel:
self._supports_device_ptrs = getattr(
graph_result, "supports_device_ptrs", False
)
# Cache converted/contiguous views for repeated calls with the same input
# tensors so we don't rebuild sparse index buffers every forward.
self._prepared_input_cache = [None] * len(self._input_names)
# Expected input dtypes from graph (used to convert user inputs)
input_dtype_codes = graph_result.input_dtypes
self._input_dtypes = [
@@ -43,6 +53,80 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
self._single_float_output_fast_path = (
self._supports_device_ptrs
and not self._has_dynamic_dims
and len(self._output_names) == 1
and self._output_dtypes[0].is_floating_point
)
@staticmethod
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
return (
id(tensor),
getattr(tensor, "_version", None),
tensor.data_ptr(),
expected_dtype,
)
def _prepare_input_tensor(
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
) -> torch.Tensor:
detached = tensor.detach()
needs_preparation = (
detached.dtype != expected_dtype or not detached.is_contiguous()
)
if not needs_preparation:
return detached
cache_key = self._input_cache_key(detached, expected_dtype)
cached = self._prepared_input_cache[index]
if cached is not None and cached[0] == cache_key:
return cached[1]
prepared = detached.contiguous().to(expected_dtype)
self._prepared_input_cache[index] = (cache_key, prepared)
return prepared
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Bind the current user inputs into the Rust graph."""
input_refs = []
for index, (name, tensor, expected_dtype) in enumerate(
zip(self._input_names, user_inputs, self._input_dtypes)
):
prepared = self._prepare_input_tensor(index, tensor, expected_dtype)
if self._supports_device_ptrs and tensor.is_cuda:
n_bytes = prepared.numel() * prepared.element_size()
self._graph.set_input_device_ptr(name, prepared.data_ptr(), n_bytes)
input_refs.append(prepared)
else:
if prepared.device.type != "cpu":
prepared = prepared.cpu()
n_bytes = prepared.numel() * prepared.element_size()
dtype_code = _torch_dtype_code(prepared.dtype)
self._graph.set_input_from_ptr(
name, prepared.data_ptr(), n_bytes, dtype_code
)
return input_refs
def _run_static_single_float_output(
self, user_inputs: List[torch.Tensor], input_device: torch.device
):
_input_refs = self._bind_user_inputs(user_inputs)
output_name = self._output_names[0]
output_dtype = self._output_dtypes[0]
out = torch.empty(
self._static_output_shapes[0], dtype=output_dtype, device=input_device
)
self._graph.set_output_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
self._graph.run()
if not self._graph.output_is_zero_copy(output_name):
self._graph.copy_output_to_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
return (out,)
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -86,25 +170,18 @@ class CompiledModel:
if self._has_dynamic_dims:
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
elif (
self._single_float_output_fast_path
and input_device.type != "cpu"
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
):
return self._run_static_single_float_output(user_inputs, input_device)
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = []
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
_input_refs = self._bind_user_inputs(user_inputs)
# Resolve output shapes before run() (needed for pre-allocation).
if self._has_dynamic_dims:
@@ -112,8 +189,6 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
output_dtype_codes = self._graph.output_dtypes
# CUDA zero-copy path: pre-allocate output tensors and register their device
# pointers so the final kernel writes directly into PyTorch's buffer.
_use_zero_copy = self._supports_device_ptrs
@@ -121,8 +196,8 @@ class CompiledModel:
if _use_zero_copy:
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = torch.empty(shape, dtype=out_dtype, device=input_device)
@@ -140,8 +215,8 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = output_tensors[i]
@@ -178,8 +253,8 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
if out_dtype == torch.int32:
@@ -199,3 +274,41 @@ class CompiledModel:
outputs.append(out)
return tuple(outputs)
def _leaf_paths(tree: Any, prefix=()):
if torch.is_tensor(tree):
return [prefix]
if isinstance(tree, (list, tuple)):
paths = []
for idx, value in enumerate(tree):
paths.extend(_leaf_paths(value, prefix + (idx,)))
return paths
if isinstance(tree, dict):
paths = []
for key, value in tree.items():
paths.extend(_leaf_paths(value, prefix + (key,)))
return paths
return [prefix]
def _follow_path(tree: Any, path):
value = tree
for key in path:
value = value[key]
return value
class StructuredCompiledModel:
"""Preserve a module's original nested input structure for direct PT2 compile()."""
def __init__(self, compiled_model, example_args):
self._compiled = compiled_model
self._leaf_paths = _leaf_paths(example_args)
def __getattr__(self, name):
return getattr(self._compiled, name)
def __call__(self, *args):
flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
return self._compiled(*flat_inputs)

View File

@@ -9,10 +9,12 @@ import inspect
import os
import shutil
import tempfile
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
from .compiled_model import CompiledModel
from .compiled_model import CompiledModel, StructuredCompiledModel
from .luminal import process_pt2
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
@@ -184,6 +186,66 @@ def _save_and_compile(
shutil.rmtree(tmpdir, ignore_errors=True)
def _has_cuda_inputs(flat_example_inputs):
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
"""Search env overrides for direct compile() calls.
CUDA DLRM-style models benefit materially from a deeper per-candidate
profile and from keeping more parents alive between generations. Keep the
defaults narrow so CPU and env-configured callers are unchanged.
"""
has_cuda = _has_cuda_inputs(flat_example_inputs)
overrides = {}
if search_trials is not None:
overrides["LUMINAL_SEARCH_TRIALS"] = str(search_trials)
elif has_cuda and "LUMINAL_SEARCH_TRIALS" not in os.environ:
overrides["LUMINAL_SEARCH_TRIALS"] = "5"
if search_keep_best is not None:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = str(search_keep_best)
elif has_cuda and "LUMINAL_SEARCH_KEEP_BEST" not in os.environ:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = "3"
return overrides
@contextmanager
def _temporary_env(overrides):
sentinel = object()
previous = {}
try:
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = old_value
def _strip_exported_weights_for_zero_copy(ep, original_weights):
"""Shrink the saved .pt2 artifact when original weights will be reused."""
if not original_weights:
return
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
if isinstance(orig, torch.nn.Parameter):
replacement = torch.nn.Parameter(
replacement, requires_grad=orig.requires_grad
)
ep._state_dict[key] = replacement
del orig
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
@@ -401,7 +463,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
def compile(
model,
example_input,
search_iterations=25,
search_iterations=None,
search_trials=None,
search_keep_best=None,
factory=None,
export_kwargs=None,
dynamic_dim=None,
@@ -413,7 +477,13 @@ def compile(
model: A PyTorch nn.Module.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations.
search_iterations: Number of optimization search iterations. When None,
defaults to 200 on CUDA inputs and 10 otherwise.
search_trials: Optional per-candidate profiling trials inside Luminal's
search. When unset, direct CUDA compile defaults to 5.
search_keep_best: Optional number of parent candidates to retain
between search generations. When unset, direct CUDA compile
defaults to 3.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
@@ -432,17 +502,26 @@ def compile(
Returns:
A CompiledModel callable.
"""
if factory is None:
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
if factory is None:
factory = _detect_factory_capsule(flat_example_inputs)
if search_iterations is None:
search_iterations = (
200
if _has_cuda_inputs(flat_example_inputs)
else 10
)
search_env = _direct_search_env(
flat_example_inputs,
search_trials=search_trials,
search_keep_best=search_keep_best,
)
kwargs = export_kwargs or {}
extra = _export_kwargs()
@@ -488,7 +567,16 @@ def compile(
)
ep = ep.run_decompositions(_decomp_table())
return _save_and_compile(ep, factory, search_iterations)
original_weights = model.state_dict()
_strip_exported_weights_for_zero_copy(ep, original_weights)
with _temporary_env(search_env):
compiled = _save_and_compile(
ep,
factory,
search_iterations,
original_weights=original_weights,
)
return StructuredCompiledModel(compiled, example_args)
def _legacy_auto_dim(example_args):
@@ -576,12 +664,7 @@ def _eager_pt2_compile(
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
_strip_exported_weights_for_zero_copy(ep, original_weights)
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
@@ -595,11 +678,20 @@ def _eager_pt2_compile(
if torch.cuda.is_available():
torch.cuda.empty_cache()
default_search_iterations = (
50
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
else 10
)
search_iterations = int(
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
)
try:
return _save_and_compile(
pt2_path,
factory,
10,
search_iterations,
original_weights=original_weights,
user_indices=user_indices,
)

View File

@@ -0,0 +1,11 @@
# DLRM CUDA Benchmark
These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.
| Path | Median latency | Throughput |
| --- | ---: | ---: |
| eager | 0.267 ms | 7,674,321 candidates/s |
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |

View File

@@ -0,0 +1,213 @@
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
These tests are intended for the local integration workflow where the
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
then run the same model through ``backend=luminal_backend``.
"""
from __future__ import annotations
import copy
import sys
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
pytest.importorskip("tqdm")
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
if not DEEPCTR_ROOT.exists():
pytest.skip(
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
allow_module_level=True,
)
deepctr_root = str(DEEPCTR_ROOT)
if deepctr_root not in sys.path:
sys.path.insert(0, deepctr_root)
from deepctr_torch.inputs import (
DenseFeat,
SparseFeat,
VarLenSparseFeat,
build_input_features,
)
from deepctr_torch.models import DCN
from deepctr_torch.models.din import DIN
from luminal import luminal_backend
def _stack_features(
feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
) -> torch.Tensor:
parts = []
for name in build_input_features(feature_columns):
value = np.asarray(feature_dict[name])
if value.ndim == 1:
value = np.expand_dims(value, axis=1)
parts.append(value)
stacked = np.concatenate(parts, axis=-1)
return torch.tensor(stacked, dtype=torch.float32, device=device)
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _make_dcn(
device: torch.device, cross_parameterization: str
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("s0", 5, embedding_dim=4),
SparseFeat("s1", 7, embedding_dim=4),
DenseFeat("d0", 1),
DenseFeat("d1", 1),
]
feature_dict = {
"s0": np.array([0, 1, 2, 3], dtype=np.int64),
"s1": np.array([1, 2, 3, 4], dtype=np.int64),
"d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
"d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
}
model = DCN(
linear_feature_columns=feature_columns,
dnn_feature_columns=feature_columns,
cross_num=2,
cross_parameterization=cross_parameterization,
dnn_hidden_units=(16,),
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("user", 4, embedding_dim=4),
SparseFeat("gender", 2, embedding_dim=4),
SparseFeat("item_id", 4, embedding_dim=8),
SparseFeat("cate_id", 3, embedding_dim=4),
DenseFeat("pay_score", 1),
VarLenSparseFeat(
SparseFeat(
"hist_item_id",
vocabulary_size=4,
embedding_dim=8,
embedding_name="item_id",
),
maxlen=4,
length_name="seq_length",
),
VarLenSparseFeat(
SparseFeat(
"hist_cate_id",
vocabulary_size=3,
embedding_dim=4,
embedding_name="cate_id",
),
maxlen=4,
length_name="seq_length",
),
]
feature_dict = {
"user": np.array([0, 1, 2, 3], dtype=np.int64),
"gender": np.array([0, 1, 0, 1], dtype=np.int64),
"item_id": np.array([1, 2, 3, 2], dtype=np.int64),
"cate_id": np.array([1, 2, 1, 2], dtype=np.int64),
"pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
"hist_item_id": np.array(
[[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"hist_cate_id": np.array(
[[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"seq_length": np.array([3, 3, 2, 2], dtype=np.int64),
}
model = DIN(
feature_columns,
["item_id", "cate_id"],
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
def test_deepctr_dcn_matches_inductor_when_supported(
device: torch.device, cross_parameterization: str
) -> None:
model, inputs = _make_dcn(device, cross_parameterization)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
model, inputs = _make_din(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")

View File

@@ -0,0 +1,456 @@
"""DLRM coverage for the luminal torch.compile backend.
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
that eager mode, ``torch.compile(..., backend="inductor")``, and
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
configurations, including a CUDA benchmark that compares Luminal against
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
"""
from __future__ import annotations
import copy
import importlib.machinery
import sys
import types
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
if not DLRM_ROOT.exists():
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
dlrm_root = str(DLRM_ROOT)
if dlrm_root not in sys.path:
sys.path.insert(0, dlrm_root)
def _install_dlrm_import_stubs() -> None:
ext_dist = types.ModuleType("extend_distributed")
ext_dist.my_size = 1
ext_dist.dist = None
ext_dist.get_split_lengths = lambda n: (n, [n])
ext_dist.get_my_slice = lambda n: slice(0, n)
class _AllToAll:
def __init__(self, values):
self._values = values
def wait(self):
return self._values
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
"extend_distributed", loader=None
)
sys.modules["extend_distributed"] = ext_dist
mlperf_logger = types.ModuleType("mlperf_logger")
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
"mlperf_logger", loader=None
)
sys.modules["mlperf_logger"] = mlperf_logger
tensorboard = types.ModuleType("torch.utils.tensorboard")
class SummaryWriter:
def __init__(self, *args, **kwargs):
pass
def add_scalar(self, *args, **kwargs):
pass
def close(self):
pass
tensorboard.SummaryWriter = SummaryWriter
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
"torch.utils.tensorboard", loader=None
)
sys.modules["torch.utils.tensorboard"] = tensorboard
onnx = types.ModuleType("onnx")
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
sys.modules["onnx"] = onnx
_install_dlrm_import_stubs()
import dlrm_s_pytorch as dlrm_mod
from luminal import luminal_backend
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend="inductor")
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(
copy.deepcopy(model),
backend="inductor",
mode="reduce-overhead",
)
def _compile_luminal(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend=luminal_backend)
def _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters: int,
timed_iters: int,
mark_step_begin: bool = False,
) -> dict[str, float]:
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
with torch.no_grad():
for _ in range(warmup_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
_unwrap(compiled_model(*inputs))
torch.cuda.synchronize()
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
with torch.no_grad():
for idx in range(timed_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
starts[idx].record()
_unwrap(compiled_model(*inputs))
ends[idx].record()
torch.cuda.synchronize()
elapsed_ms = np.array(
[start.elapsed_time(end) for start, end in zip(starts, ends)],
dtype=np.float64,
)
return {
"mean_ms": float(elapsed_ms.mean()),
"median_ms": float(np.median(elapsed_ms)),
"min_ms": float(elapsed_ms.min()),
}
def _timed_cuda_rounds(
compiled_model,
*inputs,
pre_round_warmup_iters: int,
timed_iters: int,
rounds: int,
mark_step_begin: bool = False,
) -> dict[str, float | list[float]]:
assert rounds > 0, "rounds must be positive"
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=pre_round_warmup_iters,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians = [stats["median_ms"]]
for _ in range(rounds - 1):
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=0,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians.append(stats["median_ms"])
round_medians_np = np.array(round_medians, dtype=np.float64)
return {
"round_medians_ms": [float(value) for value in round_medians],
"median_ms": float(np.median(round_medians_np)),
"mean_ms": float(round_medians_np.mean()),
"min_ms": float(round_medians_np.min()),
}
def _make_dlrm(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
m_spa = 4
ln_emb = np.array([8, 6, 4])
ln_bot = np.array([3, 4])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 8, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=1,
).eval()
inputs = (
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device),
[
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
],
[
torch.tensor([1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 3], dtype=torch.int64, device=device),
torch.tensor([2, 1], dtype=torch.int64, device=device),
],
)
return model.to(device), inputs
def _make_dlrm_batch_2048(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
batch_size = 2048
indices_per_bag = 2
m_spa = 16
ln_emb = np.array([4096, 2048, 1024])
ln_bot = np.array([3, 64, m_spa])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 64, 32, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=2,
).eval()
dense_x = torch.linspace(
-1.0,
1.0,
steps=batch_size * 3,
dtype=torch.float32,
device=device,
).reshape(batch_size, 3)
total_sparse_indices = batch_size * indices_per_bag
positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
offsets = torch.arange(
0,
total_sparse_indices,
indices_per_bag,
dtype=torch.int64,
device=device,
)
inputs = (
dense_x,
[offsets.clone(), offsets.clone(), offsets.clone()],
[
((positions * 3 + 1) % int(ln_emb[0])).to(torch.int64),
((positions * 5 + 2) % int(ln_emb[1])).to(torch.int64),
((positions * 7 + 3) % int(ln_emb[2])).to(torch.int64),
],
)
return model.to(device), inputs
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
model, inputs = _make_dlrm(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
@pytest.mark.slow
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
device: torch.device,
) -> None:
if device.type != "cuda":
pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")
model, inputs = _make_dlrm_batch_2048(device)
eager = _run_eager(model, *inputs)
eager_model = copy.deepcopy(model).to(device).eval()
inductor_compiled = _compile_inductor(model)
inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad():
inductor_output = _unwrap(inductor_reduce_overhead(*inputs))
luminal_compiled = _compile_luminal(model)
with torch.no_grad():
luminal_output = _unwrap(luminal_compiled(*inputs))
with torch.no_grad():
inductor_default_output = _unwrap(inductor_compiled(*inputs))
_assert_allclose(inductor_default_output, eager, "inductor default vs eager", atol=1e-4)
_assert_allclose(inductor_output, eager, "inductor reduce-overhead vs eager", atol=1e-4)
_assert_allclose(luminal_output, eager, "luminal vs eager", atol=1e-4)
_assert_allclose(
luminal_output,
inductor_default_output,
"luminal vs inductor default",
atol=1e-4,
)
_assert_allclose(
luminal_output,
inductor_output,
"luminal vs inductor reduce-overhead",
atol=1e-4,
)
benchmark_rounds = 5
benchmark_iters = 20
post_compile_warmup_iters = 10
eager_stats = _timed_cuda_rounds(
eager_model,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_default_stats = _timed_cuda_rounds(
inductor_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_stats = _timed_cuda_rounds(
inductor_reduce_overhead,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
mark_step_begin=True,
)
luminal_stats = _timed_cuda_rounds(
luminal_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
batch_size = inputs[0].shape[0]
benchmark_results = [
("eager", eager_stats),
("inductor default", inductor_default_stats),
("inductor reduce-overhead", inductor_stats),
("luminal backend", luminal_stats),
]
ranked_results = sorted(
benchmark_results,
key=lambda item: float(item[1]["median_ms"]),
)
speed_lines = []
for idx, (label, stats) in enumerate(ranked_results, start=1):
throughput = batch_size / (float(stats["median_ms"]) / 1000.0)
rounds_repr = ", ".join(
f"{value:.3f}" for value in stats["round_medians_ms"] # type: ignore[index]
)
speed_lines.append(
f" {idx}. {label}: {float(stats['median_ms']):.3f} ms"
f" ({throughput:,.0f} candidates/s)"
f" [round medians: {rounds_repr}]"
)
luminal_vs_inductor = float(luminal_stats["median_ms"]) / float(
inductor_stats["median_ms"]
)
luminal_vs_eager = float(luminal_stats["median_ms"]) / float(eager_stats["median_ms"])
print(
"\n"
f"DLRM batch={batch_size} candidates on CUDA after compile/warmup\n"
f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations\n"
f" Ranking by median latency:\n"
+ "\n".join(speed_lines)
+ "\n"
f" Luminal backend / TorchInductor reduce-overhead latency ratio:"
f" {luminal_vs_inductor:.3f}x\n"
f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x"
)

View File

@@ -7,13 +7,14 @@
//! - [`NativeDynBackend`]: the reference implementation for CPU
use std::collections::HashMap;
use std::time::Duration;
use half::{bf16, f16};
use petgraph::stable_graph::NodeIndex;
use rustc_hash::FxHashMap;
use crate::dtype::DType;
use crate::graph::Graph;
use crate::graph::{Graph, SearchOptions};
use crate::hlir::{NativeData, NativeRuntime, Output};
use crate::op::Runtime;
@@ -46,6 +47,18 @@ pub trait DynBackend {
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
// --- Optional diagnostics --------------------------------------------
fn kernel_names(&self) -> Vec<String> {
vec![]
}
fn host_op_names(&self) -> Vec<String> {
vec![]
}
fn print_execution_stats(&self) {}
// --- Optional device pointer support (GPU backends) --------------------
fn supports_device_ptrs(&self) -> bool {
@@ -162,7 +175,9 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, args.search_iters);
let mut rng = rand::rng();
let search_options = search_options_from_env(args.search_iters);
let mut rt = graph.search_options(rt, search_options, &mut rng);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);
@@ -179,6 +194,39 @@ pub fn compile_backend<Rt: Runtime + 'static>(
Ok(wrap(rt))
}
fn env_usize(name: &str) -> Option<usize> {
std::env::var(name).ok()?.parse().ok()
}
fn env_duration_ms(name: &str) -> Option<Duration> {
Some(Duration::from_millis(env_usize(name)? as u64))
}
fn search_options_from_env(limit: usize) -> SearchOptions {
let mut options = SearchOptions::new(limit);
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
options = options.generation_size(generation_size);
}
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
options = options.mutations(mutations);
}
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
options = options.trials(trials);
}
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
options = options.keep_best(keep_best);
}
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
options = options.profile_timeout(profile_timeout);
}
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
options = options.group_timeout(group_timeout);
}
options
}
// ---------------------------------------------------------------------------
// Shared utilities
// ---------------------------------------------------------------------------

View File

@@ -217,32 +217,38 @@ pub fn reduce_sort(name: &str) -> SortDef {
}
pub type HLIROps = (
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Scatter,
SumReduce,
MaxReduce,
Softmax,
(
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
),
(
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Concat2D,
EmbeddingBagSum,
Scatter,
SumReduce,
MaxReduce,
Softmax,
),
);
#[derive(Default, Debug, Clone)]
@@ -1721,7 +1727,9 @@ impl NativeOp for Add {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
}
}
}
}
@@ -1808,7 +1816,9 @@ impl NativeOp for Mul {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
}
}
}
}
@@ -2126,6 +2136,233 @@ impl NativeOp for Gather {
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Concat2D {
pub rows: Expression,
pub lhs_cols: Expression,
pub rhs_cols: Expression,
}
impl Display for Concat2D {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Concat2D")
}
}
impl HLIROp for Concat2D {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (Concat2D {} {} {}) {})",
self.rows.to_egglog(),
self.lhs_cols.to_egglog(),
self.rhs_cols.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
)
}
}
impl EgglogOp for Concat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"Concat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for Concat2D {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let rows = self.rows.exec(dyn_map).unwrap();
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
fn concat_rows<T: Clone>(
lhs: &[T],
rhs: &[T],
rows: usize,
lhs_cols: usize,
rhs_cols: usize,
) -> Vec<T> {
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
for row in 0..rows {
let lhs_base = row * lhs_cols;
let rhs_base = row * rhs_cols;
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
}
out
}
match (inputs[0], inputs[1]) {
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
_ => panic!("Concat2D inputs must have the same dtype"),
}
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingBagSum {
pub n_bags: Expression,
pub n_indices: Expression,
pub hidden_dim: Expression,
pub num_embeddings: Expression,
}
impl Display for EmbeddingBagSum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "EmbeddingBagSum")
}
}
impl HLIROp for EmbeddingBagSum {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (EmbeddingBagSum {} {} {} {}) {})",
self.n_bags.to_egglog(),
self.n_indices.to_egglog(),
self.hidden_dim.to_egglog(),
self.num_embeddings.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
)
}
}
impl EgglogOp for EmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"EmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for EmbeddingBagSum {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let n_bags = self.n_bags.exec(dyn_map).unwrap();
let n_indices = self.n_indices.exec(dyn_map).unwrap();
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
let clamp_index = |value: i32, limit: usize| -> usize {
if limit == 0 {
return 0;
}
value.clamp(0, limit.saturating_sub(1) as i32) as usize
};
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
let NativeData::Int(indices) = inputs[1] else {
panic!("EmbeddingBagSum indices must be Int")
};
let NativeData::Int(offsets) = inputs[2] else {
panic!("EmbeddingBagSum offsets must be Int")
};
let bag_bounds = |bag: usize| -> (usize, usize) {
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
let end = offsets
.get(bag + 1)
.copied()
.map(clamp_offset)
.unwrap_or(n_indices);
(start, end.max(start))
};
match inputs[0] {
NativeData::F32(weight) => {
let mut out = vec![0.0f32; n_bags * hidden_dim];
for bag in 0..n_bags {
let (start, end) = bag_bounds(bag);
let out_base = bag * hidden_dim;
for pos in start..end {
let row = clamp_index(indices[pos], num_embeddings);
let row_base = row * hidden_dim;
for dim in 0..hidden_dim {
out[out_base + dim] += weight[row_base + dim];
}
}
}
NativeData::F32(out)
}
_ => panic!("EmbeddingBagSum only supports F32 weights"),
}
}
}
// Scatter Op (inverse of Gather)
#[derive(Debug, Clone, Default, PartialEq)]