Compare commits

..

1 Commits

Author SHA1 Message Date
Tucker Morgan
d6b0eb0ec1 Add recommender model compile coverage 2026-05-13 21:40:15 +00:00
46 changed files with 2408 additions and 2873 deletions

View File

@@ -231,9 +231,7 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
}
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
@@ -280,9 +278,7 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
}
fn gather_experts(

View File

@@ -5,6 +5,7 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
use luminal::prelude::*;
use crate::cudarc::driver::CudaContext;
use crate::host::describe_host_op;
use crate::runtime::CudaRuntime;
/// [`DynBackend`] wrapper for [`CudaRuntime`].
@@ -39,6 +40,26 @@ impl DynBackend for CudaLiteDynBackend {
self.runtime.execute(dyn_map);
}
fn kernel_names(&self) -> Vec<String> {
self.runtime
.kernel_names()
.iter()
.map(|name| (*name).to_string())
.collect()
}
fn host_op_names(&self) -> Vec<String> {
self.runtime
.host_ops()
.iter()
.map(|op| describe_host_op(*op))
.collect()
}
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
fn supports_device_ptrs(&self) -> bool {
true
}

View File

@@ -247,6 +247,10 @@ impl HostOp for CuBlasSgemmV2 {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasSgemmV2")
}
fn output_size(&self) -> Expression {
self.m * self.n
}

View File

@@ -419,7 +419,6 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
}
}
#[cfg(test)]
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
match epilogue {
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
@@ -978,6 +977,18 @@ impl CuBlasLt {
&& normalize(self.stride_c) == normalize(self.stride_d)
&& self.c_order == self.d_order
}
pub(crate) fn debug_summary(&self) -> String {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
format!(
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
resolve(&self.m),
resolve(&self.n),
resolve(&self.k),
resolve(&self.batch_count),
epilogue_name(self.epilogue),
)
}
}
impl HostOp for CuBlasLt {
@@ -1114,6 +1125,10 @@ impl HostOp for CuBlasLt {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasLt")
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)

View File

@@ -1,6 +1,7 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use crate::kernel::CudaGraphOp;
use luminal::{op::EgglogOp, prelude::*};
pub mod compute_attn_mask;
mod cublas;
@@ -79,6 +80,24 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
.map(cublaslt::CuBlasLt::c_d_layouts_match)
}
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
return op.debug_summary();
}
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
let mut summary = op.debug_summary();
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& let Some(timing) = op.debug_timing_summary()
{
summary.push_str(" [");
summary.push_str(&timing);
summary.push(']');
}
return summary;
}
op.stats_name().unwrap_or("unknown").to_string()
}
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside

View File

