Compare commits

1 Commits

Author SHA1 Message Date
Tucker Morgan
c41ede0e5b luminal_python: vanilla-PyTorch DLRMv1 fast paths + batched FFI + input cache
Lands `torch.compile(model, backend=luminal_backend)` for vanilla
DLRMv1 (`nn.EmbeddingBag` per table, pairwise-dot interaction, top
MLP) at correctness parity with PyTorch eager (max abs diff
≤ 1.8e-7) and ~6.5× faster than `pt_eager` across an 11 × 5
(batch × num_cat) sweep. Crosses over `graph_safe + inductor +
manual CUDAGraph` at `nc=32, b=2048` (174 µs vs 241 µs).

Three layers of changes.

Translator (crates/luminal_python/rust/src/translator/)
- aten._embedding_bag(_forward_only).default: 2-D weight + 1-D
  indices + uniform-stride offsets → embedding_bag_sum_kernel
  (per-table) or, when a sibling group of emb-bag outputs feeds the
  pairwise-dot peephole, the fused multi_table_embedding_bag_sum
  kernel.
- aten.index_select.default: 2-D source, dim=0, 1-D index lowering.
- aten.addmm.default fast path: detect addmm(bias, x, weight.t())
  with α=β=1 and emit luminal_cuda_lite::kernel::linear_bias against
  the original (N, K) weight. With a unique relu / sigmoid consumer
  downstream, absorb it into linear_bias_relu / linear_bias_sigmoid
  via a forward-looking absorbed-nodes pass.
- aten.bmm.default fast path: detect bmm(A, permute(B, [0,2,1]))
  via transpose_2d_source and emit matmul_3d_t.
- aten.sum.dim_IntList peephole: sum ← view ← index_select chain →
  embedding_bag_sum_kernel.
- aten.index.Tensor peephole for DLRM strict-lower-tri pairwise
  dot: detect Z[:, li, lj] over bmm(cat(unsqueezes), permute([0,2,1]))
  and fold to dlrm_pairwise_dot_lower_tri (variadic) or, when all
  cat inputs but one are emb-bag outputs with matching shapes,
  dlrm_pairwise_dot_lower_tri_stacked over (dense, emb_stack).
- aten.permute.default: record 2-D [1,0] and 3-D [0,2,1] sources
  into transpose_2d_source for the matmul fast paths.
- node_chain + op_inputs side-tables populated alongside every
  translated op (including variadic-first-input ops like cat) so
  multi-input peepholes can walk back across the producer chain.
- consumers map + absorbed_nodes set populated once at start of
  translate_graph for forward-looking fusions (addmm + relu).
- Post-translation DCE: walks back from Output HLIR sinks and drops
  every unreachable producer. Load-bearing for the pairwise-dot
  peephole — the cat/bmm/permute chain it supersedes survives
  egglog on its own.
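
The post-translation DCE described in the last bullet amounts to a reverse reachability walk from the outputs. A minimal sketch, assuming a hypothetical adjacency-list graph where `inputs[n]` lists the producers of node `n`; the real pass runs over luminal's HLIR graph, whose API is not reproduced here, and `live_nodes` is an illustrative name:

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for the translated graph: `inputs[n]` holds the
/// producer node ids that node `n` reads from. Returns every node reachable
/// by walking backwards from `outputs`.
fn live_nodes(inputs: &[Vec<usize>], outputs: &[usize]) -> HashSet<usize> {
    let mut live = HashSet::new();
    let mut stack: Vec<usize> = outputs.to_vec();
    while let Some(n) = stack.pop() {
        // `insert` returns true only the first time a node is seen.
        if live.insert(n) {
            stack.extend(inputs[n].iter().copied());
        }
    }
    live
}

fn main() {
    // 0: dense, 1: emb, 2: cat(0, 1), 3: bmm(2, 2),
    // 4: fused pairwise-dot(0, 1), 5: output(4)
    let inputs: Vec<Vec<usize>> =
        vec![vec![], vec![], vec![0, 1], vec![2, 2], vec![0, 1], vec![4]];
    let live = live_nodes(&inputs, &[5]);
    // The superseded cat/bmm chain (nodes 2 and 3) is unreachable.
    let dead: Vec<usize> = (0..inputs.len()).filter(|n| !live.contains(n)).collect();
    assert_eq!(dead, vec![2, 3]);
    println!("dead nodes: {:?}", dead);
}
```

In this toy graph the fused node replaces the consumer of the cat/bmm chain, so nothing reachable from the output still reads nodes 2 and 3, and the walk never visits them.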

cuda_lite kernels (crates/luminal_cuda_lite/src/kernel/)
- embedding_bag.rs (new): EmbeddingBagSumKernel (single table),
  MultiTableEmbeddingBagSumKernel (K weight + K idx pointer pairs
  through packed staging buffers), and StackedEmbeddingBagKernel
  (pre-stacked weight + block-bundle launch layout). The single-table
  kernel produces (batch, d) F32; the multi-table and stacked kernels
  produce (batch, num_tables, d) F32 in one launch.
- dlrm_interact.rs (new): PairwiseDotLowerTriKernel (variadic) and
  PairwiseDotLowerTriStackedKernel (block-per-batch with
  cooperative cache of all F·D feature vectors into shared memory;
  pair (i, j) recovered from p via closed-form
  (1+sqrt(1+8p))/2). Both emit the F(F-1)/2 dot products in one
  launch.
- matmul2d.rs: linear_bias_(relu|sigmoid)(_split_a) frontends over
  the existing fused Matmul2DKernel.
- runtime.rs: gate end-of-execute cuStreamSynchronize on
  self.profiling so the regular execute() path stays capturable by
  external graph capture (and is consistent with `set_input_*_ptr`
  being persistent across calls). Preserve CudaInput::Ptr entries
  in the post-execute "consume" step — they're non-owning views
  over caller memory and represent stable input slots that the
  wrapper may skip re-registering on a hot iter.
- runtime.rs: synchronize_stream() and read_per_kernel_timings_ms()
  helpers (the latter is a stub; the per-kernel timing infra from
  origin/dlrm-fused-kernels isn't ported yet — example's optional
  LUMINAL_KERNEL_TIMING=1 path prints "no timings available").
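
The closed-form pair recovery mentioned for PairwiseDotLowerTriStackedKernel can be checked on the host. A minimal sketch in plain Rust (not the CUDA source), assuming the same strict-lower-tri row-major ordering; `pair_from_index` is a hypothetical helper name, and the two `while` loops mirror the rounding fix-up in the kernel:

```rust
/// Map a flat strict-lower-triangular index `p` to `(i, j)` with `i > j`.
/// Row `i` holds `i` pairs and starts at flat index `i * (i - 1) / 2`.
fn pair_from_index(p: usize) -> (usize, usize) {
    let t = (8.0 * p as f32 + 1.0).sqrt();
    // t >= 1, so the initial guess is always at least 1.
    let mut i = ((t + 1.0) * 0.5) as usize;
    // Absorb floating-point rounding in either direction.
    while i * (i - 1) / 2 > p {
        i -= 1;
    }
    while (i + 1) * i / 2 <= p {
        i += 1;
    }
    (i, p - i * (i - 1) / 2)
}

fn main() {
    // Compare against the nested-loop enumeration for F = 33
    // (num_cat = 32 plus the dense feature), i.e. 528 pairs.
    let f = 33;
    let mut p = 0;
    for i in 0..f {
        for j in 0..i {
            assert_eq!(pair_from_index(p), (i, j));
            p += 1;
        }
    }
    println!("checked {p} pairs");
}
```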

Wrapper (crates/luminal_python/{rust/src/compiled_graph,
src/luminal/compiled_model}.{rs,py})
- compiled_graph.rs: tensor_id(name) → u32 returns the underlying
  NodeIndex once, and run_with_ptrs(inputs, outputs) takes two
  flat Vec<(u32, u64, usize)> lists for batched input + output
  pointer registration in one pyo3 hop instead of N
  set_input_device_ptr + M set_output_device_ptr + run calls.
- compiled_model.py: __init__ resolves and caches the input /
  output IDs once. __call__ has a fast path that uses
  run_with_ptrs when all user inputs are on CUDA and outputs are
  fp. Per-input cache of (id(orig_tensor), orig_data_ptr,
  cast_tensor) skips Python-side detach/contiguous/.to/data_ptr/
  numel/element_size and the FFI register when the same user
  tensor is passed again with an unchanged data_ptr. Always
  returns a tuple to match the unbatched path's contract
  (returning a bare tensor broke dynamo's output handling — it
  iterated the first dim of the result, slicing the output to
  shape (1,)).
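
The per-input cache in compiled_model.py is Python; the idea can be modelled in Rust as below. This is a sketch under stated assumptions, not the wrapper's code: `CachedInput`, `InputCache` and `needs_register` are hypothetical names, and the real cache also holds the cast tensor.

```rust
/// Hypothetical cache entry: identity of the caller's tensor plus the
/// device pointer it had when it was last registered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CachedInput {
    tensor_id: usize,
    data_ptr: u64,
}

/// One optional entry per graph input slot.
struct InputCache {
    slots: Vec<Option<CachedInput>>,
}

impl InputCache {
    fn new(num_inputs: usize) -> Self {
        Self { slots: vec![None; num_inputs] }
    }

    /// Returns true when the slot has to be (re-)registered, i.e. when the
    /// tensor identity or its data pointer differs from the cached entry.
    fn needs_register(&mut self, slot: usize, tensor_id: usize, data_ptr: u64) -> bool {
        let seen = Some(CachedInput { tensor_id, data_ptr });
        if self.slots[slot] == seen {
            return false;
        }
        self.slots[slot] = seen;
        true
    }
}

fn main() {
    let mut cache = InputCache::new(1);
    assert!(cache.needs_register(0, 7, 0x1000)); // first call: register
    assert!(!cache.needs_register(0, 7, 0x1000)); // hot iteration: skip
    assert!(cache.needs_register(0, 7, 0x2000)); // storage moved: re-register
    println!("cache behaves as expected");
}
```

Keying on both the tensor identity and the data pointer is what lets a hot iteration skip the conversion work and the FFI call while still catching a tensor whose storage has been reallocated.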

Per-iter time at (nc=32, batch=2048):

  Before:  set_in 995us + run 61us + collect 13us = 1069us
  After:   cache 15us + out 11us + run_with_ptrs 56us = 82us

13× wall-clock reduction, none of it in the kernels themselves.

Status
- 165-cell sweep (batch ∈ {2..2048} × num_cat ∈ {2..32} × 3
  variants) at examples/dlrm/results.csv.
- Beats pt_eager in 55/55 cells (4.7–7.1×).
- Beats graph_safe_inductor_cg at 1 cell (the largest), within
  10% at 3 more.
- Hand-written Rust DLRM at examples/dlrm/src/main.rs runs the
  same kernel set and lands at 104 µs on the same cell; the
  remaining gap to `luminal_compiled` (174 µs) is the residual
  Python+dynamo+runtime per-call fixed cost.

Full breakdown at examples/dlrm/RESULTS.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 16:18:31 +00:00
22 changed files with 5156 additions and 22 deletions

@@ -0,0 +1,483 @@
//! Fused DLRM pairwise-dot interaction.
//!
//! Replaces the cat→bmm(T,Tᵀ)→tril-gather chain with a single kernel
//! that reads N separate `(batch, d)` tensors and writes the strict
//! lower-triangular pairwise dot products directly into the output —
//! `out[b, p] = Σ_d v_i[b, d] * v_j[b, d]` for each ordered pair (i, j)
//! with i > j.
//!
//! Why this matters for the DLRM forward: the natural luminal lowering
//! materializes the `(B, F, D)` stacked tensor, then the full `(B, F, F)`
//! BMM output, then a flat gather to pull out F(F-1)/2 pairs. That's
//! ~12 small kernels and an `F²·B` intermediate even though only half
//! of those elements are kept. The fused version uses N pointer args
//! (one per feature vector), computes only the F(F-1)/2 dot products,
//! and writes directly to the final `(B, F(F-1)/2)` buffer.
//!
//! All shapes are static. The kernel source is generated with the
//! exact pair table baked in (so the inner loop is a fixed `D`-element
//! reduction with no shape-dependent branching).
use std::sync::Arc;
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use luminal::{
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
shape::Expression,
};
use crate::compile_module_image_for_current_device;
use crate::kernel::KernelOp;
#[derive(Debug, Clone)]
pub struct PairwiseDotLowerTriKernel {
pub batch: usize,
pub num_features: usize, // F
pub d: usize,
}
impl PairwiseDotLowerTriKernel {
fn pair_count(&self) -> usize {
self.num_features * (self.num_features - 1) / 2
}
}
impl KernelOp for PairwiseDotLowerTriKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let f = self.num_features;
let p = self.pair_count();
// Pair table (i, j) with i > j, in strict-lower-tri (row-major over
// i then j) order — same convention as torch.tril_indices(F, F, -1).
let mut pairs: Vec<(usize, usize)> = Vec::with_capacity(p);
for i in 0..f {
for j in 0..i {
pairs.push((i, j));
}
}
// Build kernel params signature: one pointer per input feature.
let in_params: String = (0..f)
.map(|k| format!(", const float* __restrict__ v{k}"))
.collect::<Vec<_>>()
.concat();
// For each pair p, generate one branch in the switch that selects
// the two input pointers to dot-product. With F small (DLRM has
// F=4), the branch is fully unrolled.
let mut pair_switch = String::new();
for (pidx, (i, j)) in pairs.iter().enumerate() {
pair_switch += &format!(
" case {pidx}: pa = v{i}; pb = v{j}; break;\n"
);
}
let kernel = format!(
"
extern \"C\" __global__ void dlrm_pairwise_dot_lower_tri_kernel(
float* __restrict__ out{in_params}
) {{
const int B = {batch};
const int D = {d};
const int P = {p};
int b = blockIdx.x;
int p = blockIdx.y;
int t = threadIdx.x;
if (b >= B || p >= P) return;
const float* pa = nullptr;
const float* pb = nullptr;
switch (p) {{
{pair_switch}
default: return;
}}
// Block-wide reduction of dot(pa[b], pb[b]) over D using shared mem.
extern __shared__ float smem[];
float partial = 0.0f;
for (int d = t; d < D; d += blockDim.x) {{
partial += pa[b * D + d] * pb[b * D + d];
}}
smem[t] = partial;
__syncthreads();
// Power-of-two tree reduce. blockDim.x must be a power of two.
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {{
if (t < stride) {{
smem[t] += smem[t + stride];
}}
__syncthreads();
}}
if (t == 0) {{
out[b * P + p] = smem[0];
}}
}}
",
batch = self.batch,
d = self.d,
p = p,
pair_switch = pair_switch,
in_params = in_params,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module
.load_function("dlrm_pairwise_dot_lower_tri_kernel")
.unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
// Pick the largest power-of-two thread count ≤ max(D, 32), capped at 1024.
let mut threads = 1usize;
while threads * 2 <= self.d.max(32) {
threads *= 2;
}
let threads = threads.max(32).min(1024);
(
func,
module,
kernel,
(
Expression::from(self.batch),
Expression::from(p),
Expression::from(1usize),
),
(
Expression::from(threads),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(threads * 4),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch * self.pair_count())
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
fn bytes_loaded(&self) -> Expression {
// Each pair reads 2 vectors of D floats per batch row. F-choose-2
// pairs, so per-batch each input vector is read F-1 times.
Expression::from(self.batch * self.num_features * (self.num_features - 1) * self.d * 4)
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
// 2D-1 flops per dot product (D mul + D-1 add).
Expression::from(self.batch * self.pair_count() * (2 * self.d - 1))
}
fn kernel_name(&self) -> &'static str {
"DLRMPairwiseDotLowerTri"
}
}
#[derive(Debug, Clone)]
pub struct PairwiseDotLowerTriCustom(pub PairwiseDotLowerTriKernel);
impl CustomOp for PairwiseDotLowerTriCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
/// Two-input variant of [`PairwiseDotLowerTriKernel`] that consumes the
/// dense MLP output and a stacked embedding output without requiring
/// the caller to first slice the stack into individual (B, D) views.
///
/// Treats feature 0 as `dense_out[b, t]` and features 1..=num_emb as
/// `emb_stack[b, k-1, t]`. Output pair table is the strict lower tri
/// of an `F × F` matrix where `F = num_emb + 1`.
#[derive(Debug, Clone)]
pub struct PairwiseDotLowerTriStackedKernel {
pub batch: usize,
pub num_emb: usize, // N (excluding the dense feature)
pub d: usize,
}
impl PairwiseDotLowerTriStackedKernel {
fn num_features(&self) -> usize {
self.num_emb + 1
}
fn pair_count(&self) -> usize {
let f = self.num_features();
f * (f - 1) / 2
}
}
impl KernelOp for PairwiseDotLowerTriStackedKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let f = self.num_features();
let p = self.pair_count();
let n_emb = self.num_emb;
let d_ = self.d;
// Block-per-batch layout. Each block:
// 1. Cooperatively loads all F feature vectors for batch b into
// shared memory once — F*D floats total. Feature 0 = dense[b];
// features 1..F = emb_stack[b, k-1, :].
// 2. Each thread `tid` strides over pairs `p = tid, tid+blockDim.x,
// …, P-1`. For each, derives (i, j) such that i > j and writes
// the dot product of feat[i] and feat[j].
//
// Compared to the previous (B, P) grid-of-one-block-per-output
// layout this:
// - Cuts block count by P× (e.g. 528× at num_cat=32).
// - Reads each feature vector once per batch instead of (F-1)
// times: F(F-1) reads → F reads, an (F-1)× global-memory
// traffic reduction (e.g. 32× at num_cat=32, F=33).
// - Reuses cached features across all P pairs at shared-memory
// latency instead of refetching from global per pair.
//
// Pair-index → (i, j) is computed from `p` directly using the
// closed-form for strict lower-tri row indexing:
// row i contains i pairs (j ∈ [0, i)); cumulative row starts
// at `i*(i-1)/2`; so `i = floor((1+sqrt(1+8p))/2)` and
// `j = p - i*(i-1)/2`. We do a tiny defensive adjustment
// afterwards to absorb sqrtf rounding.
let kernel = format!(
"
extern \"C\" __global__ void dlrm_pairwise_dot_lower_tri_stacked_kernel(
float* __restrict__ out,
const float* __restrict__ dense, // (B, D)
const float* __restrict__ emb_stack // (B, N, D)
) {{
const int B = {batch};
const int D = {d};
const int N = {n_emb};
const int F = {f};
const int P = {p};
int b = blockIdx.x;
int tid = threadIdx.x;
int tcount = blockDim.x;
if (b >= B) return;
// Shared feature cache: F * D floats.
extern __shared__ float feat[];
for (int i = tid; i < F * D; i += tcount) {{
int feat_idx = i / D;
int dim = i - feat_idx * D;
if (feat_idx == 0) {{
feat[i] = dense[b * D + dim];
}} else {{
int slot = feat_idx - 1;
feat[i] = emb_stack[(b * N + slot) * D + dim];
}}
}}
__syncthreads();
// Each thread handles a strided slice of the P pairs.
for (int p = tid; p < P; p += tcount) {{
float t = sqrtf(8.0f * (float)p + 1.0f);
int pi = (int)((t + 1.0f) * 0.5f);
// Adjust for fp rounding — pi*(pi-1)/2 must be the largest
// row-start ≤ p.
while (pi * (pi - 1) / 2 > p) pi--;
while ((pi + 1) * pi / 2 <= p) pi++;
int pj = p - pi * (pi - 1) / 2;
float acc = 0.0f;
#pragma unroll
for (int d = 0; d < {d}; ++d) {{
acc += feat[pi * {d} + d] * feat[pj * {d} + d];
}}
out[b * P + p] = acc;
}}
}}
",
batch = self.batch,
d = d_,
n_emb = n_emb,
f = f,
p = p,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module
.load_function("dlrm_pairwise_dot_lower_tri_stacked_kernel")
.unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
// Block size: enough threads to cover both the feature-load phase
// (F*D elements) and the pair computation (P elements) without
// serial waves dominating, clamped to [32, 1024] (1024 is the max
// CUDA block size) and rounded up to a multiple of 32 for warp
// alignment.
let want = std::cmp::max(f * d_, p);
let threads = want.clamp(32, 1024).next_multiple_of(32);
let threads = threads.min(1024);
let shared_bytes = f * d_ * 4;
(
func,
module,
kernel,
(
Expression::from(self.batch),
Expression::from(1usize),
Expression::from(1usize),
),
(
Expression::from(threads),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(shared_bytes),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch * self.pair_count())
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
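fn bytes_loaded(&self) -> Expression {
// Each block loads its F feature vectors (D floats each) from global
// memory once; all pair reads after that hit the shared-memory cache.
Expression::from(self.batch * self.num_features() * self.d * 4)
}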
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
Expression::from(self.batch * self.pair_count() * (2 * self.d - 1))
}
fn kernel_name(&self) -> &'static str {
"DLRMPairwiseDotLowerTriStacked"
}
}
#[derive(Debug, Clone)]
pub struct PairwiseDotLowerTriStackedCustom(pub PairwiseDotLowerTriStackedKernel);
impl CustomOp for PairwiseDotLowerTriStackedCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
/// Pairwise lower-tri dot product over `dense_out` plus a stacked
/// embedding output. Avoids the per-table slice that the variadic
/// variant would otherwise need to materialize.
///
/// * `dense_out`: `(batch, d)` — feature 0 in the pair table.
/// * `emb_stack`: `(batch, num_emb, d)` — features 1..=num_emb.
///
/// Returns `(batch, (num_emb+1) * num_emb / 2)`, same strict-lower-tri
/// ordering as [`dlrm_pairwise_dot_lower_tri`].
pub fn dlrm_pairwise_dot_lower_tri_stacked(
dense_out: GraphTensor,
emb_stack: GraphTensor,
) -> GraphTensor {
assert_eq!(dense_out.dtype, DType::F32, "dense_out must be F32");
assert_eq!(emb_stack.dtype, DType::F32, "emb_stack must be F32");
let dd = dense_out.dims();
let sd = emb_stack.dims();
assert_eq!(dd.len(), 2, "dense_out must be 2D");
assert_eq!(sd.len(), 3, "emb_stack must be 3D (batch, num_emb, d)");
let batch = dd[0].to_usize().expect("batch must be static");
let d = dd[1].to_usize().expect("d must be static");
assert_eq!(sd[0].to_usize().unwrap(), batch);
let num_emb = sd[1].to_usize().expect("num_emb must be static");
assert_eq!(sd[2].to_usize().unwrap(), d);
let kern = PairwiseDotLowerTriStackedKernel {
batch,
num_emb,
d,
};
let f = num_emb + 1;
let p = f * (f - 1) / 2;
let cx = unsafe { &mut *dense_out.graph_ref };
cx.custom_op(
PairwiseDotLowerTriStackedCustom(kern),
vec![dense_out, emb_stack],
(batch, p),
DType::F32,
)
}
/// Strict-lower-triangular pairwise dot product of N feature vectors.
///
/// * `features`: N tensors, each `(batch, d)`, all F32, all the same shape.
///
/// Returns `(batch, N*(N-1)/2)` with pair ordering matching
/// `torch.tril_indices(N, N, -1)` (row-major: (1,0), (2,0), (2,1), …).
pub fn dlrm_pairwise_dot_lower_tri(features: Vec<GraphTensor>) -> GraphTensor {
assert!(features.len() >= 2, "need at least 2 feature vectors");
let first = features[0];
let dims = first.dims();
assert_eq!(dims.len(), 2, "each feature vector must be 2D (batch, d)");
let batch = dims[0].to_usize().expect("batch must be static");
let d = dims[1].to_usize().expect("d must be static");
let f = features.len();
for v in &features {
assert_eq!(v.dtype, DType::F32, "features must all be F32");
let vd = v.dims();
assert_eq!(vd.len(), 2, "features must all be 2D");
assert_eq!(vd[0].to_usize().unwrap(), batch, "batch mismatch");
assert_eq!(vd[1].to_usize().unwrap(), d, "d mismatch");
}
let kern = PairwiseDotLowerTriKernel {
batch,
num_features: f,
d,
};
let p = f * (f - 1) / 2;
let cx = unsafe { &mut *first.graph_ref };
cx.custom_op(
PairwiseDotLowerTriCustom(kern),
features,
(batch, p),
DType::F32,
)
}
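
A host-side reference is handy for checking either kernel's output ordering. A minimal sketch, assuming each feature is a row-major `(batch, d)` buffer held in a `Vec<f32>`; `pairwise_dot_lower_tri_ref` is a hypothetical helper, not part of the crate:

```rust
/// CPU reference for the strict-lower-tri pairwise dot.
/// `features` holds F row-major `(batch, d)` buffers; the result is a
/// row-major `(batch, F * (F - 1) / 2)` buffer with pairs in the order
/// (1,0), (2,0), (2,1), ...
fn pairwise_dot_lower_tri_ref(features: &[Vec<f32>], batch: usize, d: usize) -> Vec<f32> {
    let f = features.len();
    assert!(f >= 2, "need at least 2 feature vectors");
    let p = f * (f - 1) / 2;
    let mut out = vec![0.0f32; batch * p];
    for b in 0..batch {
        let mut pidx = 0;
        for i in 0..f {
            for j in 0..i {
                let a = &features[i][b * d..(b + 1) * d];
                let c = &features[j][b * d..(b + 1) * d];
                out[b * p + pidx] = a.iter().zip(c).map(|(x, y)| x * y).sum::<f32>();
                pidx += 1;
            }
        }
    }
    out
}

fn main() {
    // batch = 1, d = 2, three features.
    let features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    let out = pairwise_dot_lower_tri_ref(&features, 1, 2);
    // (1,0): 3*1 + 4*2, (2,0): 5*1 + 6*2, (2,1): 5*3 + 6*4
    assert_eq!(out, vec![11.0, 17.0, 39.0]);
    println!("{:?}", out);
}
```

For the stacked variant, feature 0 would be the dense buffer and features 1..=num_emb the slices of the embedding stack, with the same pair ordering.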

@@ -0,0 +1,757 @@
//! Single-kernel fused EmbeddingBag (sum-pool) operator.
//!
//! DLRM-style embedding lookups in luminal currently lower into a chain
//! of broadcast-iota + multiply + add + Gather + SumReduce kernels (~6
//! kernels per table). For a model with even a handful of tables that
//! eats most of the per-iter launch budget once everything else is
//! captured into a single CUDA graph.
//!
//! This op collapses the whole pattern — `gather(table, idx) → sum(L)` —
//! into one kernel. Same template as `Matmul2DKernel`: implement
//! [`KernelOp`], wrap in a [`CustomOp`] so the user-facing call comes
//! out as a `dyn KernelOp` in the LLIR (which means it can be absorbed
//! into the same CudaGraphOp as everything around it — no extra host
//! op, no extra CUDA launch outside the graph).
//!
//! Semantics: `out[b, d] = Σ_l table[indices[b, l], d]` with
//! table: (n_emb, d), F32, row-major
//! indices: (batch, bag), I32, row-major
//! out: (batch, d), F32, row-major
//!
//! Fixed-shape: `n_emb`, `d`, `batch`, `bag` are static (baked into
//! the kernel source as `const int` constants), matching how the rest
//! of the `kernel::` ops in this crate handle shape.
use std::sync::Arc;
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr};
use luminal::{
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
shape::Expression,
};
use crate::compile_module_image_for_current_device;
use crate::kernel::KernelOp;
/// One-kernel fused EmbeddingBag with sum pooling and fixed bag size.
#[derive(Debug, Clone)]
pub struct EmbeddingBagSumKernel {
pub batch: usize,
pub bag: usize,
pub d: usize,
pub n_emb: usize,
}
impl KernelOp for EmbeddingBagSumKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
// One block per batch row, `d` threads per block. Each thread sums
// one output column over the `bag` indices. This is the standard
// bag-size-1..L pattern and is memory-bandwidth bound on `table`,
// which is exactly the right roofline for this op.
let kernel = format!(
"
extern \"C\" __global__ void embedding_bag_sum_kernel(
float* __restrict__ out,
const float* __restrict__ table,
const int* __restrict__ indices
) {{
const int B = {batch};
const int L = {bag};
const int D = {d};
const int N = {n_emb};
int b = blockIdx.x;
int d = threadIdx.x;
if (b >= B || d >= D) return;
float acc = 0.0f;
#pragma unroll 4
for (int l = 0; l < L; ++l) {{
int row = indices[b * L + l];
// Index is from user input; trust it (matches torch.EmbeddingBag).
acc += table[row * D + d];
}}
out[b * D + d] = acc;
(void)N;
}}
",
batch = self.batch,
bag = self.bag,
d = self.d,
n_emb = self.n_emb,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embedding_bag_sum_kernel").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(
Expression::from(self.batch),
Expression::from(1usize),
Expression::from(1usize),
),
(
Expression::from(self.d),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(0usize),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch * self.d)
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
fn bytes_loaded(&self) -> Expression {
// For each output element, L reads from table (4 bytes each), plus
// L reads from indices per batch row (4 bytes each, shared across the
// D threads, so they are billed once per output row, not per element).
Expression::from(self.batch * self.d * self.bag * 4 + self.batch * self.bag * 4)
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
// L adds per output element. Pointer math doesn't count.
Expression::from(self.batch * self.d * self.bag)
}
fn kernel_name(&self) -> &'static str {
"EmbeddingBagSum"
}
}
/// CustomOp wrapper for [`EmbeddingBagSumKernel`].
#[derive(Debug, Clone)]
pub struct EmbeddingBagSumCustom(pub EmbeddingBagSumKernel);
impl CustomOp for EmbeddingBagSumCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
/// One-kernel fused multi-table EmbeddingBag with sum pooling.
///
/// Folds all `num_tables` independent embedding lookups into a single
/// CUDA kernel launch. Reads from one big weight tensor that is the
/// row-wise concatenation of every table; per-table row offsets are
/// baked into the kernel source. Per-table index tensors stay separate.
/// Output is `(batch, num_tables, d)` so downstream ops can consume it
/// as a single stacked tensor (matches v3's `index_select + reshape`
/// trick — Inductor fuses gather+sum across all tables; this kernel
/// just does it directly).
#[derive(Debug, Clone)]
pub struct StackedEmbeddingBagKernel {
pub batch: usize,
pub bag: usize,
pub d: usize,
pub num_tables: usize,
/// Cumulative row counts: `row_offsets[k]` = number of rows in all
/// tables strictly before table `k`. Length = `num_tables + 1`.
/// `row_offsets[num_tables]` = total rows in the stacked weight.
pub row_offsets: Vec<usize>,
}
impl KernelOp for StackedEmbeddingBagKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert_eq!(
self.row_offsets.len(),
self.num_tables + 1,
"row_offsets must have num_tables+1 entries"
);
// One index pointer per table — variadic via generated kernel signature.
let idx_params: String = (0..self.num_tables)
.map(|k| format!(", const int* __restrict__ idx_{k}"))
.collect::<Vec<_>>()
.concat();
// For each table k, generate a `case k` branch that picks the right
// index pointer and row offset. The case body is the same fused
// gather+sum loop as the single-table kernel.
let mut switch = String::new();
for k in 0..self.num_tables {
let off = self.row_offsets[k];
switch += &format!(
" case {k}: {{ const int* __restrict__ idx_ptr = idx_{k}; const int row_off = {off}; for (int l = 0; l < L; ++l) {{ int row = idx_ptr[b * L + l] + row_off; acc += weight[row * D + d]; }} break; }}\n"
);
}
// Grid is (B,); one block per batch row. Block holds *all* (k, d)
// output threads together. The previous (B, K) grid had 16-thread
// blocks at D=16, which left each SM under-occupied (Hopper's
// max-blocks-per-SM × 16 threads ≪ 64 warps/SM, so the warp
// scheduler couldn't hide memory latency). With one batch row
// per block we get K·D threads (e.g. 512 at K=32, D=16), which
// is 16 warps — enough for the SM to overlap pending loads with
// compute on other warps. Each block now produces (K, D) outputs
// instead of (1, D), so total block count drops from B·K to B
// (e.g. 65k → 2k at K=32, B=2048).
//
// Threads stride over `total = K · D` if the requested block
// size exceeds 1024 (CUDA max). At D=16 this only kicks in for
// K > 64, well above the DLRM range.
let kernel = format!(
"
extern \"C\" __global__ void stacked_embedding_bag_sum_kernel(
float* __restrict__ out,
const float* __restrict__ weight{idx_params}
) {{
const int B = {batch};
const int L = {bag};
const int D = {d};
const int K = {num_tables};
const int total = K * D;
int b = blockIdx.x;
if (b >= B) return;
for (int tid = threadIdx.x; tid < total; tid += blockDim.x) {{
int k = tid / D;
int d = tid - k * D;
float acc = 0.0f;
switch (k) {{
{switch}
default: continue;
}}
// Output laid out as (B, K, D) row-major.
out[(b * K + k) * D + d] = acc;
}}
}}
",
batch = self.batch,
bag = self.bag,
d = self.d,
num_tables = self.num_tables,
idx_params = idx_params,
switch = switch,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module
.load_function("stacked_embedding_bag_sum_kernel")
.unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
// Block size: enough threads to cover K·D output cells per batch
// row, rounded up to a warp (32) for full warp utilization, capped
// at 1024 (CUDA max block dim). Lower bound of 32 ensures we never
// launch sub-warp blocks when K·D < 32 (e.g. K=1, D=16).
let total = self.num_tables * self.d;
let block_threads = total.next_multiple_of(32).clamp(32, 1024);
(
func,
module,
kernel,
(
Expression::from(self.batch),
Expression::from(1usize),
Expression::from(1usize),
),
(
Expression::from(block_threads),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(0usize),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch * self.num_tables * self.d)
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
fn bytes_loaded(&self) -> Expression {
// Per output element, L reads from weight. Index reads ~negligible
// (D threads share the same L indices per output row).
Expression::from(self.batch * self.num_tables * self.d * self.bag * 4)
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
Expression::from(self.batch * self.num_tables * self.d * self.bag)
}
fn kernel_name(&self) -> &'static str {
"StackedEmbeddingBagSum"
}
}
#[derive(Debug, Clone)]
pub struct StackedEmbeddingBagSumCustom(pub StackedEmbeddingBagKernel);
impl CustomOp for StackedEmbeddingBagSumCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
/// Stacked-table fused EmbeddingBag with sum pooling.
///
/// * `stacked_weight`: `(sum_k rows_per_table[k], d)` F32, row-major.
/// The k-th table's rows occupy indices `[row_offsets[k], row_offsets[k+1])`
/// where `row_offsets[k] = sum_{j<k} rows_per_table[j]`.
/// * `indices`: list of `num_tables` tensors, each `(batch, bag)` I32.
/// Index values for table k are in `[0, rows_per_table[k])` — the
/// per-table row offset is added inside the kernel.
/// * `row_offsets`: cumulative starting row index for each table
/// (length `num_tables + 1`).
///
/// Returns `(batch, num_tables, d)` F32. Use `slice_along` + `squeeze`
/// (or the bundled `dlrm_pairwise_dot_lower_tri_stacked` op) to consume
/// per-table outputs downstream.
pub fn stacked_embedding_bag_sum_kernel(
stacked_weight: GraphTensor,
indices: Vec<GraphTensor>,
row_offsets: &[usize],
) -> GraphTensor {
assert_eq!(
stacked_weight.dtype,
DType::F32,
"stacked_embedding_bag_sum_kernel: weight must be F32"
);
let num_tables = indices.len();
assert!(num_tables >= 1, "need at least one index tensor");
assert_eq!(
row_offsets.len(),
num_tables + 1,
"row_offsets must have num_tables+1 entries"
);
let w_dims = stacked_weight.dims();
assert_eq!(w_dims.len(), 2, "stacked weight must be 2D (total_rows, d)");
let total_rows = w_dims[0].to_usize().expect("total_rows must be static");
assert_eq!(
total_rows, row_offsets[num_tables],
"row_offsets[-1] must equal weight total_rows"
);
let d = w_dims[1].to_usize().expect("d must be static");
let i_dims = indices[0].dims();
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
let batch = i_dims[0].to_usize().expect("batch must be static");
let bag = i_dims[1].to_usize().expect("bag must be static");
for idx in &indices {
assert_eq!(idx.dtype, DType::Int, "indices must be Int");
let id = idx.dims();
assert_eq!(id.len(), 2);
assert_eq!(id[0].to_usize().unwrap(), batch);
assert_eq!(id[1].to_usize().unwrap(), bag);
}
let kern = StackedEmbeddingBagKernel {
batch,
bag,
d,
num_tables,
row_offsets: row_offsets.to_vec(),
};
let cx = unsafe { &mut *stacked_weight.graph_ref };
let mut inputs = vec![stacked_weight];
inputs.extend(indices);
cx.custom_op(
StackedEmbeddingBagSumCustom(kern),
inputs,
(batch, num_tables, d),
DType::F32,
)
}
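
The `row_offsets` convention and the `(batch, num_tables, d)` output layout documented above can be pinned down with a host-side reference. A minimal sketch, assuming plain `Vec` buffers in place of `GraphTensor`; `row_offsets_from_sizes` and `stacked_embedding_bag_sum_ref` are hypothetical helpers, not crate functions:

```rust
/// Cumulative row starts: `offsets[k]` is the number of rows in all tables
/// before table `k`; the final entry is the total row count (length K + 1).
fn row_offsets_from_sizes(rows_per_table: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(rows_per_table.len() + 1);
    let mut total = 0usize;
    offsets.push(total);
    for &rows in rows_per_table {
        total += rows;
        offsets.push(total);
    }
    offsets
}

/// CPU reference: `weight` is `(total_rows, d)` row-major, `indices[k]` is
/// `(batch, bag)` row-major with table-local row ids. Output is
/// `(batch, K, d)` row-major.
fn stacked_embedding_bag_sum_ref(
    weight: &[f32],
    indices: &[Vec<i32>],
    offsets: &[usize],
    batch: usize,
    bag: usize,
    d: usize,
) -> Vec<f32> {
    let k_tables = indices.len();
    let mut out = vec![0.0f32; batch * k_tables * d];
    for b in 0..batch {
        for k in 0..k_tables {
            for l in 0..bag {
                // Table-local index plus the table's starting row.
                let row = indices[k][b * bag + l] as usize + offsets[k];
                for c in 0..d {
                    out[(b * k_tables + k) * d + c] += weight[row * d + c];
                }
            }
        }
    }
    out
}

fn main() {
    // Two tables with 2 and 3 rows, d = 2; stacked row r is [r, 10 * r].
    let offsets = row_offsets_from_sizes(&[2, 3]);
    assert_eq!(offsets, vec![0, 2, 5]);
    let weight: Vec<f32> = (0..5).flat_map(|r| [r as f32, 10.0 * r as f32]).collect();
    let indices = vec![vec![0, 1], vec![0, 2]]; // batch = 1, bag = 2
    let out = stacked_embedding_bag_sum_ref(&weight, &indices, &offsets, 1, 2, 2);
    // Table 0 sums stacked rows 0 and 1; table 1 sums stacked rows 2 and 4.
    assert_eq!(out, vec![1.0, 10.0, 6.0, 60.0]);
    println!("{:?}", out);
}
```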
/// Fused EmbeddingBag with sum pooling (single table).
///
/// * `table`: `(n_emb, d)` F32, row-major.
/// * `indices`: `(batch, bag)` I32, row-major. Values must be in `[0, n_emb)`.
///
/// Returns: `(batch, d)` F32, row-major. Each output row is the sum of
/// `bag` looked-up rows from `table`.
///
/// All dimensions must be static. The returned tensor's graph node is a
/// `dyn KernelOp` in LLIR, so it lives inside the same CudaGraphOp as
/// surrounding kernel ops and benefits from the same CUDA-graph replay.
pub fn embedding_bag_sum_kernel(table: GraphTensor, indices: GraphTensor) -> GraphTensor {
assert_eq!(table.dtype, DType::F32, "embedding_bag_sum_kernel: table must be F32");
assert_eq!(
indices.dtype,
DType::Int,
"embedding_bag_sum_kernel: indices must be Int"
);
let t_dims = table.dims();
let i_dims = indices.dims();
assert_eq!(t_dims.len(), 2, "table must be 2D (n_emb, d)");
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
let n_emb = t_dims[0].to_usize().expect("n_emb must be static");
let d = t_dims[1].to_usize().expect("d must be static");
let batch = i_dims[0].to_usize().expect("batch must be static");
let bag = i_dims[1].to_usize().expect("bag must be static");
let kern = EmbeddingBagSumKernel {
batch,
bag,
d,
n_emb,
};
let cx = unsafe { &mut *table.graph_ref };
cx.custom_op(
EmbeddingBagSumCustom(kern),
vec![table, indices],
(batch, d),
DType::F32,
)
}
// ---------------------------------------------------------------------------
// Multi-table EmbeddingBag (one kernel for K independent (weight, idx) pairs)
// ---------------------------------------------------------------------------
/// Folds K independent `EmbeddingBag(sum)` lookups into a single CUDA
/// kernel launch. Used by the vanilla-DLRMv1 translator path where the
/// model has K separate `nn.EmbeddingBag` modules — each one would
/// otherwise lower to its own (~5 µs) launch.
///
/// Inputs (in `KernelOp`-order):
/// - `weight_0, weight_1, ..., weight_{K-1}` — each `(n_emb_k, d)` F32.
/// **The per-table `n_emb` may differ**; only `d` and bag size `L`
/// must match across tables.
/// - `idx_0, idx_1, ..., idx_{K-1}` — each `(batch, L)` Int (i32).
///
/// Two packed staging buffers carry the K weight + K idx device pointers
/// into the kernel (`build_params` fills them per execution via
/// `cuMemcpyHtoD`). The hot loop reads each pointer from shared memory
/// — no per-table switch needed.
///
/// Output shape: `(batch, num_tables, d)` F32, row-major.
#[derive(Debug, Clone)]
pub struct MultiTableEmbeddingBagSumKernel {
pub batch: usize,
pub bag: usize,
pub d: usize,
pub num_tables: usize,
}
impl KernelOp for MultiTableEmbeddingBagSumKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
// Layout (mirrors worktree's StackedEmbeddingBagSumKernel):
// - One block per batch row (B blocks).
// - Each block produces the K·D output cells of its batch row.
// Threads stride over those cells; the block size is K·D rounded
// up to a warp multiple and clamped to [32, 1024].
// - K weight pointers + K idx pointers come in via two packed
// staging buffers populated in `build_params`.
// - Shared memory caches both pointer arrays so the hot loop
// reads at shmem latency.
let kernel = format!(
"
extern \"C\" __global__ void multi_table_embedding_bag_sum_kernel(
float* __restrict__ out,
const long* __restrict__ w_ptrs_packed,
const long* __restrict__ idx_ptrs_packed
) {{
const int B = {batch};
const int L = {bag};
const int D = {d};
const int K = {num_tables};
const int total = K * D;
int b = blockIdx.x;
if (b >= B) return;
__shared__ const float* s_w_ptrs[K];
__shared__ const int* s_idx_ptrs[K];
// Strided load so all K pointers are cached even when K > blockDim.x.
for (int t = threadIdx.x; t < K; t += blockDim.x) {{
s_w_ptrs[t] = (const float*)(w_ptrs_packed[t]);
s_idx_ptrs[t] = (const int*)(idx_ptrs_packed[t]);
}}
__syncthreads();
for (int tid = threadIdx.x; tid < total; tid += blockDim.x) {{
int k = tid / D;
int d = tid - k * D;
const float* w = s_w_ptrs[k];
const int* idx = s_idx_ptrs[k];
float acc = 0.0f;
#pragma unroll 4
for (int l = 0; l < L; ++l) {{
int row = idx[b * L + l];
// 64-bit offset: row * D can exceed INT_MAX for large tables.
acc += w[(long long)row * D + d];
}}
// (B, K, D) row-major.
out[(b * K + k) * D + d] = acc;
}}
}}
",
batch = self.batch,
bag = self.bag,
d = self.d,
num_tables = self.num_tables,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module
.load_function("multi_table_embedding_bag_sum_kernel")
.unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let total = self.num_tables * self.d;
let block_threads = total.next_multiple_of(32).clamp(32, 1024);
(
func,
module,
kernel,
(
Expression::from(self.batch),
Expression::from(1usize),
Expression::from(1usize),
),
(
Expression::from(block_threads),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(0usize),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch * self.num_tables * self.d)
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
fn bytes_loaded(&self) -> Expression {
Expression::from(self.batch * self.num_tables * self.d * self.bag * 4
+ self.batch * self.num_tables * self.bag * 4)
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
Expression::from(self.batch * self.num_tables * self.d * self.bag)
}
fn kernel_name(&self) -> &'static str {
"MultiTableEmbeddingBagSum"
}
/// Two staging buffers: one for K weight ptrs, one for K idx ptrs.
/// Each is `K * 8` bytes (an array of u64s, written as `long*` on
/// the device side).
fn allocate_internal_buffers(
&self,
stream: &Arc<CudaStream>,
_dyn_map: &FxHashMap<char, usize>,
) -> Vec<CudaSlice<u8>> {
let buf_size = self.num_tables * 8;
vec![
stream
.alloc_zeros::<u8>(buf_size)
.expect("alloc MultiTableEmbBag w-ptr staging buffer"),
stream
.alloc_zeros::<u8>(buf_size)
.expect("alloc MultiTableEmbBag idx-ptr staging buffer"),
]
}
/// Pack the K weight + K idx pointers into the two staging buffers
/// each execution, then emit `[out, w_buf, idx_buf]` as kernel params.
///
/// `input_ptrs` layout: `[w_0, w_1, ..., w_{K-1}, idx_0, ..., idx_{K-1}]`.
/// `cuMemcpyHtoD_v2` is a blocking host call so by the time we return
/// the staging buffers are populated and the subsequent CUDA-graph
/// node-param update reads stable device pointers.
fn build_params(
&self,
stream: &Arc<CudaStream>,
output_ptr: u64,
input_ptrs: &[u64],
internal_bufs: &[CudaSlice<u8>],
_dyn_dims_ptr: u64,
) -> Vec<u64> {
assert_eq!(
input_ptrs.len(),
2 * self.num_tables,
"MultiTableEmbeddingBagSum: expected {} input pointers (K weights + K idx), got {}",
2 * self.num_tables,
input_ptrs.len(),
);
let (w_ptrs, idx_ptrs) = input_ptrs.split_at(self.num_tables);
let w_buf = &internal_bufs[0];
let idx_buf = &internal_bufs[1];
let w_dev_ptr: u64 = w_buf.device_ptr(stream).0;
let idx_dev_ptr: u64 = idx_buf.device_ptr(stream).0;
unsafe {
let r1 = cudarc::driver::sys::cuMemcpyHtoD_v2(
w_dev_ptr,
w_ptrs.as_ptr() as *const std::ffi::c_void,
w_ptrs.len() * 8,
);
assert_eq!(
r1,
cudarc::driver::sys::CUresult::CUDA_SUCCESS,
"cuMemcpyHtoD_v2 for MultiTableEmbBag w-ptr staging failed: {r1:?}",
);
let r2 = cudarc::driver::sys::cuMemcpyHtoD_v2(
idx_dev_ptr,
idx_ptrs.as_ptr() as *const std::ffi::c_void,
idx_ptrs.len() * 8,
);
assert_eq!(
r2,
cudarc::driver::sys::CUresult::CUDA_SUCCESS,
"cuMemcpyHtoD_v2 for MultiTableEmbBag idx-ptr staging failed: {r2:?}",
);
}
vec![output_ptr, w_dev_ptr, idx_dev_ptr]
}
}
#[derive(Debug, Clone)]
pub struct MultiTableEmbeddingBagSumCustom(pub MultiTableEmbeddingBagSumKernel);
impl CustomOp for MultiTableEmbeddingBagSumCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
/// Frontend helper: K independent EmbeddingBag(sum) lookups in one
/// kernel launch. Returns `(batch, num_tables, d)` F32, row-major;
/// slice along axis 1 (`out.slice_along(k..k+1, 1).squeeze(1)`) to
/// recover the k-th table's `(batch, d)` output.
///
/// * `weights`: K `(n_emb_k, d)` F32 tensors. Per-table `n_emb` may
/// differ; only `d` must be shared.
/// * `indices`: K `(batch, bag)` Int (i32) tensors that all share the
/// same `batch` and `bag`. Cast with `.cast(DType::Int)` on the
/// caller side if your indices are i64.
pub fn multi_table_embedding_bag_sum_kernel(
weights: Vec<GraphTensor>,
indices: Vec<GraphTensor>,
) -> GraphTensor {
assert_eq!(
weights.len(),
indices.len(),
"multi_table_embedding_bag_sum_kernel: need one weight per index tensor"
);
let num_tables = weights.len();
assert!(num_tables >= 1, "need at least one table");
let first_w = weights[0];
let first_idx = indices[0];
let w_dims = first_w.dims();
let i_dims = first_idx.dims();
assert_eq!(w_dims.len(), 2, "weights must be 2D (n_emb, d)");
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
let d = w_dims[1].to_usize().expect("d must be static");
let batch = i_dims[0].to_usize().expect("batch must be static");
let bag = i_dims[1].to_usize().expect("bag must be static");
for w in &weights {
assert_eq!(w.dtype, DType::F32, "weights must all be F32");
let wd = w.dims();
assert_eq!(wd.len(), 2, "weight must be 2D");
assert_eq!(
wd[1].to_usize().unwrap(),
d,
"all weights must share inner dim"
);
}
for idx in &indices {
assert_eq!(idx.dtype, DType::Int, "indices must all be Int (i32)");
let id = idx.dims();
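assert_eq!(id.len(), 2, "every index tensor must be 2D (batch, bag)");
assert_eq!(
id[0].to_usize().expect("batch must be static"),
batch,
"all index tensors must share the same batch"
);
assert_eq!(
id[1].to_usize().expect("bag must be static"),
bag,
"all index tensors must share the same bag size"
);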
}
let kern = MultiTableEmbeddingBagSumKernel {
batch,
bag,
d,
num_tables,
};
let mut inputs = weights;
inputs.extend(indices);
let cx = unsafe { &mut *first_w.graph_ref };
cx.custom_op(
MultiTableEmbeddingBagSumCustom(kern),
inputs,
(batch, num_tables, d),
DType::F32,
)
}
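
For reviewers checking the indexing in `multi_table_embedding_bag_sum_kernel`, the following is a CPU reference sketch of the same computation. It is not code from this commit. It assumes plain `Vec<f32>` / `Vec<i32>` buffers instead of `GraphTensor`, and the name `multi_table_embedding_bag_sum_ref` is hypothetical. It mirrors the `(B, K, D)` row-major write `out[(b * K + k) * D + d]` used by the CUDA kernel above.

```rust
/// Hypothetical CPU reference for a multi-table EmbeddingBag(sum).
///
/// * `weights[k]`: row-major `(n_emb_k, d)` table; `n_emb_k` may differ per table.
/// * `indices[k]`: row-major `(batch, bag)` index matrix for table `k`.
///
/// Returns a row-major `(batch, K, d)` buffer.
pub fn multi_table_embedding_bag_sum_ref(
    weights: &[Vec<f32>],
    indices: &[Vec<i32>],
    batch: usize,
    bag: usize,
    d: usize,
) -> Vec<f32> {
    assert_eq!(weights.len(), indices.len(), "one weight table per index tensor");
    let k_tables = weights.len();
    let mut out = vec![0.0f32; batch * k_tables * d];
    for b in 0..batch {
        for k in 0..k_tables {
            assert_eq!(indices[k].len(), batch * bag, "indices[k] must be (batch, bag)");
            for l in 0..bag {
                let row = indices[k][b * bag + l];
                assert!(row >= 0, "indices must be non-negative");
                let row = row as usize;
                // Source row of table k and destination cell block (b, k, :).
                let src = &weights[k][row * d..(row + 1) * d];
                let base = (b * k_tables + k) * d;
                let dst = &mut out[base..base + d];
                for (o, w) in dst.iter_mut().zip(src) {
                    *o += *w;
                }
            }
        }
    }
    out
}
```

Comparing this reference against the kernel output on small random inputs is one way to test the fused path; slicing axis 1 of the result recovers the per-table `(batch, d)` output, as the frontend helper's doc comment describes.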

@@ -52,6 +52,19 @@ use crate::kernel::KernelOp;
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
/// `(N,)` and broadcast across batches.
/// Activation epilogue fused into the matmul kernel's store path.
///
/// Saves one full pass over the output buffer per MLP layer — the same
/// trick cuBLASLt does with `CUBLASLT_EPILOGUE_RELU_BIAS` etc., but
/// inside our custom kernel so we don't have to invoke cuBLASLt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Activation {
#[default]
None,
Relu,
Sigmoid,
}
#[derive(Debug, Clone)]
pub struct Matmul2DKernel {
pub m: usize,
@@ -65,6 +78,18 @@ pub struct Matmul2DKernel {
pub has_bias: bool,
/// Storage dtype of B. Currently F32 and BF16 are supported.
pub weight_dtype: DType,
/// Activation applied to `acc + bias` before writing to C.
/// Defaults to None; ReLU and Sigmoid avoid a separate elementwise
/// pass over the matmul output.
pub activation: Activation,
/// When `Some(split)`, A is read from two source pointers:
/// columns `0..split` → `A_lo`, stride `split` per row
/// columns `split..K` → `A_hi`, stride `K - split` per row
/// This lets a `cat(A_lo, A_hi)` materialization be skipped entirely —
/// the K-loop's A-load branches on the column index instead. `None`
/// keeps the existing single-pointer path. Only supported for
/// `batch == 1` (DLRM's use case); `compile` asserts on this.
pub a_split: Option<usize>,
}
const TILE: usize = 16;
@@ -93,6 +118,46 @@ impl KernelOp for Matmul2DKernel {
} else {
""
};
let activation_apply = match self.activation {
Activation::None => "",
// Branchless ReLU; keeps the fully-occupied write path simple.
Activation::Relu => " acc = fmaxf(acc, 0.0f);\n",
// Sigmoid: 1/(1+exp(-acc)). Used by DLRM's final layer.
Activation::Sigmoid => " acc = 1.0f / (1.0f + __expf(-acc));\n",
};
// A-input parameter declaration + per-K-tile load expression depend
// on whether the caller asked for the dual-source (split) path.
// Single-source (default) keeps the original `const float* A` and
// reads `A[a_m * K + a_k]`. Split mode takes two pointer args
// (A_lo / A_hi) and selects between them at runtime by comparing
// `a_k` against the compile-time-baked split column.
let (a_param_decl, a_load_expr) = if let Some(split) = self.a_split {
assert!(
split > 0 && split < self.k,
"Matmul2DKernel a_split must be in 1..K; got split={split}, K={}",
self.k
);
assert_eq!(
self.batch, 1,
"Matmul2DKernel a_split path only supports batch=1 (got batch={})",
self.batch
);
let hi = self.k - split;
(
"const float* __restrict__ A_lo, const float* __restrict__ A_hi"
.to_string(),
format!(
"((a_k < {split}) \
? A_lo[a_m * {split} + a_k] \
: A_hi[a_m * {hi} + (a_k - {split})])"
),
)
} else {
(
"const float* __restrict__ A".to_string(),
"A[a_batch_off + a_m * K + a_k]".to_string(),
)
};
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
@@ -122,7 +187,7 @@ impl KernelOp for Matmul2DKernel {
"
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
float* __restrict__ C,
const float* __restrict__ A,
{a_param_decl},
{b_param_type}{bias_param}
) {{
const int M = {m};
@@ -156,10 +221,12 @@ impl KernelOp for Matmul2DKernel {
for (int t = 0; t < n_tiles; ++t) {{
int k0 = t * TILE;
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
// Load A tile (TILE, TILE) row-major from A[m, k]. In single-source
// mode this is `A[a_batch_off + a_m * K + a_k]`. In split mode the
// load expression branches on `a_k < split` (baked in by the host).
int a_m = a_m_base + ty;
int a_k = k0 + tx;
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
As[ty][tx] = (a_m < M && a_k < K) ? ({a_load_expr}) : 0.0f;
// Load B tile depending on transpose_b
int b_n_or_k = b_n_base + tx; // an N index for both transpose_b=true and transpose_b=false
@@ -182,7 +249,7 @@ impl KernelOp for Matmul2DKernel {
if (m_global < M && n_global < N) {{
int n = n_global;
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
{bias_add}{activation_apply} C[c_batch_off + m_global * N + n_global] = acc;
}}
}}
",
@@ -195,7 +262,10 @@ impl KernelOp for Matmul2DKernel {
b_param_type = b_param_type,
bias_param = bias_param,
bias_add = bias_add,
activation_apply = activation_apply,
bf16_include = bf16_include,
a_param_decl = a_param_decl,
a_load_expr = a_load_expr,
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
@@ -264,7 +334,20 @@ impl KernelOp for Matmul2DKernel {
}
fn kernel_name(&self) -> &'static str {
"Matmul2D"
match (self.has_bias, self.activation, self.a_split.is_some()) {
(true, Activation::Relu, false) => "Matmul2D_BiasRelu",
(true, Activation::Sigmoid, false) => "Matmul2D_BiasSigmoid",
(true, Activation::None, false) => "Matmul2D_Bias",
(false, Activation::Relu, false) => "Matmul2D_Relu",
(false, Activation::Sigmoid, false) => "Matmul2D_Sigmoid",
(false, Activation::None, false) => "Matmul2D",
(true, Activation::Relu, true) => "Matmul2D_BiasRelu_SplitA",
(true, Activation::Sigmoid, true) => "Matmul2D_BiasSigmoid_SplitA",
(true, Activation::None, true) => "Matmul2D_Bias_SplitA",
(false, Activation::Relu, true) => "Matmul2D_Relu_SplitA",
(false, Activation::Sigmoid, true) => "Matmul2D_Sigmoid_SplitA",
(false, Activation::None, true) => "Matmul2D_SplitA",
}
}
}
@@ -280,20 +363,82 @@ impl CustomOp for Matmul2DCustom {
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ false, None)
matmul_inner(a, b, /*transpose_b=*/ false, None, Activation::None)
}
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
/// pattern produced by linear / projection layers (`x @ w.t()`).
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ true, None)
matmul_inner(a, b, /*transpose_b=*/ true, None, Activation::None)
}
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
/// `(N,)`, row-major F32 throughout.
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::None)
}
/// Like [`linear_bias`] but applies ReLU in the kernel epilogue. Saves
/// one full pass over the output buffer per layer.
pub fn linear_bias_relu(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::Relu)
}
/// Like [`linear_bias`] but applies Sigmoid in the kernel epilogue.
/// Used for the final layer of binary-classifier MLPs (DLRM CTR head).
pub fn linear_bias_sigmoid(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::Sigmoid)
}
/// Two-A-input variant of [`linear_bias`].
///
/// Computes `cat(a_lo, a_hi) @ bᵀ + bias` *without* materializing the
/// concat — the K-loop's A-load reads from `a_lo` for columns `0..K_lo`
/// and from `a_hi` for columns `K_lo..K_lo+K_hi`. Logically equivalent
/// to feeding `concat_along(a_lo, a_hi, 1)` into [`linear_bias`], but
/// skips ~9 scaffolding kernels (Iota + Cast + Gather + masked-add) per
/// concat call.
///
/// Shapes:
/// * `a_lo`: `(M, K_lo)` F32
/// * `a_hi`: `(M, K_hi)` F32
/// * `b`: `(N, K_lo + K_hi)` F32 (transposed convention, same as
/// [`linear_bias`])
/// * `bias`: `(N,)` F32
///
/// Output: `(M, N)` F32. Only 2D inputs are supported (batch=1).
pub fn linear_bias_split_a(
a_lo: GraphTensor,
a_hi: GraphTensor,
b: GraphTensor,
bias: GraphTensor,
) -> GraphTensor {
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::None)
}
/// Like [`linear_bias_split_a`] but applies ReLU in the kernel epilogue.
/// Use this for hidden MLP layers that consume a concat of two upstream
/// tensors — the natural shape of DLRM's top-MLP first layer (which reads
/// `cat(dense_out, interactions)`).
pub fn linear_bias_relu_split_a(
a_lo: GraphTensor,
a_hi: GraphTensor,
b: GraphTensor,
bias: GraphTensor,
) -> GraphTensor {
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::Relu)
}
/// Like [`linear_bias_split_a`] but applies Sigmoid in the kernel
/// epilogue.
pub fn linear_bias_sigmoid_split_a(
a_lo: GraphTensor,
a_hi: GraphTensor,
b: GraphTensor,
bias: GraphTensor,
) -> GraphTensor {
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::Sigmoid)
}
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
@@ -321,12 +466,12 @@ pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ false, None)
matmul_inner(a, b, /*transpose_b=*/ false, None, Activation::None)
}
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
matmul_inner(a, b, /*transpose_b=*/ true, None)
matmul_inner(a, b, /*transpose_b=*/ true, None, Activation::None)
}
fn matmul_inner(
@@ -334,6 +479,7 @@ fn matmul_inner(
b: GraphTensor,
transpose_b: bool,
bias: Option<GraphTensor>,
activation: Activation,
) -> GraphTensor {
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
let weight_dtype = b.dtype;
@@ -412,6 +558,8 @@ fn matmul_inner(
transpose_b,
has_bias,
weight_dtype,
activation,
a_split: None,
};
let cx = unsafe { &mut *a.graph_ref };
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
@@ -425,3 +573,71 @@ fn matmul_inner(
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
}
}
/// Internal helper for the split-A path. Validates shapes and dispatches
/// to a [`Matmul2DKernel`] with `a_split = Some(K_lo)`. Always uses
/// `transpose_b = true` (linear-projection convention; matches
/// [`linear_bias`]). Only 2D inputs are supported.
fn matmul_inner_split_a(
a_lo: GraphTensor,
a_hi: GraphTensor,
b: GraphTensor,
bias: Option<GraphTensor>,
activation: Activation,
) -> GraphTensor {
assert_eq!(a_lo.dtype, DType::F32, "split-A matmul requires F32 A_lo");
assert_eq!(a_hi.dtype, DType::F32, "split-A matmul requires F32 A_hi");
let weight_dtype = b.dtype;
assert_eq!(
weight_dtype,
DType::F32,
"split-A matmul currently only supports F32 B (got {weight_dtype:?})"
);
let lo_dims = a_lo.dims();
let hi_dims = a_hi.dims();
let b_dims = b.dims();
assert_eq!(lo_dims.len(), 2, "split-A matmul A_lo must be 2D");
assert_eq!(hi_dims.len(), 2, "split-A matmul A_hi must be 2D");
assert_eq!(b_dims.len(), 2, "split-A matmul B must be 2D");
let m = lo_dims[0].to_usize().expect("M must be a static dim");
let m_hi = hi_dims[0].to_usize().expect("M (A_hi) must be a static dim");
assert_eq!(m, m_hi, "split-A matmul: A_lo and A_hi must have the same M");
let k_lo = lo_dims[1].to_usize().expect("K_lo must be a static dim");
let k_hi = hi_dims[1].to_usize().expect("K_hi must be a static dim");
let k = k_lo + k_hi;
let n = b_dims[0].to_usize().expect("N must be a static dim");
let k_b = b_dims[1].to_usize().expect("K (B) must be a static dim");
assert_eq!(
k, k_b,
"split-A matmul: A_lo.K + A_hi.K = {k} must equal B.K = {k_b}"
);
let has_bias = bias.is_some();
if let Some(bias) = bias {
let bdims = bias.dims();
assert_eq!(bdims.len(), 1, "split-A matmul bias must be 1D");
assert_eq!(
bdims[0].to_usize().expect("bias dim must be static"),
n,
"split-A matmul bias size must equal N"
);
assert_eq!(bias.dtype, DType::F32, "split-A matmul bias must be F32");
}
let kern = Matmul2DKernel {
m,
n,
k,
batch: 1,
transpose_b: true,
has_bias,
weight_dtype,
activation,
a_split: Some(k_lo),
};
let cx = unsafe { &mut *a_lo.graph_ref };
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
vec![a_lo, a_hi, b, bias]
} else {
vec![a_lo, a_hi, b]
};
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
}
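
The split-A contract can be illustrated with a CPU reference. This is a sketch under stated assumptions, not the kernel from this commit: it uses plain `f32` slices instead of `GraphTensor`, and the names `Act` and `linear_bias_split_a_ref` are hypothetical. It computes `cat(a_lo, a_hi) @ bᵀ + bias` with `b` stored as `(N, K_lo + K_hi)`, selecting the A source by column index the way the generated load expression does.

```rust
/// Hypothetical stand-in for the kernel's activation epilogue.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Act {
    None,
    Relu,
    Sigmoid,
}

/// CPU reference: `cat(a_lo, a_hi) @ b^T + bias`, without building the concat.
/// `a_lo` is `(m, k_lo)`, `a_hi` is `(m, k_hi)`, `b` is `(n, k_lo + k_hi)`,
/// `bias` is `(n,)`; all row-major. Returns `(m, n)` row-major.
pub fn linear_bias_split_a_ref(
    a_lo: &[f32],
    a_hi: &[f32],
    b: &[f32],
    bias: &[f32],
    m: usize,
    k_lo: usize,
    k_hi: usize,
    n: usize,
    act: Act,
) -> Vec<f32> {
    let k = k_lo + k_hi;
    assert_eq!(a_lo.len(), m * k_lo);
    assert_eq!(a_hi.len(), m * k_hi);
    assert_eq!(b.len(), n * k);
    assert_eq!(bias.len(), n);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for kk in 0..k {
                // Columns 0..k_lo come from a_lo, the rest from a_hi.
                let a = if kk < k_lo {
                    a_lo[i * k_lo + kk]
                } else {
                    a_hi[i * k_hi + (kk - k_lo)]
                };
                acc += a * b[j * k + kk];
            }
            acc += bias[j];
            out[i * n + j] = match act {
                Act::None => acc,
                Act::Relu => acc.max(0.0),
                Act::Sigmoid => 1.0 / (1.0 + (-acc).exp()),
            };
        }
    }
    out
}
```

Note that the GPU kernel uses the fast `__expf` intrinsic for sigmoid, so a comparison against this reference should use a tolerance rather than exact equality.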


@@ -11,6 +11,8 @@ use uuid::Uuid;
pub mod conv2d;
pub mod cuda_graph;
pub mod dlrm_interact;
pub mod embedding_bag;
pub mod fusion;
pub mod generic_matmul;
pub mod hlir;
@@ -20,10 +22,22 @@ pub mod rope;
pub use conv2d::KernelConv2D;
pub use cuda_graph::*;
pub use dlrm_interact::{
PairwiseDotLowerTriCustom, PairwiseDotLowerTriKernel, PairwiseDotLowerTriStackedCustom,
PairwiseDotLowerTriStackedKernel, dlrm_pairwise_dot_lower_tri,
dlrm_pairwise_dot_lower_tri_stacked,
};
pub use embedding_bag::{
EmbeddingBagSumCustom, EmbeddingBagSumKernel, MultiTableEmbeddingBagSumCustom,
MultiTableEmbeddingBagSumKernel, StackedEmbeddingBagKernel,
StackedEmbeddingBagSumCustom, embedding_bag_sum_kernel,
multi_table_embedding_bag_sum_kernel, stacked_embedding_bag_sum_kernel,
};
pub use generic_matmul::GenericMatmul;
pub use matmul2d::{
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
matmul_3d, matmul_3d_t,
Activation, Matmul2DCustom, Matmul2DKernel, linear_bias, linear_bias_relu,
linear_bias_relu_split_a, linear_bias_sigmoid, linear_bias_sigmoid_split_a,
linear_bias_split_a, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t, matmul_3d, matmul_3d_t,
};
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
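
The `dlrm_pairwise_dot_lower_tri` exports above correspond to the strict-lower-triangular pairwise-dot interaction named in the commit message. The sketch below is a CPU reference for that interaction, assuming a row-major `(batch, F, D)` feature buffer and assuming the pairs are emitted in row-major `(i, j)` order with `j < i`; the actual output order of the kernel is not shown in this diff. The function name `pairwise_dot_lower_tri_ref` is hypothetical.

```rust
/// Hypothetical CPU reference for the strict-lower-triangular pairwise dot.
/// `t` is a row-major `(batch, f, d)` buffer. For each batch row the output
/// holds `dot(t[b, i], t[b, j])` for every pair with `j < i`, ordered by
/// `i` and then `j` (an assumed ordering), i.e. `f * (f - 1) / 2` values.
pub fn pairwise_dot_lower_tri_ref(t: &[f32], batch: usize, f: usize, d: usize) -> Vec<f32> {
    assert_eq!(t.len(), batch * f * d, "t must be (batch, f, d)");
    let pairs = f * f.saturating_sub(1) / 2;
    let mut out = Vec::with_capacity(batch * pairs);
    for b in 0..batch {
        for i in 1..f {
            for j in 0..i {
                let ti = &t[(b * f + i) * d..(b * f + i + 1) * d];
                let tj = &t[(b * f + j) * d..(b * f + j + 1) * d];
                out.push(ti.iter().zip(tj).map(|(x, y)| x * y).sum::<f32>());
            }
        }
    }
    out
}
```

With `F = 3` this yields the pairs `(1, 0)`, `(2, 0)`, `(2, 1)` per batch row, which is the part of `T @ Tᵀ` strictly below the diagonal.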


@@ -215,6 +215,25 @@ impl CudaRuntime {
.collect()
}
/// Read per-kernel GPU elapsed times (ms), if event-record nodes were
/// inserted at graph build time.
///
/// The per-kernel event recording infra from `origin/dlrm-fused-kernels`
/// is not ported on this branch yet — this stub returns empty so the
/// dlrm example's optional `LUMINAL_KERNEL_TIMING=1` path falls back to
/// "(no per-kernel timings available — events not recorded)".
pub fn read_per_kernel_timings_ms(&self) -> Vec<(&'static str, f32)> {
Vec::new()
}
/// Synchronize the runtime's CUDA stream. Use this after `execute()` if
/// you need GPU-side completion time (e.g. for benchmarking) — `execute`
/// itself no longer syncs at the end so it stays capturable by
/// `torch.cuda.CUDAGraph` and similar external graph-capture machinery.
pub fn synchronize_stream(&self) {
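// Surface sync failures instead of silently dropping them; a failed
// sync would otherwise yield misleading benchmark timings.
self.cuda_stream
.synchronize()
.expect("CUDA stream synchronize failed");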
}
fn bucket_buffer(
bucket: &CompiledBucket,
stream: &Arc<CudaStream>,
@@ -1675,8 +1694,15 @@ impl Runtime for CudaRuntime {
}
}
}
// Single sync at end - CUDA stream ordering guarantees sequential execution
self.cuda_stream.synchronize().unwrap();
// Sync only when profiling (kernel search timing needs an accurate
// total). In the regular execute path, dropping the sync lets the
// call be captured by `torch.cuda.CUDAGraph` (or any external graph
// capture). PyTorch syncs on tensor reads, so correctness is
// preserved there; other callers should call `synchronize_stream()`
// before reading results. Without the sync, the CPU-side
// `last_total_time_us` measures dispatch time only.
if self.profiling {
self.cuda_stream.synchronize().unwrap();
}
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
// Populate last_kernel_stats from HostOps that report stats
@@ -1717,9 +1743,22 @@ impl Runtime for CudaRuntime {
let to_consume: Vec<NodeIndex> = self
.hlir_buffers
.keys()
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
.copied()
.iter()
// Don't consume external device pointers: they are non-owning
// views over caller-provided memory and represent stable input
// slots that the wrapper may skip re-registering on a hot
// iteration when the underlying pointer is unchanged. Consuming
// them here would force the wrapper to re-register every input
// on every iteration even when nothing changed, as the prior
// behavior assumed. Internal `CudaInput::Buffer` entries (e.g.
// weights loaded via `set_data_bytes` and one-shot CPU input
// copies) are still consumed when the bucket does not preserve
// them.
.filter(|(hlir_node, input)| {
!inputs_with_outputs.contains(hlir_node)
&& !matches!(input, CudaInput::Ptr(_))
})
.map(|(n, _)| *n)
.collect();
for hlir_node in to_consume {


@@ -391,6 +391,73 @@ impl CompiledGraph {
Ok(())
}
/// Resolve an input or output tensor name to an opaque integer ID.
/// One FFI hop at compile time, then per-iter `run_with_ptrs` and
/// the `*_by_id` setters can skip the string-keyed HashMap lookup.
/// Returns the underlying `NodeIndex.index() as u32` — caller should
/// treat it as opaque.
fn tensor_id(&self, name: &str) -> PyResult<u32> {
self.tensor_ids
.get(name)
.map(|n| n.index() as u32)
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown tensor: {}",
name
))
})
}
/// One-shot batched: register all input + output device pointers and
/// execute. Collapses the per-iter `set_input_device_ptr` ×
/// (n_inputs) + `set_output_device_ptr` × (n_outputs) + `run` chain
/// into a single Python→Rust FFI crossing.
///
/// Inputs and outputs use the opaque IDs returned by `tensor_id(name)`
/// to skip the per-call string lookup. Returns a parallel vector of
/// "is zero-copy" booleans for each output (same as
/// `output_is_zero_copy(name)` on the unbatched path), so the Python
/// caller can fall back to a DtoD copy when a kernel aliased an
/// output instead of writing into the registered buffer.
///
/// Safety contract:
/// * Every `device_ptr` must point to a valid CUDA allocation with
/// at least `n_bytes` bytes.
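/// * Pointers must remain valid until the GPU work launched by this
/// call has completed. Returning from `run_with_ptrs` is not
/// enough on the CUDA runtime, whose `execute` only synchronizes
/// the stream when profiling.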
/// * `tensor_id`s must come from `self.tensor_id(name)` on this
/// same graph — using stale IDs from a different graph is UB
/// (will likely panic in `NodeIndex` lookups, but not guaranteed).
fn run_with_ptrs(
&mut self,
inputs: Vec<(u32, u64, usize)>,
outputs: Vec<(u32, u64, usize)>,
) -> PyResult<Vec<bool>> {
if !self.runtime.supports_device_ptrs() {
return Err(pyo3::exceptions::PyValueError::new_err(
"run_with_ptrs requires a GPU backend",
));
}
// Register inputs.
for (id, ptr, n) in &inputs {
let node_id = NodeIndex::new(*id as usize);
unsafe { self.runtime.set_device_ptr(node_id, *ptr, *n) };
}
// Register outputs.
for (id, ptr, n) in &outputs {
let node_id = NodeIndex::new(*id as usize);
unsafe { self.runtime.set_output_device_ptr(node_id, *ptr, *n) };
}
// Execute.
self.runtime.execute(&self.graph.dyn_map);
// Report zero-copy status for each output (parallel to `outputs`
// input order). Aliased outputs need a DtoD copy on the Python
// side, same as the unbatched path.
Ok(outputs
.iter()
.map(|(id, _, _)| self.runtime.output_is_zero_copy(NodeIndex::new(*id as usize)))
.collect())
}
/// Register a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
/// Requires a GPU backend.
fn set_weight_device_ptr(


@@ -119,30 +119,147 @@ impl<'a> Translator<'a> {
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
// Matmul
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
"torch.ops.aten.mm.default" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
a.matmul(b)
}
// bmm: batched 3-D matmul. Fast path under cuda + F32 when
// B was produced by a permute([0, 2, 1]) (i.e. `T @ T.T` —
// the DLRM pairwise-interaction pattern): route to
// `matmul_3d_t` with the original (B, F, D) tensor, which
// uses the fused Matmul2DKernel and avoids the
// expand+mul+sum-reduce decomposition that produces ~25
// small kernels per bmm.
"torch.ops.aten.bmm.default" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let b_src = node.inputs.get(1).and_then(|n| n.arg.as_tensor_name())
.and_then(|n| self.transpose_2d_source.get(n).cloned());
let f32_all = a.dtype == DType::F32 && b.dtype == DType::F32;
// The former LUMINAL_BACKEND_CUDA check is replaced by the
// cfg!(feature = "cuda") gate below.
if cfg!(feature = "cuda")
&& f32_all
&& a.shape.dims.len() == 3
&& b.shape.dims.len() == 3
&& let Some(orig_name) = b_src
&& let Some(orig_b) = self.tensors.get(&orig_name).copied()
&& orig_b.shape.dims.len() == 3
{
// a: (B, M, K), orig_b: (B, N, K) — matmul_3d_t does
// a @ orig_b.t() = (B, M, K) @ (B, K, N) = (B, M, N).
luminal_cuda_lite::kernel::matmul_3d_t(a, orig_b)
} else {
let (a, b) = ensure_same_dtype(a, b);
a.matmul(b)
}
}
// addmm: beta*input + alpha*(mat1 @ mat2)
//
// Fast path (CUDA, the common nn.Linear case): when
// * shapes are 2-D F32, alpha=beta=1
// * mat2 was produced by `aten.permute([1,0])` of a 2-D
// tensor (`weight.t()` from nn.Linear)
// * bias is 1-D
// we lower to the fused `linear_bias` kernel using the
// *original* (N,K) weight — bypassing the
// expand+mul+sum-reduce decomposition that otherwise
// produces ~25 small kernels per Linear layer (~3.7 ms on
// tiny shapes due to launch overhead).
//
// The transpose detection comes from
// `translate_permute`, which records the source tensor in
// `transpose_2d_source` whenever it sees a 2-D `[1, 0]` permute.
"torch.ops.aten.addmm.default" => {
let input = self.get_input_tensor(node, 0)?;
let mat1 = self.get_input_tensor(node, 1)?;
let mat2 = self.get_input_tensor(node, 2)?;
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
let mat2_src = node.inputs.get(2).and_then(|n| n.arg.as_tensor_name())
.and_then(|n| self.transpose_2d_source.get(n).cloned());
let unit_scale = (alpha - 1.0).abs() < 1e-7 && (beta - 1.0).abs() < 1e-7;
let f32_all = mat1.dtype == DType::F32
&& mat2.dtype == DType::F32
&& input.dtype == DType::F32;
let two_d = mat1.shape.dims.len() == 2 && mat2.shape.dims.len() == 2;
let bias_is_1d = input.shape.dims.len() == 1;
// The former LUMINAL_BACKEND_CUDA check is replaced by the
// cfg!(feature = "cuda") gate below.
if cfg!(feature = "cuda")
&& two_d
&& f32_all
&& unit_scale
&& bias_is_1d
&& let Some(weight_name) = mat2_src
&& let Some(orig_weight) = self.tensors.get(&weight_name).copied()
&& orig_weight.shape.dims.len() == 2
{
// Forward-looking fusion: if this addmm has exactly one
// consumer and that consumer is `aten.relu.default` or
// `aten.sigmoid.default`, emit the fused
// `linear_bias_relu` / `linear_bias_sigmoid` kernel and
// mark the consumer as absorbed so we don't emit a
// redundant unary op downstream. This collapses the
// standard nn.Linear+ReLU MLP layer to one kernel.
let addmm_out: Option<String> = node
.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
let fuse_act = addmm_out
.as_deref()
.and_then(|n| self.unique_consumer(n));
let (fused, absorbed) = match fuse_act {
Some((target, out_name)) if target == "torch.ops.aten.relu.default" => {
(
Some(luminal_cuda_lite::kernel::linear_bias_relu(
mat1, orig_weight, input,
)),
Some(out_name),
)
}
Some((target, out_name)) if target == "torch.ops.aten.sigmoid.default" => {
(
Some(luminal_cuda_lite::kernel::linear_bias_sigmoid(
mat1, orig_weight, input,
)),
Some(out_name),
)
}
_ => (None, None),
};
if let Some(t) = fused {
if let Some(out_name) = absorbed {
self.absorbed_nodes.insert(out_name.clone());
self.tensors.insert(out_name, t);
}
t
} else {
// No unary consumer to fuse; plain linear+bias.
luminal_cuda_lite::kernel::linear_bias(mat1, orig_weight, input)
}
} else {
// Generic fallback (non-cuda, scaled, or unknown
// mat2 source).
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
}
}
// Convolution
"torch.ops.aten.convolution.default" => self.translate_conv(node)?,
// Reduction ops
"torch.ops.aten.sum.dim_IntList" => self.translate_sum_with_embbag_peephole(node)?,
"torch.ops.aten.mean.dim" => self.translate_reduction(node, ReductionOp::Mean)?,
"torch.ops.aten.amax.default" => self.translate_reduction(node, ReductionOp::Max)?,
@@ -151,10 +268,25 @@ impl<'a> Translator<'a> {
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
// EmbeddingBag (sum-pool, fixed bag size). PT2 export decomposes
// nn.EmbeddingBag → `_embedding_bag` (with backward) or
// `_embedding_bag_forward_only` (without). Both return the same
// 4-tuple `(output, offset2bag, bag_size, max_indices)` and only
// the first slot is consumed at inference. The two op variants
// are math-identical for the forward path, so route both through
// the same handler. The handler stores into `tensors` itself
// (multi-output op) so we return early afterwards.
"torch.ops.aten._embedding_bag_forward_only.default"
| "torch.ops.aten._embedding_bag.default" => {
self.translate_embedding_bag_forward_only(node)?;
return Ok(());
}
// Softmax
"torch.ops.aten._softmax.default" => {
let a = self.get_input_tensor(node, 0)?;
@@ -514,6 +646,51 @@ impl<'a> Translator<'a> {
};
if !output_name.is_empty() {
// Record the chain (FX target + first input name) keyed by
// output name so multi-node peepholes (e.g. the EmbBag fast
// path that detects sum ← view ← index_select) can walk
// back without re-scanning all parsed nodes.
//
// For variadic ops (e.g. `aten.cat.default` whose first
// arg is `as_tensors`) fall back to the first entry of the
// variadic tensor list. The DLRM PairwiseDot peephole
// needs `node_chain[cat]` to walk back from `bmm → cat`.
let first_input_name: Option<String> = node
.inputs
.first()
.and_then(|i| {
i.arg
.as_tensor_name()
.map(|s| s.to_string())
.or_else(|| {
i.arg
.as_tensors()
.and_then(|ts| ts.first().map(|tn| tn.name.clone()))
})
});
if let Some(first_input) = first_input_name {
self.node_chain.insert(
output_name.clone(),
(node.target.clone(), first_input),
);
}
// Also record the full input-name list (in order, including
// entries that come from `as_tensors` for variadic ops like
// `aten.cat`). Used by the DLRM PairwiseDot peephole which
// needs all cat inputs and both bmm inputs.
let mut all_inputs: Vec<String> = Vec::new();
for inp in &node.inputs {
if let Some(names) = inp.arg.as_tensors() {
for tn in names {
all_inputs.push(tn.name.clone());
}
} else if let Some(name) = inp.arg.as_tensor_name() {
all_inputs.push(name.to_string());
}
}
if !all_inputs.is_empty() {
self.op_inputs.insert(output_name.clone(), all_inputs);
}
self.tensors.insert(output_name, result);
}
Ok(())
@@ -521,6 +698,69 @@ impl<'a> Translator<'a> {
}
impl<'a> Translator<'a> {
/// Peephole for the DLRM-v3 embedding-bag pattern:
/// `sum(dim=[1], keepdim=False)( view([?, L, D])( index_select(W, 0, IDX) ) )`
/// substitutes the fused `embedding_bag_sum_kernel(W, IDX.view(?, L))`
/// — same kernel as the hand-rolled DLRM example uses. Falls back to
/// the generic reduction path when the chain doesn't match.
pub(crate) fn translate_sum_with_embbag_peephole(
&mut self,
node: &Node,
) -> Result<GraphTensor> {
let dims = self.get_ints_arg(node, 1).unwrap_or_default();
let keepdim = self.get_bool_arg(node, 2).unwrap_or(false);
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by the cfg!(feature = "cuda") gate below.
// Only attempt fast-path under cuda + the specific sum(dim=[1]) pattern.
if cfg!(feature = "cuda")
&& backend_is_cuda
&& dims.len() == 1
&& dims[0] == 1
&& !keepdim
&& let Some(sum_input_name) = node.inputs.first().and_then(|i| i.arg.as_tensor_name())
&& let Some((view_target, view_src)) = self.node_chain.get(sum_input_name).cloned()
&& view_target == "torch.ops.aten.view.default"
&& let Some((is_target, is_src)) = self.node_chain.get(&view_src).cloned()
&& is_target == "torch.ops.aten.index_select.default"
// Pull the FX index_select node so we can grab its dim + index args.
&& let Some(is_node) = self.parsed.program.graph_module.graph.nodes.iter()
.find(|n| n.outputs.first()
.and_then(|o| o.as_tensor.as_ref())
.map(|t| t.name == view_src)
.unwrap_or(false))
{
let weight = self.tensors.get(&is_src).copied();
let idx_name = is_node.inputs.get(2).and_then(|i| i.arg.as_tensor_name());
let is_dim = self.get_int_arg(is_node, 1).unwrap_or(-1);
let in_tensor = self.tensors.get(sum_input_name).copied();
if is_dim == 0
&& let Some(w) = weight
&& let Some(idx_n) = idx_name
&& let Some(idx) = self.tensors.get(idx_n).copied()
&& let Some(inp) = in_tensor
&& w.shape.dims.len() == 2
&& idx.shape.dims.len() == 1
&& inp.shape.dims.len() == 3
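// No explicit length check here: the view preserves index_select's
// element count, so idx has kb * l entries whenever the view keeps
// w's D as its last dim (which this chain does not verify).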
&& inp.dtype == DType::F32
&& w.dtype == DType::F32
{
let l = inp.shape.dims[1];
let kb = inp.shape.dims[0];
let d = inp.shape.dims[2];
// Reshape flat indices (K*B*L,) to (K*B, L).
let idx_2d = reshape_tensor(idx, vec![kb, l]);
// embedding_bag_sum_kernel expects (n_emb, d) weights +
// (batch, bag) indices, returns (batch, d).
let _ = d; // d already encoded in `w.shape.dims[1]`
return Ok(luminal_cuda_lite::kernel::embedding_bag_sum_kernel(w, idx_2d));
}
}
// Generic fallback.
self.translate_reduction(node, ReductionOp::Sum)
}
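Review note: the equivalence this peephole relies on can be checked on the host with plain vectors. The sketch below is illustrative only; `select_view_sum` and `embedding_bag_sum` are hypothetical reference functions, not the luminal kernels, and weights are modelled as `Vec<Vec<f32>>` rather than `GraphTensor`.

```rust
/// Reference for the unfused chain:
/// index_select(W, 0, idx) -> view(rows, bag, d) -> sum(dim=1).
fn select_view_sum(w: &[Vec<f32>], idx: &[usize], bag: usize) -> Vec<Vec<f32>> {
    let d = w[0].len();
    // index_select: one gathered row per flat index.
    let gathered: Vec<&Vec<f32>> = idx.iter().map(|&i| &w[i]).collect();
    // view + sum: each consecutive group of `bag` rows collapses into one row.
    gathered
        .chunks(bag)
        .map(|rows| {
            let mut acc = vec![0.0f32; d];
            for r in rows {
                for (a, x) in acc.iter_mut().zip(r.iter()) {
                    *a += *x;
                }
            }
            acc
        })
        .collect()
}

/// Reference for the fused form: indices already shaped (rows, bag).
fn embedding_bag_sum(w: &[Vec<f32>], idx_2d: &[Vec<usize>]) -> Vec<Vec<f32>> {
    let d = w[0].len();
    idx_2d
        .iter()
        .map(|bag| {
            let mut acc = vec![0.0f32; d];
            for &i in bag {
                for (a, x) in acc.iter_mut().zip(w[i].iter()) {
                    *a += *x;
                }
            }
            acc
        })
        .collect()
}
```

Both functions agree only when the flat index list is split into rows of exactly `bag` entries, which is the same assumption the `reshape_tensor(idx, vec![kb, l])` call above makes.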
fn translate_scalar_comparison(
&mut self,
node: &Node,


@@ -51,6 +51,42 @@ pub(crate) struct Translator<'a> {
pub(crate) output_ids: Vec<(String, NodeIndex)>,
/// Extra tensor metadata from inlined subgraphs.
pub(crate) extra_tensor_values: HashMap<String, TensorMeta>,
/// Peephole: maps an output-tensor name produced by a 2-D `permute([1,0])`
/// or a 3-D `permute([0,2,1])` (i.e. an inner-axis transpose) back to its
/// input-tensor name. Used by the addmm dispatch to detect
/// `aten.addmm(bias, x, weight.t())` and route it through the fused
/// `linear_bias` kernel with the original weight, instead of through the
/// generic expand+mul+sum decomposition that materializes ~25 small
/// kernels; the bmm dispatch uses it the same way for `matmul_3d_t`.
pub(crate) transpose_2d_source: HashMap<String, String>,
/// Trace each emitted node by its first output's name → (FX target,
/// first input's name). Used by the EmbBag peephole to walk back
/// `sum.dim_IntList → aten.view.default → aten.index_select.default`
/// and substitute the fused `embedding_bag_sum_kernel` for the slow
/// expand+gather decomposition. Populated at the bottom of
/// `translate_node` after dispatching each op.
pub(crate) node_chain: HashMap<String, (String, String)>,
/// Per-node side table mapping the primary output name → list of all
/// input tensor names (in order). Lets multi-input peepholes — e.g.
/// `index.Tensor(bmm(cat([…]), permute(cat([…]))), [None, li, lj])`
/// → `dlrm_pairwise_dot_lower_tri([…])` — walk back through cat and
/// bmm without re-scanning the FX node array. Populated alongside
/// `node_chain` after each translated op.
pub(crate) op_inputs: HashMap<String, Vec<String>>,
/// Tensor name → list of *consumer output-tensor names*. Built once at
/// the start of `translate_graph` from the parsed FX node array.
/// (The pt2_schema's `Node` has no name field; nodes are identified by
/// their primary output tensor name.) Used by forward-looking fusions:
/// e.g. when the addmm fast path fires and the single consumer is
/// `relu`/`sigmoid`, we emit the fused `linear_bias_relu`/
/// `linear_bias_sigmoid` kernel and absorb the consumer node via
/// `absorbed_nodes`.
pub(crate) consumers: HashMap<String, Vec<String>>,
/// Set of *output tensor names* whose producing FX node was absorbed
/// into an earlier node's emission (e.g. a `relu` folded into
/// `linear_bias` by the addmm fast path). `translate_graph` short-
/// circuits these nodes. The absorbed node's output tensor must be
/// pre-populated under its name by the absorbing node.
pub(crate) absorbed_nodes: std::collections::HashSet<String>,
}
impl<'a> Translator<'a> {
@@ -61,17 +97,112 @@ impl<'a> Translator<'a> {
graph: Graph::new(),
tensors: HashMap::new(),
sym_map,
transpose_2d_source: HashMap::new(),
node_chain: HashMap::new(),
op_inputs: HashMap::new(),
consumers: HashMap::new(),
absorbed_nodes: std::collections::HashSet::new(),
user_input_ids: Vec::new(),
output_ids: Vec::new(),
extra_tensor_values: HashMap::new(),
})
}
/// Helper: extract a node's primary output tensor name. Nodes in the
/// pt2 schema are identified by this (no separate name field).
fn node_out_name(node: &Node) -> Option<String> {
node.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
.or_else(|| {
node.outputs
.first()
.and_then(|o| o.as_tensors.as_ref())
.and_then(|ts| ts.first().map(|t| t.name.clone()))
})
}
/// Build a `tensor_name → consumer output-tensor names` map for the
/// parsed graph. One pass over all FX nodes; each node contributes to
/// the `consumers` entry of every tensor it reads (Argument::Tensor,
/// Argument::Tensors, or the non-None entries of
/// Argument::OptionalTensors). The consumer is keyed by its primary
/// output tensor name. Used by forward-looking fast paths to detect when an
/// op's output has a single downstream consumer of a known kind
/// (relu/sigmoid) and emit a fused kernel that absorbs the consumer.
fn build_consumers(&mut self) {
let nodes = &self.parsed.program.graph_module.graph.nodes;
for node in nodes {
let Some(consumer_out) = Self::node_out_name(node) else {
continue;
};
for inp in &node.inputs {
match &inp.arg {
Argument::Tensor(t) => {
self.consumers
.entry(t.as_tensor.name.clone())
.or_default()
.push(consumer_out.clone());
}
Argument::Tensors(ts) => {
for t in &ts.as_tensors {
self.consumers
.entry(t.name.clone())
.or_default()
.push(consumer_out.clone());
}
}
Argument::OptionalTensors(ots) => {
for ot in &ots.as_optional_tensors {
if let crate::pt2_schema::OptionalTensorEntry::Tensor(t) = ot {
self.consumers
.entry(t.as_tensor.name.clone())
.or_default()
.push(consumer_out.clone());
}
}
}
_ => {}
}
}
}
}
/// Return the unique consumer of `tensor_name`, if there is exactly one
/// and we can find the corresponding FX node. Returns `(target,
/// output_tensor_name)`. None if the consumer set is empty, has more
/// than one entry, or the FX node lookup fails.
pub(crate) fn unique_consumer(&self, tensor_name: &str) -> Option<(String, String)> {
let consumers = self.consumers.get(tensor_name)?;
if consumers.len() != 1 {
return None;
}
let consumer_out = &consumers[0];
let node = self
.parsed
.program
.graph_module
.graph
.nodes
.iter()
.find(|n| Self::node_out_name(n).as_deref() == Some(consumer_out.as_str()))?;
Some((node.target.clone(), consumer_out.clone()))
}
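Review note: a minimal, self-contained sketch of the consumer-map idea used by `build_consumers` and `unique_consumer`. `ToyNode` and `toy_graph` are hypothetical stand-ins for the pt2 schema types, which are not reproduced here; only the "exactly one reader" rule is modelled.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an FX node: op target, input tensor names,
/// and a single primary output name.
struct ToyNode {
    target: &'static str,
    inputs: Vec<&'static str>,
    output: &'static str,
}

/// tensor name -> output names of the nodes that read it.
fn build_consumers(nodes: &[ToyNode]) -> HashMap<String, Vec<String>> {
    let mut consumers: HashMap<String, Vec<String>> = HashMap::new();
    for n in nodes {
        for inp in &n.inputs {
            consumers
                .entry((*inp).to_string())
                .or_default()
                .push(n.output.to_string());
        }
    }
    consumers
}

/// (target, output name) of the only reader, or None for zero or several.
fn unique_consumer(
    nodes: &[ToyNode],
    consumers: &HashMap<String, Vec<String>>,
    tensor: &str,
) -> Option<(String, String)> {
    let list = consumers.get(tensor)?;
    if list.len() != 1 {
        return None;
    }
    let node = nodes.iter().find(|n| n.output == list[0].as_str())?;
    Some((node.target.to_string(), list[0].clone()))
}

/// addmm_0 = addmm(bias, x, w_t); relu_0 = relu(addmm_0); mul_0 = mul(x, relu_0)
fn toy_graph() -> Vec<ToyNode> {
    vec![
        ToyNode { target: "addmm", inputs: vec!["bias", "x", "w_t"], output: "addmm_0" },
        ToyNode { target: "relu", inputs: vec!["addmm_0"], output: "relu_0" },
        ToyNode { target: "mul", inputs: vec!["x", "relu_0"], output: "mul_0" },
    ]
}
```

In this toy graph `addmm_0` has a single reader (`relu`), so it would qualify for fusion, while `x` is read by two nodes and does not.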
fn translate_graph(&mut self) -> Result<()> {
self.create_inputs()?;
self.build_consumers();
let nodes = &self.parsed.program.graph_module.graph.nodes;
for (i, node) in nodes.iter().enumerate() {
// Skip nodes whose translation was absorbed by an earlier
// node's fast path (e.g. a `relu` folded into a fused
// `linear_bias_relu`). The absorbing node has already
// populated `tensors` under this node's output name.
if let Some(out_name) = Self::node_out_name(node)
&& self.absorbed_nodes.contains(&out_name)
{
continue;
}
self.translate_node(node)
.with_context(|| format!("Failed to translate node {i}: {}", node.target))?;
}
@@ -90,9 +221,64 @@ impl<'a> Translator<'a> {
self.output_ids.push((name.clone(), tensor.id));
}
// Post-translation dead-code elimination. luminal's egglog DOES
// prune unreachable subgraphs in the common case (e.g. an unused
// `x*2.0` next to a returned `x+1.0`), but in some patterns the
// optimizer holds onto subgraphs that were created and then
// superseded by a translator peephole — most notably the DLRM
// PairwiseDot path where `index.Tensor(bmm(cat(...), perm(cat(...))), ...)`
// is replaced with a fused custom op but the original bmm/cat
// pad-and-add chain remains in the HLIR. Walk back from every
// `Output` HLIR node, mark reachable producers, and drop the rest.
// Preserves `Input` nodes unconditionally so the runtime's input
// signature stays intact even when an input is unused (a few
// models pass dead constants alongside live tensors).
self.dce();
Ok(())
}
/// Sweep the HLIR graph: remove every node not reachable backward
/// from an `Output` HLIR sink. Inputs are kept regardless so the
/// runtime input contract is preserved.
fn dce(&mut self) {
use luminal::hlir::{Input, Output};
use petgraph::Direction;
use std::collections::HashSet;
let mut keep: HashSet<NodeIndex> = HashSet::new();
let mut stack: Vec<NodeIndex> = Vec::new();
let node_ids: Vec<NodeIndex> = self.graph.graph.node_indices().collect();
for n in &node_ids {
if self.graph.try_get_op::<Output>(*n).is_some() {
if keep.insert(*n) {
stack.push(*n);
}
}
if self.graph.try_get_op::<Input>(*n).is_some() {
keep.insert(*n);
}
}
while let Some(n) = stack.pop() {
// Walk incoming edges — operands of `n`.
let preds: Vec<NodeIndex> = self
.graph
.graph
.neighbors_directed(n, Direction::Incoming)
.collect();
for pred in preds {
if keep.insert(pred) {
stack.push(pred);
}
}
}
for n in node_ids {
if !keep.contains(&n) {
self.graph.graph.remove_node(n);
}
}
}
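Review note: the sweep in `dce` is a plain backward reachability pass. The sketch below restates it over an adjacency list using only the standard library; `dce_keep` and the `operands` layout are assumptions made for illustration and do not touch petgraph or the HLIR types.

```rust
use std::collections::HashSet;

/// Toy graph: `operands[n]` lists the nodes that node `n` reads.
/// Returns the sorted ids that survive: everything reachable backward
/// from `outputs`, plus every id in `inputs` unconditionally.
fn dce_keep(operands: &[Vec<usize>], outputs: &[usize], inputs: &[usize]) -> Vec<usize> {
    // Inputs are kept even when nothing reads them.
    let mut keep: HashSet<usize> = inputs.iter().copied().collect();
    let mut stack: Vec<usize> = Vec::new();
    for &o in outputs {
        if keep.insert(o) {
            stack.push(o);
        }
    }
    // Walk operand edges; `insert` returning false means already visited.
    while let Some(n) = stack.pop() {
        for &p in &operands[n] {
            if keep.insert(p) {
                stack.push(p);
            }
        }
    }
    let mut kept: Vec<usize> = keep.into_iter().collect();
    kept.sort_unstable();
    kept
}
```

With nodes `0 = x`, `1 = unused input`, `2 = x + 1`, `3 = x * 2` and `4 = output(2)`, node 3 is dropped and the unused input 1 is retained.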
fn create_inputs(&mut self) -> Result<()> {
let inputs = self.parsed.classify_inputs();
for input in &inputs {


@@ -72,6 +72,24 @@ impl<'a> Translator<'a> {
.iter()
.map(|&d| normalize_dim(d, a.shape.len()))
.collect();
// Record matmul-compatible inner-axis transposes so addmm /
// bmm can route them through the fused `linear_bias` /
// `matmul_3d_t` kernels with the *original* input. The view-transposed
// tensor has non-contiguous strides that the SGEMM kernel
// doesn't honor, so we need the original. We recognize two
// patterns:
// * 2-D permute [1, 0] — `weight.t()` from nn.Linear
// * 3-D permute [0, 2, 1] — `T.transpose(1, 2)` for bmm
let is_inner_transpose = (axes == [1usize, 0usize] && a.shape.dims.len() == 2)
|| (axes == [0usize, 2usize, 1usize] && a.shape.dims.len() == 3);
if is_inner_transpose
&& let Some(src_name) = node.inputs.first().and_then(|i| i.arg.as_tensor_name())
&& let Some(out_ref) = node.outputs.first()
&& let Some(out_t) = out_ref.as_tensor.as_ref()
{
self.transpose_2d_source
.insert(out_t.name.clone(), src_name.to_string());
}
Ok(a.permute(axes))
}
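Review note: the reason the recorded source matters is that `x @ weight.t() + bias` can be computed directly from the original `(N, K)` weight, with no transposed view involved. A host-side reference, assuming row-major `Vec<Vec<f32>>` inputs; `linear_bias_ref` is a hypothetical name and is not the CUDA kernel's signature.

```rust
/// y[m][n] = bias[n] + sum_k x[m][k] * w[n][k]
/// i.e. x @ w^T + bias with `w` kept in its original (N, K) layout.
/// When `relu` is set the activation is applied in the same pass.
fn linear_bias_ref(x: &[Vec<f32>], w: &[Vec<f32>], bias: &[f32], relu: bool) -> Vec<Vec<f32>> {
    x.iter()
        .map(|row| {
            w.iter()
                .zip(bias.iter())
                .map(|(w_row, &b)| {
                    // Each output column reads one contiguous weight row.
                    let dot: f32 = row.iter().zip(w_row.iter()).map(|(a, c)| a * c).sum();
                    let y = b + dot;
                    if relu { y.max(0.0) } else { y }
                })
                .collect()
        })
        .collect()
}
```

Reading contiguous weight rows is what lets a fused kernel skip the non-contiguous strides of the permuted view.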
@@ -256,7 +274,231 @@ impl<'a> Translator<'a> {
Ok(weight.gather(ids_expanded + arange_expanded))
}
/// `aten.index_select(input, dim, index)` — pick rows/slices of `input`
/// along `dim` using a 1-D `index` tensor. Output shape is
/// `input.shape` with `dim` replaced by `index.shape[0]`.
///
/// For the DLRM v3 use case this is `index_select(emb_weight, 0,
/// flat_indices)` — a 2-D source and 1-D index along dim 0. We lower
/// it the same way `translate_embedding` does: build a flat-rows
/// gather index `(index * hidden_dim) + arange(hidden_dim)` and read
/// the flattened weight in one pass. Higher-rank sources and non-zero
/// `dim` are not yet wired (would need stride math over Expression
/// shapes); they error out cleanly so they're easy to add when the
/// next model surfaces them.
pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
let source = self.get_input_tensor(node, 0)?;
let dim_raw = self.get_int_arg(node, 1)?;
let index = self.get_input_tensor(node, 2)?;
let rank = source.shape.dims.len();
anyhow::ensure!(
rank == 2,
"translate_index_select: only 2-D source supported (got rank {rank}); \
extend this when a model needs higher rank."
);
let dim = if dim_raw < 0 {
dim_raw + rank as i64
} else {
dim_raw
};
anyhow::ensure!(
dim == 0,
"translate_index_select: only dim=0 supported (got {dim}); \
extend this when a model needs another axis."
);
anyhow::ensure!(
index.shape.dims.len() == 1,
"translate_index_select: index must be 1-D (got rank {})",
index.shape.dims.len()
);
// Same lowering as `translate_embedding`: build a flat gather index
// that combines the row-base offsets (`index * hidden_dim`) with a
// per-row `arange(hidden_dim)` broadcast.
let hidden_dim = source.shape.dims[1];
let n_idx = index.shape.dims[0];
let index_int = index.cast(DType::Int);
let base_expanded = (index_int * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim);
let arange_expanded = arange.expand_dim(0, n_idx);
Ok(source.gather(base_expanded + arange_expanded))
}
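Review note: the flat gather index built above, `(index * hidden_dim) + arange(hidden_dim)`, can be written out for concrete numbers. The helper names below are hypothetical and operate on a row-major `&[f32]` instead of a `GraphTensor`.

```rust
/// Flat positions of the selected rows in a row-major (rows, hidden) matrix:
/// for each requested row r, emit r * hidden + 0, r * hidden + 1, ...
fn flat_row_gather_indices(index: &[usize], hidden: usize) -> Vec<usize> {
    index
        .iter()
        .flat_map(|&r| (0..hidden).map(move |c| r * hidden + c))
        .collect()
}

/// index_select along dim 0, expressed as a single gather over the flat buffer.
fn index_select_rows(flat_weight: &[f32], index: &[usize], hidden: usize) -> Vec<f32> {
    flat_row_gather_indices(index, hidden)
        .into_iter()
        .map(|i| flat_weight[i])
        .collect()
}
```

For a 3 x 2 weight and `index = [2, 0, 2]`, the gather positions are `[4, 5, 0, 1, 4, 5]`.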
/// `aten._embedding_bag_forward_only(weight, indices, offsets,
/// scale_grad_by_freq, mode, sparse, per_sample_weights,
/// include_last_offset, padding_idx)` →
/// `(output, offset2bag, bag_size, max_indices)`.
///
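/// PyTorch decomposes `nn.EmbeddingBag` to this op. For the DLRM use
/// case all bags share a fixed stride `L = indices.len() / offsets.len()`
/// and `mode == 0` (sum). We require `mode == 0` and check that
/// `indices.len()` is a multiple of `offsets.len()`; the offset values
/// themselves are not inspected, so uniform stride is assumed. We lower
/// to the fused [`embedding_bag_sum_kernel`] on CUDA with F32 weights,
/// or to a generic `gather → sum` chain otherwise.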
///
/// Only `output` (tuple slot 0) is computed — `offset2bag`, `bag_size`
/// and `max_indices` are training-time dead ends for inference DLRM
/// and never read by any downstream `getitem`.
pub(crate) fn translate_embedding_bag_forward_only(&mut self, node: &Node) -> Result<()> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?;
let offsets = self.get_input_tensor(node, 2)?;
let mode = self.get_int_arg(node, 4).unwrap_or(0);
anyhow::ensure!(
mode == 0,
"translate_embedding_bag_forward_only: only mode=0 (sum) supported (got {mode}); \
vanilla DLRM uses sum-pooled bags. Extend this when a model needs mean/max."
);
// per_sample_weights is input index 6 and may be None / absent.
let has_per_sample_weights = node
.inputs
.get(6)
.and_then(|i| i.arg.as_tensor_name())
.is_some();
anyhow::ensure!(
!has_per_sample_weights,
"translate_embedding_bag_forward_only: per_sample_weights not supported \
(DLRM doesn't use them)."
);
anyhow::ensure!(
weight.shape.dims.len() == 2,
"translate_embedding_bag_forward_only: weight must be 2-D (got rank {})",
weight.shape.dims.len()
);
anyhow::ensure!(
indices.shape.dims.len() == 1,
"translate_embedding_bag_forward_only: indices must be 1-D (got rank {})",
indices.shape.dims.len()
);
anyhow::ensure!(
offsets.shape.dims.len() == 1,
"translate_embedding_bag_forward_only: offsets must be 1-D (got rank {})",
offsets.shape.dims.len()
);
let n_idx = indices.shape.dims[0]
.to_usize()
.context("translate_embedding_bag_forward_only: indices length must be static")?;
let batch = offsets.shape.dims[0]
.to_usize()
.context("translate_embedding_bag_forward_only: offsets length must be static")?;
anyhow::ensure!(
n_idx % batch == 0,
"translate_embedding_bag_forward_only: indices length ({n_idx}) must be a \
multiple of offsets length ({batch}); variable bag sizes not supported."
);
let bag = n_idx / batch;
let d = weight.shape.dims[1]
.to_usize()
.context("translate_embedding_bag_forward_only: weight dim 1 must be static")?;
// Reshape indices (B*L,) → (B, L) and cast to i32 (luminal kernel
// wants Int). Then either use the fused kernel under CUDA or
// a host-portable gather+sum lowering.
let indices_int = indices.cast(DType::Int);
let indices_2d = {
let new_shape = ShapeTracker::new(vec![
Expression::from(batch),
Expression::from(bag),
]);
GraphTensor {
id: indices_int.id,
graph_ref: indices_int.graph_ref,
shape: new_shape,
dtype: indices_int.dtype,
}
};
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by the cfg!(feature = "cuda") gate below.
let result = if cfg!(feature = "cuda") && backend_is_cuda && weight.dtype == DType::F32 {
// Fused CUDA path: one kernel for the whole bag-sum.
#[cfg(feature = "cuda")]
{
luminal_cuda_lite::kernel::embedding_bag_sum_kernel(weight, indices_2d)
}
#[cfg(not(feature = "cuda"))]
{
// Unreachable — gated above, but keep the compiler happy.
unreachable!()
}
} else {
// Generic fallback: gather to (B, L, D), then sum along L.
let hidden_dim = weight.shape.dims[1];
let ids_expanded = (indices_2d * hidden_dim).expand_dim(2, hidden_dim);
let arange = self.graph.arange(hidden_dim);
// Insert L first, then B, so the broadcast arange is (B, L, D)
// and lines up with `ids_expanded` axis for axis.
let arange_expanded = arange.expand_dim(0, bag).expand_dim(0, batch);
// Note: weight.gather expects the gather indices to broadcast
// against weight's row-flattened layout; we want (B, L, D)
// out, then sum along L.
let _ = d; // hidden_dim is used; keep `d` reachable for debug only.
let gathered = weight.gather(ids_expanded + arange_expanded);
gathered.sum(1)
};
// Record the output under outputs[0][0] (the tuple's first slot).
// The other three slots are dead under inference and there's no
// downstream `getitem` that reads them — but if there ever is,
// we'd need to materialize them too.
let out_name = node
.outputs
.first()
.and_then(|o| o.as_tensors.as_ref())
.and_then(|ts| ts.first().map(|t| t.name.clone()))
.or_else(|| {
node.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
})
.context(
"translate_embedding_bag_forward_only: missing output[0] name in FX node",
)?;
self.tensors.insert(out_name.clone(), result);
// Record node_chain / op_inputs for the *primary* output (tuple
// slot 0). Multi-output ops normally skip the bookkeeping at the
// bottom of `translate_node` (they return early), but later
// peepholes — specifically the stacked-emb-bag fusion inside the
// pairwise-dot peephole — need to identify which tensor sources
// came from an embedding_bag, so we record explicitly.
let first_input_name: Option<String> = node
.inputs
.first()
.and_then(|i| i.arg.as_tensor_name().map(|s| s.to_string()));
if let Some(first_input) = first_input_name {
self.node_chain
.insert(out_name.clone(), (node.target.clone(), first_input));
}
let mut all_inputs: Vec<String> = Vec::new();
for inp in &node.inputs {
if let Some(names) = inp.arg.as_tensors() {
for tn in names {
all_inputs.push(tn.name.clone());
}
} else if let Some(name) = inp.arg.as_tensor_name() {
all_inputs.push(name.to_string());
}
}
if !all_inputs.is_empty() {
self.op_inputs.insert(out_name, all_inputs);
}
Ok(())
}
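Review note: the handler above only checks that `indices.len()` divides evenly by `offsets.len()`. If the offset values were available on the host (an assumption; at translation time they are graph tensors, not constants), a stricter uniform-stride check could look like the sketch below. `uniform_bag_len` is a hypothetical helper, not part of this PR.

```rust
/// Returns Some(L) when `offsets` is exactly 0, L, 2L, ... and the bags
/// tile `n_idx` indices with no remainder; None for ragged bags.
fn uniform_bag_len(offsets: &[usize], n_idx: usize) -> Option<usize> {
    let batch = offsets.len();
    if batch == 0 || n_idx % batch != 0 {
        return None;
    }
    let bag = n_idx / batch;
    // Divisibility alone is not enough: every offset must sit on the stride.
    offsets
        .iter()
        .enumerate()
        .all(|(b, &o)| o == b * bag)
        .then_some(bag)
}
```

The second case in a test such as `offsets = [0, 1, 4]` with six indices passes the divisibility check but has bag sizes 1, 3 and 2, which is the situation the length check alone cannot catch.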
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
// Try the DLRM PairwiseDot peephole before falling back to the
// generic gather-based lowering. Detects the
// Z[:, li, lj] ← bmm(T, T.transpose(1, 2)) ← cat([t1.unsqueeze(1), ..., tF.unsqueeze(1)], dim=1)
// pattern that vanilla `nn.Sequential` DLRM (DLRMv1) emits — the
// version a user writes with one `EmbeddingBag` per categorical
// table. Replacing it with `dlrm_pairwise_dot_lower_tri` collapses
// the cat-then-bmm-then-gather chain (which lowers to ~40
// small Iota/Cast/Gather/FusedRegion kernels via pad_along+add)
// into a single CUDA kernel.
if let Some(t) = self.try_translate_pairwise_dot_lower_tri(node)? {
return Ok(t);
}
let source = self.get_input_tensor(node, 0)?;
// Handle indices as_tensors (all non-None) or as individual args with None entries
@@ -308,6 +550,98 @@ impl<'a> Translator<'a> {
expanded.shape.expand(target);
return Ok(source.gather_elements(expanded, first_non_none_dim));
}
// Multi-index advanced indexing through leading dims that
// pass through (e.g. DLRM's `Z[:, li, lj]` where Z has
// shape `(B, F, F)` and output[b, p] = Z[b, li[p], lj[p]]).
//
// Strategy: reduce to the proven single-index simple
// case. Combine the multi-axis indices into one (`li * F
// + lj`) and reshape the source so the indexed region
// becomes a single dim. Then take the exact `gather_elements`
// path the rest of this translator already uses.
//
// Supported shape pattern (DLRM): exactly one leading
// passthrough dim, no trailing dims after the indexed
// region, per-axis indices all 1-D of the same length.
if first_non_none_dim > 0 {
let src_dims = source.shape.dims;
let src_rank = src_dims.len();
let n_idx = index_names.len();
let trailing_start = first_non_none_dim + n_idx;
anyhow::ensure!(
first_non_none_dim == 1,
"index.Tensor: leading-dim passthrough only supported for \
exactly one leading dim (got {first_non_none_dim})."
);
anyhow::ensure!(
trailing_start == src_rank,
"index.Tensor: trailing dims after indexed region not yet supported."
);
let mut idx_tensors: Vec<GraphTensor> = Vec::with_capacity(n_idx);
for n in &index_names {
idx_tensors.push(self.get_tensor(&n.name)?.cast(DType::Int));
}
let idx0_shape = idx_tensors[0].shape.dims;
anyhow::ensure!(
idx0_shape.len() == 1,
"index.Tensor: only 1-D per-axis indices supported (got rank {})",
idx0_shape.len()
);
for it in idx_tensors.iter().skip(1) {
anyhow::ensure!(
it.shape.dims == idx0_shape,
"index.Tensor: per-axis indices must share a common shape"
);
}
// strides over indexed axes (no trailing dims).
let mut strides_idx: Vec<Expression> = vec![Expression::from(1usize); n_idx];
for i in (0..n_idx - 1).rev() {
strides_idx[i] =
strides_idx[i + 1] * src_dims[first_non_none_dim + i + 1];
}
// combined[p] = sum_i idx_i * stride_i (1-D)
let mut combined: Option<GraphTensor> = None;
for (i, it) in idx_tensors.into_iter().enumerate() {
let weighted = if strides_idx[i].to_usize() == Some(1) {
it
} else {
it * strides_idx[i]
};
combined = Some(match combined {
Some(acc) => {
let (a, b) = broadcast_binary(acc, weighted);
a + b
}
None => weighted,
});
}
let combined = combined.context("index.Tensor: no indices")?;
// Indexed region size, then a (leading, indexed_size) reshape.
let mut indexed_size = Expression::from(1usize);
for d in &src_dims[first_non_none_dim..trailing_start] {
indexed_size *= *d;
}
let leading_dim = src_dims[0];
let flat_source =
reshape_tensor(source, vec![leading_dim, indexed_size]);
// Now dispatch through the exact single-index simple
// case lowering — known-good. Add unit leading dims
// to match flat_source rank, then expand to the full
// (leading_dim, pair_count) shape.
let mut expanded = combined;
let flat_rank = 2; // (leading, indexed_size)
for _ in 0..(flat_rank - expanded.shape.len()) {
expanded = expanded.expand_dim(0, Expression::from(1usize));
}
let idx_dim_size = expanded.shape.dims[1];
let target: Vec<Expression> = vec![leading_dim, idx_dim_size];
expanded.shape.expand(target);
return Ok(flat_source.gather_elements(expanded, 1));
}
} else {
bail!(
"index.Tensor: unsupported indices format: {:?}",
@@ -552,4 +886,394 @@ impl<'a> Translator<'a> {
first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
}
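Review note: both the generic `Z[:, li, lj]` lowering (combine the indices as `li * F + lj` and gather from the flattened `F * F` block) and the fused pairwise-dot peephole compute the same `F(F-1)/2` values. The sketch below checks that for one batch element on the host. All names are hypothetical, and the row-by-row pair order (`(1,0), (2,0), (2,1), ...`) is an assumption about how the model builds `li` and `lj`.

```rust
/// Strict lower-triangular (i, j) pairs with i > j, enumerated row by row.
fn tril_pairs(f: usize) -> (Vec<usize>, Vec<usize>) {
    let (mut li, mut lj) = (Vec::new(), Vec::new());
    for i in 0..f {
        for j in 0..i {
            li.push(i);
            lj.push(j);
        }
    }
    (li, lj)
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Unfused shape of the computation: build the full F x F Gram matrix,
/// flatten it, then gather at the combined index li * F + lj.
fn pairwise_via_bmm(feats: &[Vec<f32>]) -> Vec<f32> {
    let f = feats.len();
    let mut z = vec![0.0f32; f * f];
    for i in 0..f {
        for j in 0..f {
            z[i * f + j] = dot(&feats[i], &feats[j]);
        }
    }
    let (li, lj) = tril_pairs(f);
    li.iter().zip(lj.iter()).map(|(&i, &j)| z[i * f + j]).collect()
}

/// Fused shape: compute only the F(F-1)/2 dot products that are kept.
fn pairwise_direct(feats: &[Vec<f32>]) -> Vec<f32> {
    let (li, lj) = tril_pairs(feats.len());
    li.iter()
        .zip(lj.iter())
        .map(|(&i, &j)| dot(&feats[i], &feats[j]))
        .collect()
}
```

The direct form skips the diagonal and upper triangle entirely, which is where the fused kernel saves work relative to bmm followed by a gather.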
/// DLRM PairwiseDot peephole: detect
/// `aten.index.Tensor(bmm, [None, li, lj])`
/// where
/// `bmm = aten.bmm.default(T, T_permuted)`
/// `T_permuted = aten.permute.default(T, [0, 2, 1])`
/// `T = aten.cat.default([unsqueeze_a, unsqueeze_b, …], dim=1)`
/// each `unsqueeze_k = aten.unsqueeze.default(t_k, 1)`
/// and lower to `dlrm_pairwise_dot_lower_tri([t_0, t_1, …])`.
///
/// Why this matters: at DLRM nc=3 the generic lowering produces
/// ~80 CUDA-graph kernels from the cat+bmm+gather chain alone (the
/// `pad_along + add` decomposition of cat fans out into many
/// Iota/Cast/Gather/FusedRegion launches). The fused kernel
/// computes the F(F-1)/2 dot products directly with one launch.
///
/// Returns `Ok(Some(out))` on match, `Ok(None)` if the pattern
/// doesn't apply, `Err(_)` only if matching diagnostics surface a
/// genuine bug.
fn try_translate_pairwise_dot_lower_tri(
&mut self,
node: &Node,
) -> Result<Option<GraphTensor>> {
// CUDA-only fast path. The kernel is in luminal_cuda_lite.
#[cfg(not(feature = "cuda"))]
{
let _ = node;
return Ok(None);
}
#[cfg(feature = "cuda")]
{
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
if !backend_is_cuda {
return Ok(None);
}
// 1. Detect [None, li, lj] index list.
let opt_tensors =
match node.inputs.get(1).and_then(|i| i.arg.as_optional_tensors()) {
Some(t) => t,
None => {
return Ok(None);
}
};
// Expect [None, li, lj]: three entries, first None, last two
// are tensors.
if opt_tensors.len() != 3 {
return Ok(None);
}
use crate::pt2_schema::OptionalTensorEntry;
let (li_name, lj_name) =
match (&opt_tensors[0], &opt_tensors[1], &opt_tensors[2]) {
(OptionalTensorEntry::None(_), OptionalTensorEntry::Tensor(li), OptionalTensorEntry::Tensor(lj)) => {
(li.as_tensor.name.clone(), lj.as_tensor.name.clone())
}
_ => {
return Ok(None);
}
};
// 2. Source must be a bmm of (T, T_permuted).
let source_name = match node.inputs.first().and_then(|i| i.arg.as_tensor_name()) {
Some(s) => s.to_string(),
None => {
return Ok(None);
}
};
let bmm_info = match self.node_chain.get(&source_name) {
Some(x) => x.clone(),
None => {
return Ok(None);
}
};
if bmm_info.0 != "torch.ops.aten.bmm.default" {
return Ok(None);
}
let bmm_inputs = match self.op_inputs.get(&source_name) {
Some(v) => v.clone(),
None => {
return Ok(None);
}
};
if bmm_inputs.len() != 2 {
return Ok(None);
}
let (bmm_a, bmm_b) = (bmm_inputs[0].clone(), bmm_inputs[1].clone());
// 3. Both bmm inputs must descend from the same cat — one
// directly, one via a [0, 2, 1] permute. The permute is
// already recorded in `transpose_2d_source` (we extended
// it to cover 3-D `[0, 2, 1]` for bmm fast paths).
let permute_src = self.transpose_2d_source.get(&bmm_b).cloned();
let (cat_name, _has_transpose) = if permute_src.as_deref() == Some(bmm_a.as_str()) {
(bmm_a.clone(), true)
} else if self
.transpose_2d_source
.get(&bmm_a)
.map(|s| s.as_str() == bmm_b.as_str())
.unwrap_or(false)
{
(bmm_b.clone(), true)
} else {
return Ok(None);
};
let cat_info = match self.node_chain.get(&cat_name) {
Some(x) => x.clone(),
None => {
return Ok(None);
}
};
if cat_info.0 != "torch.ops.aten.cat.default" {
return Ok(None);
}
let cat_inputs = match self.op_inputs.get(&cat_name) {
Some(v) => v.clone(),
None => {
return Ok(None);
}
};
if cat_inputs.len() < 2 {
return Ok(None);
}
// 4. Each cat input should be `unsqueeze(t_k, 1)` — peel the
// unsqueeze so we recover the original (B, D) tensor.
// Also track the FX *source name* of each feature so the
// multi-call fusion below can recognize emb-bag-rooted
// features and emit the stacked path.
let mut feature_tensors: Vec<GraphTensor> = Vec::with_capacity(cat_inputs.len());
let mut feature_source_names: Vec<String> = Vec::with_capacity(cat_inputs.len());
for ci in &cat_inputs {
let unsqueeze_info = match self.node_chain.get(ci) {
Some(x) => x.clone(),
None => {
return Ok(None);
}
};
if unsqueeze_info.0 != "torch.ops.aten.unsqueeze.default" {
return Ok(None);
}
// unsqueeze's first input is the source tensor name.
let src = unsqueeze_info.1;
let t = self.get_tensor(&src)?;
if t.dtype != DType::F32 || t.shape.dims.len() != 2 {
return Ok(None);
}
feature_tensors.push(t);
feature_source_names.push(src);
}
// 5. Sanity check on li/lj — they must be the strict lower-tri
// pair table for F = feature_tensors.len(). We don't
// materialize them; just verify the index buffers have
// the right length, then trust they're tril-indices.
// A user passing arbitrary indices through this exact
// chain would silently get tril results; gating on
// `index buffer length == F*(F-1)/2` catches the common
// case without invasive constant-folding work.
let f = feature_tensors.len();
let pair_count = f * (f - 1) / 2;
let li_t = self.get_tensor(&li_name)?;
let lj_t = self.get_tensor(&lj_name)?;
if li_t.shape.dims.len() != 1
|| lj_t.shape.dims.len() != 1
|| li_t.shape.dims[0].to_usize() != Some(pair_count)
|| lj_t.shape.dims[0].to_usize() != Some(pair_count)
{
return Ok(None);
}
// 6. Multi-call fusion: detect when exactly one feature is the
// dense MLP output and all others came from
// `_embedding_bag.default` calls with matching `(rows, D)`.
// Then collapse the N separate per-table embedding-bag
// kernels into one fused `multi_table_embedding_bag_sum_kernel`
// launch and use the stacked-variant pairwise dot. Mirrors the
// hand-written DLRM rust example's kernel shape.
if let Some((dense, emb_stack)) =
self.try_fuse_stacked_emb_bag(&feature_tensors, &feature_source_names)?
{
let out = luminal_cuda_lite::kernel::dlrm_pairwise_dot_lower_tri_stacked(
dense, emb_stack,
);
return Ok(Some(out));
}
// Fallback: per-feature variadic kernel. The bmm, cat,
// and unsqueeze nodes left dangling in the HLIR get picked
// up by `Translator::dce()` after every FX node is translated
// (walks back from `Output` HLIR sinks and drops everything
// unreachable). luminal's egglog optimizer leaves some of
// these subgraphs alive on its own, so the explicit pass is
// load-bearing for this peephole.
let out = luminal_cuda_lite::kernel::dlrm_pairwise_dot_lower_tri(feature_tensors);
Ok(Some(out))
}
}
/// Multi-call fusion: scan a list of feature sources for the DLRM
/// pattern "1 dense MLP output + N `_embedding_bag.default` outputs"
/// and, when matched, emit a single `multi_table_embedding_bag_sum_kernel`
/// over the N per-table weights (no concatenated weight tensor is
/// built). Returns `(dense, emb_stack)` where `dense` is the (B, D) MLP
/// output and `emb_stack` is `(B, num_emb, D)` for the
/// stacked-pairwise-dot kernel to consume.
///
/// Requirements:
/// * All embedding-bag weights are 2-D F32 with the same `(rows, D)`
/// * All tables share the same batch size and a uniform bag size
/// (1-D indices whose length is a multiple of the 1-D offsets length)
/// * Every table uses sum mode (`mode == 0`) with no `per_sample_weights`
/// * The dense feature is the *first* cat input (DLRMv1 emits
/// `cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)`)
#[cfg(feature = "cuda")]
fn try_fuse_stacked_emb_bag(
&mut self,
feature_tensors: &[GraphTensor],
feature_source_names: &[String],
) -> Result<Option<(GraphTensor, GraphTensor)>> {
use crate::pt2_schema::OptionalTensorEntry;
if feature_tensors.len() < 2 {
return Ok(None);
}
// Look up each feature source's producing FX node target. The
// dense MLP output has target ≠ `_embedding_bag*`. We require the
// *first* feature to be dense (DLRMv1 convention) and the rest to
// be emb-bag.
let producer_target = |name: &str| -> Option<String> {
self.node_chain.get(name).map(|(target, _)| target.clone())
};
let dense_target = producer_target(&feature_source_names[0]);
if let Some(t) = &dense_target
&& (t == "torch.ops.aten._embedding_bag.default"
|| t == "torch.ops.aten._embedding_bag_forward_only.default")
{
return Ok(None);
}
for src in feature_source_names.iter().skip(1) {
match producer_target(src).as_deref() {
Some("torch.ops.aten._embedding_bag.default")
| Some("torch.ops.aten._embedding_bag_forward_only.default") => {}
_ => return Ok(None),
}
}
// For each emb-bag feature source, pull (weight, indices, offsets)
// by walking back through the parsed FX node array. The FX nodes
// for emb-bag store the output tensor name(s) in `outputs[0].as_tensors`.
let mut emb_weights: Vec<GraphTensor> = Vec::new();
let mut emb_indices: Vec<GraphTensor> = Vec::new();
let mut emb_rows: Option<usize> = None;
let mut emb_d: Option<usize> = None;
let mut emb_batch: Option<usize> = None;
let mut emb_bag: Option<usize> = None;
for src in feature_source_names.iter().skip(1) {
// Find the FX emb-bag node whose primary output is `src`.
// emb-bag is a multi-output op: `outputs[0].as_tensors[0].name`
// is what downstream `getitem(node, 0)` references, which is
// what our translator stores in `tensors`. Some exports drop
// the getitem entirely and just inline the name as a single
// tensor output, so check both shapes.
let node_opt = self
.parsed
.program
.graph_module
.graph
.nodes
.iter()
.find(|n| {
if n.target != "torch.ops.aten._embedding_bag.default"
&& n.target != "torch.ops.aten._embedding_bag_forward_only.default"
{
return false;
}
let Some(out) = n.outputs.first() else {
return false;
};
if let Some(ts) = out.as_tensors.as_ref() {
return ts.iter().any(|t| t.name == *src);
}
if let Some(t) = out.as_tensor.as_ref() {
return t.name == *src;
}
false
});
let Some(node) = node_opt else {
return Ok(None);
};
// _embedding_bag(weight, indices, offsets, ...)
let weight_name = node
.inputs
.first()
.and_then(|i| i.arg.as_tensor_name())
.map(|s| s.to_string());
let indices_name = node
.inputs
.get(1)
.and_then(|i| i.arg.as_tensor_name())
.map(|s| s.to_string());
let offsets_name = node
.inputs
.get(2)
.and_then(|i| i.arg.as_tensor_name())
.map(|s| s.to_string());
let mode = self.get_int_arg(node, 4).unwrap_or(0);
if mode != 0 {
return Ok(None);
}
// per_sample_weights at index 6
let has_psw = node
.inputs
.get(6)
.and_then(|i| i.arg.as_tensor_name())
.is_some();
if has_psw {
return Ok(None);
}
let (Some(wn), Some(in_), Some(on)) = (weight_name, indices_name, offsets_name) else {
return Ok(None);
};
let w = self.get_tensor(&wn)?;
let i = self.get_tensor(&in_)?;
let o = self.get_tensor(&on)?;
if w.dtype != DType::F32 || w.shape.dims.len() != 2 {
return Ok(None);
}
if i.shape.dims.len() != 1 || o.shape.dims.len() != 1 {
return Ok(None);
}
let rows = w.shape.dims[0].to_usize().ok_or_else(|| {
anyhow::anyhow!("stacked emb-bag fusion: rows must be static")
})?;
let d = w.shape.dims[1].to_usize().ok_or_else(|| {
anyhow::anyhow!("stacked emb-bag fusion: D must be static")
})?;
let n_idx = i.shape.dims[0].to_usize().ok_or_else(|| {
anyhow::anyhow!("stacked emb-bag fusion: indices length must be static")
})?;
let batch = o.shape.dims[0].to_usize().ok_or_else(|| {
anyhow::anyhow!("stacked emb-bag fusion: offsets length must be static")
})?;
// Decline on an empty offsets tensor so the modulo/division below cannot panic.
if batch == 0 || n_idx % batch != 0 {
return Ok(None);
}
let bag = n_idx / batch;
// All tables must agree on (rows, d, batch, bag).
if emb_rows.get_or_insert(rows) != &rows {
return Ok(None);
}
if emb_d.get_or_insert(d) != &d {
return Ok(None);
}
if emb_batch.get_or_insert(batch) != &batch {
return Ok(None);
}
if emb_bag.get_or_insert(bag) != &bag {
return Ok(None);
}
// Reshape indices to (B, bag) of int32 for the kernel.
let i_int = i.cast(DType::Int);
let new_shape = ShapeTracker::new(vec![
Expression::from(batch),
Expression::from(bag),
]);
let indices_2d = GraphTensor {
id: i_int.id,
graph_ref: i_int.graph_ref,
shape: new_shape,
dtype: i_int.dtype,
};
emb_weights.push(w);
emb_indices.push(indices_2d);
// Reference the `OptionalTensorEntry` import so it is not flagged as unused.
let _ = OptionalTensorEntry::None;
}
let _ = emb_rows; // (already validated via per-table equality)
// Use the multi-table kernel: takes N (weight, idx) pairs and
// produces (B, num_emb, D) in one launch. Crucially this avoids
// the `concat_along`-of-persistent-weights expansion (which would
// emit pad+add HLIR kernels per pair) — the kernel reads each
// table's weight pointer directly via a packed staging buffer.
let emb_stack = luminal_cuda_lite::kernel::multi_table_embedding_bag_sum_kernel(
emb_weights,
emb_indices,
);
Ok(Some((feature_tensors[0], emb_stack)))
}
#[cfg(not(feature = "cuda"))]
fn try_fuse_stacked_emb_bag(
&mut self,
_feature_tensors: &[GraphTensor],
_feature_source_names: &[String],
) -> Result<Option<(GraphTensor, GraphTensor)>> {
Ok(None)
}
}


@@ -43,6 +43,63 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
# Pre-resolve opaque integer ids for inputs and outputs so the
# per-iter `run_with_ptrs` path can skip the string-keyed lookup
# inside Rust. One pyo3 call per name at compile time, no per-iter
# name handling thereafter. Falls back to None when the Rust side
# is an older build that doesn't expose `tensor_id` /
# `run_with_ptrs`.
self._batched_ptrs_supported = (
self._supports_device_ptrs
and hasattr(graph_result, "tensor_id")
and hasattr(graph_result, "run_with_ptrs")
)
if self._batched_ptrs_supported:
self._input_ids = [graph_result.tensor_id(n) for n in self._input_names]
self._output_ids = [graph_result.tensor_id(n) for n in self._output_names]
else:
self._input_ids = None
self._output_ids = None
# Output dtype codes — cached at __init__ instead of re-fetched in
# each call (was a per-iter pyo3 attribute access on the dynamic
# path; for static-shape models the codes don't change).
self._output_dtype_codes_cached = (
None if self._has_dynamic_dims else list(graph_result.output_dtypes)
)
# Per-input "have-we-seen-this-ptr-before" cache. In a hot bench
# loop the same user tensor objects are passed each iter, so most
# of the per-iter Python work (detach + contiguous + dtype cast +
# data_ptr + numel + element_size) repeats with identical inputs.
# Cache (id(orig_tensor), orig_data_ptr, cast_tensor, cast_ptr,
# cast_n_bytes) per input slot; on hit, skip everything and rely
# on luminal's previously-registered pointer (the runtime keeps
# CudaInput::Ptr entries across `execute()` calls thanks to the
# consume-step filter in `cuda_lite/runtime.rs`). The cast tensor
# reference is held inside the cache so PyTorch's caching
# allocator can't recycle the converted buffer.
#
# Sharp edge: callers that mutate a user tensor in place via
# `.copy_(...)` keep the same `id()` and `data_ptr()` and will
# silently get stale cached data. Fresh-tensor callers
# (`make_inputs(...)` each iter, or new outputs from upstream)
# cache-miss naturally and pay the full cold-path cost. If a
# future model needs in-place input mutation, swap this check
# for one that also looks at `tensor._version` (autograd's
# mutation counter) — but PyTorch flags `_version` as private,
# so don't reach for it unless an actual model needs it.
self._input_cache_ids = [None] * len(self._input_names)
self._input_cache_orig_ptrs = [0] * len(self._input_names)
self._input_cache_cast_tensors = [None] * len(self._input_names)
self._input_cache_specs = [None] * len(self._input_names)
# Cached output tensors mirror the input-side cache. For static-
# shape models the output tensors can be reused across calls if
# the input device is unchanged — saves ~3 μs/output of
# `torch.empty` + the FFI to register the device pointer.
# NB: callers that stash the returned tensor must `.clone()`
# before the next call; the default contract returns a fresh
# tensor each call so leave this opt-in via env var for now.
self._output_cache_tensors = None
self._output_cache_specs = None
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -87,6 +144,90 @@ class CompiledModel:
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
# Batched CUDA fast path. When all user inputs are on CUDA and all
# outputs are floating-point, collapse the per-iter `set_input_device_ptr`
# × N, `set_output_device_ptr` × M, and `run` calls into a single
# `run_with_ptrs` FFI crossing. The Rust side iterates the
# (id, ptr, n_bytes) tuples without paying per-call pyo3
# marshalling cost.
all_cuda = bool(user_inputs) and all(t.is_cuda for t in user_inputs)
if self._batched_ptrs_supported and all_cuda:
output_shapes = (
self._graph.resolve_output_shapes()
if self._has_dynamic_dims
else self._output_shapes
)
output_dtype_codes = (
self._graph.output_dtypes
if self._output_dtype_codes_cached is None
else self._output_dtype_codes_cached
)
output_dtypes = [
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
for i in range(len(self._output_names))
]
if all(d.is_floating_point for d in output_dtypes):
# Build the input-spec list. Hot path: when the user passes
# the same tensor object as last call AND that tensor's
# data_ptr is unchanged, the previous registration is still
# valid in the luminal runtime — skip both the Python-side
# cast/contiguous/data_ptr work AND the Rust-side
# `set_device_ptr` call by omitting it from `input_specs`.
input_specs = []
_cache_ids = self._input_cache_ids
_cache_orig = self._input_cache_orig_ptrs
_cache_cast = self._input_cache_cast_tensors
_cache_spec = self._input_cache_specs
for i, (id_, tensor, expected_dtype) in enumerate(zip(
self._input_ids, user_inputs, self._input_dtypes
)):
orig_id = id(tensor)
orig_ptr = tensor.data_ptr()
if _cache_ids[i] == orig_id and _cache_orig[i] == orig_ptr:
# Pointer unchanged — luminal already has the
# registration. Skip everything.
continue
# Cold path: cast (if needed) and update cache.
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
cast_ptr = t.data_ptr()
spec = (id_, cast_ptr, n_bytes)
input_specs.append(spec)
_cache_ids[i] = orig_id
_cache_orig[i] = orig_ptr
_cache_cast[i] = t # keep alive
_cache_spec[i] = spec
# Outputs: pre-allocate fresh each call so the caller gets
# a unique tensor (matches the unbatched path's contract;
# callers that stash results don't get blown away on the
# next iteration). The output_specs list is always passed
# through to `run_with_ptrs`.
output_tensors = []
output_specs = []
for id_, shape, dt in zip(self._output_ids, output_shapes, output_dtypes):
out = torch.empty(shape, dtype=dt, device=input_device)
output_specs.append((id_, out.data_ptr(), out.numel() * out.element_size()))
output_tensors.append(out)
zero_copy_flags = self._graph.run_with_ptrs(input_specs, output_specs)
# For any output the runtime had to alias (not zero-copy),
# request the DtoD copy explicitly into the registered buffer.
# In DLRMv1 and similar models this never fires, but it's the
# same fallback the unbatched path has.
for ok, name, tensor in zip(zero_copy_flags, self._output_names, output_tensors):
if not ok:
self._graph.copy_output_to_device_ptr(
name, tensor.data_ptr(), tensor.numel() * tensor.element_size()
)
# Return tuple unconditionally — the unbatched path below
# also returns `tuple(outputs)` even for single-output
# models, and torch.compile / dynamo's output handling
# depends on that contract. Returning a bare Tensor here
# made dynamo iterate the first dim and reshape the
# output to a 1-element slice.
return tuple(output_tensors)
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs, keep references alive so the caching allocator doesn't

11
examples/dlrm/.gitignore vendored Normal file

@@ -0,0 +1,11 @@
# Correctness-test fixtures generated by `correctness_dump.py`.
# Regenerate with: python correctness_dump.py --num-cat N --rows R --out-dir weights_N
weights/
weights_*/
# Python bytecode
__pycache__/
# Intermediate sweep outputs from sweep_all.py — `results.csv` is the
# canonical published result; everything else is regenerable.
quick*.csv

18
examples/dlrm/Cargo.toml Normal file

@@ -0,0 +1,18 @@
[package]
name = "dlrm"
version = "0.1.0"
edition = "2024"
[[bin]]
name = "dlrm"
path = "src/main.rs"
[[bin]]
name = "check"
path = "src/bin/check.rs"
[dependencies]
luminal = { path = "../.." }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
rand = "0.9.2"
bytemuck = "1"

236
examples/dlrm/RESULTS.md Normal file

@@ -0,0 +1,236 @@
# DLRMv1 via `torch.compile(backend=luminal_backend)` — sweep results
## TL;DR
- **Compiles end-to-end**: vanilla DLRMv1 (`nn.EmbeddingBag` × num_cat, then
pairwise-dot interaction, then top MLP) lands on `torch.compile` with
`backend=luminal_backend` for all 55 (`batch`, `num_cat`) cells in the
sweep.
- **Correctness**: max abs diff ≤ 1.8 × 10⁻⁷ vs PyTorch eager at
`num_cat ∈ {2, 4, 8, 16, 32}` (essentially fp32 noise).
- **Beats `pt_eager` in 55/55 cells (100%)**, by **4.7–7.1×**.
- **Beats `graph_safe_inductor_cg` at the largest cell** (`nc=32, b=2048`):
174 μs vs 241 μs (1.38× faster). Still slower at smaller cells, where
luminal pays roughly 100 μs of fixed Python-wrapper overhead per call
against PT-CUDAGraph's ~22 μs replay floor.
## Hardware / setup
- NVIDIA GH200 480GB, CUDA 12.8, driver 570.148.08.
- DLRMv1 inline (matches `examples/dlrm/sweep_pytorch.py`'s `DLRMv1`).
Fixed dials: `m_den=3`, `m_spa=16`, `bag=2`, `rows=4096`.
- Harness: 5 rounds × 20 timed iters, with 10 warmup iters before round 0;
median-of-round-medians, CUDA-event timing.
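
The median-of-round-medians scheme described in the harness bullet can be sketched in a few lines. This is a minimal illustration, not the repository's harness: it assumes a host-side clock (`time.perf_counter`) rather than CUDA events, and the function name `median_of_round_medians` and its `clock` parameter are hypothetical.

```python
import statistics
import time


def median_of_round_medians(fn, *, rounds=5, iters=20, warmup=10,
                            clock=time.perf_counter):
    """Time `fn` and return the median of the per-round median latencies."""
    # Warmup runs happen once, before round 0, and are not timed.
    for _ in range(warmup):
        fn()
    round_medians = []
    for _ in range(rounds):
        samples = []
        for _ in range(iters):
            start = clock()
            fn()
            samples.append(clock() - start)
        # One median per round damps outliers inside the round.
        round_medians.append(statistics.median(samples))
    # The outer median damps whole rounds that ran slow.
    return statistics.median(round_medians)
```

For GPU work a host clock only measures launch latency unless the device is synchronized, which is why the real harness records CUDA events around each iteration instead.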
## Sweep dimensions
| dim | values |
|---|---|
| `batch` | 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 |
| `num_cat` | 2, 4, 8, 16, 32 |
55 cells × 3 variants = 165 timings. CSV: `results.csv`.
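
As a quick check of the cell count, the sweep grid can be enumerated directly. The snippet below is only a sketch of the grid itself; the names `BATCHES`, `NUM_CATS`, and `VARIANTS` are illustrative and are not taken from `sweep_all.py`.

```python
from itertools import product

BATCHES = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
NUM_CATS = [2, 4, 8, 16, 32]
VARIANTS = ["pt_eager", "graph_safe_inductor_cg", "luminal_compiled"]

# One cell per (num_cat, batch) pair, one timing per (cell, variant).
cells = list(product(NUM_CATS, BATCHES))
timings = [(nc, b, v) for (nc, b) in cells for v in VARIANTS]

print(len(cells), len(timings))  # prints "55 165"
```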
## Variants
- `pt_eager`: vanilla DLRMv1 forward, no compile, no graph capture.
- `graph_safe_inductor_cg`: wrap DLRMv1 in `GraphSafeDLRM` (pre-bakes the
`li`/`lj` lower-tri index buffers so the forward is capturable), then
`torch.compile(backend="inductor", mode="max-autotune-no-cudagraphs",
fullgraph=False, dynamic=False)`, then manual
`torch.cuda.CUDAGraph` capture/replay. **This is the named PT baseline.**
- `luminal_compiled`: `torch.compile(backend=luminal_backend,
fullgraph=False, dynamic=False)`, no external CUDAGraph wrap. luminal's
`cuda_lite` runtime captures/replays a CUDA graph internally on
every `execute()` call.
## Wall-clock results (median-of-round-medians, μs)
### `luminal_compiled`
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|---|---|---|---|---|---|---|
| 2 | 101 | 95 | 96 | 100 | 98 | 102 |
| 4 | 100 | 99 | 110 | 103 | 104 | 112 |
| 8 | 109 | 106 | 116 | 110 | 111 | 121 |
| 16 | 131 | 133 | 134 | 136 | 140 | 135 |
| 32 | 172 | 172 | 178 | 174 | 176 | 174 |
### `graph_safe_inductor_cg` — the named PT baseline
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|---|---|---|---|---|---|---|
| 2 | 22 | 24 | 26 | 27 | 30 | 34 |
| 4 | 27 | 29 | 31 | 33 | 35 | 45 |
| 8 | 46 | 48 | 52 | 56 | 57 | 66 |
| 16 | 75 | 79 | 91 | 96 | 96 | 121 |
| 32 | 132 | 138 | 164 | 157 | 170 | 241 |
### `pt_eager`
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|---|---|---|---|---|---|---|
| 2 | 506 | 453 | 472 | 496 | 578 | 555 |
| 4 | 514 | 488 | 569 | 535 | 605 | 588 |
| 8 | 624 | 632 | 709 | 639 | 676 | 604 |
| 16 | 786 | 783 | 879 | 813 | 862 | 865 |
| 32 | 1125 | 1133 | 1180 | 1169 | 1176 | 1233 |
## Speedup vs `graph_safe_inductor_cg` (>1 = luminal faster)
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|---|---|---|---|---|---|---|
| 2 | 0.22× | 0.25× | 0.27× | 0.28× | 0.30× | 0.33× |
| 4 | 0.27× | 0.29× | 0.29× | 0.32× | 0.34× | 0.40× |
| 8 | 0.42× | 0.45× | 0.45× | 0.51× | 0.52× | 0.55× |
| 16 | 0.57× | 0.60× | 0.68× | 0.71× | 0.68× | 0.90× |
| 32 | 0.77× | 0.80× | 0.92× | 0.91× | 0.97× | **1.38×** |
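
The speedup column is simply the baseline latency divided by the luminal latency. The sketch below recomputes it from the rounded values printed in the two latency tables of this report; because those inputs are rounded to whole microseconds, individual cells can differ from the published ratios in the last digit. The helper name `speedup_table` is hypothetical.

```python
# Median latencies in microseconds, copied from the tables in this report.
BATCHES = [2, 16, 64, 256, 1024, 2048]
LUMINAL_US = {
    2: [101, 95, 96, 100, 98, 102],
    4: [100, 99, 110, 103, 104, 112],
    8: [109, 106, 116, 110, 111, 121],
    16: [131, 133, 134, 136, 140, 135],
    32: [172, 172, 178, 174, 176, 174],
}
BASELINE_US = {
    2: [22, 24, 26, 27, 30, 34],
    4: [27, 29, 31, 33, 35, 45],
    8: [46, 48, 52, 56, 57, 66],
    16: [75, 79, 91, 96, 96, 121],
    32: [132, 138, 164, 157, 170, 241],
}


def speedup_table(baseline_us, candidate_us):
    """Map (num_cat, batch) to baseline / candidate; > 1 means candidate is faster."""
    return {
        (nc, batch): baseline_us[nc][col] / candidate_us[nc][col]
        for nc in baseline_us
        for col, batch in enumerate(BATCHES)
    }


speedups = speedup_table(BASELINE_US, LUMINAL_US)
wins = sorted(cell for cell, s in speedups.items() if s > 1.0)
print(wins)  # prints [(32, 2048)]
```

Among the columns shown, the only cell where luminal comes out ahead is `nc=32, b=2048`, matching the bolded entry in the table.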
## Speedup vs `pt_eager` (>1 = luminal faster)
`luminal_compiled` is **4.7–7.1× faster** at every cell, winning 55/55 cells (100%).
## What changed to get here
Three rounds of changes (all in `crates/luminal_python/`
and `crates/luminal_cuda_lite/`):
### Round 1: Translator + cuda_lite kernel matching
Made `torch.compile(model, backend=luminal_backend)` emit the same kernel
set as the hand-written rust DLRM:
| FX subgraph | luminal kernel |
|---|---|
| `addmm(bias, x, weight.t())` + `relu/sigmoid` consumer | `linear_bias_relu` / `linear_bias_sigmoid` (one fused kernel per MLP layer) |
| N × `_embedding_bag(W_k, idx_k, off_k)` | `multi_table_embedding_bag_sum_kernel` (one fused kernel for all tables) |
| `index.Tensor(bmm(cat(unsqueezes), permute([0,2,1])), [None, li, lj])` | `dlrm_pairwise_dot_lower_tri_stacked` (one fused kernel) |
| `bmm(A, permute(B, [0,2,1]))` (generic) | `matmul_3d_t` |
Plus a post-translation DCE pass to drop the now-unreachable
`bmm/cat/permute` chain superseded by the pairwise-dot peephole.
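
To illustrate why the pairwise-dot fold is a valid rewrite, the sketch below compares the two formulations for a single batch row in plain Python: building the full F × F Gram matrix and gathering `Z[li, lj]`, versus computing only the F·(F−1)/2 strict-lower-triangular dot products. It is a reference model of the math only, not the CUDA kernel, and all function names are hypothetical.

```python
def tril_pairs(num_features):
    """Strict lower-triangular (i, j) pairs with j < i, in row-major order."""
    return [(i, j) for i in range(num_features) for j in range(i)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def interact_via_bmm(features):
    """Reference path: full Gram matrix T @ T^T, then gather Z[li, lj]."""
    z = [[dot(u, v) for v in features] for u in features]
    return [z[i][j] for i, j in tril_pairs(len(features))]


def interact_fused(features):
    """Fused path: compute only the dot products that are actually kept."""
    return [dot(features[i], features[j]) for i, j in tril_pairs(len(features))]


# Three features of width 2 for one batch row.
row = [[1, 2], [3, 4], [5, 6]]
print(interact_via_bmm(row))  # prints [11, 17, 39]
print(interact_fused(row))    # prints [11, 17, 39]
```

The fused form skips the diagonal and upper triangle entirely, so it does roughly half the multiply-adds and never materializes the (B, F, F) intermediate.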
### Round 2: Batched FFI + by-id setter
Added `tensor_id(name) → u32` and `run_with_ptrs(inputs, outputs)` on the
Rust side (one pyo3 hop instead of N), and cached the IDs once at
`CompiledModel.__init__`. The Python wrapper now passes
`[(id, ptr, n_bytes), …]` lists instead of N separate
`set_input_device_ptr(name, ptr, n_bytes)` calls.
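
The effect of batching can be shown with a mock that counts boundary crossings. `FakeGraph` below is a hypothetical pure-Python stand-in: its method names mirror the ones mentioned in this section, but the bodies are invented for illustration and say nothing about how the Rust side is implemented.

```python
class FakeGraph:
    """Mock graph object that counts simulated FFI crossings."""

    def __init__(self, names):
        self._ids = {name: i for i, name in enumerate(names)}
        self.ptrs = {}
        self.ffi_calls = 0

    def tensor_id(self, name):
        self.ffi_calls += 1
        return self._ids[name]

    def set_input_device_ptr(self, name, ptr, n_bytes):
        self.ffi_calls += 1
        self.ptrs[self._ids[name]] = (ptr, n_bytes)

    def run(self):
        self.ffi_calls += 1

    def run_with_ptrs(self, input_specs, output_specs):
        # One crossing registers every (id, ptr, n_bytes) tuple and runs.
        self.ffi_calls += 1
        for tensor_id, ptr, n_bytes in list(input_specs) + list(output_specs):
            self.ptrs[tensor_id] = (ptr, n_bytes)
        return [True] * len(output_specs)


names = [f"in{i}" for i in range(65)]
buffers = {name: (0x1000 + 64 * i, 64) for i, name in enumerate(names)}

# Per-name path: one crossing per input, plus one for run().
per_name = FakeGraph(names)
for name, (ptr, n_bytes) in buffers.items():
    per_name.set_input_device_ptr(name, ptr, n_bytes)
per_name.run()

# Batched path: resolve ids once up front, then one crossing per iteration.
batched = FakeGraph(names)
ids = [batched.tensor_id(n) for n in names]
batched.ffi_calls = 0  # count only per-iteration crossings from here on
batched.run_with_ptrs([(i, *buffers[n]) for i, n in zip(ids, names)], [])

print(per_name.ffi_calls, batched.ffi_calls)  # prints "66 1"
```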
### Round 3: Skip re-registration on unchanged inputs
The Python wrapper now caches `(id(tensor), data_ptr)` per input slot.
On a hot bench loop the same user tensors are passed each iteration —
all 65 inputs (at `nc=32`) hit the cache, the `input_specs` list passed to
`run_with_ptrs` is empty, and the runtime relies on the
previously-registered pointers. Required one runtime change: in
`cuda_lite/runtime.rs::execute()`'s post-run consume step, don't drop
`CudaInput::Ptr` entries — they're non-owning views over caller memory
and must persist across `execute()` calls for the skip path to work.
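
A minimal sketch of the per-slot cache follows. It assumes a tensor-like object that exposes `data_ptr()`; `FakeTensor`, `InputCache`, and `specs_to_register` are hypothetical names for illustration, and the dtype cast performed by the real wrapper is omitted.

```python
class FakeTensor:
    """Minimal stand-in exposing only what the cache inspects."""

    def __init__(self, ptr, n_bytes):
        self._ptr = ptr
        self.n_bytes = n_bytes

    def data_ptr(self):
        return self._ptr


class InputCache:
    def __init__(self, tensor_ids):
        self._tensor_ids = list(tensor_ids)
        # (id(tensor), data_ptr) last seen in each input slot.
        self._seen = [None] * len(self._tensor_ids)
        # Hold a reference so a cached id() cannot be reused by a new object.
        self._keepalive = [None] * len(self._tensor_ids)

    def specs_to_register(self, tensors):
        """Return (id, ptr, n_bytes) only for slots whose tensor changed."""
        specs = []
        for slot, (tid, t) in enumerate(zip(self._tensor_ids, tensors)):
            key = (id(t), t.data_ptr())
            if self._seen[slot] == key:
                continue  # hit: the earlier registration is still valid
            specs.append((tid, t.data_ptr(), t.n_bytes))
            self._seen[slot] = key
            self._keepalive[slot] = t
        return specs


a, b = FakeTensor(0x1000, 64), FakeTensor(0x2000, 128)
cache = InputCache([7, 9])
first = cache.specs_to_register([a, b])   # cold: both slots registered
second = cache.specs_to_register([a, b])  # hot: nothing to register
c = FakeTensor(0x3000, 128)
third = cache.specs_to_register([a, c])   # only the changed slot
```

Note that a key built from `id()` and `data_ptr()` cannot detect in-place writes to the same buffer, so a cache like this only tracks where the data lives, not what it contains.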
## Where the per-iter time goes (nc=32, b=2048)
Before any of the rounds:
| section | μs/iter |
|---|---|
| `set_input_device_ptr` × 65 (Python loop + FFI) | 995 |
| `_graph.run()` (CUDA-graph replay) | 61 |
| set_output + alloc + collect | 13 |
| **total** | **1069** |
After all three rounds:
| section | μs/iter |
|---|---|
| input cache check (0 cold-path inputs after warmup) | 15 |
| `torch.empty` + register output | 11 |
| `run_with_ptrs` (batched FFI, all caches hit, runtime executes graph) | 56 |
| **total** | **82** |
**13× per-iter reduction**, none of it in the kernels themselves.
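
The 13× figure follows directly from the two breakdown tables. The short check below re-adds the rows as printed; the dictionary keys are abbreviated labels for those rows.

```python
# Per-iteration costs in microseconds, copied from the two tables above.
before_us = {
    "set_input_device_ptr x 65": 995,
    "_graph.run()": 61,
    "set_output + alloc + collect": 13,
}
after_us = {
    "input cache check": 15,
    "torch.empty + register output": 11,
    "run_with_ptrs": 56,
}

total_before = sum(before_us.values())
total_after = sum(after_us.values())
reduction = total_before / total_after
input_share = before_us["set_input_device_ptr x 65"] / total_before

print(total_before, total_after)  # prints "1069 82"
```

Here `reduction` comes out just above 13, and `input_share` shows that input registration alone accounted for over 90% of the original per-iteration cost.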
## Hand-written rust DLRM (`examples/dlrm/src/main.rs`)
The hand-written rust uses the same `cuda_lite` kernels directly — no
Python, no `torch.compile`. On `nc=32, b=2048`: **104 μs**
(median-of-round-medians with explicit `synchronize_stream` for accurate
GPU time). The `luminal_compiled` figure for the same cell is 174 μs —
the residual 70 μs gap is roughly:
- `torch.empty` + output registration (~11 μs)
- Python-side cache check + 1 round-trip into `run_with_ptrs` (~10 μs)
- Marshalling 1 input spec + 1 output spec across the pyo3 boundary
- pt2 backend wrapper invocation by torch.compile (a few μs of
dynamo/eval_frame work)
- The remaining ~40 μs is the runtime's per-call exec_op iteration +
buffer_map building inside `cuda_lite/src/runtime.rs::execute()` —
a structural cost of how the runtime dispatches host ops, paid once
per call regardless of input count.
## What's left
In scope, deferred:
- `linear_bias_relu_split_a` peephole on top MLP first layer. Saves
one materialized `cat` + one small kernel. Modest cell-by-cell win.
Out of scope:
- The remaining ~80–100 μs of wrapper / runtime fixed overhead is a
combination of `torch.compile`'s dynamo eval-frame dispatch, the
Python `__call__` setup work, and the runtime's per-execute toposort
+ buffer_map build. Closing it further would either need a deeper
rewrite of `runtime.rs::execute()` (probably worth it independent of
DLRM) or a way to skip dynamo's per-call overhead.
## Files this work touches
In scope:
- `crates/luminal_python/rust/src/translator/{mod,dispatch,movement}.rs` —
peephole infra, op handlers, multi-call fusions, post-translation DCE.
- `crates/luminal_cuda_lite/src/kernel/{embedding_bag,dlrm_interact,matmul2d,mod}.rs` —
fused kernels for embedding bag (single + multi-table + stacked),
pairwise dot (variadic + stacked), and fused-activation linear
(relu/sigmoid + split-A).
- `crates/luminal_cuda_lite/src/runtime.rs` — gate end-of-execute
stream sync on `profiling`; expose `synchronize_stream()` and
`read_per_kernel_timings_ms()` stub; preserve `CudaInput::Ptr`
entries across `execute()` calls (the unchanged-input cache fix).
Wrapper layer (relaxed scope; minimal targeted changes):
- `crates/luminal_python/rust/src/compiled_graph.rs` — `tensor_id(name)`
+ `run_with_ptrs(inputs, outputs)` for batched FFI.
- `crates/luminal_python/src/luminal/compiled_model.py` — fast-path
`run_with_ptrs` call with the per-input cache.
Leaf consumer (allowed):
- `examples/dlrm/` — bench harness, hand-written rust DLRM, correctness
check, sweep scripts, this report. Cherry-picked from
`origin/dlrm-fused-kernels` then adapted (inline DLRMv1, no
upstream `dlrm_s_pytorch` dependency, added third PT variant).
## Reproduction
From `examples/dlrm`:
```bash
# Build
(cd ../.. && cargo build -p dlrm --release)
(cd ../../crates/luminal_python && CUDARC_CUDA_VERSION=12080 \
uv run --group dev maturin develop --manifest-path rust/Cargo.toml \
--features cuda -r)
# Correctness check (PyTorch eager → luminal hand-written rust)
for nc in 2 4 8 16 32; do
python correctness_dump.py --num-cat $nc --rows 4096
(cd ../.. && ./target/release/check)
done
# Full 55-cell sweep, 3 variants
python sweep_all.py --variants pt_eager graph_safe_inductor_cg luminal_compiled
# Pretty summary
python summarize.py results.csv
```


@@ -0,0 +1,348 @@
"""PyTorch reference for the DLRM `_make_dlrm_batch_2048` config.
Mirrors `bench_luminal_exact.py` from https://github.com/jss8649/tmp-dlrm-bench
in spirit (shape-for-shape, same indices, same MLPs, same interaction op)
but reimplements the model inline so we don't need the upstream
facebookresearch/dlrm package.
Measures:
* eager
* torch.compile (inductor, default)
* torch.compile (inductor, mode="reduce-overhead" → CUDA-graph capture)
* the v3 fused trick (index_select + reshape + sum on a stacked table)
with reduce-overhead — that's the WINNER in the reference repo.
Harness: 5 rounds × 20 iters, 10 warmup, median-of-round-medians (CUDA-event
timing). Matches the luminal-side harness verbatim.
Usage:
python bench_pytorch.py
"""
from __future__ import annotations
import copy
from contextlib import contextmanager
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
# ---- Config (matches `_make_dlrm_batch_2048`) -----------------------------
BATCH = 2048
M_DEN = 3
M_SPA = 16
INDICES_PER_BAG = 2
LN_EMB = np.array([4096, 2048, 1024], dtype=np.int64)
NUM_EMB = len(LN_EMB)
NUM_FEA = NUM_EMB + 1
PAIR_COUNT = NUM_FEA * (NUM_FEA - 1) // 2 # strict lower tri, no diagonal
TOP_IN = PAIR_COUNT + M_SPA # 6 + 16 = 22
LN_BOT = [M_DEN, 64, M_SPA] # [3, 64, 16]
LN_TOP_TAIL = [64, 32, 1] # ln_top = [22, 64, 32, 1]
SEED = 0
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
layers: List[nn.Module] = []
for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
layers.append(nn.Linear(a, b, bias=True))
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
return nn.Sequential(*layers)
# ---- v1-shape model: EmbeddingBag per table (the "natural" expression
# a user writes; same shape as the upstream DLRM forward) ------------
class DLRMv1(nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(SEED)
np.random.seed(SEED)
ln_top = [TOP_IN] + LN_TOP_TAIL # [22, 64, 32, 1]
self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
# sigmoid_top in the upstream is `len(ln_top) - 2` = 2 → final linear
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
self.emb = nn.ModuleList(
[
nn.EmbeddingBag(int(n), M_SPA, mode="sum", sparse=False)
for n in LN_EMB
]
)
# Pre-compute strict lower-tri (no diagonal) indices.
li, lj = [], []
for i in range(NUM_FEA):
for j in range(i):
li.append(i)
lj.append(j)
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
def forward(self, dense_x, lS_o, lS_i):
x = self.bot(dense_x) # (B, M_SPA)
ly = [
self.emb[k](lS_i[k], lS_o[k]) for k in range(NUM_EMB)
] # each (B, M_SPA)
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1) # (B, F, M_SPA)
Z = torch.bmm(T, T.transpose(1, 2)) # (B, F, F)
Zflat = Z[:, self.li, self.lj] # (B, PAIRS)
R = torch.cat([x, Zflat], dim=1) # (B, M_SPA + PAIRS)
return self.top(R)
# ---- v3 fused: stacked embedding table, index_select + reshape + sum ------
# This is the winner per perf.md (Inductor can fuse gather+sum into
# a single Triton kernel; opaque EmbeddingBag blocks that fusion). ----
class DLRMv3(nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(SEED)
np.random.seed(SEED)
total_rows = int(LN_EMB.sum())
self.num_emb = NUM_EMB
self.m_spa = M_SPA
self.L = INDICES_PER_BAG
big_w = np.empty((total_rows, M_SPA), dtype=np.float32)
starts = np.zeros(NUM_EMB, dtype=np.int64)
s = 0
for k, n in enumerate(LN_EMB):
starts[k] = s
big_w[s : s + int(n)] = np.random.uniform(
-np.sqrt(1.0 / int(n)),
np.sqrt(1.0 / int(n)),
(int(n), M_SPA),
).astype(np.float32)
s += int(n)
self.emb_weight = nn.Parameter(torch.from_numpy(big_w))
self.register_buffer("row_offsets", torch.from_numpy(starts))
ln_top = [TOP_IN] + LN_TOP_TAIL
self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
li, lj = [], []
for i in range(NUM_FEA):
for j in range(i):
li.append(i)
lj.append(j)
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
def forward(self, dense_x, flat_indices):
bs = dense_x.shape[0]
gathered = torch.index_select(self.emb_weight, 0, flat_indices)
gathered = gathered.view(self.num_emb * bs, self.L, self.m_spa)
pooled = gathered.sum(dim=1) # (num_emb*B, m_spa)
ly = pooled.view(self.num_emb, bs, self.m_spa).transpose(0, 1) # (B, num_emb, m_spa)
x = self.bot(dense_x)
T = torch.cat([x.unsqueeze(1), ly], dim=1)
Z = torch.bmm(T, T.transpose(1, 2))
Zflat = Z[:, self.li, self.lj]
R = torch.cat([x, Zflat], dim=1)
return self.top(R)
# ---- Deterministic inputs matching `_make_dlrm_batch_2048` ---------------
def make_v1_inputs(device):
dense_x = (
torch.linspace(-1.0, 1.0, BATCH * M_DEN, dtype=torch.float32, device=device)
.reshape(BATCH, M_DEN)
)
total = BATCH * INDICES_PER_BAG
positions = torch.arange(total, dtype=torch.int64, device=device)
offsets = torch.arange(0, total, INDICES_PER_BAG, dtype=torch.int64, device=device)
lS_o = [offsets.clone() for _ in range(NUM_EMB)]
lS_i = [
((positions * 3 + 1) % int(LN_EMB[0])).to(torch.int64),
((positions * 5 + 2) % int(LN_EMB[1])).to(torch.int64),
((positions * 7 + 3) % int(LN_EMB[2])).to(torch.int64),
]
return dense_x, lS_o, lS_i
def make_v3_inputs(device, model: DLRMv3):
dense_x, _, lS_i = make_v1_inputs(device)
# Stack and add per-table row offsets so a single index_select pulls
# from the unified table.
stacked = torch.stack(lS_i, dim=0) # (NUM_EMB, B*L)
flat = (stacked + model.row_offsets.view(NUM_EMB, 1)).reshape(-1)
return dense_x, flat
# ---- Timing harness (mirrors bench_luminal_exact.py) ---------------------
@contextmanager
def _relaxed_dynamo_limits():
prev_r = torch._dynamo.config.recompile_limit
prev_c = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_r
torch._dynamo.config.cache_size_limit = prev_c
def _compile_inductor(model):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend="inductor")
def _compile_inductor_ro(model):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(
copy.deepcopy(model), backend="inductor", mode="reduce-overhead"
)
def _timed_cuda_runs(model, inputs, warmup, timed, mark_step=False):
    with torch.no_grad():
        for _ in range(warmup):
            if mark_step:
                torch.compiler.cudagraph_mark_step_begin()
            _ = model(*inputs)
    torch.cuda.synchronize()
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed)]
    with torch.no_grad():
        for i in range(timed):
            if mark_step:
                torch.compiler.cudagraph_mark_step_begin()
            starts[i].record()
            _ = model(*inputs)
            ends[i].record()
    torch.cuda.synchronize()
    return np.array(
        [s.elapsed_time(e) for s, e in zip(starts, ends)], dtype=np.float64
    )


def _timed_rounds(model, inputs, *, warmup, timed, rounds, mark_step=False):
    arr = _timed_cuda_runs(model, inputs, warmup, timed, mark_step=mark_step)
    round_medians = [float(np.median(arr))]
    for _ in range(rounds - 1):
        arr = _timed_cuda_runs(model, inputs, 0, timed, mark_step=mark_step)
        round_medians.append(float(np.median(arr)))
    return round_medians


def _manual_cudagraph_capture(model, inputs):
    """Capture model(*inputs) into a CUDA graph and return a closure that
    replays it. Each replay issues no new kernel launches, only a single
    graph launch (cuGraphLaunch)."""
    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            _ = model(*inputs)
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), torch.no_grad():
        out = model(*inputs)
    torch.cuda.synchronize()

    def replay():
        g.replay()
        # `out` is the static output tensor from capture; every replay
        # overwrites it in place.
        return out

    return replay


def main():
    print(f"torch={torch.__version__} cuda={torch.version.cuda}")
    print(f"device={torch.cuda.get_device_name(0)} cap={torch.cuda.get_device_capability(0)}")
    print(f"float32 matmul precision: {torch.get_float32_matmul_precision()}")
    device = torch.device("cuda")
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # v1: EmbeddingBag-style
    v1 = DLRMv1().to(device).eval()
    v1_inputs = make_v1_inputs(device)
    v1_eager = copy.deepcopy(v1).to(device).eval()
    v1_ind = _compile_inductor(v1)
    v1_ind_ro = _compile_inductor_ro(v1)

    # v3: index_select + reshape + sum on stacked table
    v3 = DLRMv3().to(device).eval()
    v3_inputs = make_v3_inputs(device, v3)
    v3_ind_ro = _compile_inductor_ro(v3)

    # Prime
    with torch.no_grad():
        _ = v1_eager(*v1_inputs)
        torch.compiler.cudagraph_mark_step_begin()
        _ = v1_ind_ro(*v1_inputs)
        _ = v1_ind(*v1_inputs)
        torch.compiler.cudagraph_mark_step_begin()
        _ = v3_ind_ro(*v3_inputs)
    torch.cuda.synchronize()

    # Manual CUDA graph capture for v3 eager
    v3_eager_replay = _manual_cudagraph_capture(v3, v3_inputs)
    # Manual CUDA graph capture for v1 eager
    v1_eager_replay = _manual_cudagraph_capture(v1, v1_inputs)

    rounds, iters, warmup = 5, 20, 10

    def report(label, model, inputs, mark_step=False, *, raw_replay=False):
        if raw_replay:
            # Time the replay closure directly.
            replay = model
            with torch.no_grad():
                for _ in range(warmup):
                    replay()
            torch.cuda.synchronize()
            starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
            ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
            with torch.no_grad():
                for i in range(iters):
                    starts[i].record()
                    replay()
                    ends[i].record()
            torch.cuda.synchronize()
            arr = np.array([s.elapsed_time(e) for s, e in zip(starts, ends)])
            rms = [float(np.median(arr))]
            for _ in range(rounds - 1):
                with torch.no_grad():
                    for i in range(iters):
                        starts[i].record()
                        replay()
                        ends[i].record()
                torch.cuda.synchronize()
                arr = np.array(
                    [s.elapsed_time(e) for s, e in zip(starts, ends)]
                )
                rms.append(float(np.median(arr)))
        else:
            rms = _timed_rounds(
                model, inputs, warmup=warmup, timed=iters, rounds=rounds,
                mark_step=mark_step,
            )
        rms_sorted = sorted(rms)
        med = rms_sorted[len(rms_sorted) // 2]
        tput = BATCH / (med / 1000.0)
        print(
            f" {label:<60} median {med:7.3f} ms ({tput:>9,.0f} samples/s) "
            f"round medians: [{', '.join(f'{v:.3f}' for v in rms)}]"
        )

    print()
    # Report the actual harness settings rather than a hard-coded copy.
    print(
        f"PyTorch reference ({rounds} rounds x {iters} iters, {warmup} warmup), "
        f"batch={BATCH}:"
    )
    report("v1 eager", v1_eager, v1_inputs)
    report("v1 torch.compile(inductor)", v1_ind, v1_inputs)
    report("v1 torch.compile(reduce-overhead)", v1_ind_ro, v1_inputs, mark_step=True)
    report("v1 eager + manual CUDAGraph", v1_eager_replay, None, raw_replay=True)
    report("v3 torch.compile(reduce-overhead)", v3_ind_ro, v3_inputs, mark_step=True)
    report("v3 eager + manual CUDAGraph", v3_eager_replay, None, raw_replay=True)


if __name__ == "__main__":
    main()

@@ -0,0 +1,172 @@
"""Dump a deterministic PyTorch DLRMv1 forward to disk so a parallel
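luminal binary can load the exact same weights/inputs and verify that it
produces the same output.

Writes one raw binary blob per tensor (f32 for weights, the dense input, and
the expected output; i32 for embedding indices), plus a manifest JSON
describing the shapes. Blobs are written with `ndarray.tofile`, so they use
the host's native byte order (little-endian on x86-64). All paths are under
`--out-dir` (default `weights/`).

Usage: python correctness_dump.py [--num-cat N] [--rows R] [--out-dir DIR]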
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import List
import numpy as np
import torch
import torch.nn as nn
SEED = 1234
BATCH = 2048
M_DEN = 3
M_SPA = 16
L = 2
LN_BOT = [M_DEN, 64, M_SPA]
LN_TOP_TAIL = [64, 32, 1]
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
    layers: List[nn.Module] = []
    for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
        layers.append(nn.Linear(a, b, bias=True))
        layers.append(nn.Sigmoid() if i == sigmoid_layer else nn.ReLU())
    return nn.Sequential(*layers)


class DLRMv1(nn.Module):
    def __init__(self, num_cat: int, rows: int):
        super().__init__()
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        self.num_cat = num_cat
        ni = num_cat + 1
        num_int = ni * (ni - 1) // 2 + M_SPA
        ln_top = [num_int] + LN_TOP_TAIL
        # sigmoid_layer=-1 never matches a layer index, so every bottom
        # layer is followed by ReLU.
        self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
        # Upstream's `sigmoid_top=2` places the sigmoid after the final
        # linear. `_build_mlp` indexes by Linear position (not by position
        # in the Sequential), so the last linear has index len(ln_top) - 2.
        self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
        self.emb = nn.ModuleList(
            [nn.EmbeddingBag(rows, M_SPA, mode="sum", sparse=False) for _ in range(num_cat)]
        )
        # Index pairs (i, j) with i > j: the strict lower triangle of the
        # ni x ni interaction matrix.
        li, lj = [], []
        for i in range(ni):
            for j in range(i):
                li.append(i)
                lj.append(j)
        self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
        self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)

    def forward(self, dense_x, lS_o, lS_i):
        x = self.bot(dense_x)
        ly = [self.emb[k](lS_i[k], lS_o[k]) for k in range(self.num_cat)]
        T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)
        Z = torch.bmm(T, T.transpose(1, 2))
        Zflat = Z[:, self.li, self.lj]
        R = torch.cat([x, Zflat], dim=1)
        return self.top(R)


def build_indices(table_idx: int, batch: int, bag: int, rows: int) -> np.ndarray:
    # Match sweep_categories.py and the luminal Rust binary exactly.
    s = 2 * table_idx + 3
    o = table_idx + 1
    pos = np.arange(batch * bag, dtype=np.int64)
    return (pos * s + o) % rows


def build_dense_x(batch: int, m_den: int) -> np.ndarray:
    total = batch * m_den
    return np.linspace(-1.0, 1.0, num=total, dtype=np.float32).reshape(batch, m_den)


def write_f32(path: Path, arr: np.ndarray) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    arr.astype(np.float32, copy=False).tofile(path)


def write_i32(path: Path, arr: np.ndarray) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    arr.astype(np.int32, copy=False).tofile(path)


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--num-cat", type=int, default=3)
    ap.add_argument("--rows", type=int, default=4096)
    ap.add_argument("--out-dir", type=str, default="weights")
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    model = DLRMv1(args.num_cat, args.rows).to(device).eval()
    dense_np = build_dense_x(BATCH, M_DEN)
    dense = torch.from_numpy(dense_np).to(device)
    offsets = torch.arange(0, BATCH * L, L, dtype=torch.int64, device=device)
    lS_o = [offsets.clone() for _ in range(args.num_cat)]
    lS_i = [
        torch.from_numpy(build_indices(k, BATCH, L, args.rows)).to(device)
        for k in range(args.num_cat)
    ]
    with torch.no_grad():
        out = model(dense, lS_o, lS_i)
    print(f"output[:8] = {out.detach().cpu().numpy().flatten()[:8]}")
    print(f"output stats: min={out.min().item():.6f} max={out.max().item():.6f} "
          f"mean={out.mean().item():.6f}")

    # ---- dump weights ----
    # Bottom MLP: linears sit at indices 0 and 2 of the Sequential
    # (Linear, ReLU, Linear, ReLU). Luminal stores W as (out, in), matching
    # PyTorch's nn.Linear.
    bot_lins = [m for m in model.bot if isinstance(m, nn.Linear)]
    top_lins = [m for m in model.top if isinstance(m, nn.Linear)]
    for i, l in enumerate(bot_lins):
        write_f32(out_dir / f"bot_{i}_w.bin", l.weight.detach().cpu().numpy())
        write_f32(out_dir / f"bot_{i}_b.bin", l.bias.detach().cpu().numpy())
    for i, l in enumerate(top_lins):
        write_f32(out_dir / f"top_{i}_w.bin", l.weight.detach().cpu().numpy())
        write_f32(out_dir / f"top_{i}_b.bin", l.bias.detach().cpu().numpy())
    for k, e in enumerate(model.emb):
        write_f32(out_dir / f"emb_{k}.bin", e.weight.detach().cpu().numpy())

    # ---- dump inputs ----
    # Indices are int64 on the PyTorch side but are written as int32.
    # Offsets are not dumped: they are a uniform stride of `L`, recorded as
    # `indices_per_bag` in the manifest.
    write_f32(out_dir / "dense.bin", dense_np)
    for k, idx in enumerate(lS_i):
        write_i32(out_dir / f"idx_{k}.bin", idx.cpu().numpy())

    # ---- dump expected output ----
    write_f32(out_dir / "expected.bin", out.detach().cpu().numpy())

    manifest = {
        "num_cat": args.num_cat,
        "rows": args.rows,
        "batch": BATCH,
        "m_den": M_DEN,
        "m_spa": M_SPA,
        "indices_per_bag": L,
        "ln_bot": LN_BOT,
        "ln_top": [args.num_cat * (args.num_cat + 1) // 2 + M_SPA] + LN_TOP_TAIL,
        "bot_layer_shapes": [list(l.weight.shape) for l in bot_lins],
        "top_layer_shapes": [list(l.weight.shape) for l in top_lins],
        "output_shape": list(out.shape),
        "output_head": out.detach().cpu().numpy().flatten()[:8].tolist(),
    }
    with open(out_dir / "manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"\nWrote weights/inputs/expected to {out_dir}/")
    print(f" manifest: {out_dir}/manifest.json")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
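
The dump above can be read back roughly as follows. This is a minimal sketch, not part of the commit: the helper names `read_f32` and `load_expected` are hypothetical, it assumes the blobs were written on a little-endian host, and it relies only on the file names the script writes (`manifest.json`, `expected.bin`) and the `output_shape` manifest key. It uses only the standard library so it runs without NumPy.

```python
import json
import struct
from pathlib import Path


def read_f32(path, shape):
    """Read a flat f32 blob (assumed little-endian) and check its size against `shape`."""
    raw = Path(path).read_bytes()
    count = 1
    for dim in shape:
        count *= dim
    if len(raw) != 4 * count:
        raise ValueError(f"{path}: expected {4 * count} bytes, got {len(raw)}")
    # "<" forces little-endian decoding regardless of the reading host.
    return list(struct.unpack(f"<{count}f", raw))


def load_expected(out_dir="weights"):
    """Return (manifest dict, flat list of expected output values)."""
    out_dir = Path(out_dir)
    manifest = json.loads((out_dir / "manifest.json").read_text())
    expected = read_f32(out_dir / "expected.bin", manifest["output_shape"])
    return manifest, expected
```

A consumer could compare the first eight values of `expected` against the manifest's `output_head` as a quick sanity check before running a full comparison.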

166
examples/dlrm/results.csv Normal file

@@ -0,0 +1,166 @@
variant,num_cat,batch,m_spa,bag,rows,ms,samples_per_sec,status
pt_eager,2,2,16,2,4096,0.5056319832801819,3955.4459886524933,ok
graph_safe_inductor_cg,2,2,16,2,4096,0.022352000698447227,89477.44888621707,ok
luminal_compiled,2,2,16,2,4096,0.10134400054812431,19734.764654867537,ok
pt_eager,2,4,16,2,4096,0.4413599967956543,9062.89656752006,ok
graph_safe_inductor_cg,2,4,16,2,4096,0.022592000663280487,177053.8191644678,ok
luminal_compiled,2,4,16,2,4096,0.09467199817299843,42251.141596173125,ok
pt_eager,2,8,16,2,4096,0.4634399861097336,17262.21353309327,ok
graph_safe_inductor_cg,2,8,16,2,4096,0.02404800057411194,332667.9893966789,ok
luminal_compiled,2,8,16,2,4096,0.09742399677634239,82115.29258408182,ok
pt_eager,2,16,16,2,4096,0.4532639980316162,35299.516549920125,ok
graph_safe_inductor_cg,2,16,16,2,4096,0.023744000121951103,673854.4439783823,ok
luminal_compiled,2,16,16,2,4096,0.09547200053930283,167588.40193584614,ok
pt_eager,2,32,16,2,4096,0.5251840054988861,60931.025440507,ok
graph_safe_inductor_cg,2,32,16,2,4096,0.02582399919629097,1239157.411552122,ok
luminal_compiled,2,32,16,2,4096,0.09985600039362907,320461.46324564423,ok
pt_eager,2,64,16,2,4096,0.47198399901390076,135597.81715844796,ok
graph_safe_inductor_cg,2,64,16,2,4096,0.02619200013577938,2443494.184034204,ok
luminal_compiled,2,64,16,2,4096,0.09620799869298935,665225.3541229069,ok
pt_eager,2,128,16,2,4096,0.5091679990291595,251390.50420305296,ok
graph_safe_inductor_cg,2,128,16,2,4096,0.027855999767780304,4595060.348472987,ok
luminal_compiled,2,128,16,2,4096,0.10607999935746193,1206636.508062876,ok
pt_eager,2,256,16,2,4096,0.49588799476623535,516245.6092946552,ok
graph_safe_inductor_cg,2,256,16,2,4096,0.027456000447273254,9324009.172116116,ok
luminal_compiled,2,256,16,2,4096,0.09959999844431877,2570281.1646439577,ok
pt_eager,2,512,16,2,4096,0.499007984995842,1026035.6855898134,ok
graph_safe_inductor_cg,2,512,16,2,4096,0.036959998309612274,13852814.486380616,ok
luminal_compiled,2,512,16,2,4096,0.10465599969029427,4892218.32971973,ok
pt_eager,2,1024,16,2,4096,0.5781759917736053,1771087.0298484561,ok
graph_safe_inductor_cg,2,1024,16,2,4096,0.029888000339269638,34261241.58110951,ok
luminal_compiled,2,1024,16,2,4096,0.098191998898983,10428548.267496424,ok
pt_eager,2,2048,16,2,4096,0.5551839768886566,3688867.2678871835,ok
graph_safe_inductor_cg,2,2048,16,2,4096,0.03388800099492073,60434370.27480501,ok
luminal_compiled,2,2048,16,2,4096,0.10235200077295303,20009379.245483134,ok
pt_eager,4,2,16,2,4096,0.5139839947223663,3891.171749580102,ok
graph_safe_inductor_cg,4,2,16,2,4096,0.02729600016027689,73270.8084794982,ok
luminal_compiled,4,2,16,2,4096,0.09959999844431877,20080.32159878092,ok
pt_eager,4,4,16,2,4096,0.512287974357605,7808.10833011626,ok
graph_safe_inductor_cg,4,4,16,2,4096,0.027664000168442726,144592.24897500317,ok
luminal_compiled,4,4,16,2,4096,0.10916800051927567,36640.77367885587,ok
pt_eager,4,8,16,2,4096,0.5412319898605347,14781.092303988627,ok
graph_safe_inductor_cg,4,8,16,2,4096,0.02937600016593933,272331.15314574994,ok
luminal_compiled,4,8,16,2,4096,0.10155199840664864,78777.37637387782,ok
pt_eager,4,16,16,2,4096,0.48787200450897217,32795.48703784202,ok
graph_safe_inductor_cg,4,16,16,2,4096,0.028704000636935234,557413.5885229809,ok
luminal_compiled,4,16,16,2,4096,0.09879999980330467,161943.32016046048,ok
pt_eager,4,32,16,2,4096,0.5792959928512573,55239.46375409585,ok
graph_safe_inductor_cg,4,32,16,2,4096,0.036240000277757645,883002.2007378418,ok
luminal_compiled,4,32,16,2,4096,0.10119999945163727,316205.5353102305,ok
pt_eager,4,64,16,2,4096,0.5686399936676025,112549.24154598081,ok
graph_safe_inductor_cg,4,64,16,2,4096,0.03144000098109245,2035623.3461471149,ok
luminal_compiled,4,64,16,2,4096,0.1101440005004406,581057.5220549029,ok
pt_eager,4,128,16,2,4096,0.5731199979782104,223338.9175941239,ok
graph_safe_inductor_cg,4,128,16,2,4096,0.03721600025892258,3439380.887507165,ok
luminal_compiled,4,128,16,2,4096,0.10672000050544739,1199400.294169474,ok
pt_eager,4,256,16,2,4096,0.5351200103759766,478397.3595383469,ok
graph_safe_inductor_cg,4,256,16,2,4096,0.0326399989426136,7843137.508983668,ok
luminal_compiled,4,256,16,2,4096,0.10263999924063683,2494154.344251451,ok
pt_eager,4,512,16,2,4096,0.5590240061283112,915881.9556712956,ok
graph_safe_inductor_cg,4,512,16,2,4096,0.0337119996547699,15187470.492500357,ok
luminal_compiled,4,512,16,2,4096,0.10249600186944008,4995316.799304895,ok
pt_eager,4,1024,16,2,4096,0.6048000156879425,1693121.6492037117,ok
graph_safe_inductor_cg,4,1024,16,2,4096,0.03484800085425377,29384755.937154543,ok
luminal_compiled,4,1024,16,2,4096,0.10395199805498123,9850700.507539993,ok
pt_eager,4,2048,16,2,4096,0.5883680284023285,3480814.5601677205,ok
graph_safe_inductor_cg,4,2048,16,2,4096,0.04467200115323067,45845271.02278446,ok
luminal_compiled,4,2048,16,2,4096,0.11158400028944016,18353885.814163756,ok
pt_eager,8,2,16,2,4096,0.6239520013332367,3205.3747655692696,ok
graph_safe_inductor_cg,8,2,16,2,4096,0.04572800174355507,43736.877268683216,ok
luminal_compiled,8,2,16,2,4096,0.10910400003194809,18331.13359193389,ok
pt_eager,8,4,16,2,4096,0.6168799996376038,6484.243292617471,ok
graph_safe_inductor_cg,8,4,16,2,4096,0.045903999358415604,87138.37695857928,ok
luminal_compiled,8,4,16,2,4096,0.10825600102543831,36949.4527980954,ok
pt_eager,8,8,16,2,4096,0.6259680092334747,12780.205828403836,ok
graph_safe_inductor_cg,8,8,16,2,4096,0.04761600121855736,168010.74838855147,ok
luminal_compiled,8,8,16,2,4096,0.11219200119376183,71306.33124355768,ok
pt_eager,8,16,16,2,4096,0.6317119896411896,25327.997983840625,ok
graph_safe_inductor_cg,8,16,16,2,4096,0.047807998955249786,334672.02873261116,ok
luminal_compiled,8,16,16,2,4096,0.10617600008845329,150693.18854233247,ok
pt_eager,8,32,16,2,4096,0.6701280176639557,47752.0699873898,ok
graph_safe_inductor_cg,8,32,16,2,4096,0.0514880008995533,621504.0289178838,ok
luminal_compiled,8,32,16,2,4096,0.11286400258541107,283527.0703409942,ok
pt_eager,8,64,16,2,4096,0.7092800140380859,90232.34651098376,ok
graph_safe_inductor_cg,8,64,16,2,4096,0.052239999175071716,1225114.873863551,ok
luminal_compiled,8,64,16,2,4096,0.11633599922060966,550130.659716395,ok
pt_eager,8,128,16,2,4096,0.6788640022277832,188550.28338511224,ok
graph_safe_inductor_cg,8,128,16,2,4096,0.053888000547885895,2375296.8879640796,ok
luminal_compiled,8,128,16,2,4096,0.11127999797463417,1150251.6384766388,ok
pt_eager,8,256,16,2,4096,0.6391039788722992,400560.7983410035,ok
graph_safe_inductor_cg,8,256,16,2,4096,0.05567999929189682,4597701.207895956,ok
luminal_compiled,8,256,16,2,4096,0.10991999879479408,2328966.5466419603,ok
pt_eager,8,512,16,2,4096,0.6350559890270233,806228.1260971039,ok
graph_safe_inductor_cg,8,512,16,2,4096,0.0562559999525547,9101251.429746367,ok
luminal_compiled,8,512,16,2,4096,0.12015999853610992,4260985.404773753,ok
pt_eager,8,1024,16,2,4096,0.6764000058174133,1513897.089285977,ok
graph_safe_inductor_cg,8,1024,16,2,4096,0.05718399956822395,17907107.018254407,ok
luminal_compiled,8,1024,16,2,4096,0.11070400103926659,9249891.515996683,ok
pt_eager,8,2048,16,2,4096,0.6043839752674103,3388574.2901999685,ok
graph_safe_inductor_cg,8,2048,16,2,4096,0.06619199737906456,30940296.12479633,ok
luminal_compiled,8,2048,16,2,4096,0.12067200243473053,16971625.221084144,ok
pt_eager,16,2,16,2,4096,0.7858880162239075,2544.891840455523,ok
graph_safe_inductor_cg,16,2,16,2,4096,0.07545600086450577,26505.512843058623,ok
luminal_compiled,16,2,16,2,4096,0.13145600259304047,15214.215863474643,ok
pt_eager,16,4,16,2,4096,0.7631199955940247,5241.639615125452,ok
graph_safe_inductor_cg,16,4,16,2,4096,0.07680000364780426,52083.330859507856,ok
luminal_compiled,16,4,16,2,4096,0.12828800082206726,31179.84514816717,ok
pt_eager,16,8,16,2,4096,0.7696959972381592,10393.713919139223,ok
graph_safe_inductor_cg,16,8,16,2,4096,0.08003199845552444,99960.0179226535,ok
luminal_compiled,16,8,16,2,4096,0.1287200003862381,62150.40379113693,ok
pt_eager,16,16,16,2,4096,0.7825759947299957,20445.29874126834,ok
graph_safe_inductor_cg,16,16,16,2,4096,0.07948800176382065,201288.24029996534,ok
luminal_compiled,16,16,16,2,4096,0.13308800011873245,120221.20691366495,ok
pt_eager,16,32,16,2,4096,0.8636959791183472,37050.074069657356,ok
graph_safe_inductor_cg,16,32,16,2,4096,0.08278399705886841,386548.1389747891,ok
luminal_compiled,16,32,16,2,4096,0.13150399923324585,243338.6070886124,ok
pt_eager,16,64,16,2,4096,0.878896027803421,72818.62470120825,ok
graph_safe_inductor_cg,16,64,16,2,4096,0.09142400324344635,700034.9769149688,ok
luminal_compiled,16,64,16,2,4096,0.13363199681043625,478927.2144962949,ok
pt_eager,16,128,16,2,4096,0.8490720093250275,150752.82024872545,ok
graph_safe_inductor_cg,16,128,16,2,4096,0.0907679982483387,1410188.6399410903,ok
luminal_compiled,16,128,16,2,4096,0.1300320029258728,984373.0552467828,ok
pt_eager,16,256,16,2,4096,0.812527984380722,315066.07147212717,ok
graph_safe_inductor_cg,16,256,16,2,4096,0.09612800180912018,2663115.795419685,ok
luminal_compiled,16,256,16,2,4096,0.13608000427484512,1881246.2665929128,ok
pt_eager,16,512,16,2,4096,0.8759520053863525,584506.9100266221,ok
graph_safe_inductor_cg,16,512,16,2,4096,0.09095999971032143,5628847.863132768,ok
luminal_compiled,16,512,16,2,4096,0.13556800037622452,3776702.4561777995,ok
pt_eager,16,1024,16,2,4096,0.8617600202560425,1188265.8465587131,ok
graph_safe_inductor_cg,16,1024,16,2,4096,0.09561599791049957,10709504.919443557,ok
luminal_compiled,16,1024,16,2,4096,0.14022399485111237,7302601.819947201,ok
pt_eager,16,2048,16,2,4096,0.8651839792728424,2367126.587019415,ok
graph_safe_inductor_cg,16,2048,16,2,4096,0.12144000083208084,16864295.00961416,ok
luminal_compiled,16,2048,16,2,4096,0.13519999384880066,15147929.68326875,ok
pt_eager,32,2,16,2,4096,1.1254720091819763,1777.0321995423542,ok
graph_safe_inductor_cg,32,2,16,2,4096,0.13230399787425995,15116.70117407015,ok
luminal_compiled,32,2,16,2,4096,0.1716800034046173,11649.58038407291,ok
pt_eager,32,4,16,2,4096,1.1279360055923462,3546.300481736428,ok
graph_safe_inductor_cg,32,4,16,2,4096,0.14241600036621094,28086.731755661804,ok
luminal_compiled,32,4,16,2,4096,0.17348799854516983,23056.34991205774,ok
pt_eager,32,8,16,2,4096,1.0751680135726929,7440.69754588092,ok
graph_safe_inductor_cg,32,8,16,2,4096,0.1366880014538765,58527.4487512314,ok
luminal_compiled,32,8,16,2,4096,0.16710400581359863,47874.37596752677,ok
pt_eager,32,16,16,2,4096,1.1331039667129517,14120.504799232836,ok
graph_safe_inductor_cg,32,16,16,2,4096,0.13777600228786469,116130.52878809854,ok
luminal_compiled,32,16,16,2,4096,0.17158400267362595,93248.78631275425,ok
pt_eager,32,32,16,2,4096,1.1613919734954834,27553.143753601504,ok
graph_safe_inductor_cg,32,32,16,2,4096,0.14433600008487701,221704.91063339947,ok
luminal_compiled,32,32,16,2,4096,0.17432000488042831,183570.44001891708,ok
pt_eager,32,64,16,2,4096,1.1802719831466675,54224.789636514644,ok
graph_safe_inductor_cg,32,64,16,2,4096,0.16435199975967407,389408.10025789076,ok
luminal_compiled,32,64,16,2,4096,0.17791999876499176,359712.2327127224,ok
pt_eager,32,128,16,2,4096,1.1843199729919434,108078.9000599509,ok
graph_safe_inductor_cg,32,128,16,2,4096,0.15452799946069717,828328.8494429494,ok
luminal_compiled,32,128,16,2,4096,0.1780960038304329,718713.4873720702,ok
pt_eager,32,256,16,2,4096,1.168511986732483,219082.0487138128,ok
graph_safe_inductor_cg,32,256,16,2,4096,0.15727999806404114,1627670.4167796476,ok
luminal_compiled,32,256,16,2,4096,0.17375999689102173,1473296.5272815775,ok
pt_eager,32,512,16,2,4096,1.0871520042419434,470955.30156062293,ok
graph_safe_inductor_cg,32,512,16,2,4096,0.16273599863052368,3146200.0068125455,ok
luminal_compiled,32,512,16,2,4096,0.1789119988679886,2861742.10360135,ok
pt_eager,32,1024,16,2,4096,1.1761599779129028,870629.8626289675,ok
graph_safe_inductor_cg,32,1024,16,2,4096,0.16993600130081177,6025797.901336805,ok
luminal_compiled,32,1024,16,2,4096,0.17558399587869644,5831966.603080603,ok
pt_eager,32,2048,16,2,4096,1.2334399819374084,1660396.962958127,ok
graph_safe_inductor_cg,32,2048,16,2,4096,0.24084799736738205,8503288.473999824,ok
luminal_compiled,32,2048,16,2,4096,0.17404799908399582,11766868.971654378,ok
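
The per-configuration speedups implied by this CSV can be derived with a short script. The following is a sketch under stated assumptions: the function name `speedups` is hypothetical and not part of the commit, and it relies only on the column names in the header row above (`variant`, `num_cat`, `batch`, `ms`, `status`).

```python
import csv
from collections import defaultdict


def speedups(csv_path, baseline="pt_eager", candidate="luminal_compiled"):
    """Map (num_cat, batch) -> baseline_ms / candidate_ms, using rows with status 'ok'."""
    ms = defaultdict(dict)
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            if row["status"] != "ok":
                continue
            key = (int(row["num_cat"]), int(row["batch"]))
            ms[key][row["variant"]] = float(row["ms"])
    # Keep only configurations where both variants were measured.
    return {
        key: by_variant[baseline] / by_variant[candidate]
        for key, by_variant in ms.items()
        if baseline in by_variant and candidate in by_variant
    }
```

With `baseline="graph_safe_inductor_cg"`, the only key in the rows above whose ratio exceeds 1 is `(32, 2048)` (about 0.241 ms vs 0.174 ms), which matches the crossover point quoted in the commit message.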
161 pt_eager 32 1024 16 2 4096 1.1761599779129028 870629.8626289675 ok
162 graph_safe_inductor_cg 32 1024 16 2 4096 0.16993600130081177 6025797.901336805 ok
163 luminal_compiled 32 1024 16 2 4096 0.17558399587869644 5831966.603080603 ok
164 pt_eager 32 2048 16 2 4096 1.2334399819374084 1660396.962958127 ok
165 graph_safe_inductor_cg 32 2048 16 2 4096 0.24084799736738205 8503288.473999824 ok
166 luminal_compiled 32 2048 16 2 4096 0.17404799908399582 11766868.971654378 ok


@@ -0,0 +1,255 @@
//! Numerical correctness check: load weights and inputs dumped by
//! `correctness_dump.py`, run the luminal DLRM forward, and compare
//! element-wise against PyTorch's expected output.
//!
//! Usage: `cargo run --release --bin check -- [weights/]`.
use luminal::prelude::*;
use luminal_cuda_lite::kernel::{
dlrm_pairwise_dot_lower_tri_stacked, linear_bias, linear_bias_relu, linear_bias_relu_split_a,
linear_bias_sigmoid, stacked_embedding_bag_sum_kernel,
};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use rand::{SeedableRng, rngs::StdRng};
use std::fs;
use std::path::{Path, PathBuf};
const BATCH: usize = 2048;
const M_DEN: usize = 3;
const M_SPA: usize = 16;
const L: usize = 2;
const LN_BOT: &[usize] = &[M_DEN, 64, M_SPA];
const LN_TOP_TAIL: &[usize] = &[64, 32, 1];
#[derive(Clone, Copy)]
enum Act {
None,
Relu,
Sigmoid,
}
struct LinearWB {
w: GraphTensor,
b: GraphTensor,
act: Act,
}
impl LinearWB {
fn new(cx: &mut Graph, in_dim: usize, out_dim: usize, name: &str, act: Act) -> Self {
Self {
w: cx
.named_tensor(format!("{name}_w").as_str(), (out_dim, in_dim))
.persist(),
b: cx
.named_tensor(format!("{name}_b").as_str(), out_dim)
.persist(),
act,
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
match self.act {
Act::None => linear_bias(x, self.w, self.b),
Act::Relu => linear_bias_relu(x, self.w, self.b),
Act::Sigmoid => linear_bias_sigmoid(x, self.w, self.b),
}
}
}
fn read_f32(path: &Path) -> Vec<f32> {
let bytes = fs::read(path)
.unwrap_or_else(|e| panic!("can't read {}: {e}", path.display()));
assert_eq!(bytes.len() % 4, 0, "{} not a multiple of 4 bytes", path.display());
bytemuck::cast_slice::<u8, f32>(&bytes).to_vec()
}
fn read_i32(path: &Path) -> Vec<i32> {
let bytes = fs::read(path)
.unwrap_or_else(|e| panic!("can't read {}: {e}", path.display()));
assert_eq!(bytes.len() % 4, 0, "{} not a multiple of 4 bytes", path.display());
bytemuck::cast_slice::<u8, i32>(&bytes).to_vec()
}
fn main() {
let weights_dir: PathBuf = std::env::args()
.nth(1)
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("examples/dlrm/weights"));
let manifest_path = weights_dir.join("manifest.json");
let manifest_bytes = fs::read(&manifest_path)
.unwrap_or_else(|e| panic!("can't read {}: {e}", manifest_path.display()));
let manifest_text = std::str::from_utf8(&manifest_bytes).unwrap();
// Extract num_cat and rows from the manifest with a tiny parse — we only
// need two integers, so avoid pulling in serde_json.
let extract = |key: &str| -> usize {
let needle = format!("\"{key}\":");
let i = manifest_text.find(&needle).expect("key not found");
let rest = &manifest_text[i + needle.len()..];
let s: String = rest
.chars()
.skip_while(|c| !c.is_ascii_digit())
.take_while(|c| c.is_ascii_digit())
.collect();
s.parse().unwrap()
};
let num_cat = extract("num_cat");
let rows = extract("rows");
let num_fea = num_cat + 1;
let pair_count = num_fea * (num_fea - 1) / 2;
let top_in = pair_count + M_SPA;
let ln_top: Vec<usize> = std::iter::once(top_in)
.chain(LN_TOP_TAIL.iter().copied())
.collect();
println!(
"check: num_cat={num_cat} rows={rows} F={num_fea} pairs={pair_count} top_in={top_in}"
);
println!(" ln_top={ln_top:?}");
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
let dense = cx.named_tensor("dense", (BATCH, M_DEN)).persist();
let stacked_w = cx
.named_tensor("emb_stacked", (num_cat * rows, M_SPA))
.persist();
let mut sparse_inputs = Vec::with_capacity(num_cat);
for i in 0..num_cat {
sparse_inputs.push(
cx.named_tensor(format!("idx_{i}").as_str(), (BATCH, L))
.as_dtype(DType::Int)
.persist(),
);
}
// Upstream DLRM applies ReLU to every bot layer (sigmoid_bot=-1 default).
let mut bot_layers = Vec::new();
for (i, win) in LN_BOT.windows(2).enumerate() {
bot_layers.push(LinearWB::new(&mut cx, win[0], win[1], &format!("bot_{i}"), Act::Relu));
}
let mut h = dense;
for l in bot_layers.iter() {
h = l.forward(h);
}
let dense_out = h;
let row_offsets: Vec<usize> = (0..=num_cat).map(|k| k * rows).collect();
let emb_stack =
stacked_embedding_bag_sum_kernel(stacked_w, sparse_inputs.clone(), &row_offsets);
let interactions = dlrm_pairwise_dot_lower_tri_stacked(dense_out, emb_stack);
let mut top_layers = Vec::new();
let mut prev = top_in;
let last_top = LN_TOP_TAIL.len() - 1;
for (i, &h) in LN_TOP_TAIL.iter().enumerate() {
let act = if i < last_top { Act::Relu } else { Act::Sigmoid };
top_layers.push(LinearWB::new(&mut cx, prev, h, &format!("top_{i}"), act));
prev = h;
}
// top_0 reads `dense_out` and `interactions` directly via the
// split-A kernel — no materialized concat.
let mut t = linear_bias_relu_split_a(
dense_out,
interactions,
top_layers[0].w,
top_layers[0].b,
);
for l in top_layers.iter().skip(1) {
t = l.forward(t);
}
let out_t = t.output();
cx.build_search_space::<CudaRuntime>();
let mut runtime = CudaRuntime::initialize(stream);
// Load weights/biases from disk.
for (i, _win) in LN_BOT.windows(2).enumerate() {
let l = &bot_layers[i];
let w = read_f32(&weights_dir.join(format!("bot_{i}_w.bin")));
let b = read_f32(&weights_dir.join(format!("bot_{i}_b.bin")));
runtime.set_data(l.w, w);
runtime.set_data(l.b, b);
}
for i in 0..top_layers.len() {
let l = &top_layers[i];
let w = read_f32(&weights_dir.join(format!("top_{i}_w.bin")));
let b = read_f32(&weights_dir.join(format!("top_{i}_b.bin")));
runtime.set_data(l.w, w);
runtime.set_data(l.b, b);
}
// Read per-table weight files (as PyTorch dumped them) and concat them
// into the single stacked weight that the fused kernel expects.
let mut stacked = Vec::with_capacity(num_cat * rows * M_SPA);
for i in 0..num_cat {
let t = read_f32(&weights_dir.join(format!("emb_{i}.bin")));
assert_eq!(t.len(), rows * M_SPA, "emb_{i}.bin shape mismatch");
stacked.extend(t);
}
runtime.set_data(stacked_w, stacked);
let dense_data = read_f32(&weights_dir.join("dense.bin"));
runtime.set_data(dense, dense_data);
for i in 0..num_cat {
let idx = read_i32(&weights_dir.join(format!("idx_{i}.bin")));
runtime.set_data(sparse_inputs[i], idx);
}
let mut rng = StdRng::seed_from_u64(0);
runtime = cx.search_options(runtime, SearchOptions::new(50).trials(1).keep_best(2), &mut rng);
runtime.execute(&cx.dyn_map);
let lum_out = runtime.get_f32(out_t);
let expected = read_f32(&weights_dir.join("expected.bin"));
assert_eq!(
lum_out.len(),
expected.len(),
"output length mismatch: luminal={} expected={}",
lum_out.len(),
expected.len()
);
let mut max_abs = 0.0f32;
let mut sum_abs = 0.0f64;
let mut max_rel = 0.0f32;
let mut diff_count = 0usize;
for (i, (&a, &b)) in lum_out.iter().zip(expected.iter()).enumerate() {
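let d = (a - b).abs();
sum_abs += d as f64;
// `d > max_abs` is false when `d` is NaN, so propagate NaN explicitly;
// otherwise a NaN output would leave max_abs at 0 and report PASS.
if d > max_abs || d.is_nan() {
max_abs = d;
}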
let r = if b.abs() > 1e-8 { d / b.abs() } else { d };
if r > max_rel {
max_rel = r;
}
if d > 1e-4 {
diff_count += 1;
if diff_count <= 5 {
println!(
" diff @ {i}: luminal={a:.6} expected={b:.6} abs={d:.3e}"
);
}
}
}
let mean_abs = sum_abs / lum_out.len() as f64;
println!();
println!(
" luminal head: {:?}",
&lum_out[..8.min(lum_out.len())]
);
println!(
" expected head: {:?}",
&expected[..8.min(expected.len())]
);
println!(
" max abs diff = {max_abs:.3e} mean abs diff = {mean_abs:.3e} max rel diff = {max_rel:.3e}"
);
println!(" elements with abs diff > 1e-4: {diff_count}/{}", lum_out.len());
let tol: f32 = 1e-3;
if max_abs < tol {
println!("PASS (max abs diff < {tol})");
} else {
println!("FAIL (max abs diff >= {tol})");
std::process::exit(1);
}
}
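Reference note on the interaction op: the check above compares against PyTorch, so it can help to see what `dlrm_pairwise_dot_lower_tri_stacked` is expected to reproduce. The following is a minimal pure-Python sketch for a single sample, not the kernel's implementation. The function name `pairwise_dot_lower_tri` is hypothetical, and the row-major pair order (1,0), (2,0), (2,1), ... is an assumption based on the `Z[:, li, lj]` indexing described in the commit message.

```python
def pairwise_dot_lower_tri(features):
    """Strict lower-triangle pairwise dots for one sample.

    `features` is a list of F equal-length vectors: the bottom-MLP output
    first, then one pooled embedding per table (assumed ordering).
    """
    out = []
    for i in range(len(features)):
        for j in range(i):  # j < i only: no self-interaction, no duplicates
            out.append(sum(a * b for a, b in zip(features[i], features[j])))
    return out


dense_out = [1.0, 2.0]
emb = [[3.0, 4.0], [5.0, 6.0]]  # two tables, m_spa = 2 in this toy case
pairs = pairwise_dot_lower_tri([dense_out] + emb)

num_fea = 1 + len(emb)
pair_count = num_fea * (num_fea - 1) // 2  # same formula as in check.rs
print(pairs, pair_count)
```

With F = num_cat + 1 features this yields F*(F-1)/2 values per sample, which is why the first top layer's input width is `pair_count + M_SPA`.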

examples/dlrm/src/main.rs Normal file

@@ -0,0 +1,415 @@
//! DLRM forward-pass benchmark on luminal's CUDA backend.
//!
//! Mirrors `sweep_categories.py` from https://github.com/jss8649/tmp-dlrm-bench:
//! batch=2048, m_spa=16, indices_per_bag=2, rows_per_table=4096,
//! ln_bot=[3, 64, 16], top MLP scales with F = num_cat + 1:
//! num_int = F*(F-1)/2 + m_spa, ln_top = [num_int, 64, 32, 1].
//! arch_interaction_op="dot", arch_interaction_itself=False, sigmoid_top=2.
//!
//! CLI: `dlrm [--num-cat N] [--rows R] [--batch B] [--m-spa D] [--bag L] [--print-outputs]`.
//! Defaults: num-cat=3, rows=4096, batch=2048, m-spa=16, bag=2.
//!
//! Harness: 5 rounds × 20 timed iters, 10 warmup before round 0,
//! median-of-round-medians (same as bench_luminal_exact.py).
use luminal::prelude::*;
use luminal_cuda_lite::kernel::{
dlrm_pairwise_dot_lower_tri_stacked, linear_bias, linear_bias_relu, linear_bias_relu_split_a,
linear_bias_sigmoid, stacked_embedding_bag_sum_kernel,
};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::time::Instant;
// ---- Fixed config (matches sweep_categories.py defaults). num_cat, rows,
// batch, m_spa and bag are per-call overridable from the CLI (see `Args`). ----
const M_DEN: usize = 3;
const LN_TOP_TAIL: &[usize] = &[64, 32, 1]; // → ln_top = [num_int, 64, 32, 1]
const SEARCH_GRAPHS: usize = 200;
const SEARCH_TRIALS: usize = 1;
const SEARCH_KEEP_BEST: usize = 4;
const SEARCH_SEED: u64 = 0;
const PRE_ROUND_WARMUP: usize = 10;
const TIMED_ITERS_PER_ROUND: usize = 20;
const ROUNDS: usize = 5;
#[derive(Clone, Copy)]
enum Act {
None,
Relu,
Sigmoid,
}
struct LinearWB {
w: GraphTensor, // (out_dim, in_dim) — PyTorch/cuBLASLt convention
b: GraphTensor, // (out_dim,)
act: Act,
}
impl LinearWB {
fn new(cx: &mut Graph, in_dim: usize, out_dim: usize, name: &str, act: Act) -> Self {
Self {
w: cx
.named_tensor(format!("{name}_w").as_str(), (out_dim, in_dim))
.persist(),
b: cx
.named_tensor(format!("{name}_b").as_str(), out_dim)
.persist(),
act,
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
match self.act {
Act::None => linear_bias(x, self.w, self.b),
Act::Relu => linear_bias_relu(x, self.w, self.b),
Act::Sigmoid => linear_bias_sigmoid(x, self.w, self.b),
}
}
}
fn build_indices(table_idx: usize, batch: usize, bag: usize, rows: usize) -> Vec<i32> {
// Match sweep_categories.py's `_make_inputs`:
// positions = arange(batch * L)
// lS_i[k] = (positions * (2k+3) + (k+1)) % ROWS_PER_TABLE
let n = rows as i64;
let s = (2 * table_idx + 3) as i64;
let o = (table_idx + 1) as i64;
(0..(batch * bag) as i64)
.map(|p| ((p * s + o).rem_euclid(n)) as i32)
.collect()
}
fn rand_normal(rng: &mut StdRng, n: usize, std: f32) -> Vec<f32> {
let mut out = Vec::with_capacity(n);
while out.len() < n {
let u1: f32 = rng.random::<f32>().max(1e-9);
let u2: f32 = rng.random::<f32>();
let r = (-2.0 * u1.ln()).sqrt() * std;
let z0 = r * (2.0 * std::f32::consts::PI * u2).cos();
let z1 = r * (2.0 * std::f32::consts::PI * u2).sin();
out.push(z0);
if out.len() < n {
out.push(z1);
}
}
out
}
fn rand_uniform(rng: &mut StdRng, n: usize, hi: f32) -> Vec<f32> {
(0..n).map(|_| (rng.random::<f32>() * 2.0 - 1.0) * hi).collect()
}
fn build_dense_x(batch: usize, m_den: usize) -> Vec<f32> {
let total = batch * m_den;
(0..total)
.map(|i| -1.0 + 2.0 * (i as f32) / ((total - 1) as f32))
.collect()
}
struct Args {
num_cat: usize,
rows: usize,
batch: usize,
m_spa: usize,
bag: usize,
print_outputs: bool,
}
fn parse_args() -> Args {
let mut a = Args {
num_cat: 3,
rows: 4096,
batch: 2048,
m_spa: 16,
bag: 2,
print_outputs: false,
};
let mut args = std::env::args().skip(1);
while let Some(arg) = args.next() {
match arg.as_str() {
"--num-cat" => {
a.num_cat = args.next().expect("missing value for --num-cat").parse().unwrap();
}
"--rows" => {
a.rows = args.next().expect("missing value for --rows").parse().unwrap();
}
"--batch" => {
a.batch = args.next().expect("missing value for --batch").parse().unwrap();
}
"--m-spa" => {
a.m_spa = args.next().expect("missing value for --m-spa").parse().unwrap();
}
"--bag" => {
a.bag = args.next().expect("missing value for --bag").parse().unwrap();
}
"--print-outputs" => {
a.print_outputs = true;
}
"-h" | "--help" => {
eprintln!(
"usage: dlrm [--num-cat N] [--rows R] [--batch B] [--m-spa D] [--bag L] [--print-outputs]"
);
std::process::exit(0);
}
other => {
eprintln!("unknown arg: {other}");
std::process::exit(2);
}
}
}
a
}
fn main() {
let args = parse_args();
let Args {
num_cat,
rows,
batch,
m_spa,
bag,
print_outputs,
} = args;
let num_fea = num_cat + 1;
let pair_count = num_fea * (num_fea - 1) / 2;
let top_in = pair_count + m_spa;
// ln_bot last entry must equal m_spa (bot output feeds into interaction
// alongside emb rows that are m_spa wide).
let ln_bot: Vec<usize> = vec![M_DEN, 64, m_spa];
let ln_top: Vec<usize> = std::iter::once(top_in)
.chain(LN_TOP_TAIL.iter().copied())
.collect();
println!(
"==== DLRM luminal config ==== num_cat={num_cat} F={num_fea} \
rows/table={rows} batch={batch} m_spa={m_spa} L={bag}"
);
println!(" ln_bot={ln_bot:?} ln_top={ln_top:?} pairs={pair_count}");
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
let dense = cx
.named_tensor("dense", (batch, M_DEN))
.persist();
let stacked_w = cx
.named_tensor("emb_stacked", (num_cat * rows, m_spa))
.persist();
let mut sparse_inputs = Vec::with_capacity(num_cat);
for i in 0..num_cat {
sparse_inputs.push(
cx.named_tensor(format!("idx_{i}").as_str(), (batch, bag))
.as_dtype(DType::Int)
.persist(),
);
}
// Upstream `dlrm_s_pytorch.DLRM_Net.create_mlp` with `sigmoid_bot=-1`
// (the default) applies ReLU to every bot layer — including the final
// one. The output of the bot MLP feeds the interaction op AFTER ReLU.
let mut bot_layers = Vec::new();
for (i, win) in ln_bot.windows(2).enumerate() {
bot_layers.push(LinearWB::new(&mut cx, win[0], win[1], &format!("bot_{i}"), Act::Relu));
}
let mut h = dense;
for l in bot_layers.iter() {
h = l.forward(h);
}
let dense_out = h;
// One fused kernel for all num_cat tables. row_offsets[k] = k * rows
// (uniform rows-per-table in this benchmark, matching upstream).
let row_offsets: Vec<usize> = (0..=num_cat).map(|k| k * rows).collect();
let emb_stack =
stacked_embedding_bag_sum_kernel(stacked_w, sparse_inputs.clone(), &row_offsets);
// emb_stack: (batch, num_cat, m_spa)
// Feature interaction over [dense_out, emb_stack[:, 0, :], …,
// emb_stack[:, num_cat-1, :]] in one fused kernel — no per-table slice.
let interactions = dlrm_pairwise_dot_lower_tri_stacked(dense_out, emb_stack);
// Skip the materialized `cat(dense_out, interactions)`: the first top
// layer reads both halves directly via the split-A matmul kernel. The
// remaining top layers see the dense fused output and stay vanilla.
let mut top_layers = Vec::new();
let mut prev = top_in;
let last_top = LN_TOP_TAIL.len() - 1;
for (i, &h) in LN_TOP_TAIL.iter().enumerate() {
let act = if i < last_top {
Act::Relu
} else {
Act::Sigmoid
};
top_layers.push(LinearWB::new(&mut cx, prev, h, &format!("top_{i}"), act));
prev = h;
}
// top_0: split-A matmul (no concat materialization).
let mut t = linear_bias_relu_split_a(
dense_out,
interactions,
top_layers[0].w,
top_layers[0].b,
);
// top_1..: standard fused linear_bias_(relu|sigmoid) kernels.
for l in top_layers.iter().skip(1) {
t = l.forward(t);
}
let out = t.output();
println!("Building E-graph...");
let t = Instant::now();
cx.build_search_space::<CudaRuntime>();
println!(" build: {:.2}s", t.elapsed().as_secs_f64());
let mut runtime = CudaRuntime::initialize(stream);
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
for (i, win) in ln_bot.windows(2).enumerate() {
let (a, b) = (win[0], win[1]);
let l = &bot_layers[i];
runtime.set_data(
l.w,
rand_normal(&mut rng, a * b, (2.0 / (a + b) as f32).sqrt()),
);
runtime.set_data(l.b, rand_normal(&mut rng, b, (1.0 / b as f32).sqrt()));
}
let mut top_shapes: Vec<(usize, usize)> = vec![];
let mut prev = top_in;
for &h in LN_TOP_TAIL.iter() {
top_shapes.push((prev, h));
prev = h;
}
for (i, &(a, b)) in top_shapes.iter().enumerate() {
let l = &top_layers[i];
runtime.set_data(
l.w,
rand_normal(&mut rng, a * b, (2.0 / (a + b) as f32).sqrt()),
);
runtime.set_data(l.b, rand_normal(&mut rng, b, (1.0 / b as f32).sqrt()));
}
// One stacked weight: concatenated per-table uniform inits.
let mut stacked_data: Vec<f32> = Vec::with_capacity(num_cat * rows * m_spa);
for _ in 0..num_cat {
stacked_data
.extend(rand_uniform(&mut rng, rows * m_spa, 1.0 / (rows as f32).sqrt()));
}
runtime.set_data(stacked_w, stacked_data);
runtime.set_data(dense, build_dense_x(batch, M_DEN));
for i in 0..num_cat {
runtime.set_data(sparse_inputs[i], build_indices(i, batch, bag, rows));
}
println!("Searching/compiling...");
let t = Instant::now();
runtime = cx.search_options(
runtime,
SearchOptions::new(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST),
&mut rng,
);
println!(" search/compile: {:.2}s", t.elapsed().as_secs_f64());
{
let host_ops = runtime.host_ops();
let total = host_ops.len();
let cublaslt = host_ops
.iter()
.filter(|op| format!("{op:?}").contains("CuBlasLt"))
.count();
let cudagraph = host_ops
.iter()
.filter(|op| format!("{op:?}").contains("CudaGraph"))
.count();
println!("Host ops: total={total} cuBLASLt={cublaslt} CudaGraph={cudagraph}");
}
for _ in 0..PRE_ROUND_WARMUP {
runtime.execute(&cx.dyn_map);
}
let _ = runtime.get_f32(out);
let mut round_medians = Vec::with_capacity(ROUNDS);
for round in 0..ROUNDS {
let mut times_us = Vec::with_capacity(TIMED_ITERS_PER_ROUND);
for _ in 0..TIMED_ITERS_PER_ROUND {
let t = Instant::now();
runtime.execute(&cx.dyn_map);
// execute() no longer syncs at end (so it stays capturable by
// torch.cuda.CUDAGraph) — sync explicitly for accurate GPU time.
runtime.synchronize_stream();
times_us.push(t.elapsed().as_micros() as f64);
}
let _ = runtime.get_f32(out);
times_us.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median = times_us[times_us.len() / 2] / 1000.0;
println!(
" round {round}: median {median:.3} ms min {:.3} max {:.3}",
times_us[0] / 1000.0,
times_us[times_us.len() - 1] / 1000.0
);
round_medians.push(median);
}
round_medians.sort_by(|a, b| a.partial_cmp(b).unwrap());
let med_of_med = round_medians[round_medians.len() / 2];
let tput = batch as f64 / (med_of_med / 1000.0);
println!();
println!(
"==== cfg(num_cat={num_cat},batch={batch},m_spa={m_spa},bag={bag},rows={rows}): \
luminal median-of-round-medians = {med_of_med:.4} ms ({tput:.0} samples/s)"
);
if print_outputs {
let sample = runtime.get_f32(out);
let head: Vec<f32> = sample.iter().take(8).copied().collect();
println!(" output[0..8] = {head:?} (len={})", sample.len());
}
// Per-kernel breakdown. Available when `LUMINAL_KERNEL_TIMING=1` was
// set before the first execute (it gates the event-record-node
// insertion at graph-build time). `get_f32` above already
// synchronized the stream, so the events have valid data.
if std::env::var_os("LUMINAL_KERNEL_TIMING").is_some() {
let timings = runtime.read_per_kernel_timings_ms();
if timings.is_empty() {
println!(" (no per-kernel timings available — events not recorded)");
} else {
// Aggregate per kernel name in case the same op appears more
// than once (e.g. two Matmul2D_BiasRelu calls in the bot MLP).
let mut sum_per_name: std::collections::BTreeMap<&'static str, (f32, usize)> =
std::collections::BTreeMap::new();
let mut total = 0.0_f32;
for (name, ms) in &timings {
let e = sum_per_name.entry(*name).or_insert((0.0, 0));
e.0 += *ms;
e.1 += 1;
total += *ms;
}
println!();
println!(" per-kernel GPU time (single replay, ms):");
println!(
" {:>40} {:>3} {:>10} {:>10} {:>5}",
"kernel", "n", "total_ms", "each_ms", "pct"
);
// Sort by total descending so the bottleneck is on top.
let mut sorted: Vec<_> = sum_per_name.iter().collect();
sorted.sort_by(|a, b| b.1.0.partial_cmp(&a.1.0).unwrap());
for (name, (sum, n)) in sorted {
let pct = if total > 0.0 { 100.0 * sum / total } else { 0.0 };
println!(
" {:>40} {:>3} {:>10.4} {:>10.4} {:>4.1}%",
name,
n,
sum,
sum / (*n as f32),
pct
);
}
println!(" {:>40} {:>10.4}", "TOTAL", total);
}
}
}
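Reference note on the split-A top layer: the benchmark skips the materialized `cat(dense_out, interactions)` by letting the first top layer read both halves directly. The identity this relies on is that a linear layer over a concatenation equals the sum of two partial products over column slices of the weight. Below is a minimal pure-Python sketch of that identity, assuming the `(out_dim, in_dim)` weight layout noted in `LinearWB` and the dense-first concat order; the helper names are hypothetical and this is not the kernel's code.

```python
def matvec(w, x):
    """w: list of rows (out_dim x in_dim), x: vector of length in_dim."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def linear_bias_relu_concat(a, b, w, bias):
    x = a + b  # materialized concat of the two inputs
    return [max(0.0, y + c) for y, c in zip(matvec(w, x), bias)]


def linear_bias_relu_split(a, b, w, bias):
    ka = len(a)
    left = matvec([row[:ka] for row in w], a)   # columns that multiply `a`
    right = matvec([row[ka:] for row in w], b)  # columns that multiply `b`
    return [max(0.0, p + q + c) for p, q, c in zip(left, right, bias)]


a = [1.0, -2.0]          # stands in for dense_out
b = [0.5, 3.0, -1.0]     # stands in for interactions
w = [[1.0, 0.0, 2.0, -1.0, 0.5],
     [-1.0, 1.0, 0.0, 2.0, 1.0]]
bias = [2.0, 0.5]

y_concat = linear_bias_relu_concat(a, b, w, bias)
y_split = linear_bias_relu_split(a, b, w, bias)
print(y_concat, y_split)
```

In floating point the two forms can differ in the last bits because the summation order changes, which is one reason the correctness check uses a tolerance rather than exact equality.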

examples/dlrm/summarize.py Normal file

@@ -0,0 +1,101 @@
"""Summarize a results.csv from sweep_all.py: pretty per-cell tables and
ratios of luminal_compiled vs the PT baselines.
"""
from __future__ import annotations
import argparse
import csv
import sys
from pathlib import Path
def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("csv", default="results.csv", nargs="?")
    args = ap.parse_args()
    p = Path(args.csv)
    # Open with newline="" (as the csv module expects) and close the file promptly.
    with p.open(newline="") as f:
        rows = list(csv.DictReader(f))
    # Index: (num_cat, batch, variant) -> ms
    table: dict[tuple[int, int, str], float | None] = {}
    nums_cat: set[int] = set()
    batches: set[int] = set()
    variants: set[str] = set()
    for r in rows:
        if r["status"] != "ok":
            continue
        nc = int(r["num_cat"])
        bs = int(r["batch"])
        v = r["variant"]
        ms = float(r["ms"]) if r["ms"] else None
        if ms is None:
            continue
        table[(nc, bs, v)] = ms
        nums_cat.add(nc)
        batches.add(bs)
        variants.add(v)
    nums_cat_s = sorted(nums_cat)
    batches_s = sorted(batches)
    variants_s = sorted(variants)

    def fmt_ms(ms: float | None) -> str:
        if ms is None:
            return " - "
        if ms < 1.0:
            return f"{ms*1000:>5.0f} us"
        return f"{ms:>6.3f} ms"

    print(f"# DLRMv1 sweep — {len(rows)} rows, {len(variants_s)} variants")
    print()
    print(f"Configurations: batch ∈ {batches_s}, num_cat ∈ {nums_cat_s}")
    print()
    # Per-variant table.
    for v in variants_s:
        print(f"## {v}")
        print()
        hdr = "| nc \\ batch | " + " | ".join(f"{b:>6}" for b in batches_s) + " |"
        print(hdr)
        print("|" + "|".join(["---"] * (len(batches_s) + 1)) + "|")
        for nc in nums_cat_s:
            cells = [fmt_ms(table.get((nc, bs, v))) for bs in batches_s]
            print(f"| {nc:>3} | " + " | ".join(cells) + " |")
        print()
    # Speedup ratios: luminal_compiled vs each PT variant
    if "luminal_compiled" in variants_s:
        for ref in ["pt_eager", "graph_safe_inductor_cg", "graph_safe_cg", "v3_inductor_cg"]:
            if ref not in variants_s:
                continue
            print(f"## speedup luminal_compiled vs {ref} (>1 = luminal faster)")
            print()
            print("| nc \\ batch | " + " | ".join(f"{b:>6}" for b in batches_s) + " |")
            print("|" + "|".join(["---"] * (len(batches_s) + 1)) + "|")
            wins = 0
            cells_total = 0
            for nc in nums_cat_s:
                rcells = []
                for bs in batches_s:
                    lum = table.get((nc, bs, "luminal_compiled"))
                    pt = table.get((nc, bs, ref))
                    if lum is None or pt is None or lum <= 0:
                        rcells.append(" - ")
                        continue
                    cells_total += 1
                    r = pt / lum
                    if r >= 1.0:
                        wins += 1
                    rcells.append(f"{r:>5.2f}x")
                print(f"| {nc:>3} | " + " | ".join(rcells) + " |")
            print()
            if cells_total:
                print(f"luminal faster in {wins}/{cells_total} cells ({100*wins/cells_total:.0f}%)")
                print()
    return 0


if __name__ == "__main__":
    sys.exit(main())
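Worked example of one speedup cell: the ratio printed by the summary is the baseline time divided by the luminal time, so values above 1 mean luminal is faster. The sketch below recomputes the nc=32, batch=2048 cell from the `results.csv` rows shown earlier in this diff. The helper name `speedup_cell` is hypothetical and only mirrors the guard conditions in the script.

```python
def speedup_cell(pt_ms, lum_ms):
    """Return baseline/luminal, or None when either timing is missing or invalid."""
    if pt_ms is None or lum_ms is None or lum_ms <= 0:
        return None
    return pt_ms / lum_ms


# ms values copied from the nc=32, batch=2048 rows of results.csv
inductor_cg_ms = 0.24084799736738205
luminal_ms = 0.17404799908399582

ratio = speedup_cell(inductor_cg_ms, luminal_ms)
print(f"{ratio:>5.2f}x")

# samples_per_sec in the CSV is batch / (ms / 1000)
samples_per_sec = 2048 / (luminal_ms / 1000.0)
print(f"{samples_per_sec:,.0f} samples/s")
```

The same rows are the source of the "174 µs vs 241 µs" crossover quoted in the commit message.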

examples/dlrm/sweep_all.py Normal file

@@ -0,0 +1,115 @@
"""Drive the full DLRM sweep across (batch, num_cat) and produce a CSV.
For each cell, we shell out to `sweep_pytorch.py` (one process per
variant, so torch.compile cache state is fresh) and to `sweep_luminal.py`
(also separate, so the luminal compiled graph builds from clean state).
Each cell × variant → one row in `results.csv`. Run from this dir:
    python sweep_all.py [--variants ...] [--quick]
"""
from __future__ import annotations
import argparse
import csv
import json
import subprocess
import sys
import time
from pathlib import Path
HERE = Path(__file__).parent
PY = sys.executable
def _powers(lo: int, hi: int) -> list[int]:
    out = []
    v = lo
    while v <= hi:
        out.append(v)
        v *= 2
    return out
def _run_json(cmd: list[str]) -> dict | None:
    try:
        r = subprocess.run(cmd, cwd=HERE, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired:
        print(f" TIMEOUT: {' '.join(cmd)}", file=sys.stderr)
        return None
    if r.returncode != 0:
        print(f" ERR : {' '.join(cmd)}\n stderr tail: {r.stderr[-400:]}", file=sys.stderr)
        return None
    for line in r.stdout.splitlines()[::-1]:
        s = line.strip()
        if s.startswith("{") and s.endswith("}"):
            try:
                return json.loads(s)
            except json.JSONDecodeError:
                continue
    return None
PT_VARIANTS = ["pt_eager", "graph_safe_cg", "graph_safe_inductor_cg", "v3_inductor_cg"]
LUMINAL_VARIANTS = ["luminal_compiled"]
def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--batch", type=int, nargs="+", default=None)
    ap.add_argument("--num-cat", type=int, nargs="+", default=None)
    ap.add_argument("--variants", nargs="+", default=PT_VARIANTS + LUMINAL_VARIANTS)
    ap.add_argument("--quick", action="store_true",
                    help="Smaller sweep (nc∈{2,8,32}, batch∈{32,512,2048}) for fast iteration.")
    ap.add_argument("--out", default="results.csv")
    args = ap.parse_args()
    if args.quick:
        nums_cat = [2, 8, 32]
        batches = [32, 512, 2048]
    else:
        nums_cat = args.num_cat or _powers(2, 32)
        batches = args.batch or _powers(2, 2048)
    rows: list[dict] = []
    t0 = time.time()
    total_cells = len(nums_cat) * len(batches) * len(args.variants)
    seen = 0
    for nc in nums_cat:
        for bs in batches:
            for variant in args.variants:
                seen += 1
                cmd_base = ["--num-cat", str(nc), "--batch", str(bs), "--json"]
                if variant == "luminal_compiled":
                    cmd = [PY, "sweep_luminal.py"] + cmd_base
                else:
                    cmd = [PY, "sweep_pytorch.py", "--variant", variant] + cmd_base
                r = _run_json(cmd)
                if r is None:
                    print(f"[{seen}/{total_cells}] FAIL nc={nc:>2} batch={bs:>4} variant={variant}", flush=True)
                    rows.append({"variant": variant, "num_cat": nc, "batch": bs, "ms": None,
                                 "samples_per_sec": None, "status": "failed"})
                else:
                    r["status"] = "ok"
                    rows.append(r)
                    print(f"[{seen}/{total_cells}] {r['ms']:8.4f} ms nc={nc:>2} batch={bs:>4} variant={variant}", flush=True)
    out_path = HERE / args.out
    with open(out_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["variant", "num_cat", "batch", "m_spa", "bag", "rows", "ms",
                                          "samples_per_sec", "status"])
        w.writeheader()
        for row in rows:
            row = dict(row)
            row.setdefault("m_spa", 16)
            row.setdefault("bag", 2)
            row.setdefault("rows", 4096)
            row.setdefault("status", "ok")
            w.writerow(row)
    print(f"\nWrote {out_path} ({len(rows)} rows, {time.time()-t0:.1f}s)")
    return 0


if __name__ == "__main__":
    sys.exit(main())
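Note on the child-process protocol: each sweep script prints one JSON record on stdout, possibly after other log lines, and the driver scans stdout from the end for the last line that parses as a JSON object. The following is a small standalone sketch of that scan, so it can be exercised without launching a subprocess. The function name `last_json_line` and the sample text are hypothetical.

```python
import json


def last_json_line(stdout):
    """Return the last stdout line that parses as a JSON object, else None."""
    for line in reversed(stdout.splitlines()):
        s = line.strip()
        if s.startswith("{") and s.endswith("}"):
            try:
                return json.loads(s)
            except json.JSONDecodeError:
                continue  # looks like JSON but is not; keep scanning upward
    return None


sample = 'compiling...\n{"variant": "pt_eager", "ms": 0.5}\n{not json}\n'
rec = last_json_line(sample)
print(rec)
```

Scanning from the end tolerates compiler or autotuner chatter printed before the record, at the cost of silently ignoring any earlier records.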


@@ -0,0 +1,98 @@
"""Time DLRMv1 under `torch.compile(model, backend=luminal_backend)`.
Sibling to sweep_pytorch.py — same shape, same harness, same JSON output.
Strategy: import DLRMv1/make_inputs from sweep_pytorch (so the model
definition is the single source of truth across all variants), warm up,
then time 5 rounds × 20 timed iters after 10 warmup iters, as elsewhere.
Outputs one JSON line per variant (currently `luminal_compiled`).
"""
from __future__ import annotations
import argparse
import json
import sys
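from contextlib import contextmanager
from pathlib import Path
import numpy as np
import torch
# Make the sibling import resolve regardless of CWD and of where the repo is
# checked out: add this file's own directory instead of a hard-coded path.
sys.path.insert(0, str(Path(__file__).resolve().parent))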
from sweep_pytorch import DLRMv1, make_inputs, time_rounds # noqa: E402
import luminal # noqa: E402
from luminal import luminal_backend # noqa: E402
@contextmanager
def _relaxed_dynamo_limits():
    pr = torch._dynamo.config.recompile_limit
    pc = torch._dynamo.config.cache_size_limit
    torch._dynamo.config.recompile_limit = 64
    torch._dynamo.config.cache_size_limit = 64
    try:
        yield
    finally:
        torch._dynamo.config.recompile_limit = pr
        torch._dynamo.config.cache_size_limit = pc
def _run_luminal_compiled(num_cat: int, batch: int, rows: int, m_spa: int, bag: int, device):
    """Time `torch.compile(model, backend=luminal_backend)`.
    Do NOT wrap the compiled call in `torch.cuda.CUDAGraph` — luminal's
    `cuda_lite` runtime already captures and replays a CUDA graph
    internally (one host op per `execute()` call), so an external wrap
    would just be hiding Python-wrapper / FFI overhead rather than
    measuring luminal's actual perf.
    """
    plain = DLRMv1(num_cat, rows, m_spa).eval().to(device)
    inputs = make_inputs(num_cat, batch, bag, rows, device)
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        compiled = torch.compile(plain, backend=luminal_backend, fullgraph=False, dynamic=False)
        # Warm up: triggers the export + translate + search.
        with torch.no_grad():
            for _ in range(3):
                _ = compiled(*inputs)
        torch.cuda.synchronize()
        return time_rounds(lambda: compiled(*inputs))
def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--num-cat", type=int, default=3)
    ap.add_argument("--rows", type=int, default=4096)
    ap.add_argument("--batch", type=int, default=2048)
    ap.add_argument("--m-spa", type=int, default=16)
    ap.add_argument("--bag", type=int, default=2)
    ap.add_argument("--json", action="store_true")
    args = ap.parse_args()
    device = torch.device("cuda")
    ms = _run_luminal_compiled(args.num_cat, args.batch, args.rows, args.m_spa, args.bag, device)
    rec = {
        "variant": "luminal_compiled",
        "num_cat": args.num_cat,
        "batch": args.batch,
        "m_spa": args.m_spa,
        "bag": args.bag,
        "rows": args.rows,
        "ms": ms,
        "samples_per_sec": args.batch / (ms / 1000.0),
    }
    if args.json:
        print(json.dumps(rec))
    else:
        print(
            f"==== luminal_compiled cfg: num_cat={args.num_cat} batch={args.batch} ===="
        )
        print(f" luminal_compiled {ms:8.4f} ms "
              f"({rec['samples_per_sec']:>12,.0f} samples/s)")
    return 0


if __name__ == "__main__":
    sys.exit(main())
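Note on the timing harness: every variant is timed as 5 rounds of 20 timed iterations after 10 warmup iterations, and the reported number is the median of the per-round medians. `time_rounds` itself lives in `sweep_pytorch.py` and is not shown here, so the following is only a CPU-side sketch of the reduction under stated assumptions. The name `time_rounds_sketch` and the injectable `clock` parameter are hypothetical, wall-clock timing stands in for the CUDA events the real harness uses, and the upper-median choice mirrors `times_us[times_us.len() / 2]` in `main.rs`.

```python
import time


def upper_median(values):
    """Median as computed in main.rs: the element at index len // 2 after sorting."""
    ordered = sorted(values)
    return ordered[len(ordered) // 2]


def time_rounds_sketch(fn, rounds=5, iters=20, warmup=10, clock=time.perf_counter):
    """Return the median-of-round-medians latency of `fn` in milliseconds."""
    for _ in range(warmup):  # warmup runs once, before round 0
        fn()
    round_medians = []
    for _ in range(rounds):
        samples_ms = []
        for _ in range(iters):
            t0 = clock()
            fn()
            samples_ms.append((clock() - t0) * 1000.0)
        round_medians.append(upper_median(samples_ms))
    return upper_median(round_medians)


ms = time_rounds_sketch(lambda: sum(range(1000)))
print(f"{ms:.4f} ms")
```

For GPU work the timed region also has to wait for the device, for example with CUDA events or an explicit synchronize, otherwise only launch overhead is measured. Taking a median per round and then across rounds keeps a single slow round from skewing the result.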


@@ -0,0 +1,332 @@
"""Parameterized PyTorch sweep for DLRMv1 across (batch, num_cat).
Self-contained: defines DLRMv1 inline (vanilla `nn.EmbeddingBag` per table,
strict lower-tri pairwise-dot interaction, top MLP), so it does NOT need
facebookresearch/dlrm cloned locally.
Variants measured (toggle with `--variant`):
  * `pt_eager`               — eager forward of plain DLRMv1, no CG.
  * `graph_safe_cg`          — `GraphSafeDLRM` wrap of plain DLRMv1, captured
                               via `torch.cuda.CUDAGraph`, replayed.
  * `graph_safe_inductor_cg` — `GraphSafeDLRM` wrap, then
                               `torch.compile(backend="inductor",
                               mode="max-autotune-no-cudagraphs",
                               fullgraph=False)`, then manual
                               `torch.cuda.CUDAGraph` capture/replay.
                               This is the user-named primary baseline.
  * `v3_inductor_cg`         — `FusedDLRMv3` (stacked-emb-table rewrite,
                               index_select+sum), `torch.compile(inductor,
                               max-autotune-no-cudagraphs)`, manual CG.
  * `all`                    — all of the above.
Harness: 5 rounds × 20 timed iters with 10 warmup before round 0, median
of CUDA-event-timed per-iter ms, then median across rounds.
"""
from __future__ import annotations
import argparse
import copy
import json
import os
import sys
from contextlib import contextmanager
from typing import List
import numpy as np
import torch
import torch.nn as nn
# ---- Fixed dials matching the upstream sweep_categories.py defaults ------
M_DEN = 3
M_SPA = 16
INDICES_PER_BAG = 2
LN_BOT_TAIL = [64] # → ln_bot = [M_DEN, 64, M_SPA]
LN_TOP_TAIL = [64, 32, 1] # → ln_top = [num_int, 64, 32, 1]
SEED = 0
@contextmanager
def _relaxed_dynamo_limits():
pr = torch._dynamo.config.recompile_limit
pc = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 64
torch._dynamo.config.cache_size_limit = 64
try:
yield
finally:
torch._dynamo.config.recompile_limit = pr
torch._dynamo.config.cache_size_limit = pc
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
layers: List[nn.Module] = []
for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
layers.append(nn.Linear(a, b, bias=True))
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
return nn.Sequential(*layers)
# ----------------------------------------------------------------------
# Vanilla DLRMv1 — what a user writes with nn.EmbeddingBag.
# This is the model the luminal torch.compile backend must ingest.
# ----------------------------------------------------------------------
class DLRMv1(nn.Module):
def __init__(self, num_cat: int, rows: int, m_spa: int):
super().__init__()
torch.manual_seed(SEED)
np.random.seed(SEED)
self.num_cat = num_cat
self.rows = rows
self.m_spa = m_spa
self.num_fea = num_cat + 1 # +1 for dense
self.pair_count = self.num_fea * (self.num_fea - 1) // 2
self.top_in = self.pair_count + m_spa
ln_bot = [M_DEN] + LN_BOT_TAIL + [m_spa]
ln_top = [self.top_in] + LN_TOP_TAIL
self.bot = _build_mlp(ln_bot, sigmoid_layer=-1)
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
self.emb = nn.ModuleList(
[nn.EmbeddingBag(rows, m_spa, mode="sum", sparse=False) for _ in range(num_cat)]
)
li, lj = [], []
for i in range(self.num_fea):
for j in range(i):
li.append(i)
lj.append(j)
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
def forward(self, dense_x, lS_o, lS_i):
x = self.bot(dense_x) # (B, M_SPA)
ly = [self.emb[k](lS_i[k], lS_o[k]) for k in range(self.num_cat)] # each (B, M_SPA)
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1) # (B, F, D)
Z = torch.bmm(T, T.transpose(1, 2)) # (B, F, F)
Zflat = Z[:, self.li, self.lj] # (B, PAIRS)
R = torch.cat([x, Zflat], dim=1)
return self.top(R)
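# Illustrative sketch (not part of the benchmark): how the strict-lower-tri
# (li, lj) pairs relate to `pair_count`. The helper name
# `_sketch_lower_tri_pairs` is hypothetical; it mirrors the loop in
# `DLRMv1.__init__` in plain Python so the pair ordering can be inspected
# without a GPU.
def _sketch_lower_tri_pairs(num_fea):
    """Return (li, lj) covering every pair with j < i < num_fea."""
    li, lj = [], []
    for i in range(num_fea):
        for j in range(i):
            li.append(i)
            lj.append(j)
    return li, lj


# With num_cat=3 there are 4 features (three tables plus the dense vector),
# hence 4 * 3 // 2 = 6 pairs and top_in = 6 + m_spa.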
# ----------------------------------------------------------------------
# Graph-safe wrapper: identical math to `DLRMv1.forward`. The strict-lower-tri
# (li, lj) indices come from the inner model's pre-baked buffers, so the
# forward does no Python-side allocs and is `torch.cuda.CUDAGraph`-capturable.
# ----------------------------------------------------------------------
class GraphSafeDLRM(nn.Module):
def __init__(self, inner: DLRMv1):
super().__init__()
self.inner = inner
# Reuse the inner's li/lj buffers — they're already pre-computed.
def forward(self, dense_x, lS_o, lS_i):
inner = self.inner
x = inner.bot(dense_x)
ly = [inner.emb[k](lS_i[k], lS_o[k]) for k in range(inner.num_cat)]
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)
Z = torch.bmm(T, T.transpose(1, 2))
Zflat = Z[:, inner.li, inner.lj]
R = torch.cat([x, Zflat], dim=1)
return inner.top(R)
# ----------------------------------------------------------------------
# Fused v3: stacked embedding table + index_select + view + sum. Same
# math as `mode="sum"` EmbeddingBag at fixed bag size; Inductor can
# fuse the gather+sum into one Triton kernel.
# ----------------------------------------------------------------------
class FusedDLRMv3(nn.Module):
def __init__(self, plain: DLRMv1, bag: int):
super().__init__()
self.num_emb = plain.num_cat
self.m_spa = plain.m_spa
self.L = bag
self.num_fea = plain.num_fea
total_rows = plain.rows * plain.num_cat
big = torch.empty(total_rows, plain.m_spa, dtype=torch.float32)
starts = np.arange(plain.num_cat, dtype=np.int64) * plain.rows
for k in range(plain.num_cat):
big[starts[k] : starts[k] + plain.rows].copy_(plain.emb[k].weight.detach())
self.emb_weight = nn.Parameter(big)
self.register_buffer("row_offsets", torch.from_numpy(starts))
self.bot = copy.deepcopy(plain.bot)
self.top = copy.deepcopy(plain.top)
self.register_buffer("li", plain.li.clone(), persistent=False)
self.register_buffer("lj", plain.lj.clone(), persistent=False)
def pack_indices(self, lS_i):
return (torch.stack(lS_i, dim=0) + self.row_offsets.view(self.num_emb, 1)).reshape(-1)
def forward(self, dense_x, flat_indices):
bs = dense_x.shape[0]
g = (
torch.index_select(self.emb_weight, 0, flat_indices)
.view(self.num_emb * bs, self.L, self.m_spa)
.sum(dim=1)
)
ly = g.view(self.num_emb, bs, self.m_spa).transpose(0, 1)
x = self.bot(dense_x)
T = torch.cat([x.unsqueeze(1), ly], dim=1)
Z = torch.bmm(T, T.transpose(1, 2))
Zflat = Z[:, self.li, self.lj]
R = torch.cat([x, Zflat], dim=1)
return self.top(R)
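# Illustrative sketch (not part of the benchmark): the index packing and
# fixed-bag reduction that `FusedDLRMv3` relies on, written with plain Python
# lists. Both helper names are hypothetical. It assumes every bag holds
# exactly `bag` indices, which is the layout `make_inputs` produces.
def _sketch_pack_indices(per_table_indices, rows):
    """Shift table k's indices by k * rows and concatenate table-major."""
    flat = []
    for k, idx in enumerate(per_table_indices):
        flat.extend(i + k * rows for i in idx)
    return flat


def _sketch_fixed_bag_sum(stacked_weight, flat_indices, bag):
    """Gather rows, then sum each consecutive group of `bag` rows."""
    gathered = [stacked_weight[i] for i in flat_indices]
    out = []
    for start in range(0, len(gathered), bag):
        group = gathered[start:start + bag]
        # zip(*group) walks the embedding dimension column by column.
        out.append([sum(col) for col in zip(*group)])
    return out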
# ----------------------------------------------------------------------
# Inputs match sweep_categories.py's `_make_inputs` for parity with the
# hand-written luminal example.
# ----------------------------------------------------------------------
def make_inputs(num_cat: int, batch: int, bag: int, rows: int, device):
dense = torch.linspace(
-1.0, 1.0, steps=batch * M_DEN, dtype=torch.float32, device=device
).reshape(batch, M_DEN)
total = batch * bag
positions = torch.arange(total, dtype=torch.int64, device=device)
offsets = torch.arange(0, total, bag, dtype=torch.int64, device=device)
lS_o = [offsets.clone() for _ in range(num_cat)]
lS_i = [((positions * (2 * k + 3) + (k + 1)) % rows).to(torch.int64) for k in range(num_cat)]
return dense, lS_o, lS_i
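# Illustrative sketch (not part of the benchmark): the deterministic
# offset/index pattern built by `make_inputs`, in plain Python so it can be
# checked by hand. The helper name is hypothetical.
def _sketch_table_inputs(k, batch, bag, rows):
    """Return (offsets, indices) for table k as Python lists."""
    total = batch * bag
    # Uniform stride: bag b starts at position b * bag.
    offsets = list(range(0, total, bag))
    # Per-table affine pattern, wrapped into the table's row range.
    indices = [(p * (2 * k + 3) + (k + 1)) % rows for p in range(total)]
    return offsets, indices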
# ----------------------------------------------------------------------
# Timing: 5 rounds × 20 timed iters, with 10 warmup iters before round 0
# only. Each iter is CUDA-event timed; take the median per round, then the
# median of the round medians.
# ----------------------------------------------------------------------
def time_rounds(call, rounds: int = 5, iters: int = 20, warmup: int = 10) -> float:
medians = []
for r in range(rounds):
if r == 0:
for _ in range(warmup):
call()
torch.cuda.synchronize()
s = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
e = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
for i in range(iters):
s[i].record()
call()
e[i].record()
torch.cuda.synchronize()
elapsed = np.array([a.elapsed_time(b) for a, b in zip(s, e)])
medians.append(float(np.median(elapsed)))
return float(np.median(medians))
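# Illustrative sketch (not part of the benchmark): the median-of-round-medians
# reduction used by `time_rounds`, using only the standard library. The helper
# name is hypothetical; it takes already-measured per-iter times rather than
# timing anything itself.
import statistics


def _sketch_median_of_round_medians(rounds_ms):
    """Reduce a list of per-round timing lists to one number."""
    return statistics.median(statistics.median(r) for r in rounds_ms)


# A single slow iteration or a single slow round does not move the result,
# which is the reason for taking medians at both levels.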
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def _capture_cudagraph(callable_fn, warmup: int = 3) -> torch.cuda.CUDAGraph:
"""Warm up on a side stream, then capture a CUDA graph of one call."""
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
for _ in range(warmup):
callable_fn()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(g):
_ = callable_fn()
return g
def _run_variant(variant: str, num_cat: int, batch: int, rows: int, m_spa: int, bag: int, device):
plain = DLRMv1(num_cat, rows, m_spa).eval().to(device)
inputs = make_inputs(num_cat, batch, bag, rows, device)
results: dict = {}
if variant in ("pt_eager", "all"):
with torch.no_grad():
results["pt_eager"] = time_rounds(lambda: plain(*inputs))
if variant in ("graph_safe_cg", "all"):
safe = GraphSafeDLRM(copy.deepcopy(plain)).to(device).eval()
g = _capture_cudagraph(lambda: safe(*inputs))
results["graph_safe_cg"] = time_rounds(lambda: g.replay())
if variant in ("graph_safe_inductor_cg", "all"):
safe = GraphSafeDLRM(copy.deepcopy(plain)).to(device).eval()
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(
safe,
backend="inductor",
mode="max-autotune-no-cudagraphs",
fullgraph=False,
dynamic=False,
)
g = _capture_cudagraph(lambda: compiled(*inputs), warmup=5)
results["graph_safe_inductor_cg"] = time_rounds(lambda: g.replay())
if variant in ("v3_inductor_cg", "all"):
v3 = FusedDLRMv3(plain, bag).to(device).eval()
flat = v3.pack_indices(list(inputs[2])).contiguous()
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(
v3.forward,
backend="inductor",
mode="max-autotune-no-cudagraphs",
fullgraph=False,
dynamic=False,
)
g = _capture_cudagraph(lambda: compiled(inputs[0], flat), warmup=5)
results["v3_inductor_cg"] = time_rounds(lambda: g.replay())
return results
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--num-cat", type=int, default=3)
ap.add_argument("--rows", type=int, default=4096)
ap.add_argument("--batch", type=int, default=2048)
ap.add_argument("--m-spa", type=int, default=16)
ap.add_argument("--bag", type=int, default=2)
ap.add_argument(
"--variant",
choices=["pt_eager", "graph_safe_cg", "graph_safe_inductor_cg", "v3_inductor_cg", "all"],
default="all",
)
ap.add_argument("--json", action="store_true", help="emit one JSON line per variant")
args = ap.parse_args()
device = torch.device("cuda")
if not args.json:
print(
f"==== PyTorch sweep cfg: num_cat={args.num_cat} batch={args.batch} "
f"m_spa={args.m_spa} bag={args.bag} rows={args.rows} ===="
)
results = _run_variant(
args.variant, args.num_cat, args.batch, args.rows, args.m_spa, args.bag, device
)
if args.json:
for variant, ms in results.items():
print(
json.dumps(
{
"variant": variant,
"num_cat": args.num_cat,
"batch": args.batch,
"m_spa": args.m_spa,
"bag": args.bag,
"rows": args.rows,
"ms": ms,
"samples_per_sec": args.batch / (ms / 1000.0),
}
)
)
else:
print()
for label, ms in results.items():
tput = args.batch / (ms / 1000.0)
print(f" {label:<32} {ms:8.4f} ms ({tput:>12,.0f} samples/s)")
return 0
if __name__ == "__main__":
sys.exit(main())