Compare commits

...

1 Commits

Author SHA1 Message Date
Tucker Morgan
d6b0eb0ec1 Add recommender model compile coverage 2026-05-13 21:40:15 +00:00
21 changed files with 2189 additions and 98 deletions

View File

@@ -5,6 +5,7 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
use luminal::prelude::*;
use crate::cudarc::driver::CudaContext;
use crate::host::describe_host_op;
use crate::runtime::CudaRuntime;
/// [`DynBackend`] wrapper for [`CudaRuntime`].
@@ -39,6 +40,26 @@ impl DynBackend for CudaLiteDynBackend {
self.runtime.execute(dyn_map);
}
fn kernel_names(&self) -> Vec<String> {
self.runtime
.kernel_names()
.iter()
.map(|name| (*name).to_string())
.collect()
}
fn host_op_names(&self) -> Vec<String> {
self.runtime
.host_ops()
.iter()
.map(|op| describe_host_op(*op))
.collect()
}
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
fn supports_device_ptrs(&self) -> bool {
true
}

View File

@@ -247,6 +247,10 @@ impl HostOp for CuBlasSgemmV2 {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasSgemmV2")
}
fn output_size(&self) -> Expression {
self.m * self.n
}

View File

@@ -419,7 +419,6 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
}
}
#[cfg(test)]
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
match epilogue {
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
@@ -978,6 +977,18 @@ impl CuBlasLt {
&& normalize(self.stride_c) == normalize(self.stride_d)
&& self.c_order == self.d_order
}
pub(crate) fn debug_summary(&self) -> String {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
format!(
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
resolve(&self.m),
resolve(&self.n),
resolve(&self.k),
resolve(&self.batch_count),
epilogue_name(self.epilogue),
)
}
}
impl HostOp for CuBlasLt {
@@ -1114,6 +1125,10 @@ impl HostOp for CuBlasLt {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasLt")
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)

View File

@@ -1,6 +1,7 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use crate::kernel::CudaGraphOp;
use luminal::{op::EgglogOp, prelude::*};
pub mod compute_attn_mask;
mod cublas;
@@ -79,6 +80,24 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
.map(cublaslt::CuBlasLt::c_d_layouts_match)
}
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
return op.debug_summary();
}
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
let mut summary = op.debug_summary();
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& let Some(timing) = op.debug_timing_summary()
{
summary.push_str(" [");
summary.push_str(&timing);
summary.push(']');
}
return summary;
}
op.stats_name().unwrap_or("unknown").to_string()
}
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside

View File