@@ -12,7 +12,10 @@ use luminal::{
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
extract_dtype, extract_expr, extract_expr_list,
},
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
hlir::{
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
Sin, Sqrt, SumReduce,
},
op::*,
prelude::*,
};
@@ -65,6 +68,8 @@ pub type Ops = (
KernelConstant,
KernelCast,
KernelEmbed,
KernelConcat2D,
KernelEmbeddingBagSum,
);
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
@@ -1540,22 +1545,19 @@ impl KernelOp for KernelIota {
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
vars.extend(self.range.dyn_vars());
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let range = self.range.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void iota_k(int *C{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {range}) return;
C[const_z] = {};
}}
}}",
@@ -1574,8 +1576,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.range.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
(self.range, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -2938,14 +2940,6 @@ impl KernelOp for KernelCast {
) {
let out_dtype = cuda_dtype(self.out_dtype);
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let size = self.size.to_kernel();
let kernel = if self.in_dtype.bits() < 8 {
// Sub-byte packed types: multiple values packed per byte.
@@ -2955,11 +2949,9 @@ impl KernelOp for KernelCast {
let mask = (1u32 << bits) - 1;
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {size}) return;
long long bit_offset = idx * {bits};
long long byte_idx = bit_offset >> 3;
int bit_pos = (int)(bit_offset & 7);
@@ -2975,11 +2967,9 @@ extern \"C\" {{
let in_dtype = cuda_dtype(self.in_dtype);
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {size}) return;
out[const_z] = ({out_dtype})in[const_z];
}}
}}"
@@ -2998,8 +2988,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.size.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
(self.size, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -3290,15 +3280,12 @@ impl KernelOp for KernelEmbed {
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let total_threads = batch_size * self.embed_dim;
let n_elements = total_threads.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {n_elements}) return;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
long long embed_idx = idx % embed_dim;
@@ -3321,12 +3308,13 @@ extern \"C\" {{
};
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
let total_threads = batch_size * self.embed_dim;
(
func,
module,
kernel,
(total_threads.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
(total_threads, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
0.into(),
constants,
)
@@ -3371,3 +3359,361 @@ extern \"C\" {{
"Embed"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelEmbeddingBagSum {
n_bags: Expression,
n_indices: Expression,
hidden_dim: Expression,
num_embeddings: Expression,
dtype: DType,
}
impl EgglogOp for KernelEmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelEmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[4]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelEmbeddingBagSum {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
self.dtype
);
let vars = self
.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_bags = self.n_bags.to_kernel();
let n_indices = self.n_indices.to_kernel();
let hidden_dim = self.hidden_dim.to_kernel();
let num_embeddings = self.num_embeddings.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long bag = blockIdx.y;
long long hidden_dim = {hidden_dim};
long long n_bags = {n_bags};
long long n_indices = {n_indices};
long long num_embeddings = {num_embeddings};
if (bag >= n_bags || dim >= hidden_dim) return;
int start_raw = offsets[bag];
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
int start = max(0, min(start_raw, (int)n_indices));
int end = max(start, min(end_raw, (int)n_indices));
float sum = 0.0f;
for (int pos = start; pos < end; ++pos) {{
int row = indices[pos];
row = max(0, min(row, (int)num_embeddings - 1));
sum += weight[(long long)row * hidden_dim + dim];
}}
out[bag * hidden_dim + dim] = sum;
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embedding_bag_sum").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
(self.hidden_dim.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.n_bags * self.hidden_dim
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
// Approximate: weights + indices + offsets
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.n_indices * self.hidden_dim
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"EmbeddingBagSum"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelConcat2D {
rows: Expression,
lhs_cols: Expression,
rhs_cols: Expression,
dtype: DType,
}
impl EgglogOp for KernelConcat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelConcat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<Concat2D, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelConcat2D {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelConcat2D only supports F32 today, got {:?}",
self.dtype
);
let vars = self
.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let rows = self.rows.to_kernel();
let lhs_cols = self.lhs_cols.to_kernel();
let rhs_cols = self.rhs_cols.to_kernel();
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = {total};
if (idx >= total) return;
long long rows = {rows};
long long lhs_cols = {lhs_cols};
long long rhs_cols = {rhs_cols};
long long out_cols = lhs_cols + rhs_cols;
if (rows == 0 || out_cols == 0) return;
long long row = idx / out_cols;
long long col = idx - row * out_cols;
if (col < lhs_cols) {{
out[idx] = lhs[row * lhs_cols + col];
}} else {{
long long rhs_col = col - lhs_cols;
out[idx] = rhs[row * rhs_cols + rhs_col];
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("concat_2d").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let output_size = self.output_size();
(
func,
module,
kernel,
(output_size.ceil_div(256), 1.into(), 1.into()),
(output_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.rows * (self.lhs_cols + self.rhs_cols)
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
0.into()
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Concat2D"
}
}

View File

@@ -4,6 +4,8 @@
//! that can be executed like any other HostOp.
use std::cell::RefCell;
use std::cmp::Reverse;
use std::collections::BTreeMap;
use std::sync::Arc;
use cudarc::driver::{
@@ -141,6 +143,8 @@ struct CudaGraphOpState {
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
/// Timing events for profiling
timing_events: Vec<cudarc::driver::sys::CUevent>,
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
last_kernel_timings_us: Vec<(&'static str, f64)>,
}
impl CudaGraphOpState {
@@ -155,6 +159,7 @@ impl CudaGraphOpState {
last_dyn_values: FxHashMap::default(),
last_buffer_ptrs: FxHashMap::default(),
timing_events: Vec::new(),
last_kernel_timings_us: Vec::new(),
}
}
}
@@ -192,6 +197,41 @@ impl CudaGraphOp {
state: RefCell::new(state),
}
}
pub fn debug_summary(&self) -> String {
let state = self.state.borrow();
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
for kernel in &state.kernels {
*counts.entry(kernel.kernel_name).or_default() += 1;
}
let mut counts: Vec<_> = counts.into_iter().collect();
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
let top = counts
.into_iter()
.take(4)
.map(|(name, count)| format!("{name}x{count}"))
.join(", ");
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
}
pub fn debug_timing_summary(&self) -> Option<String> {
let state = self.state.borrow();
if state.last_kernel_timings_us.is_empty() {
return None;
}
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
for (name, us) in &state.last_kernel_timings_us {
*totals.entry(*name).or_default() += *us;
}
let mut totals: Vec<_> = totals.into_iter().collect();
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
let top = totals
.into_iter()
.take(4)
.map(|(name, us)| format!("{name}={us:.0}us"))
.join(", ");
Some(top)
}
}
impl std::fmt::Debug for CudaGraphOp {
@@ -566,6 +606,23 @@ impl CudaGraphOp {
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& state.timing_events.len() >= state.kernels.len() + 1
{
stream.synchronize()?;
let ctx = stream.context().clone();
state.last_kernel_timings_us.clear();
for idx in 0..state.kernels.len() {
let start_event = state.timing_events[idx];
let end_event = state.timing_events[idx + 1];
let kernel_name = state.kernels[idx].kernel_name;
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
state.last_kernel_timings_us.push((kernel_name, us));
}
}
Ok(())
}
@@ -584,8 +641,9 @@ impl CudaGraphOp {
state.kernel_params.clear();
state.kernel_params.reserve(num_kernels);
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
let tracing_enabled = enabled!(Level::TRACE);
if tracing_enabled {
if tracing_enabled || profile_cuda_graph {
let needed_events = num_kernels + 1;
while state.timing_events.len() < needed_events {
state.timing_events.push(create_cuda_event(&ctx)?);
@@ -701,7 +759,7 @@ impl CudaGraphOp {
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
let timing_event = if tracing_enabled || profile_cuda_graph {
Some(state.timing_events[idx])
} else {
None
@@ -739,7 +797,9 @@ impl CudaGraphOp {
prev_graph_node = Some(graph_node);
}
if tracing_enabled && let Some(prev) = prev_graph_node {
if (tracing_enabled || profile_cuda_graph)
&& let Some(prev) = prev_graph_node
{
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}

View File

@@ -1,6 +1,9 @@
use crate::{
host::{DeviceBuffer, HostOp},
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
host::{DeviceBuffer, HostOp, describe_host_op},
kernel::{
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
record_cuda_graph_timings, record_event_on_stream,
},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
@@ -60,6 +63,12 @@ pub struct KernelStats {
pub tflops: f64,
}
#[derive(Debug, Clone)]
pub struct HostOpStats {
pub name: String,
pub execution_time_us: f64,
}
impl Debug for ExecutableHostOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HostOp: ({:?})", self.internal)
@@ -141,6 +150,7 @@ pub struct CudaRuntime {
changed_hlir: FxHashSet<NodeIndex>,
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
pub last_kernel_stats: Vec<KernelStats>,
pub last_host_op_stats: Vec<HostOpStats>,
pub last_total_time_us: f64,
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
/// When true, execute() skips input buffer consumption (used during search/profile)
@@ -1203,6 +1213,7 @@ impl Runtime for CudaRuntime {
changed_hlir: FxHashSet::default(),
cuda_graph_timings: vec![],
last_kernel_stats: vec![],
last_host_op_stats: vec![],
last_total_time_us: 0.0,
kernel_cache: FxHashMap::default(),
profiling: false,
@@ -1454,6 +1465,8 @@ impl Runtime for CudaRuntime {
let total_start = std::time::Instant::now();
let bucket = &self.compiled_buckets[self.active_bucket];
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
let mut host_timing_events = Vec::new();
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
let exec_op = &bucket.exec_graph[exec_node];
@@ -1507,6 +1520,16 @@ impl Runtime for CudaRuntime {
n_inputs = exec_op.inputs.len()
)
.entered();
let host_op_timing = if profile_host_ops {
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
let ctx = exec_op.stream.context().clone();
let start_event = create_cuda_event(&ctx).unwrap();
let end_event = create_cuda_event(&ctx).unwrap();
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
Some((name, ctx, start_event, end_event))
} else {
None
};
exec_op
.internal
.execute(
@@ -1522,10 +1545,28 @@ impl Runtime for CudaRuntime {
exec_op.internal.stats_name().unwrap_or("unknown")
);
});
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
host_timing_events.push((name, ctx, start_event, end_event));
}
}
// Single sync at end - CUDA stream ordering guarantees sequential execution
self.cuda_stream.synchronize().unwrap();
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
self.last_host_op_stats.clear();
if profile_host_ops {
for (name, ctx, start_event, end_event) in host_timing_events {
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
self.last_host_op_stats.push(HostOpStats {
name,
execution_time_us,
});
destroy_cuda_event(&ctx, start_event);
destroy_cuda_event(&ctx, end_event);
}
}
// Populate last_kernel_stats from HostOps that report stats
self.last_kernel_stats.clear();
@@ -1861,6 +1902,16 @@ impl CudaRuntime {
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
if !self.last_host_op_stats.is_empty() {
println!("\n=== Host Operation Statistics ===\n");
println!("{:<20} {:>12}", "HostOp", "Time (us)");
println!("{}", "-".repeat(34));
for s in &self.last_host_op_stats {
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
}
println!("{}", "-".repeat(34));
}
// Print kernel stats
if !self.last_kernel_stats.is_empty() {
println!("\n=== Kernel Execution Statistics ===\n");

View File

@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
.sum(n - 1)
.output();
QwenMoeGraph {
graph: cx,
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
.sum(n - 1)
.output();
GemmaMoeGraph {
graph: cx,

View File

@@ -61,8 +61,7 @@ impl MoE {
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
weights_exp.shape.expand(expert_out.dims());
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
(expert_out * weights_exp).sum(n - 1)
}
}
@@ -479,8 +478,7 @@ mod tests {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let _output = (down_out * weights_exp).sum(n - 1).output();
// Dump the HLIR to egglog

View File

@@ -855,6 +855,8 @@ Two important details:
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
---
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.

View File

@@ -98,12 +98,7 @@ pub struct GraphTranslation {
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
/// distinctions luminal collapses internally — notably int64 vs int32,
/// both of which map to `DType::Int` in luminal but must be reported
/// back to PyTorch with their original precision.
pub output_dtypes: Vec<u32>,
pub output_dtypes: Vec<DType>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -129,9 +124,7 @@ pub struct CompiledGraph {
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
pub output_shape_exprs: Vec<Vec<Expression>>,
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
/// that luminal collapses to `DType::Int` internally).
pub output_dtypes: Vec<u32>,
pub output_dtypes: Vec<DType>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -255,6 +248,23 @@ impl CompiledGraph {
self.runtime.device_type()
}
/// Names of kernels compiled into the active runtime bucket, if available.
#[getter]
fn kernel_names(&self) -> Vec<String> {
self.runtime.kernel_names()
}
/// Names of host ops in the active runtime bucket, if available.
#[getter]
fn host_op_names(&self) -> Vec<String> {
self.runtime.host_op_names()
}
/// Print backend execution statistics for the last run, if supported.
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
#[getter]
fn supports_device_ptrs(&self) -> bool {
@@ -483,7 +493,10 @@ impl CompiledGraph {
/// Get the PT2 dtype codes for all outputs (in order).
#[getter]
fn output_dtypes(&self) -> Vec<u32> {
self.output_dtypes.clone()
self.output_dtypes
.iter()
.map(|d| luminal_dtype_to_pt2_code(*d))
.collect()
}
/// Get output tensor data by name as f32 (copies to host).

View File

@@ -262,13 +262,10 @@ pub fn translate_pt2(
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges. PT2 emits
// `min_val: null` when the constraint is unbounded; fall back to 1 in
// that case (the smallest valid dim — used only as an initial value).
// Set initial dynamic dim values from symbol ranges
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
graph.set_dim(*c, initial);
graph.set_dim(*c, rc.min_val as usize);
}
}
@@ -284,14 +281,14 @@ pub fn translate_pt2(
})
.collect();
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
// Python wrapper can return tensors with the right torch.dtype, even when
// luminal collapses the type internally (e.g. int64 → DType::Int).
let output_dtypes: Vec<u32> = translated
let output_dtypes: Vec<DType> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
parsed
.tensor_meta(name)
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
.unwrap_or(DType::F32)
})
.collect();

View File

@@ -15,16 +15,7 @@ pub struct ExportedProgram {
#[derive(Debug, Clone, Deserialize)]
pub struct RangeConstraint {
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
/// constraint is unbounded (no min set), so this must accept None.
#[serde(default)]
pub min_val: Option<i64>,
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
/// unused on the luminal side, but accepted to avoid deserialization
/// errors when PT2 emits it.
#[serde(default)]
#[allow(dead_code)]
pub max_val: Option<i64>,
pub min_val: i64,
}
#[derive(Debug, Deserialize)]

View File

@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Result, bail};
use luminal::prelude::*;
use crate::pt2_schema::*;
@@ -8,21 +8,62 @@ use super::Translator;
impl<'a> Translator<'a> {
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let arg1 = &node.inputs[1].arg;
if let Some(name) = arg1.as_tensor_name() {
let b = self.get_tensor(name)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
} else {
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
let alpha = match op {
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
BinaryOp::Mul | BinaryOp::Div => 1.0,
};
let lhs = node.inputs[0]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
let rhs = node.inputs[1]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
match (lhs, rhs) {
(Some(a), Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(Some(a), None) => {
let mut val = self.get_float_arg(node, 1)? as f32;
if alpha != 1.0 {
val *= alpha;
}
Ok(self.apply_scalar_op(a, val, op))
}
(None, Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let lhs_val = self.get_float_arg(node, 0)? as f32;
let a = self
.graph
.constant_float(lhs_val)
.cast(b.dtype)
.expand_rhs(b.shape);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(None, None) => bail!("{} expects at least one tensor operand", node.target),
}
}

View File

@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
if let Some(b) = bias {
let out_dims = out.dims();
let mut b_expanded = b.expand_dim(0, out_dims[0]);
let mut b_expanded = b.expand_dim(0, 1);
for i in 0..spatial {
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
}
@@ -389,11 +389,8 @@ fn depthwise_conv(
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
let patches = patches.expand_dim(2, group_out);
// Explicitly expand weight across the batch axis so the elementwise Mul
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
let w_expanded = w_flat
.expand_dim(0, patches.dims()[0])
.expand_dim(3, patches.dims()[3]);
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
// Element-wise multiply and sum over kernel dim
let product = patches * w_expanded;

View File

@@ -6,7 +6,6 @@ use crate::pt2_util::*;
use super::Translator;
use super::attention::SdpaVariant;
use super::reduction::ArgExtremum;
impl<'a> Translator<'a> {
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
@@ -112,6 +111,7 @@ impl<'a> Translator<'a> {
result
}
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
"torch.ops.aten.clone.default" => {
let a = self.get_input_tensor(node, 0)?;
if !a.shape.is_contiguous() { a + 0.0 } else { a }
@@ -134,8 +134,28 @@ impl<'a> Translator<'a> {
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
if alpha == 0.0 && beta == 0.0 {
self.graph
.constant_float(0.0)
.cast(mm.dtype)
.expand_rhs(mm.shape)
} else if beta == 0.0 {
if alpha == 1.0 { mm } else { mm * alpha }
} else if alpha == 0.0 {
let input = if beta == 1.0 { input } else { input * beta };
let zero = self
.graph
.constant_float(0.0)
.cast(input.dtype)
.expand_rhs(mm.shape);
let (input, _) = broadcast_binary(input, zero);
input
} else {
let input = if beta == 1.0 { input } else { input * beta };
let mm = if alpha == 1.0 { mm } else { mm * alpha };
let (input, mm) = broadcast_binary(input, mm);
input + mm
}
}
// Convolution
@@ -151,6 +171,11 @@ impl<'a> Translator<'a> {
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
"torch.ops.aten._embedding_bag.default"
| "torch.ops.aten._embedding_bag_forward_only.default" => {
self.translate_embedding_bag(node)?
}
"<built-in function getitem>" => self.translate_getitem(node)?,
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
@@ -165,6 +190,9 @@ impl<'a> Translator<'a> {
// LayerNorm
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
self.translate_native_batch_norm_no_training(node)?
}
// Where
"torch.ops.aten.where.self" => self.translate_where(node)?,
@@ -221,16 +249,6 @@ impl<'a> Translator<'a> {
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
// Tensor comparisons
"torch.ops.aten.eq.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a.eq(scalar)
}
"torch.ops.aten.ne.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
@@ -248,13 +266,6 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a.eq(b)
}
"torch.ops.aten.ne.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a.ne(b)
}
"torch.ops.aten.le.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
@@ -293,27 +304,18 @@ impl<'a> Translator<'a> {
// Clamp
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
// Cumsum
"torch.ops.aten.cumsum.default" => {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let a = if a.dtype == DType::Bool {
a.cast(DType::Int)
} else {
a
};
// Rank-0 (scalar) input: cumsum of a single element is the element
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
// and the underlying `cumop` indexes `shape.dims[axis]` which would
// panic with empty dims.
if a.shape.is_empty() {
a
} else {
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
a.cumsum(dim)
}
a.cumsum(dim)
}
// Floor / Ceil / Erf (approximations)
@@ -409,17 +411,6 @@ impl<'a> Translator<'a> {
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
// PyTorch's argmax/argmin returns int64; the dtype is preserved
// through the LUM-486 boundary widening.
"torch.ops.aten.argmax.default" => {
self.translate_argextremum(node, ArgExtremum::Max)?
}
"torch.ops.aten.argmin.default" => {
self.translate_argextremum(node, ArgExtremum::Min)?
}
// Gather (axis-aware)
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
@@ -483,28 +474,6 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a % b
}
// Remainder (Python-style modulo). For float tensors aten.remainder
// returns the same value as `%` would in luminal (Mod follows the
// language's % semantics on f32). The Tensor variant accepts a
// tensor RHS that may be rank-0; broadcast both operands so a
// scalar RHS is expanded to match the LHS shape before mod.
"torch.ops.aten.remainder.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a % b
}
"torch.ops.aten.remainder.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a % scalar
}
// Prod reduction
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,

View File

@@ -12,6 +12,64 @@ const SCATTER_INDEX_ARG: usize = 2;
const SCATTER_VALUE_ARG: usize = 3;
impl<'a> Translator<'a> {
fn try_concat_2d_fast(
&mut self,
lhs: GraphTensor,
rhs: GraphTensor,
axis: usize,
) -> Option<GraphTensor> {
if axis != 1
|| lhs.dtype != DType::F32
|| rhs.dtype != DType::F32
|| lhs.shape.len() != 2
|| rhs.shape.len() != 2
|| !lhs.shape.is_contiguous()
|| !rhs.shape.is_contiguous()
|| lhs.shape.dims[0] != rhs.shape.dims[0]
{
return None;
}
let rows = lhs.shape.dims[0];
let lhs_cols = lhs.shape.dims[1];
let rhs_cols = rhs.shape.dims[1];
let id = self.graph.add_op(
luminal::hlir::Concat2D {
rows,
lhs_cols,
rhs_cols,
},
&[lhs.id, rhs.id],
);
Some(GraphTensor::from_id(
id,
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
lhs.graph_ref,
lhs.dtype,
))
}
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
let index = self
.get_int_arg(node, 2)
.context("select.int: missing index")?;
let dim_size = a.shape.dims[dim]
.to_usize()
.context("select.int: symbolic dims are not supported for negative indices")?;
let normalized_index = if index < 0 {
(dim_size as i64 + index) as usize
} else {
index as usize
};
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
.squeeze(dim))
}
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -80,6 +138,43 @@ impl<'a> Translator<'a> {
Ok(a)
}
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
let mut a = self.get_input_tensor(node, 0)?;
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
sizes
.into_iter()
.map(|size| {
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
Ok(Expression::from(size as usize))
})
.collect::<Result<_>>()?
} else {
self.get_exprs_arg(node, 1)?
};
anyhow::ensure!(
repeats.len() >= a.shape.len(),
"repeat: repeats rank {} is smaller than input rank {}",
repeats.len(),
a.shape.len()
);
while a.shape.len() < repeats.len() {
a = a.unsqueeze(0);
}
Ok(a.repeat(repeats))
}
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
let index = self.get_int_arg(node, 1)?;
anyhow::ensure!(
index == 0,
"getitem: only tuple[0] access is supported today, got index={index}"
);
self.get_input_tensor(node, 0)
}
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1).unwrap_or(0);
@@ -120,47 +215,6 @@ impl<'a> Translator<'a> {
Ok(a.slice_along(start..end, dim))
}
/// `aten.select.int(self, dim, index)` — select element `index` along
/// `dim`, dropping that dim. Output rank = input rank 1, so a 1-D input
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
/// are normalized against the input shape.
///
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
/// slice + squeeze decomposition (rather than `gather`) because the
/// composition is a pure shape manipulation with a single iota, which the
/// luminal compiler can fold into surrounding ops.
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let index_raw = self.get_int_arg(node, 2)?;
// Normalize a possibly-negative index. PyTorch accepts indices in
// [-size, size); negative wraps from the end.
let index = if index_raw < 0 {
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"select.int: dim {} must be concrete to normalize a negative index",
dim
)
})?;
let normalized = axis_size as i64 + index_raw;
if normalized < 0 {
bail!(
"select.int: index {} out of range for dim {} of size {}",
index_raw,
dim,
axis_size
);
}
normalized as usize
} else {
index_raw as usize
};
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
}
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
names
@@ -202,7 +256,11 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, tensors[0].shape.len());
let mut result = tensors[0];
for t in &tensors[1..] {
result = result.concat_along(*t, dim);
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
result = fast;
} else {
result = result.concat_along(*t, dim);
}
}
Ok(result)
}
@@ -259,6 +317,79 @@ impl<'a> Translator<'a> {
bail!("index.Tensor: no index tensors in optional_tensors list");
}
index_names = found_tensors;
// Multiple explicit index tensors after leading `None`s mean
// "keep the prefix dims, then advanced-index the contiguous
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
if first_non_none_dim > 0
&& index_names.len() > 1
&& first_non_none_dim + index_names.len() == source.shape.len()
{
let src_dims = source.shape.dims;
let indexed_dims = &src_dims[first_non_none_dim..];
let n_indexed = index_names.len();
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * indexed_dims[i + 1];
}
let mut flat_idx: Option<GraphTensor> = None;
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
let axis_size = indexed_dims[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
idx_int
} else {
idx_int * stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (acc_b, w_b) = broadcast_binary(acc, weighted);
acc_b + w_b
}
None => weighted,
});
}
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
let idx_shape = flat_idx.shape.dims.to_vec();
let mut idx_numel = Expression::from(1usize);
for dim in &idx_shape {
idx_numel *= *dim;
}
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
let mut indexed_size = Expression::from(1usize);
for dim in indexed_dims {
indexed_size *= *dim;
}
let mut flat_source_shape = prefix_dims.clone();
flat_source_shape.push(indexed_size);
let flat_source = reshape_tensor(source, flat_source_shape);
let mut expanded_idx = flat_idx;
for _ in 0..prefix_dims.len() {
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
}
let mut target = prefix_dims.clone();
target.push(idx_numel);
expanded_idx.shape.expand(target);
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
let mut result_shape = prefix_dims;
result_shape.extend_from_slice(&idx_shape);
return Ok(reshape_tensor(gathered, result_shape));
}
// Simple case: single non-None index on a specific dim → gather_elements
if first_non_none_dim > 0 && index_names.len() == 1 {
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
@@ -374,17 +505,6 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, 2)?;
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
// gather_elements requires the index rank to match the source rank,
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
let indices = if promoted_rank0 {
indices.unsqueeze(0)
} else {
indices
};
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
// Stay in Int the whole way — multiplying an Int tensor by an
// Expression broadcasts the axis size and avoids three Cast nodes
@@ -396,12 +516,7 @@ impl<'a> Translator<'a> {
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
let result = a.gather_elements(normalized, dim);
Ok(if promoted_rank0 {
result.squeeze(0)
} else {
result
})
Ok(a.gather_elements(normalized, dim))
}
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {

View File

@@ -6,20 +6,6 @@ use crate::pt2_util::*;
use super::Translator;
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
/// smallest (ascending sort) element when scanning the input.
#[derive(Clone, Copy)]
pub(crate) enum ArgExtremum {
Max,
Min,
}
impl ArgExtremum {
fn descending(self) -> bool {
matches!(self, ArgExtremum::Max)
}
}
/// Compute total element count, returning an error if any dimension is symbolic.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
a.dims().iter().try_fold(1usize, |acc, d| {
@@ -51,26 +37,32 @@ impl<'a> Translator<'a> {
(axes, keepdim)
}
_ => {
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
let ndim = a.shape.len();
if ndim == 0 {
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
// and mean of a scalar is just the scalar.
return Ok(a);
}
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
// override below assumes contiguous, no-broadcast storage —
// otherwise the `[1, N]` view treats stride-0 broadcast dims
// as if they held N distinct values and reads past the backing
// buffer. Materialize first when that's not the case (matches
// the guard `translate_reshape` already applies).
let total = concrete_numel(&a)?;
let axes: Vec<usize> = (0..ndim).collect();
let has_broadcast = a
.shape
.dims
.iter()
.zip(a.shape.strides.iter())
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
let a = if has_broadcast || !a.shape.is_contiguous() {
a + 0.0
} else {
a
};
let mut flat = a;
flat.shape = ShapeTracker::new(vec![1, total]);
let result = match op {
ReductionOp::Sum => a.sum(axes),
// Note: the luminal `mean` helper divides by the product of the
// axis dims, but we already require concrete dims here so we
// divide by the cached `total` to avoid recomputing.
ReductionOp::Mean => a.sum(axes) / total as f32,
ReductionOp::Max => a.max(axes),
ReductionOp::Min => a.min(axes),
ReductionOp::Prod => a.prod(axes),
ReductionOp::Sum => flat.sum(vec![1]),
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
ReductionOp::Max => flat.max(vec![1]),
ReductionOp::Min => flat.min(vec![1]),
ReductionOp::Prod => flat.prod(vec![1]),
};
return Ok(result);
}
@@ -94,100 +86,4 @@ impl<'a> Translator<'a> {
Ok(result)
}
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
/// existing `stable_argsort` op and selecting the first index along the
/// sort axis.
///
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
/// for argmin). FX export emits the inputs positionally:
/// - input 0: tensor
/// - input 1: dim (Int) or None (Other) — when `dim=None`
/// - input 2: keepdim (Bool, optional)
///
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
/// The result of argsort along the sort axis is sliced at index 0,
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
/// `dim`.
///
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
/// view; we materialize it with `* 1` so the resulting node has
/// contiguous strides matching its visible shape (mirroring the
/// `topk` lowering in `translate_topk`). Without this, the output
/// buffer would be sized for the un-sliced argsort tensor while the
/// shape tracker reports a smaller rank.
///
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
/// metadata records int64 and the Python wrapper widens at the
/// boundary, so the PyTorch contract is preserved end-to-end
/// (LUM-486).
pub(crate) fn translate_argextremum(
&mut self,
node: &Node,
which: ArgExtremum,
) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
// argument (typically `Argument::Other(Null)`), so a missing or
// non-int slot means "reduce over the flattened tensor".
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
self.get_int_arg(node, 1).ok()
} else {
None
};
let keepdim = if node.inputs.len() > 2 {
self.get_bool_arg(node, 2).unwrap_or(false)
} else {
false
};
if a.shape.is_empty() {
match dim_opt {
None | Some(0) | Some(-1) => {
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
// `keepdim=True` does not add a dimension when the input is 0-d.
return Ok(self.graph.constant(0i64).cast(DType::Int));
}
Some(dim) => {
return Err(anyhow::anyhow!(
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
));
}
}
}
let descending = which.descending();
let (sort_axis, base) = match dim_opt {
None => {
// Full-reduce: flatten to 1-D, argsort along axis 0.
let total = concrete_numel(&a)?;
let flat = reshape_tensor(a, vec![Expression::from(total)]);
(0usize, flat)
}
Some(dim_raw) => {
let dim = normalize_dim(dim_raw, a.shape.len());
(dim, a)
}
};
// Pick index 0 along the sort axis. The slice-then-squeeze chain
// produces a non-contiguous view whose physical buffer is still
// sized for the un-sliced argsort tensor; the optional `keepdim`
// unsqueeze adds a stride-0 axis which is also non-contiguous.
// Materialize at the end with `* 1` so the resulting node has
// contiguous strides matching its visible shape (matches the
// pattern used by `translate_topk` for sliced index outputs).
let sorted = base.stable_argsort(sort_axis, descending);
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
let result = if keepdim {
picked.unsqueeze(sort_axis)
} else {
picked
};
Ok(result * 1)
}
}

View File

@@ -28,6 +28,45 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
impl<'a> Translator<'a> {
fn translate_embedding_bag_generic(
&mut self,
weight: GraphTensor,
indices: GraphTensor,
offsets: GraphTensor,
) -> Result<GraphTensor> {
let hidden_dim = weight.shape.dims[1];
let n_indices = indices.shape.dims[0];
let n_bags = offsets.shape.dims[0];
// Gather per-index embeddings: [E] -> [E, D].
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
let gathered = weight.gather(ids_expanded + arange);
// Bag assignment per position:
// bag_id[pos] = count(offsets <= pos) - 1
// This supports empty bags too, because repeated offsets simply skip a
// bag id when no positions land in that interval.
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
let starts = offsets.expand_dim(1, n_indices);
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
- self
.graph
.constant_float(1.0)
.cast(DType::Int)
.expand_rhs(vec![n_indices]);
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
let bag_ids = bag_ids.expand_dim(0, n_bags);
let mask = bag_ids
.eq(bag_axis)
.expand_dim(2, hidden_dim)
.cast(gathered.dtype);
let gathered = gathered.expand_dim(0, n_bags);
Ok((gathered * mask).sum(1))
}
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
let positional_args: Vec<Expression> = node
.inputs
@@ -180,6 +219,45 @@ impl<'a> Translator<'a> {
Ok(value.expand_rhs(reference.shape))
}
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
let mode = self.get_int_arg(node, 4).unwrap_or(0);
anyhow::ensure!(
mode == 0,
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
);
anyhow::ensure!(
indices.shape.len() == 1 && offsets.shape.len() == 1,
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
indices.shape.len(),
offsets.shape.len()
);
if weight.dtype == DType::F32 {
let id = self.graph.add_op(
luminal::hlir::EmbeddingBagSum {
n_bags: offsets.shape.dims[0],
n_indices: indices.shape.dims[0],
hidden_dim: weight.shape.dims[1],
num_embeddings: weight.shape.dims[0],
},
&[weight.id, indices.id, offsets.id],
);
return Ok(GraphTensor::from_id(
id,
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
weight.graph_ref,
DType::F32,
));
}
self.translate_embedding_bag_generic(weight, indices, offsets)
}
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
let output_name = node
.outputs

View File

@@ -21,6 +21,30 @@ const DIV_MODE_INPUT_ARG: usize = 0;
const DIV_MODE_OTHER_ARG: usize = 1;
impl<'a> Translator<'a> {
fn expand_channel_parameter(
&self,
input: GraphTensor,
parameter: GraphTensor,
) -> Result<GraphTensor> {
anyhow::ensure!(
input.shape.len() >= 2,
"batch_norm: expected rank >= 2 input, got rank {}",
input.shape.len()
);
anyhow::ensure!(
parameter.shape.len() == 1,
"batch_norm: expected 1D channel parameter, got rank {}",
parameter.shape.len()
);
let mut expanded = parameter.unsqueeze(0);
for axis in 2..input.shape.len() {
expanded = expanded.unsqueeze(axis);
}
expanded.shape.expand(input.dims().to_vec());
Ok(expanded)
}
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
@@ -101,6 +125,30 @@ impl<'a> Translator<'a> {
Ok(result)
}
pub(crate) fn translate_native_batch_norm_no_training(
&mut self,
node: &Node,
) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
let mut result = (input - running_mean) / (running_var + eps).sqrt();
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
result = result * weight;
}
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
result = result + bias;
}
Ok(result)
}
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let zero = self
@@ -213,18 +261,12 @@ impl<'a> Translator<'a> {
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
// Check rounding_mode kwarg. PT2 serializes string args as
// {"as_string": "<value>"}, so we have to drill into the JSON.
// Check rounding_mode kwarg
let rounding_mode = node.inputs.iter().find_map(|input| {
if input.name == "rounding_mode"
&& let Argument::Other(val) = &input.arg
{
if let Some(s) = val.as_str() {
return Some(s.to_string());
}
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
return Some(s.to_string());
}
return val.as_str().map(|s| s.to_string());
}
None
});
@@ -275,52 +317,4 @@ impl<'a> Translator<'a> {
}
Ok(result)
}
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
///
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
/// overload takes tensor bounds that appear as separate input nodes in the
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
///
/// - rank-0 (scalar wrapped in a tensor) — most common
/// - same shape as self (per-element clamp, e.g. learned bounds)
/// - any shape that broadcasts to self via right-align + size-1 expand
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
///
/// We use `broadcast_binary` to right-align and expand both operands to a
/// common shape before the elementwise max/min, matching PyTorch semantics
/// across all three modes.
///
/// Either bound may be absent (FX represents this as a non-tensor argument
/// at the corresponding input slot), in which case we clamp to one side
/// only.
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let min_tensor = node
.inputs
.get(1)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let max_tensor = node
.inputs
.get(2)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let mut result = a;
if let Some(lo) = min_tensor {
let lo = lo.cast(result.dtype);
let (r, lo) = broadcast_binary(result, lo);
result = r.maximum(lo);
}
if let Some(hi) = max_tensor {
let hi = hi.cast(result.dtype);
let (r, hi) = broadcast_binary(result, hi);
result = r.minimum(hi);
}
Ok(result)
}
}

View File

@@ -8,12 +8,14 @@ from .compiled_model import CompiledModel
# Import Rust extension components (built by maturin)
from .luminal import CompiledGraph, process_pt2
from .main import luminal_backend, register_backend
from .pt2 import compile
_register_cache_serialization()
# Re-export everything for clean package interface
__all__ = [
"CompiledModel",
"compile",
"luminal_backend",
"register_backend",
"CompiledGraph",

View File

@@ -1,9 +1,8 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
from typing import List
from typing import Any, List
import torch
from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
@@ -28,6 +27,14 @@ class CompiledModel:
self._input_names = input_names or graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
output_dtype_codes = graph_result.output_dtypes
self._output_dtypes = [
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
for i in range(len(self._output_names))
]
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
self._weight_refs = weight_refs or []
self._user_indices = user_indices
@@ -35,6 +42,9 @@ class CompiledModel:
self._supports_device_ptrs = getattr(
graph_result, "supports_device_ptrs", False
)
# Cache converted/contiguous views for repeated calls with the same input
# tensors so we don't rebuild sparse index buffers every forward.
self._prepared_input_cache = [None] * len(self._input_names)
# Expected input dtypes from graph (used to convert user inputs)
input_dtype_codes = graph_result.input_dtypes
self._input_dtypes = [
@@ -43,6 +53,80 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
self._single_float_output_fast_path = (
self._supports_device_ptrs
and not self._has_dynamic_dims
and len(self._output_names) == 1
and self._output_dtypes[0].is_floating_point
)
@staticmethod
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
return (
id(tensor),
getattr(tensor, "_version", None),
tensor.data_ptr(),
expected_dtype,
)
def _prepare_input_tensor(
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
) -> torch.Tensor:
detached = tensor.detach()
needs_preparation = (
detached.dtype != expected_dtype or not detached.is_contiguous()
)
if not needs_preparation:
return detached
cache_key = self._input_cache_key(detached, expected_dtype)
cached = self._prepared_input_cache[index]
if cached is not None and cached[0] == cache_key:
return cached[1]
prepared = detached.contiguous().to(expected_dtype)
self._prepared_input_cache[index] = (cache_key, prepared)
return prepared
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Bind the current user inputs into the Rust graph."""
input_refs = []
for index, (name, tensor, expected_dtype) in enumerate(
zip(self._input_names, user_inputs, self._input_dtypes)
):
prepared = self._prepare_input_tensor(index, tensor, expected_dtype)
if self._supports_device_ptrs and tensor.is_cuda:
n_bytes = prepared.numel() * prepared.element_size()
self._graph.set_input_device_ptr(name, prepared.data_ptr(), n_bytes)
input_refs.append(prepared)
else:
if prepared.device.type != "cpu":
prepared = prepared.cpu()
n_bytes = prepared.numel() * prepared.element_size()
dtype_code = _torch_dtype_code(prepared.dtype)
self._graph.set_input_from_ptr(
name, prepared.data_ptr(), n_bytes, dtype_code
)
return input_refs
def _run_static_single_float_output(
self, user_inputs: List[torch.Tensor], input_device: torch.device
):
_input_refs = self._bind_user_inputs(user_inputs)
output_name = self._output_names[0]
output_dtype = self._output_dtypes[0]
out = torch.empty(
self._static_output_shapes[0], dtype=output_dtype, device=input_device
)
self._graph.set_output_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
self._graph.run()
if not self._graph.output_is_zero_copy(output_name):
self._graph.copy_output_to_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
return (out,)
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -86,25 +170,18 @@ class CompiledModel:
if self._has_dynamic_dims:
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
elif (
self._single_float_output_fast_path
and input_device.type != "cpu"
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
):
return self._run_static_single_float_output(user_inputs, input_device)
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = []
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
_input_refs = self._bind_user_inputs(user_inputs)
# Resolve output shapes before run() (needed for pre-allocation).
if self._has_dynamic_dims:
@@ -112,8 +189,6 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
output_dtype_codes = self._graph.output_dtypes
# CUDA zero-copy path: pre-allocate output tensors and register their device
# pointers so the final kernel writes directly into PyTorch's buffer.
_use_zero_copy = self._supports_device_ptrs
@@ -121,8 +196,8 @@ class CompiledModel:
if _use_zero_copy:
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = torch.empty(shape, dtype=out_dtype, device=input_device)
@@ -135,18 +210,13 @@ class CompiledModel:
# Run the graph
self._graph.run()
# Integer dtypes for which we read the buffer as i32 and then cast.
# Includes int64 because luminal collapses all integer types to its
# 32-bit `Int` internally — we restore the original precision here.
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
# Collect outputs
if _use_zero_copy:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
out = output_tensors[i]
@@ -155,12 +225,11 @@ class CompiledModel:
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype in _int_dtypes:
elif out_dtype == torch.int32:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
elif out_dtype == torch.bool:
@@ -184,17 +253,13 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
self._output_dtypes[i]
if i < len(self._output_dtypes)
else torch.float32
)
if out_dtype in _int_dtypes:
if out_dtype == torch.int32:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
)
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
@@ -209,3 +274,41 @@ class CompiledModel:
outputs.append(out)
return tuple(outputs)
def _leaf_paths(tree: Any, prefix=()):
if torch.is_tensor(tree):
return [prefix]
if isinstance(tree, (list, tuple)):
paths = []
for idx, value in enumerate(tree):
paths.extend(_leaf_paths(value, prefix + (idx,)))
return paths
if isinstance(tree, dict):
paths = []
for key, value in tree.items():
paths.extend(_leaf_paths(value, prefix + (key,)))
return paths
return [prefix]
def _follow_path(tree: Any, path):
value = tree
for key in path:
value = value[key]
return value
class StructuredCompiledModel:
"""Preserve a module's original nested input structure for direct PT2 compile()."""
def __init__(self, compiled_model, example_args):
self._compiled = compiled_model
self._leaf_paths = _leaf_paths(example_args)
def __getattr__(self, name):
return getattr(self._compiled, name)
def __call__(self, *args):
flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
return self._compiled(*flat_inputs)

View File

@@ -9,10 +9,12 @@ import inspect
import os
import shutil
import tempfile
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
from .compiled_model import CompiledModel
from .compiled_model import CompiledModel, StructuredCompiledModel
from .luminal import process_pt2
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
@@ -184,6 +186,66 @@ def _save_and_compile(
shutil.rmtree(tmpdir, ignore_errors=True)
def _has_cuda_inputs(flat_example_inputs):
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
"""Search env overrides for direct compile() calls.
CUDA DLRM-style models benefit materially from a deeper per-candidate
profile and from keeping more parents alive between generations. Keep the
defaults narrow so CPU and env-configured callers are unchanged.
"""
has_cuda = _has_cuda_inputs(flat_example_inputs)
overrides = {}
if search_trials is not None:
overrides["LUMINAL_SEARCH_TRIALS"] = str(search_trials)
elif has_cuda and "LUMINAL_SEARCH_TRIALS" not in os.environ:
overrides["LUMINAL_SEARCH_TRIALS"] = "5"
if search_keep_best is not None:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = str(search_keep_best)
elif has_cuda and "LUMINAL_SEARCH_KEEP_BEST" not in os.environ:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = "3"
return overrides
@contextmanager
def _temporary_env(overrides):
sentinel = object()
previous = {}
try:
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = old_value
def _strip_exported_weights_for_zero_copy(ep, original_weights):
"""Shrink the saved .pt2 artifact when original weights will be reused."""
if not original_weights:
return
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
if isinstance(orig, torch.nn.Parameter):
replacement = torch.nn.Parameter(
replacement, requires_grad=orig.requires_grad
)
ep._state_dict[key] = replacement
del orig
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
@@ -401,7 +463,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
def compile(
model,
example_input,
search_iterations=25,
search_iterations=None,
search_trials=None,
search_keep_best=None,
factory=None,
export_kwargs=None,
dynamic_dim=None,
@@ -413,7 +477,13 @@ def compile(
model: A PyTorch nn.Module.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations.
search_iterations: Number of optimization search iterations. When None,
defaults to 200 on CUDA inputs and 10 otherwise.
search_trials: Optional per-candidate profiling trials inside Luminal's
search. When unset, direct CUDA compile defaults to 5.
search_keep_best: Optional number of parent candidates to retain
between search generations. When unset, direct CUDA compile
defaults to 3.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
@@ -432,17 +502,26 @@ def compile(
Returns:
A CompiledModel callable.
"""
if factory is None:
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
if factory is None:
factory = _detect_factory_capsule(flat_example_inputs)
if search_iterations is None:
search_iterations = (
200
if _has_cuda_inputs(flat_example_inputs)
else 10
)
search_env = _direct_search_env(
flat_example_inputs,
search_trials=search_trials,
search_keep_best=search_keep_best,
)
kwargs = export_kwargs or {}
extra = _export_kwargs()
@@ -488,63 +567,16 @@ def compile(
)
ep = ep.run_decompositions(_decomp_table())
return _save_and_compile(ep, factory, search_iterations)
def _drop_input_guards(ep):
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
``i = torch.tensor(2)``), torch.export records two equivalent input
guards: ``L['i'].item() == 2`` (referencing the original local source)
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
Two failures stack on top of each other:
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
a literal ``L`` reference in the generated ``_guards_fn`` and raising
``NameError: name 'L' is not defined`` during retracing.
2. Even after dropping the unresolvable guard, the surviving
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
a graph break.
These guards exist solely to validate inputs at runtime in eager-mode
consumers of the ExportedProgram; the luminal compiler does its own
input shape/dtype checks against the compiled graph signature, so we
are not losing any safety by clearing them.
"""
if hasattr(ep, "_guards_code"):
ep._guards_code = []
def _drop_dead_data_dependent_ops(gm):
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
When dynamo specializes a 0-d int input by tracing through ``.item()``,
the resulting graph may contain a dead ``aten.item.default`` node whose
output is never consumed. luminal's translator does not lower
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
node in the graph causes a graph break at compile time. Eliminating it
keeps the (correctly specialized) downstream graph in a single subgraph.
"""
graph = gm.graph
changed = False
for node in list(graph.nodes):
if (
node.op == "call_function"
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
and len(node.users) == 0
):
graph.erase_node(node)
changed = True
if changed:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
original_weights = model.state_dict()
_strip_exported_weights_for_zero_copy(ep, original_weights)
with _temporary_env(search_env):
compiled = _save_and_compile(
ep,
factory,
search_iterations,
original_weights=original_weights,
)
return StructuredCompiledModel(compiled, example_args)
def _legacy_auto_dim(example_args):
@@ -626,23 +658,13 @@ def _eager_pt2_compile(
if dynamic_shapes is None:
raise
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
# LUM-499: drop dynamo-emitted input guards before run_decompositions
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
# data-dependent .item() calls and unresolved `L[...]` references.
_drop_input_guards(ep)
_drop_dead_data_dependent_ops(ep.graph_module)
ep = ep.run_decompositions(_decomp_table())
# When using shared memory (original_weights), strip large weight buffers
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
_strip_exported_weights_for_zero_copy(ep, original_weights)
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
@@ -656,11 +678,20 @@ def _eager_pt2_compile(
if torch.cuda.is_available():
torch.cuda.empty_cache()
default_search_iterations = (
50
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
else 10
)
search_iterations = int(
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
)
try:
return _save_and_compile(
pt2_path,
factory,
10,
search_iterations,
original_weights=original_weights,
user_indices=user_indices,
)

View File

@@ -0,0 +1,11 @@
# DLRM CUDA Benchmark
These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.
| Path | Median latency | Throughput |
| --- | ---: | ---: |
| eager | 0.267 ms | 7,674,321 candidates/s |
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |

View File

@@ -0,0 +1,213 @@
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
These tests are intended for the local integration workflow where the
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
then run the same model through ``backend=luminal_backend``.
"""
from __future__ import annotations
import copy
import sys
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
pytest.importorskip("tqdm")
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
if not DEEPCTR_ROOT.exists():
pytest.skip(
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
allow_module_level=True,
)
deepctr_root = str(DEEPCTR_ROOT)
if deepctr_root not in sys.path:
sys.path.insert(0, deepctr_root)
from deepctr_torch.inputs import (
DenseFeat,
SparseFeat,
VarLenSparseFeat,
build_input_features,
)
from deepctr_torch.models import DCN
from deepctr_torch.models.din import DIN
from luminal import luminal_backend
def _stack_features(
feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
) -> torch.Tensor:
parts = []
for name in build_input_features(feature_columns):
value = np.asarray(feature_dict[name])
if value.ndim == 1:
value = np.expand_dims(value, axis=1)
parts.append(value)
stacked = np.concatenate(parts, axis=-1)
return torch.tensor(stacked, dtype=torch.float32, device=device)
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _make_dcn(
device: torch.device, cross_parameterization: str
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("s0", 5, embedding_dim=4),
SparseFeat("s1", 7, embedding_dim=4),
DenseFeat("d0", 1),
DenseFeat("d1", 1),
]
feature_dict = {
"s0": np.array([0, 1, 2, 3], dtype=np.int64),
"s1": np.array([1, 2, 3, 4], dtype=np.int64),
"d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
"d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
}
model = DCN(
linear_feature_columns=feature_columns,
dnn_feature_columns=feature_columns,
cross_num=2,
cross_parameterization=cross_parameterization,
dnn_hidden_units=(16,),
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("user", 4, embedding_dim=4),
SparseFeat("gender", 2, embedding_dim=4),
SparseFeat("item_id", 4, embedding_dim=8),
SparseFeat("cate_id", 3, embedding_dim=4),
DenseFeat("pay_score", 1),
VarLenSparseFeat(
SparseFeat(
"hist_item_id",
vocabulary_size=4,
embedding_dim=8,
embedding_name="item_id",
),
maxlen=4,
length_name="seq_length",
),
VarLenSparseFeat(
SparseFeat(
"hist_cate_id",
vocabulary_size=3,
embedding_dim=4,
embedding_name="cate_id",
),
maxlen=4,
length_name="seq_length",
),
]
feature_dict = {
"user": np.array([0, 1, 2, 3], dtype=np.int64),
"gender": np.array([0, 1, 0, 1], dtype=np.int64),
"item_id": np.array([1, 2, 3, 2], dtype=np.int64),
"cate_id": np.array([1, 2, 1, 2], dtype=np.int64),
"pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
"hist_item_id": np.array(
[[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"hist_cate_id": np.array(
[[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"seq_length": np.array([3, 3, 2, 2], dtype=np.int64),
}
model = DIN(
feature_columns,
["item_id", "cate_id"],
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
def test_deepctr_dcn_matches_inductor_when_supported(
device: torch.device, cross_parameterization: str
) -> None:
model, inputs = _make_dcn(device, cross_parameterization)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
model, inputs = _make_din(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")

View File

@@ -0,0 +1,456 @@
"""DLRM coverage for the luminal torch.compile backend.
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
that eager mode, ``torch.compile(..., backend="inductor")``, and
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
configurations, including a CUDA benchmark that compares Luminal against
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
"""
from __future__ import annotations
import copy
import importlib.machinery
import sys
import types
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
if not DLRM_ROOT.exists():
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
dlrm_root = str(DLRM_ROOT)
if dlrm_root not in sys.path:
sys.path.insert(0, dlrm_root)
def _install_dlrm_import_stubs() -> None:
ext_dist = types.ModuleType("extend_distributed")
ext_dist.my_size = 1
ext_dist.dist = None
ext_dist.get_split_lengths = lambda n: (n, [n])
ext_dist.get_my_slice = lambda n: slice(0, n)
class _AllToAll:
def __init__(self, values):
self._values = values
def wait(self):
return self._values
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
"extend_distributed", loader=None
)
sys.modules["extend_distributed"] = ext_dist
mlperf_logger = types.ModuleType("mlperf_logger")
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
"mlperf_logger", loader=None
)
sys.modules["mlperf_logger"] = mlperf_logger
tensorboard = types.ModuleType("torch.utils.tensorboard")
class SummaryWriter:
def __init__(self, *args, **kwargs):
pass
def add_scalar(self, *args, **kwargs):
pass
def close(self):
pass
tensorboard.SummaryWriter = SummaryWriter
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
"torch.utils.tensorboard", loader=None
)
sys.modules["torch.utils.tensorboard"] = tensorboard
onnx = types.ModuleType("onnx")
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
sys.modules["onnx"] = onnx
_install_dlrm_import_stubs()
import dlrm_s_pytorch as dlrm_mod
from luminal import luminal_backend
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend="inductor")
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(
copy.deepcopy(model),
backend="inductor",
mode="reduce-overhead",
)
def _compile_luminal(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend=luminal_backend)
def _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters: int,
timed_iters: int,
mark_step_begin: bool = False,
) -> dict[str, float]:
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
with torch.no_grad():
for _ in range(warmup_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
_unwrap(compiled_model(*inputs))
torch.cuda.synchronize()
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
with torch.no_grad():
for idx in range(timed_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
starts[idx].record()
_unwrap(compiled_model(*inputs))
ends[idx].record()
torch.cuda.synchronize()
elapsed_ms = np.array(
[start.elapsed_time(end) for start, end in zip(starts, ends)],
dtype=np.float64,
)
return {
"mean_ms": float(elapsed_ms.mean()),
"median_ms": float(np.median(elapsed_ms)),
"min_ms": float(elapsed_ms.min()),
}
def _timed_cuda_rounds(
compiled_model,
*inputs,
pre_round_warmup_iters: int,
timed_iters: int,
rounds: int,
mark_step_begin: bool = False,
) -> dict[str, float | list[float]]:
assert rounds > 0, "rounds must be positive"
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=pre_round_warmup_iters,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians = [stats["median_ms"]]
for _ in range(rounds - 1):
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=0,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians.append(stats["median_ms"])
round_medians_np = np.array(round_medians, dtype=np.float64)
return {
"round_medians_ms": [float(value) for value in round_medians],
"median_ms": float(np.median(round_medians_np)),
"mean_ms": float(round_medians_np.mean()),
"min_ms": float(round_medians_np.min()),
}
def _make_dlrm(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
m_spa = 4
ln_emb = np.array([8, 6, 4])
ln_bot = np.array([3, 4])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 8, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=1,
).eval()
inputs = (
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device),
[
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
],
[
torch.tensor([1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 3], dtype=torch.int64, device=device),
torch.tensor([2, 1], dtype=torch.int64, device=device),
],
)
return model.to(device), inputs
def _make_dlrm_batch_2048(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
batch_size = 2048
indices_per_bag = 2
m_spa = 16
ln_emb = np.array([4096, 2048, 1024])
ln_bot = np.array([3, 64, m_spa])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 64, 32, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=2,
).eval()
dense_x = torch.linspace(
-1.0,
1.0,
steps=batch_size * 3,
dtype=torch.float32,
device=device,
).reshape(batch_size, 3)
total_sparse_indices = batch_size * indices_per_bag
positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
offsets = torch.arange(
0,
total_sparse_indices,
indices_per_bag,
dtype=torch.int64,
device=device,
)
inputs = (
dense_x,
[offsets.clone(), offsets.clone(), offsets.clone()],
[
((positions * 3 + 1) % int(ln_emb[0])).to(torch.int64),
((positions * 5 + 2) % int(ln_emb[1])).to(torch.int64),
((positions * 7 + 3) % int(ln_emb[2])).to(torch.int64),
],
)
return model.to(device), inputs
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
model, inputs = _make_dlrm(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
@pytest.mark.slow
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
device: torch.device,
) -> None:
if device.type != "cuda":
pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")
model, inputs = _make_dlrm_batch_2048(device)
eager = _run_eager(model, *inputs)
eager_model = copy.deepcopy(model).to(device).eval()
inductor_compiled = _compile_inductor(model)
inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad():
inductor_output = _unwrap(inductor_reduce_overhead(*inputs))
luminal_compiled = _compile_luminal(model)
with torch.no_grad():
luminal_output = _unwrap(luminal_compiled(*inputs))
with torch.no_grad():
inductor_default_output = _unwrap(inductor_compiled(*inputs))
_assert_allclose(inductor_default_output, eager, "inductor default vs eager", atol=1e-4)
_assert_allclose(inductor_output, eager, "inductor reduce-overhead vs eager", atol=1e-4)
_assert_allclose(luminal_output, eager, "luminal vs eager", atol=1e-4)
_assert_allclose(
luminal_output,
inductor_default_output,
"luminal vs inductor default",
atol=1e-4,
)
_assert_allclose(
luminal_output,
inductor_output,
"luminal vs inductor reduce-overhead",
atol=1e-4,
)
benchmark_rounds = 5
benchmark_iters = 20
post_compile_warmup_iters = 10
eager_stats = _timed_cuda_rounds(
eager_model,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_default_stats = _timed_cuda_rounds(
inductor_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_stats = _timed_cuda_rounds(
inductor_reduce_overhead,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
mark_step_begin=True,
)
luminal_stats = _timed_cuda_rounds(
luminal_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
batch_size = inputs[0].shape[0]
benchmark_results = [
("eager", eager_stats),
("inductor default", inductor_default_stats),
("inductor reduce-overhead", inductor_stats),
("luminal backend", luminal_stats),
]
ranked_results = sorted(
benchmark_results,
key=lambda item: float(item[1]["median_ms"]),
)
speed_lines = []
for idx, (label, stats) in enumerate(ranked_results, start=1):
throughput = batch_size / (float(stats["median_ms"]) / 1000.0)
rounds_repr = ", ".join(
f"{value:.3f}" for value in stats["round_medians_ms"] # type: ignore[index]
)
speed_lines.append(
f" {idx}. {label}: {float(stats['median_ms']):.3f} ms"
f" ({throughput:,.0f} candidates/s)"
f" [round medians: {rounds_repr}]"
)
luminal_vs_inductor = float(luminal_stats["median_ms"]) / float(
inductor_stats["median_ms"]
)
luminal_vs_eager = float(luminal_stats["median_ms"]) / float(eager_stats["median_ms"])
print(
"\n"
f"DLRM batch={batch_size} candidates on CUDA after compile/warmup\n"
f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations\n"
f" Ranking by median latency:\n"
+ "\n".join(speed_lines)
+ "\n"
f" Luminal backend / TorchInductor reduce-overhead latency ratio:"
f" {luminal_vs_inductor:.3f}x\n"
f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x"
)

View File

@@ -1,6 +1,5 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
@@ -221,7 +220,6 @@ from test_models import (
Conv1dNoPadModel,
Conv1dSamePadModel,
Conv1dBiasModel,
Conv1dFloorDivPositionalModel,
Conv2dNoPadModel,
Conv2dSamePadModel,
Conv2dBiasModel,
@@ -1098,17 +1096,6 @@ def test_reduce_sum_all_axes(device: torch.device):
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
eager = model(x)
out = model_compiled(x)
assert out.dtype == eager.dtype == torch.int64
assert torch.equal(out, eager)
def test_reduce_sum_3d_axis1(device: torch.device):
"""Test sum reduction along axis 1 for a 3D tensor."""
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
@@ -2035,16 +2022,9 @@ def test_split(device: torch.device):
# ========== Argsort / MoE Routing Tests ==========
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
"""Duplicate values should follow stable lower-index-first tie-breaking.
Parametrized over int32/int64 to verify luminal preserves whichever
integer dtype the eager model declares (LUM-486).
"""
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
device
)
def test_argsort_stable_duplicates(device: torch.device):
"""Duplicate values should follow stable lower-index-first tie-breaking."""
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.tensor(
[[2.0, 1.0, 1.0, 3.0]],
@@ -2053,21 +2033,13 @@ def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype)
)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
assert output.dtype == original.dtype, (
f"luminal returned {output.dtype}, eager produced {original.dtype}"
)
assert torch.equal(output, original)
assert output.dtype == torch.int32
assert torch.equal(output, original.to(torch.int32))
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
"""Focused proof for built MoE routing support.
Parametrized over int32/int64 for the integer-valued outputs to verify
luminal preserves the dtype declared by the eager model (LUM-486).
"""
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
def test_tiny_moe_routing(device: torch.device):
"""Focused proof for build MoE routing support."""
model: torch.nn.Module = TinyMoERoutingModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
scores = torch.tensor(
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
@@ -2078,10 +2050,17 @@ def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
expected = model(scores)
output = model_compiled(scores)
for actual, eager in zip(output, expected):
assert actual.dtype == eager.dtype, (
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
)
expected_dtypes = (
torch.int32,
torch.float32,
torch.int32,
torch.bool,
torch.int32,
torch.float32,
)
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
assert actual.dtype == expected_dtype
eager = eager.to(actual.dtype)
if actual.dtype.is_floating_point:
assert torch.allclose(actual, eager)
else:
@@ -2498,17 +2477,6 @@ def test_conv1d_bias(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_floor_div_positional_pt2(device: torch.device):
"""Conv1d stride output uses floor division before positional add."""
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.shape == original.shape == (15, 16)
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)

View File

@@ -1623,32 +1623,16 @@ class SplitTestModel(torch.nn.Module):
class ArgsortStableDuplicatesModel(torch.nn.Module):
"""Tests deterministic duplicate ordering for exported argsort.
``idx_dtype`` parameterizes the integer dtype of the returned indices so
the test can verify dtype preservation across luminal's int dtype paths
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
lets us drive the same model toward int32 or int64 outputs.
"""
"""Tests deterministic duplicate ordering for exported argsort."""
SORT_DIM = 1
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
return torch.argsort(x, dim=self.SORT_DIM)
class TinyMoERoutingModel(torch.nn.Module):
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
group_ids) to the requested dtype so the test can sweep int32 and int64
output paths (LUM-486). Internal indices stay int64 because torch.gather
/ torch.scatter require int64 index tensors.
"""
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
TOP_K = 2
ROUTING_DIM = -1
@@ -1656,9 +1640,8 @@ class TinyMoERoutingModel(torch.nn.Module):
DISPATCH_ON = 1
GROUP_SIZE = 2
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
def __init__(self) -> None:
super().__init__()
self.idx_dtype = idx_dtype
self.register_buffer(
"expert_scale",
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
@@ -1694,11 +1677,11 @@ class TinyMoERoutingModel(torch.nn.Module):
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
routing_sign = torch.sign(masked_values)
return (
routed_indices.to(self.idx_dtype),
routed_indices,
masked_values,
dispatch.to(self.idx_dtype),
dispatch,
inactive_mask,
group_ids.to(self.idx_dtype),
group_ids,
routing_sign,
)
@@ -1969,24 +1952,6 @@ class Conv1dBiasModel(torch.nn.Module):
return self.conv(x)
class Conv1dFloorDivPositionalModel(torch.nn.Module):
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
self.conv2 = torch.nn.Conv1d(
16, 16, kernel_size=3, stride=2, padding=1, bias=True
)
self.position = torch.nn.Parameter(torch.randn(15, 16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.gelu(self.conv1(x))
x = torch.nn.functional.gelu(self.conv2(x))
x = x.squeeze(0).transpose(0, 1)
return x + self.position
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""

File diff suppressed because it is too large Load Diff

View File

@@ -315,13 +315,8 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,5 +1,3 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -8,14 +6,18 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{
io::Write,
time::{Duration, Instant},
};
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn env_bool(name: &str) -> bool {
std::env::var(name)
.ok()
@@ -23,10 +25,9 @@ fn env_bool(name: &str) -> bool {
}
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = env_usize("GEN_TOKENS", 30);
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
@@ -37,6 +38,11 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
@@ -57,14 +63,11 @@ fn main() {
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let reset_cache = |runtime: &mut CudaRuntime| {
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
};
reset_cache(&mut runtime);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', 1);
@@ -72,66 +75,15 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
&prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
print_token_ids: bool,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -141,7 +93,7 @@ fn run_prompt(
const EOS_TOKEN: u32 = 1;
let prefill_start = Instant::now();
let prefill_start = std::time::Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -169,26 +121,12 @@ fn run_prompt(
.unwrap()
.0 as u32;
generated_token_ids.push(next_token);
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
seen_tokens.insert(next_token);
for _ in 1..gen_tokens {
if stdio && next_token == EOS_TOKEN {
break;
}
let start = Instant::now();
let start = std::time::Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -227,21 +165,10 @@ fn run_prompt(
break;
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
fwd_durations.push(start.elapsed());
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
if print_token_ids {
println!("Generated token ids: {generated_token_ids:?}");

View File

@@ -462,13 +462,8 @@ fn hlir_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
@@ -621,8 +616,6 @@ impl Gemma4SparseMoE {
let hidden_exp = hidden.unsqueeze(2);
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
}
}

View File

@@ -1,5 +1,3 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -9,36 +7,22 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{
io::Write,
time::{Duration, Instant},
};
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let gen_tokens = 500;
let search_graphs = 500;
let prompt = "Explain what a neural network is in a paragraph.";
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -47,6 +31,14 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
// Build graph
let mut cx = Graph::default();
@@ -74,13 +66,10 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', 1);
@@ -88,65 +77,12 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -158,16 +94,13 @@ fn run_prompt(
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
let mut generated = 0usize;
for i in 0..total_steps {
let start = Instant::now();
let start = std::time::Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -226,21 +159,12 @@ fn run_prompt(
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -246,13 +246,8 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,5 +1,3 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -9,36 +7,22 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{
io::Write,
time::{Duration, Instant},
};
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "Qwen/Qwen3-4B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let gen_tokens = 500;
let search_graphs = 500;
let prompt = "Explain what a neural network is in a paragraph.";
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -47,6 +31,7 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -69,13 +54,10 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', 1);
@@ -83,58 +65,12 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -146,16 +82,13 @@ fn run_prompt(
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
const STOP_TOKEN: u32 = 151643; // <|end|>
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
let mut generated = 0usize;
for i in 0..total_steps {
let start = Instant::now();
let start = std::time::Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -214,21 +147,12 @@ fn run_prompt(
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -287,13 +287,8 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,5 +1,3 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -8,27 +6,15 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{
io::Write,
time::{Duration, Instant},
};
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 30)
} else {
30
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
} else {
50
};
let gen_tokens = 30;
let search_graphs = 50;
let prompt = "The capital of France is";
let ctx = CudaContext::new(0).unwrap();
@@ -38,6 +24,7 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -60,13 +47,10 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', 1);
@@ -74,63 +58,14 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -141,7 +76,7 @@ fn run_prompt(
const STOP_TOKEN: u32 = 151643;
// Prefill: process prompt tokens one at a time
let prefill_start = Instant::now();
let prefill_start = std::time::Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -170,27 +105,13 @@ fn run_prompt(
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
seen_tokens.insert(next_token);
// Decode loop
for _ in 1..gen_tokens {
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
break;
}
let start = Instant::now();
let start = std::time::Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -229,23 +150,13 @@ fn run_prompt(
break;
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
fwd_durations.push(start.elapsed());
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
// Report benchmarks
println!();
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,

View File

@@ -287,8 +287,7 @@ impl QwenMoE {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
(down_out * weights_exp).sum(n - 1)
}
}
@@ -386,13 +385,8 @@ fn attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
// GQA expand
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -174,13 +174,8 @@ fn decoder_self_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let mut k_full = k_cache_out.slice((.., ..total, ..));
let mut v_full = v_cache_out.slice((.., ..total, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
k_full.shape.dims[1] = total;
v_full.shape.dims[1] = total;
let k_full = k_cache_out.slice((.., ..total, ..));
let v_full = v_cache_out.slice((.., ..total, ..));
let q = split_heads(q);

View File

@@ -1,58 +0,0 @@
use std::{
io::{BufRead, Write},
time::Instant,
};
pub fn enabled() -> bool {
std::env::args().any(|arg| arg == "--stdio")
}
pub fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn emit_ready() {
println!("READY");
std::io::stdout().flush().unwrap();
}
pub fn serve(mut f: impl FnMut(&str)) {
emit_ready();
let stdin = std::io::stdin();
for line in stdin.lock().lines() {
let line = line.unwrap();
f(&line);
}
}
pub fn emit_token(token: &str) {
println!("TOK\t{}", escape_token(token));
std::io::stdout().flush().unwrap();
}
pub fn emit_eoq(generated: usize, query_start: Instant) {
println!(
"EOQ\t{}\t{:.3}",
generated,
query_start.elapsed().as_secs_f64() * 1e3
);
std::io::stdout().flush().unwrap();
}
fn escape_token(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'\\' => out.push_str("\\\\"),
'\t' => out.push_str("\\t"),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
_ => out.push(ch),
}
}
out
}

View File

@@ -7,13 +7,14 @@
//! - [`NativeDynBackend`]: the reference implementation for CPU
use std::collections::HashMap;
use std::time::Duration;
use half::{bf16, f16};
use petgraph::stable_graph::NodeIndex;
use rustc_hash::FxHashMap;
use crate::dtype::DType;
use crate::graph::Graph;
use crate::graph::{Graph, SearchOptions};
use crate::hlir::{NativeData, NativeRuntime, Output};
use crate::op::Runtime;
@@ -46,6 +47,18 @@ pub trait DynBackend {
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
// --- Optional diagnostics --------------------------------------------
fn kernel_names(&self) -> Vec<String> {
vec![]
}
fn host_op_names(&self) -> Vec<String> {
vec![]
}
fn print_execution_stats(&self) {}
// --- Optional device pointer support (GPU backends) --------------------
fn supports_device_ptrs(&self) -> bool {
@@ -162,7 +175,9 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, args.search_iters);
let mut rng = rand::rng();
let search_options = search_options_from_env(args.search_iters);
let mut rt = graph.search_options(rt, search_options, &mut rng);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);
@@ -179,6 +194,39 @@ pub fn compile_backend<Rt: Runtime + 'static>(
Ok(wrap(rt))
}
fn env_usize(name: &str) -> Option<usize> {
std::env::var(name).ok()?.parse().ok()
}
fn env_duration_ms(name: &str) -> Option<Duration> {
Some(Duration::from_millis(env_usize(name)? as u64))
}
fn search_options_from_env(limit: usize) -> SearchOptions {
let mut options = SearchOptions::new(limit);
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
options = options.generation_size(generation_size);
}
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
options = options.mutations(mutations);
}
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
options = options.trials(trials);
}
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
options = options.keep_best(keep_best);
}
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
options = options.profile_timeout(profile_timeout);
}
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
options = options.group_timeout(group_timeout);
}
options
}
// ---------------------------------------------------------------------------
// Shared utilities
// ---------------------------------------------------------------------------

View File

@@ -11,7 +11,6 @@ impl Add for GraphTensor {
type Output = GraphTensor;
fn add(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to add tensors. Got {:?} and {:?}",
@@ -74,11 +73,6 @@ impl Mul for GraphTensor {
type Output = GraphTensor;
fn mul(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(
self.dims(),
rhs.dims(),
"Dims must match to multiply tensors."
);
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
@@ -480,42 +474,6 @@ pub(super) mod tests {
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
}
#[test]
#[should_panic(expected = "Dims must match to add tensors.")]
fn test_add_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a + b;
}
#[test]
#[should_panic(expected = "Dims must match to multiply tensors.")]
fn test_mul_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a * b;
}
#[test]
#[should_panic(expected = "Dims must match to mod tensors.")]
fn test_mod_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a % b;
}
#[test]
#[should_panic(expected = "Dims must match to lt tensors.")]
fn test_lt_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a.lt(b);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -599,27 +557,6 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_mod_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
test_binary_transforms(
size,
(),
|a, b| a % b.expand_rhs(a.shape),
|a, b| {
let lhs = a.to_vec1::<f32>().unwrap();
let rhs_scalar = b.to_scalar::<f32>().unwrap();
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
},
identity,
shift_from_zero,
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -633,28 +570,6 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_lt_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS for `lt`.
test_binary(
size,
(),
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|a, b| {
let scalar = b.to_scalar::<f32>().unwrap();
let lhs = a.to_vec1::<f32>().unwrap();
let result: Vec<f32> = lhs
.iter()
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
.collect();
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
},
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]

View File

@@ -467,7 +467,7 @@ impl GraphTensor {
let mut win = Vec::with_capacity(n);
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
let effective_window = *d * (*k - 1) + 1;
win.push((*dim - effective_window).floor_div(s) + 1);
win.push(((*dim - effective_window) / s) + 1);
}
// [win..., kernel...]
@@ -905,14 +905,6 @@ mod tests {
);
}
#[test]
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
let mut cx = Graph::new();
let inp = cx.tensor((80, 3000));
let out = inp.pad(((0, 0), (1, 1)), 0.).unfold((1, 3), (1, 2), (1, 1));
assert_eq!(out.dims(), &[80, 1500, 1, 3]);
}
#[test]
fn test_unsqueeze() {
let mut cx = Graph::new();

View File

@@ -217,32 +217,38 @@ pub fn reduce_sort(name: &str) -> SortDef {
}
pub type HLIROps = (
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Scatter,
SumReduce,
MaxReduce,
Softmax,
(
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
),
(
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Concat2D,
EmbeddingBagSum,
Scatter,
SumReduce,
MaxReduce,
Softmax,
),
);
#[derive(Default, Debug, Clone)]
@@ -1721,7 +1727,9 @@ impl NativeOp for Add {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
}
}
}
}
@@ -1808,7 +1816,9 @@ impl NativeOp for Mul {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
}
}
}
}
@@ -2126,6 +2136,233 @@ impl NativeOp for Gather {
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Concat2D {
pub rows: Expression,
pub lhs_cols: Expression,
pub rhs_cols: Expression,
}
impl Display for Concat2D {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Concat2D")
}
}
impl HLIROp for Concat2D {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (Concat2D {} {} {}) {})",
self.rows.to_egglog(),
self.lhs_cols.to_egglog(),
self.rhs_cols.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
)
}
}
impl EgglogOp for Concat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"Concat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for Concat2D {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let rows = self.rows.exec(dyn_map).unwrap();
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
fn concat_rows<T: Clone>(
lhs: &[T],
rhs: &[T],
rows: usize,
lhs_cols: usize,
rhs_cols: usize,
) -> Vec<T> {
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
for row in 0..rows {
let lhs_base = row * lhs_cols;
let rhs_base = row * rhs_cols;
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
}
out
}
match (inputs[0], inputs[1]) {
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
_ => panic!("Concat2D inputs must have the same dtype"),
}
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingBagSum {
pub n_bags: Expression,
pub n_indices: Expression,
pub hidden_dim: Expression,
pub num_embeddings: Expression,
}
impl Display for EmbeddingBagSum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "EmbeddingBagSum")
}
}
impl HLIROp for EmbeddingBagSum {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (EmbeddingBagSum {} {} {} {}) {})",
self.n_bags.to_egglog(),
self.n_indices.to_egglog(),
self.hidden_dim.to_egglog(),
self.num_embeddings.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
)
}
}
impl EgglogOp for EmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"EmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for EmbeddingBagSum {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let n_bags = self.n_bags.exec(dyn_map).unwrap();
let n_indices = self.n_indices.exec(dyn_map).unwrap();
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
let clamp_index = |value: i32, limit: usize| -> usize {
if limit == 0 {
return 0;
}
value.clamp(0, limit.saturating_sub(1) as i32) as usize
};
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
let NativeData::Int(indices) = inputs[1] else {
panic!("EmbeddingBagSum indices must be Int")
};
let NativeData::Int(offsets) = inputs[2] else {
panic!("EmbeddingBagSum offsets must be Int")
};
let bag_bounds = |bag: usize| -> (usize, usize) {
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
let end = offsets
.get(bag + 1)
.copied()
.map(clamp_offset)
.unwrap_or(n_indices);
(start, end.max(start))
};
match inputs[0] {
NativeData::F32(weight) => {
let mut out = vec![0.0f32; n_bags * hidden_dim];
for bag in 0..n_bags {
let (start, end) = bag_bounds(bag);
let out_base = bag * hidden_dim;
for pos in start..end {
let row = clamp_index(indices[pos], num_embeddings);
let row_base = row * hidden_dim;
for dim in 0..hidden_dim {
out[out_base + dim] += weight[row_base + dim];
}
}
}
NativeData::F32(out)
}
_ => panic!("EmbeddingBagSum only supports F32 weights"),
}
}
}
// Scatter Op (inverse of Gather)
#[derive(Debug, Clone, Default, PartialEq)]

View File

@@ -455,31 +455,6 @@ impl Expression {
terms.push(Term::CeilDiv);
Expression::new(terms)
}
/// Floor Division
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == 1 {
return self;
}
if self == 0 {
return 0.into();
}
if self == rhs {
return 1.into();
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
&& let Some(c) = floor_div_i64(a, b)
{
return c.into();
}
// Shape dimensions are non-negative, so the existing integer Div term
// evaluates with floor semantics for dynamic shape expressions.
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Div);
Expression::new(terms)
}
/// Less than
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
@@ -679,16 +654,6 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
depth == 1
}
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
let q = a.checked_div(b)?;
let r = a.checked_rem(b)?;
if r != 0 && ((r > 0) != (b > 0)) {
q.checked_sub(1)
} else {
Some(q)
}
}
impl From<Term> for Expression {
fn from(value: Term) -> Self {
Expression::new(vec![value])
@@ -1029,12 +994,8 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
impl std::iter::Product for Expression {
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
// Empty product is the multiplicative identity, 1 — not 0. Returning
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
let Some(mut p) = iter.next() else {
return 1.into();
return 0.into();
};
for n in iter {
p *= n;
@@ -1145,27 +1106,6 @@ mod tests {
use super::*;
use proptest::prelude::*;
#[test]
fn test_empty_product_is_one() {
// The empty product (e.g. for a rank-0 tensor's shape) must be the
// multiplicative identity, 1 — not 0. cuda_lite and other kernel
// emitters use `shape.iter().product()` to compute `numel`, and a
// rank-0 tensor has 1 element. Returning 0 here would yield a CUDA
// launch with grid=(0, 1, 1) and crash at runtime.
let empty: Vec<Expression> = vec![];
assert_eq!(
empty.into_iter().product::<Expression>(),
Expression::from(1)
);
}
#[test]
fn test_empty_sum_is_zero() {
// Sanity check the additive identity stays 0 (it always was).
let empty: Vec<Expression> = vec![];
assert_eq!(empty.into_iter().sum::<Expression>(), Expression::from(0));
}
#[test]
fn test_basic_simplifications() {
let x = expr('x');