@@ -12,7 +12,10 @@ use luminal::{
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
extract_dtype, extract_expr, extract_expr_list,
},
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
hlir::{
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
Sin, Sqrt, SumReduce,
},
op::*,
prelude::*,
};
@@ -65,6 +68,8 @@ pub type Ops = (
KernelConstant,
KernelCast,
KernelEmbed,
KernelConcat2D,
KernelEmbeddingBagSum,
);
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
@@ -3354,3 +3359,361 @@ extern \"C\" {{
"Embed"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelEmbeddingBagSum {
n_bags: Expression,
n_indices: Expression,
hidden_dim: Expression,
num_embeddings: Expression,
dtype: DType,
}
impl EgglogOp for KernelEmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelEmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[4]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelEmbeddingBagSum {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
self.dtype
);
let vars = self
.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_bags = self.n_bags.to_kernel();
let n_indices = self.n_indices.to_kernel();
let hidden_dim = self.hidden_dim.to_kernel();
let num_embeddings = self.num_embeddings.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long bag = blockIdx.y;
long long hidden_dim = {hidden_dim};
long long n_bags = {n_bags};
long long n_indices = {n_indices};
long long num_embeddings = {num_embeddings};
if (bag >= n_bags || dim >= hidden_dim) return;
int start_raw = offsets[bag];
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
int start = max(0, min(start_raw, (int)n_indices));
int end = max(start, min(end_raw, (int)n_indices));
float sum = 0.0f;
for (int pos = start; pos < end; ++pos) {{
int row = indices[pos];
row = max(0, min(row, (int)num_embeddings - 1));
sum += weight[(long long)row * hidden_dim + dim];
}}
out[bag * hidden_dim + dim] = sum;
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embedding_bag_sum").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
(self.hidden_dim.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.n_bags * self.hidden_dim
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
// Approximate: weights + indices + offsets
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.n_indices * self.hidden_dim
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"EmbeddingBagSum"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelConcat2D {
rows: Expression,
lhs_cols: Expression,
rhs_cols: Expression,
dtype: DType,
}
impl EgglogOp for KernelConcat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelConcat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<Concat2D, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelConcat2D {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelConcat2D only supports F32 today, got {:?}",
self.dtype
);
let vars = self
.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let rows = self.rows.to_kernel();
let lhs_cols = self.lhs_cols.to_kernel();
let rhs_cols = self.rhs_cols.to_kernel();
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = {total};
if (idx >= total) return;
long long rows = {rows};
long long lhs_cols = {lhs_cols};
long long rhs_cols = {rhs_cols};
long long out_cols = lhs_cols + rhs_cols;
if (rows == 0 || out_cols == 0) return;
long long row = idx / out_cols;
long long col = idx - row * out_cols;
if (col < lhs_cols) {{
out[idx] = lhs[row * lhs_cols + col];
}} else {{
long long rhs_col = col - lhs_cols;
out[idx] = rhs[row * rhs_cols + rhs_col];
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("concat_2d").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let output_size = self.output_size();
(
func,
module,
kernel,
(output_size.ceil_div(256), 1.into(), 1.into()),
(output_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.rows * (self.lhs_cols + self.rhs_cols)
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
0.into()
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Concat2D"
}
}

View File

@@ -4,6 +4,8 @@
//! that can be executed like any other HostOp.
use std::cell::RefCell;
use std::cmp::Reverse;
use std::collections::BTreeMap;
use std::sync::Arc;
use cudarc::driver::{
@@ -141,6 +143,8 @@ struct CudaGraphOpState {
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
/// Timing events for profiling
timing_events: Vec<cudarc::driver::sys::CUevent>,
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
last_kernel_timings_us: Vec<(&'static str, f64)>,
}
impl CudaGraphOpState {
@@ -155,6 +159,7 @@ impl CudaGraphOpState {
last_dyn_values: FxHashMap::default(),
last_buffer_ptrs: FxHashMap::default(),
timing_events: Vec::new(),
last_kernel_timings_us: Vec::new(),
}
}
}
@@ -192,6 +197,41 @@ impl CudaGraphOp {
state: RefCell::new(state),
}
}
pub fn debug_summary(&self) -> String {
let state = self.state.borrow();
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
for kernel in &state.kernels {
*counts.entry(kernel.kernel_name).or_default() += 1;
}
let mut counts: Vec<_> = counts.into_iter().collect();
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
let top = counts
.into_iter()
.take(4)
.map(|(name, count)| format!("{name}x{count}"))
.join(", ");
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
}
pub fn debug_timing_summary(&self) -> Option<String> {
let state = self.state.borrow();
if state.last_kernel_timings_us.is_empty() {
return None;
}
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
for (name, us) in &state.last_kernel_timings_us {
*totals.entry(*name).or_default() += *us;
}
let mut totals: Vec<_> = totals.into_iter().collect();
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
let top = totals
.into_iter()
.take(4)
.map(|(name, us)| format!("{name}={us:.0}us"))
.join(", ");
Some(top)
}
}
impl std::fmt::Debug for CudaGraphOp {
@@ -566,6 +606,23 @@ impl CudaGraphOp {
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& state.timing_events.len() >= state.kernels.len() + 1
{
stream.synchronize()?;
let ctx = stream.context().clone();
state.last_kernel_timings_us.clear();
for idx in 0..state.kernels.len() {
let start_event = state.timing_events[idx];
let end_event = state.timing_events[idx + 1];
let kernel_name = state.kernels[idx].kernel_name;
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
state.last_kernel_timings_us.push((kernel_name, us));
}
}
Ok(())
}
@@ -584,8 +641,9 @@ impl CudaGraphOp {
state.kernel_params.clear();
state.kernel_params.reserve(num_kernels);
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
let tracing_enabled = enabled!(Level::TRACE);
if tracing_enabled {
if tracing_enabled || profile_cuda_graph {
let needed_events = num_kernels + 1;
while state.timing_events.len() < needed_events {
state.timing_events.push(create_cuda_event(&ctx)?);
@@ -701,7 +759,7 @@ impl CudaGraphOp {
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
let timing_event = if tracing_enabled || profile_cuda_graph {
Some(state.timing_events[idx])
} else {
None
@@ -739,7 +797,9 @@ impl CudaGraphOp {
prev_graph_node = Some(graph_node);
}
if tracing_enabled && let Some(prev) = prev_graph_node {
if (tracing_enabled || profile_cuda_graph)
&& let Some(prev) = prev_graph_node
{
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}

View File

@@ -1,6 +1,9 @@
use crate::{
host::{DeviceBuffer, HostOp},
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
host::{DeviceBuffer, HostOp, describe_host_op},
kernel::{
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
record_cuda_graph_timings, record_event_on_stream,
},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
@@ -60,6 +63,12 @@ pub struct KernelStats {
pub tflops: f64,
}
#[derive(Debug, Clone)]
pub struct HostOpStats {
pub name: String,
pub execution_time_us: f64,
}
impl Debug for ExecutableHostOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HostOp: ({:?})", self.internal)
@@ -141,6 +150,7 @@ pub struct CudaRuntime {
changed_hlir: FxHashSet<NodeIndex>,
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
pub last_kernel_stats: Vec<KernelStats>,
pub last_host_op_stats: Vec<HostOpStats>,
pub last_total_time_us: f64,
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
/// When true, execute() skips input buffer consumption (used during search/profile)
@@ -1203,6 +1213,7 @@ impl Runtime for CudaRuntime {
changed_hlir: FxHashSet::default(),
cuda_graph_timings: vec![],
last_kernel_stats: vec![],
last_host_op_stats: vec![],
last_total_time_us: 0.0,
kernel_cache: FxHashMap::default(),
profiling: false,
@@ -1454,6 +1465,8 @@ impl Runtime for CudaRuntime {
let total_start = std::time::Instant::now();
let bucket = &self.compiled_buckets[self.active_bucket];
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
let mut host_timing_events = Vec::new();
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
let exec_op = &bucket.exec_graph[exec_node];
@@ -1507,6 +1520,16 @@ impl Runtime for CudaRuntime {
n_inputs = exec_op.inputs.len()
)
.entered();
let host_op_timing = if profile_host_ops {
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
let ctx = exec_op.stream.context().clone();
let start_event = create_cuda_event(&ctx).unwrap();
let end_event = create_cuda_event(&ctx).unwrap();
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
Some((name, ctx, start_event, end_event))
} else {
None
};
exec_op
.internal
.execute(
@@ -1522,10 +1545,28 @@ impl Runtime for CudaRuntime {
exec_op.internal.stats_name().unwrap_or("unknown")
);
});
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
host_timing_events.push((name, ctx, start_event, end_event));
}
}
// Single sync at end - CUDA stream ordering guarantees sequential execution
self.cuda_stream.synchronize().unwrap();
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
self.last_host_op_stats.clear();
if profile_host_ops {
for (name, ctx, start_event, end_event) in host_timing_events {
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
self.last_host_op_stats.push(HostOpStats {
name,
execution_time_us,
});
destroy_cuda_event(&ctx, start_event);
destroy_cuda_event(&ctx, end_event);
}
}
// Populate last_kernel_stats from HostOps that report stats
self.last_kernel_stats.clear();
@@ -1861,6 +1902,16 @@ impl CudaRuntime {
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
if !self.last_host_op_stats.is_empty() {
println!("\n=== Host Operation Statistics ===\n");
println!("{:<20} {:>12}", "HostOp", "Time (us)");
println!("{}", "-".repeat(34));
for s in &self.last_host_op_stats {
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
}
println!("{}", "-".repeat(34));
}
// Print kernel stats
if !self.last_kernel_stats.is_empty() {
println!("\n=== Kernel Execution Statistics ===\n");

View File

@@ -248,6 +248,23 @@ impl CompiledGraph {
self.runtime.device_type()
}
/// Names of kernels compiled into the active runtime bucket, if available.
#[getter]
fn kernel_names(&self) -> Vec<String> {
self.runtime.kernel_names()
}
/// Names of host ops in the active runtime bucket, if available.
#[getter]
fn host_op_names(&self) -> Vec<String> {
self.runtime.host_op_names()
}
/// Print backend execution statistics for the last run, if supported.
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
#[getter]
fn supports_device_ptrs(&self) -> bool {

View File

@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Result, bail};
use luminal::prelude::*;
use crate::pt2_schema::*;
@@ -8,21 +8,62 @@ use super::Translator;
impl<'a> Translator<'a> {
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let arg1 = &node.inputs[1].arg;
if let Some(name) = arg1.as_tensor_name() {
let b = self.get_tensor(name)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
} else {
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
let alpha = match op {
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
BinaryOp::Mul | BinaryOp::Div => 1.0,
};
let lhs = node.inputs[0]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
let rhs = node.inputs[1]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
match (lhs, rhs) {
(Some(a), Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(Some(a), None) => {
let mut val = self.get_float_arg(node, 1)? as f32;
if alpha != 1.0 {
val *= alpha;
}
Ok(self.apply_scalar_op(a, val, op))
}
(None, Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let lhs_val = self.get_float_arg(node, 0)? as f32;
let a = self
.graph
.constant_float(lhs_val)
.cast(b.dtype)
.expand_rhs(b.shape);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(None, None) => bail!("{} expects at least one tensor operand", node.target),
}
}

View File

@@ -111,6 +111,7 @@ impl<'a> Translator<'a> {
result
}
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
"torch.ops.aten.clone.default" => {
let a = self.get_input_tensor(node, 0)?;
if !a.shape.is_contiguous() { a + 0.0 } else { a }
@@ -133,8 +134,28 @@ impl<'a> Translator<'a> {
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
if alpha == 0.0 && beta == 0.0 {
self.graph
.constant_float(0.0)
.cast(mm.dtype)
.expand_rhs(mm.shape)
} else if beta == 0.0 {
if alpha == 1.0 { mm } else { mm * alpha }
} else if alpha == 0.0 {
let input = if beta == 1.0 { input } else { input * beta };
let zero = self
.graph
.constant_float(0.0)
.cast(input.dtype)
.expand_rhs(mm.shape);
let (input, _) = broadcast_binary(input, zero);
input
} else {
let input = if beta == 1.0 { input } else { input * beta };
let mm = if alpha == 1.0 { mm } else { mm * alpha };
let (input, mm) = broadcast_binary(input, mm);
input + mm
}
}
// Convolution
@@ -147,8 +168,14 @@ impl<'a> Translator<'a> {
// Slice/index ops
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
"torch.ops.aten._embedding_bag.default"
| "torch.ops.aten._embedding_bag_forward_only.default" => {
self.translate_embedding_bag(node)?
}
"<built-in function getitem>" => self.translate_getitem(node)?,
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
@@ -163,6 +190,9 @@ impl<'a> Translator<'a> {
// LayerNorm
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
self.translate_native_batch_norm_no_training(node)?
}
// Where
"torch.ops.aten.where.self" => self.translate_where(node)?,

View File

@@ -12,6 +12,64 @@ const SCATTER_INDEX_ARG: usize = 2;
const SCATTER_VALUE_ARG: usize = 3;
impl<'a> Translator<'a> {
fn try_concat_2d_fast(
&mut self,
lhs: GraphTensor,
rhs: GraphTensor,
axis: usize,
) -> Option<GraphTensor> {
if axis != 1
|| lhs.dtype != DType::F32
|| rhs.dtype != DType::F32
|| lhs.shape.len() != 2
|| rhs.shape.len() != 2
|| !lhs.shape.is_contiguous()
|| !rhs.shape.is_contiguous()
|| lhs.shape.dims[0] != rhs.shape.dims[0]
{
return None;
}
let rows = lhs.shape.dims[0];
let lhs_cols = lhs.shape.dims[1];
let rhs_cols = rhs.shape.dims[1];
let id = self.graph.add_op(
luminal::hlir::Concat2D {
rows,
lhs_cols,
rhs_cols,
},
&[lhs.id, rhs.id],
);
Some(GraphTensor::from_id(
id,
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
lhs.graph_ref,
lhs.dtype,
))
}
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
let index = self
.get_int_arg(node, 2)
.context("select.int: missing index")?;
let dim_size = a.shape.dims[dim]
.to_usize()
.context("select.int: symbolic dims are not supported for negative indices")?;
let normalized_index = if index < 0 {
(dim_size as i64 + index) as usize
} else {
index as usize
};
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
.squeeze(dim))
}
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -80,6 +138,43 @@ impl<'a> Translator<'a> {
Ok(a)
}
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
let mut a = self.get_input_tensor(node, 0)?;
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
sizes
.into_iter()
.map(|size| {
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
Ok(Expression::from(size as usize))
})
.collect::<Result<_>>()?
} else {
self.get_exprs_arg(node, 1)?
};
anyhow::ensure!(
repeats.len() >= a.shape.len(),
"repeat: repeats rank {} is smaller than input rank {}",
repeats.len(),
a.shape.len()
);
while a.shape.len() < repeats.len() {
a = a.unsqueeze(0);
}
Ok(a.repeat(repeats))
}
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
let index = self.get_int_arg(node, 1)?;
anyhow::ensure!(
index == 0,
"getitem: only tuple[0] access is supported today, got index={index}"
);
self.get_input_tensor(node, 0)
}
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1).unwrap_or(0);
@@ -161,7 +256,11 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, tensors[0].shape.len());
let mut result = tensors[0];
for t in &tensors[1..] {
result = result.concat_along(*t, dim);
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
result = fast;
} else {
result = result.concat_along(*t, dim);
}
}
Ok(result)
}
@@ -218,6 +317,79 @@ impl<'a> Translator<'a> {
bail!("index.Tensor: no index tensors in optional_tensors list");
}
index_names = found_tensors;
// Multiple explicit index tensors after leading `None`s mean
// "keep the prefix dims, then advanced-index the contiguous
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
if first_non_none_dim > 0
&& index_names.len() > 1
&& first_non_none_dim + index_names.len() == source.shape.len()
{
let src_dims = source.shape.dims;
let indexed_dims = &src_dims[first_non_none_dim..];
let n_indexed = index_names.len();
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * indexed_dims[i + 1];
}
let mut flat_idx: Option<GraphTensor> = None;
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
let axis_size = indexed_dims[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
idx_int
} else {
idx_int * stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (acc_b, w_b) = broadcast_binary(acc, weighted);
acc_b + w_b
}
None => weighted,
});
}
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
let idx_shape = flat_idx.shape.dims.to_vec();
let mut idx_numel = Expression::from(1usize);
for dim in &idx_shape {
idx_numel *= *dim;
}
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
let mut indexed_size = Expression::from(1usize);
for dim in indexed_dims {
indexed_size *= *dim;
}
let mut flat_source_shape = prefix_dims.clone();
flat_source_shape.push(indexed_size);
let flat_source = reshape_tensor(source, flat_source_shape);
let mut expanded_idx = flat_idx;
for _ in 0..prefix_dims.len() {
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
}
let mut target = prefix_dims.clone();
target.push(idx_numel);
expanded_idx.shape.expand(target);
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
let mut result_shape = prefix_dims;
result_shape.extend_from_slice(&idx_shape);
return Ok(reshape_tensor(gathered, result_shape));
}
// Simple case: single non-None index on a specific dim → gather_elements
if first_non_none_dim > 0 && index_names.len() == 1 {
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);

View File

@@ -28,6 +28,45 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
impl<'a> Translator<'a> {
fn translate_embedding_bag_generic(
&mut self,
weight: GraphTensor,
indices: GraphTensor,
offsets: GraphTensor,
) -> Result<GraphTensor> {
let hidden_dim = weight.shape.dims[1];
let n_indices = indices.shape.dims[0];
let n_bags = offsets.shape.dims[0];
// Gather per-index embeddings: [E] -> [E, D].
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
let gathered = weight.gather(ids_expanded + arange);
// Bag assignment per position:
// bag_id[pos] = count(offsets <= pos) - 1
// This supports empty bags too, because repeated offsets simply skip a
// bag id when no positions land in that interval.
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
let starts = offsets.expand_dim(1, n_indices);
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
- self
.graph
.constant_float(1.0)
.cast(DType::Int)
.expand_rhs(vec![n_indices]);
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
let bag_ids = bag_ids.expand_dim(0, n_bags);
let mask = bag_ids
.eq(bag_axis)
.expand_dim(2, hidden_dim)
.cast(gathered.dtype);
let gathered = gathered.expand_dim(0, n_bags);
Ok((gathered * mask).sum(1))
}
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
let positional_args: Vec<Expression> = node
.inputs
@@ -180,6 +219,45 @@ impl<'a> Translator<'a> {
Ok(value.expand_rhs(reference.shape))
}
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
let mode = self.get_int_arg(node, 4).unwrap_or(0);
anyhow::ensure!(
mode == 0,
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
);
anyhow::ensure!(
indices.shape.len() == 1 && offsets.shape.len() == 1,
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
indices.shape.len(),
offsets.shape.len()
);
if weight.dtype == DType::F32 {
let id = self.graph.add_op(
luminal::hlir::EmbeddingBagSum {
n_bags: offsets.shape.dims[0],
n_indices: indices.shape.dims[0],
hidden_dim: weight.shape.dims[1],
num_embeddings: weight.shape.dims[0],
},
&[weight.id, indices.id, offsets.id],
);
return Ok(GraphTensor::from_id(
id,
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
weight.graph_ref,
DType::F32,
));
}
self.translate_embedding_bag_generic(weight, indices, offsets)
}
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
let output_name = node
.outputs

View File

@@ -21,6 +21,30 @@ const DIV_MODE_INPUT_ARG: usize = 0;
const DIV_MODE_OTHER_ARG: usize = 1;
impl<'a> Translator<'a> {
fn expand_channel_parameter(
&self,
input: GraphTensor,
parameter: GraphTensor,
) -> Result<GraphTensor> {
anyhow::ensure!(
input.shape.len() >= 2,
"batch_norm: expected rank >= 2 input, got rank {}",
input.shape.len()
);
anyhow::ensure!(
parameter.shape.len() == 1,
"batch_norm: expected 1D channel parameter, got rank {}",
parameter.shape.len()
);
let mut expanded = parameter.unsqueeze(0);
for axis in 2..input.shape.len() {
expanded = expanded.unsqueeze(axis);
}
expanded.shape.expand(input.dims().to_vec());
Ok(expanded)
}
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
@@ -101,6 +125,30 @@ impl<'a> Translator<'a> {
Ok(result)
}
pub(crate) fn translate_native_batch_norm_no_training(
&mut self,
node: &Node,
) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
let mut result = (input - running_mean) / (running_var + eps).sqrt();
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
result = result * weight;
}
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
result = result + bias;
}
Ok(result)
}
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let zero = self

View File

@@ -8,12 +8,14 @@ from .compiled_model import CompiledModel
# Import Rust extension components (built by maturin)
from .luminal import CompiledGraph, process_pt2
from .main import luminal_backend, register_backend
from .pt2 import compile
_register_cache_serialization()
# Re-export everything for clean package interface
__all__ = [
"CompiledModel",
"compile",
"luminal_backend",
"register_backend",
"CompiledGraph",

View File

@@ -1,9 +1,8 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
from typing import List
from typing import Any, List
import torch
from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
@@ -28,6 +27,14 @@ class CompiledModel:
self._input_names = input_names or graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
output_dtype_codes = graph_result.output_dtypes
self._output_dtypes = [
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
for i in range(len(self._output_names))
]
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
self._weight_refs = weight_refs or []
self._user_indices = user_indices
@@ -35,6 +42,9 @@ class CompiledModel:
self._supports_device_ptrs = getattr(
graph_result, "supports_device_ptrs", False
)
# Cache converted/contiguous views for repeated calls with the same input
# tensors so we don't rebuild sparse index buffers every forward.
self._prepared_input_cache = [None] * len(self._input_names)
# Expected input dtypes from graph (used to convert user inputs)
input_dtype_codes = graph_result.input_dtypes
self._input_dtypes = [
@@ -43,6 +53,80 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
self._single_float_output_fast_path = (
self._supports_device_ptrs
and not self._has_dynamic_dims
and len(self._output_names) == 1
and self._output_dtypes[0].is_floating_point
)
@staticmethod
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
return (
id(tensor),
getattr(tensor, "_version", None),
tensor.data_ptr(),
expected_dtype,
)
def _prepare_input_tensor(
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
) -> torch.Tensor:
detached = tensor.detach()
needs_preparation = (
detached.dtype != expected_dtype or not detached.is_contiguous()
)
if not needs_preparation:
return detached
cache_key = self._input_cache_key(detached, expected_dtype)
cached = self._prepared_input_cache[index]
if cached is not None and cached[0] == cache_key:
return cached[1]
prepared = detached.contiguous().to(expected_dtype)
self._prepared_input_cache[index] = (cache_key, prepared)
return prepared
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Bind the current user inputs into the Rust graph."""
input_refs = []
for index, (name, tensor, expected_dtype) in enumerate(
zip(self._input_names, user_inputs, self._input_dtypes)
):
prepared = self._prepare_input_tensor(index, tensor, expected_dtype)
if self._supports_device_ptrs and tensor.is_cuda:
n_bytes = prepared.numel() * prepared.element_size()
self._graph.set_input_device_ptr(name, prepared.data_ptr(), n_bytes)
input_refs.append(prepared)
else:
if prepared.device.type != "cpu":
prepared = prepared.cpu()
n_bytes = prepared.numel() * prepared.element_size()
dtype_code = _torch_dtype_code(prepared.dtype)
self._graph.set_input_from_ptr(
name, prepared.data_ptr(), n_bytes, dtype_code
)
return input_refs
def _run_static_single_float_output(
self, user_inputs: List[torch.Tensor], input_device: torch.device
):
_input_refs = self._bind_user_inputs(user_inputs)
output_name = self._output_names[0]
output_dtype = self._output_dtypes[0]
out = torch.empty(
self._static_output_shapes[0], dtype=output_dtype, device=input_device
)
self._graph.set_output_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
self._graph.run()
if not self._graph.output_is_zero_copy(output_name):
self._graph.copy_output_to_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
return (out,)
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -86,25 +170,18 @@ class CompiledModel:
if self._has_dynamic_dims:
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
elif (
self._single_float_output_fast_path
and input_device.type != "cpu"
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
):
return self._run_static_single_float_output(user_inputs, input_device)
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = []
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
_input_refs = self._bind_user_inputs(user_inputs)
# Resolve output shapes before run() (needed for pre-allocation).
if self._has_dynamic_dims:
@@ -112,8 +189,6 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
output_dtype_codes = self._graph.output_dtypes
# CUDA zero-copy path: pre-allocate output tensors and register their device
# pointers so the final kernel writes directly into PyTorch's buffer.
_use_zero_copy = self._supports_device_ptrs
@@ -121,8 +196,8 @@ class CompiledModel:
if _use_zero_copy:
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = torch.empty(shape, dtype=out_dtype, device=input_device)
@@ -140,8 +215,8 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = output_tensors[i]
@@ -178,8 +253,8 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
if out_dtype == torch.int32:
@@ -199,3 +274,41 @@ class CompiledModel:
outputs.append(out)
return tuple(outputs)
def _leaf_paths(tree: Any, prefix=()):
if torch.is_tensor(tree):
return [prefix]
if isinstance(tree, (list, tuple)):
paths = []
for idx, value in enumerate(tree):
paths.extend(_leaf_paths(value, prefix + (idx,)))
return paths
if isinstance(tree, dict):
paths = []
for key, value in tree.items():
paths.extend(_leaf_paths(value, prefix + (key,)))
return paths
return [prefix]
def _follow_path(tree: Any, path):
value = tree
for key in path:
value = value[key]
return value
class StructuredCompiledModel:
"""Preserve a module's original nested input structure for direct PT2 compile()."""
def __init__(self, compiled_model, example_args):
self._compiled = compiled_model
self._leaf_paths = _leaf_paths(example_args)
def __getattr__(self, name):
return getattr(self._compiled, name)
def __call__(self, *args):
flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
return self._compiled(*flat_inputs)

View File

@@ -9,10 +9,12 @@ import inspect
import os
import shutil
import tempfile
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
from .compiled_model import CompiledModel
from .compiled_model import CompiledModel, StructuredCompiledModel
from .luminal import process_pt2
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
@@ -184,6 +186,66 @@ def _save_and_compile(
shutil.rmtree(tmpdir, ignore_errors=True)
def _has_cuda_inputs(flat_example_inputs):
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
"""Search env overrides for direct compile() calls.
CUDA DLRM-style models benefit materially from a deeper per-candidate
profile and from keeping more parents alive between generations. Keep the
defaults narrow so CPU and env-configured callers are unchanged.
"""
has_cuda = _has_cuda_inputs(flat_example_inputs)
overrides = {}
if search_trials is not None:
overrides["LUMINAL_SEARCH_TRIALS"] = str(search_trials)
elif has_cuda and "LUMINAL_SEARCH_TRIALS" not in os.environ:
overrides["LUMINAL_SEARCH_TRIALS"] = "5"
if search_keep_best is not None:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = str(search_keep_best)
elif has_cuda and "LUMINAL_SEARCH_KEEP_BEST" not in os.environ:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = "3"
return overrides
@contextmanager
def _temporary_env(overrides):
sentinel = object()
previous = {}
try:
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = old_value
def _strip_exported_weights_for_zero_copy(ep, original_weights):
"""Shrink the saved .pt2 artifact when original weights will be reused."""
if not original_weights:
return
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
if isinstance(orig, torch.nn.Parameter):
replacement = torch.nn.Parameter(
replacement, requires_grad=orig.requires_grad
)
ep._state_dict[key] = replacement
del orig
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
@@ -401,7 +463,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
def compile(
model,
example_input,
search_iterations=25,
search_iterations=None,
search_trials=None,
search_keep_best=None,
factory=None,
export_kwargs=None,
dynamic_dim=None,
@@ -413,7 +477,13 @@ def compile(
model: A PyTorch nn.Module.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations.
search_iterations: Number of optimization search iterations. When None,
defaults to 200 on CUDA inputs and 10 otherwise.
search_trials: Optional per-candidate profiling trials inside Luminal's
search. When unset, direct CUDA compile defaults to 5.
search_keep_best: Optional number of parent candidates to retain
between search generations. When unset, direct CUDA compile
defaults to 3.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
@@ -432,17 +502,26 @@ def compile(
Returns:
A CompiledModel callable.
"""
if factory is None:
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
if factory is None:
factory = _detect_factory_capsule(flat_example_inputs)
if search_iterations is None:
search_iterations = (
200
if _has_cuda_inputs(flat_example_inputs)
else 10
)
search_env = _direct_search_env(
flat_example_inputs,
search_trials=search_trials,
search_keep_best=search_keep_best,
)
kwargs = export_kwargs or {}
extra = _export_kwargs()
@@ -488,7 +567,16 @@ def compile(
)
ep = ep.run_decompositions(_decomp_table())
return _save_and_compile(ep, factory, search_iterations)
original_weights = model.state_dict()
_strip_exported_weights_for_zero_copy(ep, original_weights)
with _temporary_env(search_env):
compiled = _save_and_compile(
ep,
factory,
search_iterations,
original_weights=original_weights,
)
return StructuredCompiledModel(compiled, example_args)
def _legacy_auto_dim(example_args):
@@ -576,12 +664,7 @@ def _eager_pt2_compile(
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
_strip_exported_weights_for_zero_copy(ep, original_weights)
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
@@ -595,11 +678,20 @@ def _eager_pt2_compile(
if torch.cuda.is_available():
torch.cuda.empty_cache()
default_search_iterations = (
50
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
else 10
)
search_iterations = int(
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
)
try:
return _save_and_compile(
pt2_path,
factory,
10,
search_iterations,
original_weights=original_weights,
user_indices=user_indices,
)

View File

@@ -0,0 +1,11 @@
# DLRM CUDA Benchmark
These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.
| Path | Median latency | Throughput |
| --- | ---: | ---: |
| eager | 0.267 ms | 7,674,321 candidates/s |
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |

View File

@@ -0,0 +1,213 @@
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
These tests are intended for the local integration workflow where the
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
then run the same model through ``backend=luminal_backend``.
"""
from __future__ import annotations
import copy
import sys
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
pytest.importorskip("tqdm")
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
if not DEEPCTR_ROOT.exists():
pytest.skip(
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
allow_module_level=True,
)
deepctr_root = str(DEEPCTR_ROOT)
if deepctr_root not in sys.path:
sys.path.insert(0, deepctr_root)
from deepctr_torch.inputs import (
DenseFeat,
SparseFeat,
VarLenSparseFeat,
build_input_features,
)
from deepctr_torch.models import DCN
from deepctr_torch.models.din import DIN
from luminal import luminal_backend
def _stack_features(
feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
) -> torch.Tensor:
parts = []
for name in build_input_features(feature_columns):
value = np.asarray(feature_dict[name])
if value.ndim == 1:
value = np.expand_dims(value, axis=1)
parts.append(value)
stacked = np.concatenate(parts, axis=-1)
return torch.tensor(stacked, dtype=torch.float32, device=device)
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _make_dcn(
device: torch.device, cross_parameterization: str
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("s0", 5, embedding_dim=4),
SparseFeat("s1", 7, embedding_dim=4),
DenseFeat("d0", 1),
DenseFeat("d1", 1),
]
feature_dict = {
"s0": np.array([0, 1, 2, 3], dtype=np.int64),
"s1": np.array([1, 2, 3, 4], dtype=np.int64),
"d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
"d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
}
model = DCN(
linear_feature_columns=feature_columns,
dnn_feature_columns=feature_columns,
cross_num=2,
cross_parameterization=cross_parameterization,
dnn_hidden_units=(16,),
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("user", 4, embedding_dim=4),
SparseFeat("gender", 2, embedding_dim=4),
SparseFeat("item_id", 4, embedding_dim=8),
SparseFeat("cate_id", 3, embedding_dim=4),
DenseFeat("pay_score", 1),
VarLenSparseFeat(
SparseFeat(
"hist_item_id",
vocabulary_size=4,
embedding_dim=8,
embedding_name="item_id",
),
maxlen=4,
length_name="seq_length",
),
VarLenSparseFeat(
SparseFeat(
"hist_cate_id",
vocabulary_size=3,
embedding_dim=4,
embedding_name="cate_id",
),
maxlen=4,
length_name="seq_length",
),
]
feature_dict = {
"user": np.array([0, 1, 2, 3], dtype=np.int64),
"gender": np.array([0, 1, 0, 1], dtype=np.int64),
"item_id": np.array([1, 2, 3, 2], dtype=np.int64),
"cate_id": np.array([1, 2, 1, 2], dtype=np.int64),
"pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
"hist_item_id": np.array(
[[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"hist_cate_id": np.array(
[[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"seq_length": np.array([3, 3, 2, 2], dtype=np.int64),
}
model = DIN(
feature_columns,
["item_id", "cate_id"],
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
def test_deepctr_dcn_matches_inductor_when_supported(
device: torch.device, cross_parameterization: str
) -> None:
model, inputs = _make_dcn(device, cross_parameterization)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
model, inputs = _make_din(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")

View File

@@ -0,0 +1,456 @@
"""DLRM coverage for the luminal torch.compile backend.
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
that eager mode, ``torch.compile(..., backend="inductor")``, and
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
configurations, including a CUDA benchmark that compares Luminal against
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
"""
from __future__ import annotations
import copy
import importlib.machinery
import sys
import types
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
if not DLRM_ROOT.exists():
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
dlrm_root = str(DLRM_ROOT)
if dlrm_root not in sys.path:
sys.path.insert(0, dlrm_root)
def _install_dlrm_import_stubs() -> None:
ext_dist = types.ModuleType("extend_distributed")
ext_dist.my_size = 1
ext_dist.dist = None
ext_dist.get_split_lengths = lambda n: (n, [n])
ext_dist.get_my_slice = lambda n: slice(0, n)
class _AllToAll:
def __init__(self, values):
self._values = values
def wait(self):
return self._values
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
"extend_distributed", loader=None
)
sys.modules["extend_distributed"] = ext_dist
mlperf_logger = types.ModuleType("mlperf_logger")
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
"mlperf_logger", loader=None
)
sys.modules["mlperf_logger"] = mlperf_logger
tensorboard = types.ModuleType("torch.utils.tensorboard")
class SummaryWriter:
def __init__(self, *args, **kwargs):
pass
def add_scalar(self, *args, **kwargs):
pass
def close(self):
pass
tensorboard.SummaryWriter = SummaryWriter
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
"torch.utils.tensorboard", loader=None
)
sys.modules["torch.utils.tensorboard"] = tensorboard
onnx = types.ModuleType("onnx")
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
sys.modules["onnx"] = onnx
_install_dlrm_import_stubs()
import dlrm_s_pytorch as dlrm_mod
from luminal import luminal_backend
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend="inductor")
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(
copy.deepcopy(model),
backend="inductor",
mode="reduce-overhead",
)
def _compile_luminal(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend=luminal_backend)
def _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters: int,
timed_iters: int,
mark_step_begin: bool = False,
) -> dict[str, float]:
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
with torch.no_grad():
for _ in range(warmup_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
_unwrap(compiled_model(*inputs))
torch.cuda.synchronize()
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
with torch.no_grad():
for idx in range(timed_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
starts[idx].record()
_unwrap(compiled_model(*inputs))
ends[idx].record()
torch.cuda.synchronize()
elapsed_ms = np.array(
[start.elapsed_time(end) for start, end in zip(starts, ends)],
dtype=np.float64,
)
return {
"mean_ms": float(elapsed_ms.mean()),
"median_ms": float(np.median(elapsed_ms)),
"min_ms": float(elapsed_ms.min()),
}
def _timed_cuda_rounds(
compiled_model,
*inputs,
pre_round_warmup_iters: int,
timed_iters: int,
rounds: int,
mark_step_begin: bool = False,
) -> dict[str, float | list[float]]:
assert rounds > 0, "rounds must be positive"
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=pre_round_warmup_iters,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians = [stats["median_ms"]]
for _ in range(rounds - 1):
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=0,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians.append(stats["median_ms"])
round_medians_np = np.array(round_medians, dtype=np.float64)
return {
"round_medians_ms": [float(value) for value in round_medians],
"median_ms": float(np.median(round_medians_np)),
"mean_ms": float(round_medians_np.mean()),
"min_ms": float(round_medians_np.min()),
}
def _make_dlrm(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
m_spa = 4
ln_emb = np.array([8, 6, 4])
ln_bot = np.array([3, 4])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 8, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=1,
).eval()
inputs = (
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device),
[
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
],
[
torch.tensor([1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 3], dtype=torch.int64, device=device),
torch.tensor([2, 1], dtype=torch.int64, device=device),
],
)
return model.to(device), inputs
def _make_dlrm_batch_2048(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
batch_size = 2048
indices_per_bag = 2
m_spa = 16
ln_emb = np.array([4096, 2048, 1024])
ln_bot = np.array([3, 64, m_spa])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 64, 32, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=2,
).eval()
dense_x = torch.linspace(
-1.0,
1.0,
steps=batch_size * 3,
dtype=torch.float32,
device=device,
).reshape(batch_size, 3)
total_sparse_indices = batch_size * indices_per_bag
positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
offsets = torch.arange(
0,
total_sparse_indices,
indices_per_bag,
dtype=torch.int64,
device=device,
)
inputs = (
dense_x,
[offsets.clone(), offsets.clone(), offsets.clone()],
[
((positions * 3 + 1) % int(ln_emb[0])).to(torch.int64),
((positions * 5 + 2) % int(ln_emb[1])).to(torch.int64),
((positions * 7 + 3) % int(ln_emb[2])).to(torch.int64),
],
)
return model.to(device), inputs
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
model, inputs = _make_dlrm(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
@pytest.mark.slow
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
device: torch.device,
) -> None:
if device.type != "cuda":
pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")
model, inputs = _make_dlrm_batch_2048(device)
eager = _run_eager(model, *inputs)
eager_model = copy.deepcopy(model).to(device).eval()
inductor_compiled = _compile_inductor(model)
inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad():
inductor_output = _unwrap(inductor_reduce_overhead(*inputs))
luminal_compiled = _compile_luminal(model)
with torch.no_grad():
luminal_output = _unwrap(luminal_compiled(*inputs))
with torch.no_grad():
inductor_default_output = _unwrap(inductor_compiled(*inputs))
_assert_allclose(inductor_default_output, eager, "inductor default vs eager", atol=1e-4)
_assert_allclose(inductor_output, eager, "inductor reduce-overhead vs eager", atol=1e-4)
_assert_allclose(luminal_output, eager, "luminal vs eager", atol=1e-4)
_assert_allclose(
luminal_output,
inductor_default_output,
"luminal vs inductor default",
atol=1e-4,
)
_assert_allclose(
luminal_output,
inductor_output,
"luminal vs inductor reduce-overhead",
atol=1e-4,
)
benchmark_rounds = 5
benchmark_iters = 20
post_compile_warmup_iters = 10
eager_stats = _timed_cuda_rounds(
eager_model,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_default_stats = _timed_cuda_rounds(
inductor_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_stats = _timed_cuda_rounds(
inductor_reduce_overhead,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
mark_step_begin=True,
)
luminal_stats = _timed_cuda_rounds(
luminal_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
batch_size = inputs[0].shape[0]
benchmark_results = [
("eager", eager_stats),
("inductor default", inductor_default_stats),
("inductor reduce-overhead", inductor_stats),
("luminal backend", luminal_stats),
]
ranked_results = sorted(
benchmark_results,
key=lambda item: float(item[1]["median_ms"]),
)
speed_lines = []
for idx, (label, stats) in enumerate(ranked_results, start=1):
throughput = batch_size / (float(stats["median_ms"]) / 1000.0)
rounds_repr = ", ".join(
f"{value:.3f}" for value in stats["round_medians_ms"] # type: ignore[index]
)
speed_lines.append(
f" {idx}. {label}: {float(stats['median_ms']):.3f} ms"
f" ({throughput:,.0f} candidates/s)"
f" [round medians: {rounds_repr}]"
)
luminal_vs_inductor = float(luminal_stats["median_ms"]) / float(
inductor_stats["median_ms"]
)
luminal_vs_eager = float(luminal_stats["median_ms"]) / float(eager_stats["median_ms"])
print(
"\n"
f"DLRM batch={batch_size} candidates on CUDA after compile/warmup\n"
f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations\n"
f" Ranking by median latency:\n"
+ "\n".join(speed_lines)
+ "\n"
f" Luminal backend / TorchInductor reduce-overhead latency ratio:"
f" {luminal_vs_inductor:.3f}x\n"
f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x"
)

View File

@@ -7,13 +7,14 @@
//! - [`NativeDynBackend`]: the reference implementation for CPU
use std::collections::HashMap;
use std::time::Duration;
use half::{bf16, f16};
use petgraph::stable_graph::NodeIndex;
use rustc_hash::FxHashMap;
use crate::dtype::DType;
use crate::graph::Graph;
use crate::graph::{Graph, SearchOptions};
use crate::hlir::{NativeData, NativeRuntime, Output};
use crate::op::Runtime;
@@ -46,6 +47,18 @@ pub trait DynBackend {
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
// --- Optional diagnostics --------------------------------------------
fn kernel_names(&self) -> Vec<String> {
vec![]
}
fn host_op_names(&self) -> Vec<String> {
vec![]
}
fn print_execution_stats(&self) {}
// --- Optional device pointer support (GPU backends) --------------------
fn supports_device_ptrs(&self) -> bool {
@@ -162,7 +175,9 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, args.search_iters);
let mut rng = rand::rng();
let search_options = search_options_from_env(args.search_iters);
let mut rt = graph.search_options(rt, search_options, &mut rng);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);
@@ -179,6 +194,39 @@ pub fn compile_backend<Rt: Runtime + 'static>(
Ok(wrap(rt))
}
fn env_usize(name: &str) -> Option<usize> {
std::env::var(name).ok()?.parse().ok()
}
fn env_duration_ms(name: &str) -> Option<Duration> {
Some(Duration::from_millis(env_usize(name)? as u64))
}
fn search_options_from_env(limit: usize) -> SearchOptions {
let mut options = SearchOptions::new(limit);
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
options = options.generation_size(generation_size);
}
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
options = options.mutations(mutations);
}
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
options = options.trials(trials);
}
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
options = options.keep_best(keep_best);
}
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
options = options.profile_timeout(profile_timeout);
}
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
options = options.group_timeout(group_timeout);
}
options
}
// ---------------------------------------------------------------------------
// Shared utilities
// ---------------------------------------------------------------------------

View File

@@ -217,32 +217,38 @@ pub fn reduce_sort(name: &str) -> SortDef {
}
pub type HLIROps = (
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Scatter,
SumReduce,
MaxReduce,
Softmax,
(
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
),
(
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Concat2D,
EmbeddingBagSum,
Scatter,
SumReduce,
MaxReduce,
Softmax,
),
);
#[derive(Default, Debug, Clone)]
@@ -1721,7 +1727,9 @@ impl NativeOp for Add {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
}
}
}
}
@@ -1808,7 +1816,9 @@ impl NativeOp for Mul {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
}
}
}
}
@@ -2126,6 +2136,233 @@ impl NativeOp for Gather {
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Concat2D {
pub rows: Expression,
pub lhs_cols: Expression,
pub rhs_cols: Expression,
}
impl Display for Concat2D {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Concat2D")
}
}
impl HLIROp for Concat2D {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (Concat2D {} {} {}) {})",
self.rows.to_egglog(),
self.lhs_cols.to_egglog(),
self.rhs_cols.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
)
}
}
impl EgglogOp for Concat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"Concat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for Concat2D {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let rows = self.rows.exec(dyn_map).unwrap();
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
fn concat_rows<T: Clone>(
lhs: &[T],
rhs: &[T],
rows: usize,
lhs_cols: usize,
rhs_cols: usize,
) -> Vec<T> {
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
for row in 0..rows {
let lhs_base = row * lhs_cols;
let rhs_base = row * rhs_cols;
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
}
out
}
match (inputs[0], inputs[1]) {
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
_ => panic!("Concat2D inputs must have the same dtype"),
}
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingBagSum {
pub n_bags: Expression,
pub n_indices: Expression,
pub hidden_dim: Expression,
pub num_embeddings: Expression,
}
impl Display for EmbeddingBagSum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "EmbeddingBagSum")
}
}
impl HLIROp for EmbeddingBagSum {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (EmbeddingBagSum {} {} {} {}) {})",
self.n_bags.to_egglog(),
self.n_indices.to_egglog(),
self.hidden_dim.to_egglog(),
self.num_embeddings.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
)
}
}
impl EgglogOp for EmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"EmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for EmbeddingBagSum {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let n_bags = self.n_bags.exec(dyn_map).unwrap();
let n_indices = self.n_indices.exec(dyn_map).unwrap();
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
let clamp_index = |value: i32, limit: usize| -> usize {
if limit == 0 {
return 0;
}
value.clamp(0, limit.saturating_sub(1) as i32) as usize
};
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
let NativeData::Int(indices) = inputs[1] else {
panic!("EmbeddingBagSum indices must be Int")
};
let NativeData::Int(offsets) = inputs[2] else {
panic!("EmbeddingBagSum offsets must be Int")
};
let bag_bounds = |bag: usize| -> (usize, usize) {
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
let end = offsets
.get(bag + 1)
.copied()
.map(clamp_offset)
.unwrap_or(n_indices);
(start, end.max(start))
};
match inputs[0] {
NativeData::F32(weight) => {
let mut out = vec![0.0f32; n_bags * hidden_dim];
for bag in 0..n_bags {
let (start, end) = bag_bounds(bag);
let out_base = bag * hidden_dim;
for pos in start..end {
let row = clamp_index(indices[pos], num_embeddings);
let row_base = row * hidden_dim;
for dim in 0..hidden_dim {
out[out_base + dim] += weight[row_base + dim];
}
}
}
NativeData::F32(out)
}
_ => panic!("EmbeddingBagSum only supports F32 weights"),
}
}
}
// Scatter Op (inverse of Gather)
#[derive(Debug, Clone, Default, PartialEq)]