mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
8 Commits
flashinfer
...
dlrm-pt2-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfcc41040e | ||
|
|
9d4a3bc555 | ||
|
|
6f8de66e3d | ||
|
|
cb9facfb11 | ||
|
|
2845e605c1 | ||
|
|
ccdb6f1540 | ||
|
|
3f57d94ecb | ||
|
|
3b36880c22 |
450
crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs
Normal file
450
crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs
Normal file
@@ -0,0 +1,450 @@
|
||||
//! DLRM-shape megakernel — one CUDA kernel does the full forward pass
|
||||
//! (bot MLP → N embedding gathers → dot-product interaction → top MLP)
|
||||
//! per (thread × batch row). All intermediate activations live in
|
||||
//! registers; weights are read straight from global memory and rely on
|
||||
//! the L1 cache (the full weight footprint is a few KB).
|
||||
//!
|
||||
//! Parameterized by the DLRM family shape: dense input width, bot MLP
|
||||
//! widths, number of sparse tables + their vocabs, embedding dim,
|
||||
//! top MLP widths. CUDA source is generated per-shape via `format!`
|
||||
//! and compiled through luminal's nvrtc wrapper with source-string
|
||||
//! caching (same path as [`crate::kernel::matmul2d::Matmul2DKernel`]).
|
||||
//!
|
||||
//! Used by `luminal_python`'s PT2 translator when it detects a DLRM-shape
|
||||
//! input graph — see `crates/luminal_python/rust/src/translator/dlrm_pattern.rs`.
|
||||
//! The standalone `examples/dlrm/src/megakernel.rs` is the proof-of-concept
|
||||
//! this module generalizes from.
|
||||
//!
|
||||
//! ## Input layout
|
||||
//!
|
||||
//! The kernel's input list (passed to `cx.custom_op`) is, in order:
|
||||
//! 1. dense_x F32 (B, n_dense_in)
|
||||
//! 2..2+n_sparse int32 indices per sparse table, each (B,)
|
||||
//! — luminal collapses all integer types to 32-bit Int,
|
||||
//! so the runtime delivers a 4-byte-per-element buffer
|
||||
//! regardless of the original PyTorch dtype.
|
||||
//! 2+n_sparse.. F32 embedding weights, one per table, each (V_k, m_spa)
|
||||
//! then bot Linear weight+bias pairs, in topological order
|
||||
//! then top Linear weight+bias pairs, in topological order
|
||||
//!
|
||||
//! The matcher in luminal_python lines up these inputs from the parsed
|
||||
//! PT2 graph; mismatches there will surface as wrong-output bugs in
|
||||
//! `tests/test_dlrm.py`, not as a crash.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Static shape description for the DLRM family. Every dim is a `usize`
/// resolved at translate time — the kernel bakes them all into the CUDA
/// source as compile-time constants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DlrmMegaKernel {
    /// Per-call batch size.
    pub batch: usize,
    /// Number of dense features (first element of `ln_bot`).
    pub n_dense_in: usize,
    /// Bot MLP layer widths. `ln_bot[0] == n_dense_in`; `ln_bot.last() == m_spa`.
    /// Must have at least 2 entries (one Linear layer).
    pub ln_bot: Vec<usize>,
    /// Number of sparse embedding tables.
    pub n_sparse: usize,
    /// Vocab size for each table (length == `n_sparse`).
    pub vocab_sizes: Vec<usize>,
    /// Sparse embedding dim (equal across tables, == bot MLP output width).
    pub m_spa: usize,
    /// Top MLP layer widths. `ln_top[0] == m_spa + n_pairs`; `ln_top.last() == 1`.
    pub ln_top: Vec<usize>,
}

impl DlrmMegaKernel {
    /// `n_feat = 1 + n_sparse` — number of feature vectors fed into the
    /// dot interaction (1 dense + sparse tables).
    fn n_feat(&self) -> usize {
        1 + self.n_sparse
    }

    /// `n_pairs = n_feat * (n_feat - 1) / 2` — number of strictly-lower-tri
    /// pairs produced by the dot interaction.
    fn n_pairs(&self) -> usize {
        let n = self.n_feat();
        n * (n - 1) / 2
    }

    /// Validation: cheap up-front check that the shape is internally
    /// consistent. The matcher should have caught all of these, but the
    /// checks keep the kernel compile path well-defined.
    ///
    /// # Panics
    ///
    /// Panics (in every build profile — these are `assert!`, not
    /// `debug_assert!`) when any of the documented field relationships on
    /// [`DlrmMegaKernel`] does not hold, or when `batch == 0`.
    fn validate(&self) {
        assert!(self.ln_bot.len() >= 2, "ln_bot must have ≥2 entries");
        assert!(self.ln_top.len() >= 2, "ln_top must have ≥2 entries");
        assert_eq!(self.ln_bot[0], self.n_dense_in, "ln_bot[0] must == n_dense_in");
        assert_eq!(*self.ln_bot.last().unwrap(), self.m_spa, "ln_bot.last() must == m_spa");
        assert_eq!(self.vocab_sizes.len(), self.n_sparse);
        assert_eq!(
            self.ln_top[0],
            self.m_spa + self.n_pairs(),
            "ln_top[0] must == m_spa + n_pairs"
        );
        assert_eq!(*self.ln_top.last().unwrap(), 1, "ln_top.last() must == 1 (binary classifier)");
        assert!(self.batch > 0);
    }

    /// Generate the CUDA source for this kernel shape.
    ///
    /// The emitted kernel is `extern "C" __global__ void dlrm_mega(out, dense_x,
    /// idx_0.., emb_0_w.., bot_l0_w, bot_l0_b, .., top_l0_w, top_l0_b, ..)`.
    /// Each thread handles one batch row `bi` and returns early when
    /// `bi >= batch`. All shape parameters are baked in as literals, so two
    /// kernels with the same shape produce byte-identical source.
    ///
    /// Callers are expected to have run [`Self::validate`] first; this
    /// function indexes `ln_bot` / `ln_top` directly.
    fn cuda_source(&self) -> String {
        let n_feat = self.n_feat();
        let n_pairs = self.n_pairs();

        // ---- Kernel signature ------------------------------------------
        // luminal's CustomOp dispatcher calls the kernel as
        //     kernel(output_ptr, input_ptrs...)
        // — see `host/cublaslt`'s C/D ordering and matmul2d's
        // `matmul_2d_kernel(float* C, const float* A, ...)`. Match that
        // by putting `out` first, then the inputs in the same order as
        // emit_megakernel builds the inputs vec.
        let mut sig = String::from(
            " float* __restrict__ out,\n const float* __restrict__ dense_x,\n",
        );
        for k in 0..self.n_sparse {
            // 32-bit signed — see module docstring re: luminal's Int collapse.
            sig.push_str(&format!(" const int* __restrict__ idx_{k},\n"));
        }
        for k in 0..self.n_sparse {
            sig.push_str(&format!(" const float* __restrict__ emb_{k}_w,\n"));
        }
        // Bot MLP: one Linear per (ln_bot[i] → ln_bot[i+1]). Stored
        // PyTorch-style as (out, in), bias (out,).
        for i in 0..self.ln_bot.len() - 1 {
            sig.push_str(&format!(" const float* __restrict__ bot_l{i}_w,\n"));
            sig.push_str(&format!(" const float* __restrict__ bot_l{i}_b,\n"));
        }
        for i in 0..self.ln_top.len() - 1 {
            // The very last parameter must not be followed by a comma.
            let trail = if i == self.ln_top.len() - 2 { "" } else { "," };
            sig.push_str(&format!(" const float* __restrict__ top_l{i}_w,\n"));
            sig.push_str(&format!(" const float* __restrict__ top_l{i}_b{trail}\n"));
        }

        // ---- Body --------------------------------------------------------
        let mut body = String::new();

        // 1. Load dense row into `layer_in`.
        //
        // `layer_in` is reused as the input of EVERY bot layer (each layer
        // copies its output back into it), so it must be as wide as the
        // widest bot layer, not just the dense input. Sizing it by
        // `ln_bot[0]` overruns the array whenever a hidden layer is wider
        // than the input (e.g. 13 → 512 → 256 → 64).
        let layer_in_len = self
            .ln_bot
            .iter()
            .copied()
            .max()
            .unwrap_or(0)
            .max(self.n_dense_in);
        body.push_str(&format!(
            " // Bot MLP layer 0 input: dense row\n \
             float layer_in[{layer_in_len}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {n_dense_in}; ++i) layer_in[i] = dense_x[bi * {n_dense_in} + i];\n\n",
            n_dense_in = self.n_dense_in,
        ));

        // 2. Bot MLP — sequence of Linear+ReLU. Output of last layer
        //    becomes `x[m_spa]` for the interaction.
        for i in 0..self.ln_bot.len() - 1 {
            let in_w = self.ln_bot[i];
            let out_w = self.ln_bot[i + 1];
            body.push_str(&format!(
                " // Bot Linear {i}: ({in_w} → {out_w}) + ReLU\n \
                 float bot_l{i}_out[{out_w}];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {out_w}; ++j) {{\n \
                 float a = bot_l{i}_b[j];\n \
                 #pragma unroll\n \
                 for (int i_ = 0; i_ < {in_w}; ++i_) a += layer_in[i_] * bot_l{i}_w[j*{in_w} + i_];\n \
                 bot_l{i}_out[j] = fmaxf(a, 0.0f);\n \
                 }}\n \
                 // shuffle output into `layer_in` for the next iteration / interaction\n \
                 #pragma unroll\n \
                 for (int i = 0; i < {out_w}; ++i) layer_in[i] = bot_l{i}_out[i];\n\n",
            ));
        }
        // After the loop, `layer_in[..m_spa]` holds dense_out ("x").
        body.push_str(&format!(
            " float x[{m_spa}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {m_spa}; ++i) x[i] = layer_in[i];\n\n",
            m_spa = self.m_spa,
        ));

        // 3. Sparse embedding gathers (one row per table, bag size 1).
        //    The row offset is widened to 64-bit before the multiply:
        //    `index * m_spa` in 32-bit `int` overflows for large tables
        //    (e.g. 40M rows × 64 dims).
        for k in 0..self.n_sparse {
            body.push_str(&format!(
                " // Embedding lookup {k}\n \
                 float ly_{k}[{m_spa}];\n \
                 {{\n \
                 int i_{k} = idx_{k}[bi];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {m_spa}; ++j) ly_{k}[j] = emb_{k}_w[(long long)i_{k}*{m_spa} + j];\n \
                 }}\n\n",
                m_spa = self.m_spa,
            ));
        }

        // 4. Dot interaction: compute n_pairs strictly-lower-tri dot products
        //    over the n_feat = 1 + n_sparse vectors (x, ly_0, ly_1, ...).
        //    Order matches MiniDLRM._interact: for i in 0..n_feat for j in 0..i.
        //    Vec[0] = x, Vec[k+1] = ly_k.
        body.push_str(&format!(" float zflat[{n_pairs}];\n"));
        let vec_name = |idx: usize| -> String {
            if idx == 0 {
                "x".to_string()
            } else {
                format!("ly_{}", idx - 1)
            }
        };
        let mut pair_idx = 0usize;
        for i in 0..n_feat {
            for j in 0..i {
                let a = vec_name(i);
                let b = vec_name(j);
                // Fully unrolled at generation time: one product term per dim.
                let terms: Vec<String> = (0..self.m_spa)
                    .map(|d| format!("{a}[{d}]*{b}[{d}]"))
                    .collect();
                body.push_str(&format!(
                    " zflat[{pair_idx}] = {};\n",
                    terms.join(" + ")
                ));
                pair_idx += 1;
            }
        }
        body.push('\n');

        // 5. R = cat([x, zflat]) → top MLP input.
        let r_len = self.m_spa + n_pairs;
        body.push_str(&format!(
            " float r[{r_len}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {m_spa}; ++i) r[i] = x[i];\n \
             #pragma unroll\n \
             for (int i = 0; i < {n_pairs}; ++i) r[{m_spa} + i] = zflat[i];\n\n",
            m_spa = self.m_spa,
        ));

        // 6. Top MLP: Linear+ReLU chain, ending with Linear+Sigmoid.
        //    We treat `r` as the first layer input and reuse a single
        //    array `top_in[]` (sized by the widest top layer) for
        //    subsequent layers.
        let max_top = *self.ln_top.iter().max().unwrap();
        body.push_str(&format!(
            " float top_in[{max_top}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {r_len}; ++i) top_in[i] = r[i];\n\n",
        ));
        let n_top_layers = self.ln_top.len() - 1;
        for i in 0..n_top_layers {
            let in_w = self.ln_top[i];
            let out_w = self.ln_top[i + 1];
            let is_last = i == n_top_layers - 1;
            body.push_str(&format!(
                " // Top Linear {i}: ({in_w} → {out_w})\n \
                 float top_l{i}_out[{out_w}];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {out_w}; ++j) {{\n \
                 float a = top_l{i}_b[j];\n \
                 #pragma unroll\n \
                 for (int i_ = 0; i_ < {in_w}; ++i_) a += top_in[i_] * top_l{i}_w[j*{in_w} + i_];\n \
                 top_l{i}_out[j] = {activation};\n \
                 }}\n",
                activation = if is_last {
                    "1.0f / (1.0f + __expf(-a))"
                } else {
                    "fmaxf(a, 0.0f)"
                },
            ));
            if !is_last {
                body.push_str(&format!(
                    " #pragma unroll\n \
                     for (int i = 0; i < {out_w}; ++i) top_in[i] = top_l{i}_out[i];\n\n",
                ));
            } else {
                // Final layer: write to global output. ln_top.last() == 1
                // so this is just a single value.
                body.push_str(&format!(
                    " out[bi] = top_l{i}_out[0];\n",
                ));
            }
        }

        // Assemble the full source.
        format!(
            "extern \"C\" __global__ void dlrm_mega(\n{sig}) {{\n \
             int bi = blockIdx.x * blockDim.x + threadIdx.x;\n \
             if (bi >= {batch}) return;\n\n\
             {body}\
             }}\n",
            batch = self.batch,
        )
    }
}
|
||||
|
||||
impl KernelOp for DlrmMegaKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
self.validate();
|
||||
let kernel = self.cuda_source();
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
if std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").is_ok() {
|
||||
let path = "/tmp/dlrm_megakernel_generated.cu";
|
||||
let _ = std::fs::write(path, &kernel);
|
||||
eprintln!("[DlrmMegaKernel] wrote generated source to {path}");
|
||||
}
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("nvrtc compile failed for DLRM megakernel");
|
||||
let module = stream.context().load_module(ptx).expect("load_module");
|
||||
let func = module
|
||||
.load_function("dlrm_mega")
|
||||
.expect("load_function dlrm_mega");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
const BLOCK: usize = 128;
|
||||
let grid_x = self.batch.div_ceil(BLOCK);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(BLOCK),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per batch row:
|
||||
// dense: n_dense_in × f32
|
||||
// indices: n_sparse × i64
|
||||
// embs: n_sparse × m_spa × f32 (single row each)
|
||||
// bot Ws: sum(in*out) for each layer, × f32 (shared across batch — costed once)
|
||||
// bot bs: sum(out) × f32
|
||||
// top Ws/bs same shape
|
||||
let bot_w: usize = (0..self.ln_bot.len() - 1)
|
||||
.map(|i| self.ln_bot[i] * self.ln_bot[i + 1])
|
||||
.sum();
|
||||
let bot_b: usize = self.ln_bot.iter().skip(1).sum();
|
||||
let top_w: usize = (0..self.ln_top.len() - 1)
|
||||
.map(|i| self.ln_top[i] * self.ln_top[i + 1])
|
||||
.sum();
|
||||
let top_b: usize = self.ln_top.iter().skip(1).sum();
|
||||
let per_row =
|
||||
self.n_dense_in * 4 + self.n_sparse * 8 + self.n_sparse * self.m_spa * 4;
|
||||
let weights = (bot_w + bot_b + top_w + top_b) * 4;
|
||||
Expression::from(self.batch * per_row + weights)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// batch × 1 × f32
|
||||
Expression::from(self.batch * 4)
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Per row:
|
||||
// bot Linears: 2*in*out + out (FMAs + bias)
|
||||
// embedding gathers: 0 FMAs (loads)
|
||||
// dot interaction: n_pairs × m_spa MACs
|
||||
// top Linears: 2*in*out + out + (relu/sigmoid cost ~5)
|
||||
let bot: usize = (0..self.ln_bot.len() - 1)
|
||||
.map(|i| 2 * self.ln_bot[i] * self.ln_bot[i + 1] + self.ln_bot[i + 1])
|
||||
.sum();
|
||||
let dot = self.n_pairs() * self.m_spa * 2;
|
||||
let top: usize = (0..self.ln_top.len() - 1)
|
||||
.map(|i| 2 * self.ln_top[i] * self.ln_top[i + 1] + self.ln_top[i + 1])
|
||||
.sum();
|
||||
Expression::from(self.batch * (bot + dot + top + 5))
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"DlrmMega"
|
||||
}
|
||||
}
|
||||
|
||||
/// `CustomOp` wrapper for [`DlrmMegaKernel`]. Same pattern as
|
||||
/// [`crate::kernel::matmul2d::Matmul2DCustom`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DlrmMegaCustom(pub DlrmMegaKernel);
|
||||
|
||||
impl CustomOp for DlrmMegaCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Small DLRM shape shared by the tests: 13 dense features, a
    /// 13→8→4 bot MLP, three sparse tables, and a 10→8→1 top MLP.
    fn mini_dlrm() -> DlrmMegaKernel {
        DlrmMegaKernel {
            batch: 2048,
            n_dense_in: 13,
            ln_bot: vec![13, 8, 4],
            n_sparse: 3,
            vocab_sizes: vec![10, 20, 30],
            m_spa: 4,
            ln_top: vec![10, 8, 1],
        }
    }

    /// The derived feature/pair counts agree with the declared top-MLP
    /// input width, and the shape passes validation.
    #[test]
    fn shape_invariants() {
        let kernel = mini_dlrm();
        assert_eq!(kernel.n_feat(), 4);
        assert_eq!(kernel.n_pairs(), 6);
        assert_eq!(kernel.ln_top[0], kernel.m_spa + kernel.n_pairs());
        kernel.validate();
    }

    /// Structural smoke test of the generated source. nvrtc is not
    /// invoked; this only checks that the expected pieces are present.
    #[test]
    fn cuda_source_compiles_in_format() {
        let generated = mini_dlrm().cuda_source();

        // Entry point and batch guard.
        assert!(generated.contains("extern \"C\" __global__ void dlrm_mega"));
        assert!(generated.contains("if (bi >= 2048)"));

        // One embedding row buffer per sparse table (3 tables).
        let has_all_lookups = ["ly_0[", "ly_1[", "ly_2["]
            .iter()
            .all(|needle| generated.contains(needle));
        assert!(has_all_lookups);

        // Six dot products, so the last pair slot is index 5.
        assert!(generated.contains("zflat[5]"));

        // Final layer uses the sigmoid epilogue.
        assert!(generated.contains("1.0f / (1.0f + __expf(-a))"));
    }
}
|
||||
@@ -11,6 +11,7 @@ use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod dlrm_megakernel;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
@@ -19,6 +20,7 @@ pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use dlrm_megakernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
|
||||
@@ -106,6 +106,12 @@ pub(crate) struct CompiledBucket {
|
||||
pub(crate) bucket_indices: FxHashMap<char, usize>,
|
||||
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
|
||||
pub(crate) hlir_synced: bool,
|
||||
/// Cached topological order of exec_graph nodes. Lazily populated on
|
||||
/// first execute() and invalidated only when the exec_graph itself
|
||||
/// changes (compilation, bucket rebuild). Avoids the per-call
|
||||
/// `petgraph::algo::toposort` Vec allocation + traversal — small but
|
||||
/// real in hot inference loops.
|
||||
pub(crate) exec_topo_order: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
impl CompiledBucket {
|
||||
@@ -130,6 +136,7 @@ impl CompiledBucket {
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
bucket_indices: FxHashMap::default(),
|
||||
hlir_synced: false,
|
||||
exec_topo_order: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,6 +334,24 @@ impl CudaRuntime {
|
||||
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
|
||||
let id = id.to_id();
|
||||
// Fast path: if the same pointer is already registered, this is a no-op.
|
||||
// PyTorch's caching allocator routinely hands back the same device
|
||||
// pointer for the same logical tensor on each forward; bench loops in
|
||||
// particular hammer this. Skipping the cudarc upgrade_device_ptr +
|
||||
// ManuallyDrop reallocation + the changed_hlir insert + the per-bucket
|
||||
// ptr re-cache that fires on the next execute saves ~2µs per input.
|
||||
if let Some(CudaInput::Ptr(prev)) = self.hlir_buffers.get(&id) {
|
||||
if *prev == device_ptr {
|
||||
// Skip re-registration only when the existing external_buffers view
|
||||
// already spans exactly n_bytes. If the length differs, fall through
|
||||
// and rebuild the registration so the slice length stays correct.
|
||||
if let Some(ext) = self.external_buffers.get(&id) {
|
||||
if ext.len() == n_bytes {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create CudaSlice view via cudarc's upgrade_device_ptr.
|
||||
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
|
||||
let slice = unsafe {
|
||||
@@ -1465,9 +1490,19 @@ impl Runtime for CudaRuntime {
|
||||
self.apply_output_ptr_registrations();
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
// Populate the topo-order cache lazily — only on first execute for
|
||||
// this bucket. Walking exec_graph + allocating a Vec every iter
|
||||
// measurably shows up at small batches where the kernel work itself
|
||||
// is sub-microsecond and the per-call overhead dominates.
|
||||
{
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
if bucket.exec_topo_order.is_empty() && bucket.exec_graph.node_count() > 0 {
|
||||
bucket.exec_topo_order = toposort(&bucket.exec_graph, None).unwrap();
|
||||
}
|
||||
}
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
|
||||
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
|
||||
for &exec_node in &bucket.exec_topo_order {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
trace!("Executing: {:?}", exec_op);
|
||||
|
||||
@@ -1539,21 +1574,26 @@ impl Runtime for CudaRuntime {
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
|
||||
// Populate last_kernel_stats from HostOps that report stats
|
||||
self.last_kernel_stats.clear();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
for exec_node in bucket.exec_graph.node_indices() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
if let Some(name) = exec_op.internal.stats_name() {
|
||||
self.last_kernel_stats.push(KernelStats {
|
||||
name,
|
||||
execution_time_us: 0.0,
|
||||
bytes_loaded: 0,
|
||||
bytes_stored: 0,
|
||||
flops: 0,
|
||||
bandwidth_gbps: 0.0,
|
||||
tflops: 0.0,
|
||||
});
|
||||
// last_kernel_stats is only read by print_execution_stats() — a
|
||||
// diagnostic API. Populating the Vec on every execute() (looping all
|
||||
// exec nodes and calling stats_name() on each) is wasteful in
|
||||
// production inference loops. Gate it on the profiling flag.
|
||||
if self.profiling {
|
||||
self.last_kernel_stats.clear();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
for exec_node in bucket.exec_graph.node_indices() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
if let Some(name) = exec_op.internal.stats_name() {
|
||||
self.last_kernel_stats.push(KernelStats {
|
||||
name,
|
||||
execution_time_us: 0.0,
|
||||
bytes_loaded: 0,
|
||||
bytes_stored: 0,
|
||||
flops: 0,
|
||||
bandwidth_gbps: 0.0,
|
||||
tflops: 0.0,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1575,11 +1615,22 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// Free owned input buffers after a step so they're not held until the
|
||||
// next set_data overwrites them. External-pointer inputs (registered
|
||||
// via set_device_ptr) are caller-owned and the runtime doesn't free
|
||||
// their memory either way — consuming them only invalidates the
|
||||
// registration and forces the caller to re-register on the next
|
||||
// execute. That's pure waste in tight inference loops (e.g.
|
||||
// luminal_python's torch.compile backend, which re-invokes execute()
|
||||
// for every forward), so leave external-pointer entries in place.
|
||||
let to_consume: Vec<NodeIndex> = self
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.copied()
|
||||
.iter()
|
||||
.filter(|(hlir_node, input)| {
|
||||
!inputs_with_outputs.contains(hlir_node)
|
||||
&& !matches!(input, CudaInput::Ptr(_))
|
||||
})
|
||||
.map(|(n, _)| *n)
|
||||
.collect();
|
||||
|
||||
for hlir_node in to_consume {
|
||||
|
||||
@@ -865,3 +865,29 @@ Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anyth
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-21 — DLRM compile: silent mis-stride on `index.Tensor` with a None-prefix
|
||||
|
||||
1. **Symptom**: compiling facebookresearch/dlrm through `luminal_backend` failed in the top-MLP with `assertion left == right failed: Dims must match to add tensors. left: [2, 8] right: [6, 8]`. The error surfaced ~5 ops downstream of the actual bug, with no mention of `index` anywhere in the trace.
|
||||
|
||||
2. **Root cause**: `translate_index_tensor` in `crates/luminal_python/rust/src/translator/movement.rs` had two code paths for advanced indexing. The first ran when an `OptionalTensors` arg held exactly one non-None entry on a specific dim (`first_non_none_dim > 0 && index_names.len() == 1`); it correctly used `first_non_none_dim` to gather on the right axis. The second — the general multi-index fall-through — silently ignored `first_non_none_dim` and computed strides/flat-source-shape as if indices always started at dim 0. DLRM's dot interaction does `Z[:, li, lj]` (Z is `[B, ni, nj]`, two 1-D index tensors after a `:`), which hits the multi-index path with `first_non_none_dim = 1`. The translator built strides over `src_shape[..n_indexed] = [B, ni]` and a flat-source of shape `[B*ni, nj]`, instead of striding over `[ni, nj]` with prefix-dim `[B]`. The downstream gather produced a tensor with the wrong leading dim (6 — the index length — instead of B), and the mismatch only blew up later when broadcast-add into the top-MLP hidden state.
|
||||
|
||||
3. **Why it was hard to find**: the trace ends in `process_pt2` with a luminal core assertion about broadcasting in a `+` op. Nothing in the message names the *upstream* op that produced the wrong shape. Worse, the bug only manifests when ALL of {two-plus index tensors, at least one leading `None`, downstream broadcast-sensitive consumer} are present — the common case (`a[idx]`, `a[idx, jdx]` with no prefix) just works. So the bug had survived through every prior model translator test.
|
||||
|
||||
4. **The fix**: split the prefix-aware case into its own helper `translate_index_tensor_with_prefix`. It explicitly partitions `src.shape` into `prefix_dims / indexed_dims / suffix_dims`, builds the flat sub-index over `indexed_dims`, promotes/expands it into the full output shape, and adds a broadcast prefix-offset constructed from `arange`s over each prefix dim. Result is fully-flat `source.gather(absolute_idx)`. The suffix-non-empty case is left guarded with a `bail!` (it's separable but DLRM doesn't need it).
|
||||
|
||||
5. **Principle**: a shape-keyed assumption baked into one branch of a multi-branch translator is a silent footgun — when the fall-through path is reached with a value the assumption rules out, you get *wrong shapes silently*, and the failure surfaces wherever the wrong shape first encounters a consumer that cares. Guard early: if an invariant the code relies on isn't met (here, "indices apply to the leading dims of source"), check it explicitly and `bail!` with the offending shape rather than computing forward. Even better, refactor so the unsupported case routes to a dedicated path the moment the assumption diverges — small risk of double-implementation, large reduction in "compile silently produces wrong output."
|
||||
|
||||
## 2026-05-21 — DLRM compile: `EmbeddingBag` translator gap
|
||||
|
||||
1. **Symptom**: same `luminal_backend` compile, first error: `RuntimeError: Failed to translate node N: torch.ops.aten._embedding_bag_forward_only.default: Unsupported ATen op`. This is the central op of DLRM — every sparse feature lookup decomposes to it via `nn.EmbeddingBag`.
|
||||
|
||||
2. **What's needed**: `_embedding_bag_forward_only(weight, indices, offsets, ..., mode, ...)` produces `output[b] = reduce_op(weight[indices[offsets[b]:offsets[b+1]]])` for each bag `b`. The general case is a *runtime segment reduction* — the bag boundaries depend on `offsets`, which is a runtime tensor — and luminal has no native segment-reduce primitive.
|
||||
|
||||
3. **The fix (in this session)**: add `translate_embedding_bag` covering the uniform-bag-size case, which is what DLRM actually uses. Read `indices.shape[0] = N` and `offsets.shape[0] = B` off the static shape info, compute bag size `K = N / B`, bail if they don't divide. Then gather `[N, D]` (same construction as `translate_embedding`), reshape to `[B, K, D]`, reduce along axis 1 according to `mode` (sum/mean/max). For `K=1` (the eval-time-1-lookup-per-sample DLRM path) skip the reshape+reduce — it's just an `embedding` lookup. `per_sample_weights` and non-uniform bags are guarded with `bail!`.
|
||||
|
||||
4. **Why this works for DLRM but isn't general**: a true segment reduction needs either (a) static knowledge of every segment boundary (what we get when bags are uniform), or (b) a scatter-add primitive that handles per-segment accumulation at runtime. (a) covers DLRM's training/eval data generator and the common recsys case where each sample has K-hot lookups for fixed K. (b) is required for any model that genuinely has variable-length bags per sample (e.g. variable-length feature crossings) and is a follow-up.
|
||||
|
||||
5. **Principle**: when a PyTorch op has no straight-line luminal lowering, look at the *shapes the model actually feeds in* before declaring it unsupportable. A "segment reduction" over offsets is a hard problem in general; "segment reduction where every bag has K elements with K statically known from indices.shape[0]/offsets.shape[0]" is a 5-line gather+reshape+reduce. The PT2 graph carries the shape info for free — use it.
|
||||
|
||||
@@ -127,6 +127,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
//
|
||||
// PyTorch's nn.Linear with bias generates `addmm(bias, input, weight.t())`
|
||||
// with the default `beta=alpha=1.0`. Emitting the multiplies in that
|
||||
// case wastes 2 HLIR nodes per Linear that egglog has to fold later;
|
||||
// for a 4-Linear DLRM that's 8 nodes off the search-space count.
|
||||
// Skip them when the scale is 1.
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
@@ -135,7 +141,9 @@ impl<'a> Translator<'a> {
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
let scaled_input = if beta == 1.0 { input } else { input * beta };
|
||||
let scaled_mm = if alpha == 1.0 { mm } else { mm * alpha };
|
||||
scaled_input + scaled_mm
|
||||
}
|
||||
|
||||
// Convolution
|
||||
@@ -154,6 +162,10 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
"torch.ops.aten._embedding_bag.default"
|
||||
| "torch.ops.aten._embedding_bag_forward_only.default" => {
|
||||
self.translate_embedding_bag(node)?
|
||||
}
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
|
||||
434
crates/luminal_python/rust/src/translator/dlrm_pattern.rs
Normal file
434
crates/luminal_python/rust/src/translator/dlrm_pattern.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
//! DLRM-family pattern matcher for the PT2 translator.
|
||||
//!
|
||||
//! Recognizes the `MiniDLRM` topology in a parsed PT2 graph (bot MLP →
|
||||
//! N sparse `_embedding_bag_forward_only` lookups (bag-size 1) →
|
||||
//! dot-product interaction via `bmm` + lower-triangular `index.Tensor` →
|
||||
//! top MLP ending in `sigmoid`) and, when matched, emits a single
|
||||
//! [`luminal_cuda_lite::kernel::DlrmMegaCustom`] op that replaces the
|
||||
//! entire per-node translation. The runtime then sees ONE host op
|
||||
//! instead of the 8 cuBLAS+CudaGraphOp ops the normal path produces.
|
||||
//!
|
||||
//! The matcher is intentionally conservative — any mismatch returns
|
||||
//! `None` and the translator falls back to its standard node-by-node
|
||||
//! walk, so an unrecognized graph never produces wrong output, only "the fast
|
||||
//! path didn't trigger." Diagnostic prints are gated on
|
||||
//! `LUMINAL_DLRM_MEGAKERNEL_DEBUG=1` for development.
|
||||
//!
|
||||
//! See `examples/dlrm/src/megakernel.rs` for the standalone proof of
|
||||
//! concept and `crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs`
|
||||
//! for the parameterized kernel itself.
|
||||
//!
|
||||
//! Companion plan: see `/home/ubuntu/.claude/plans/can-you-plan-out-mossy-wave.md`.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
|
||||
use crate::pt2_parser::ParsedPT2;
|
||||
use crate::pt2_schema::Node;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Resolved DLRM shape + the PT2 graph names of every tensor the
|
||||
/// megakernel needs as input. All weight/input lookups go through
|
||||
/// `Translator::get_tensor(name)` which is keyed by PT2 graph_name.
|
||||
#[derive(Debug)]
|
||||
pub(super) struct DlrmShape {
|
||||
pub batch: usize,
|
||||
pub n_dense_in: usize,
|
||||
pub ln_bot: Vec<usize>,
|
||||
pub n_sparse: usize,
|
||||
pub vocab_sizes: Vec<usize>,
|
||||
pub m_spa: usize,
|
||||
pub ln_top: Vec<usize>,
|
||||
|
||||
pub dense_input_name: String,
|
||||
pub index_input_names: Vec<String>, // length n_sparse
|
||||
pub emb_weight_names: Vec<String>, // length n_sparse
|
||||
pub bot_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
|
||||
pub top_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
|
||||
pub output_name: String,
|
||||
}
|
||||
|
||||
/// Reports whether matcher diagnostics were requested.
///
/// True only when the `LUMINAL_DLRM_MEGAKERNEL_DEBUG` environment variable
/// is set to exactly `"1"`; an unset, unreadable or differently-valued
/// variable yields `false`.
fn debug_enabled() -> bool {
    matches!(
        std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").as_deref(),
        Ok("1")
    )
}
|
||||
|
||||
/// Diagnostic print for the DLRM matcher.
///
/// Accepts `format!`-style arguments and writes them to stderr with a
/// `[dlrm_pattern]` prefix, but only when `debug_enabled()` returns true.
macro_rules! dbgln {
    ($($fmt_args:tt)*) => {
        if debug_enabled() {
            eprintln!("[dlrm_pattern] {}", format_args!($($fmt_args)*));
        }
    };
}
|
||||
|
||||
/// Try to interpret the parsed PT2 program as a DLRM-shape forward.
|
||||
/// Returns `None` if any structural check fails — translator falls back
|
||||
/// to the standard dispatch.
|
||||
pub(super) fn match_dlrm(parsed: &ParsedPT2) -> Option<DlrmShape> {
|
||||
let nodes = &parsed.program.graph_module.graph.nodes;
|
||||
|
||||
// ---- 1. Index the key op types ----------------------------------
|
||||
let emb_node_idxs: Vec<usize> = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.target == "torch.ops.aten._embedding_bag_forward_only.default"
|
||||
|| n.target == "torch.ops.aten._embedding_bag.default")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
if emb_node_idxs.is_empty() {
|
||||
dbgln!("no embedding_bag nodes — not DLRM");
|
||||
return None;
|
||||
}
|
||||
let n_sparse = emb_node_idxs.len();
|
||||
let first_emb = emb_node_idxs[0];
|
||||
let last_emb = *emb_node_idxs.last().unwrap();
|
||||
|
||||
let addmm_idxs: Vec<usize> = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.target == "torch.ops.aten.addmm.default")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
let bot_addmms: Vec<usize> =
|
||||
addmm_idxs.iter().filter(|&&i| i < first_emb).copied().collect();
|
||||
let top_addmms: Vec<usize> =
|
||||
addmm_idxs.iter().filter(|&&i| i > last_emb).copied().collect();
|
||||
if bot_addmms.is_empty() || top_addmms.is_empty() {
|
||||
dbgln!(
|
||||
"addmm split: bot={}, top={} (expected ≥1 each)",
|
||||
bot_addmms.len(),
|
||||
top_addmms.len()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let sigmoid_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, n)| n.target == "torch.ops.aten.sigmoid.default")
|
||||
.map(|(i, _)| i)?;
|
||||
if sigmoid_idx < *top_addmms.last().unwrap() {
|
||||
dbgln!("sigmoid before last top addmm — not DLRM ordering");
|
||||
return None;
|
||||
}
|
||||
|
||||
let bmm_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, n)| n.target == "torch.ops.aten.bmm.default")
|
||||
.map(|(i, _)| i)?;
|
||||
if bmm_idx < last_emb || bmm_idx > top_addmms[0] {
|
||||
dbgln!("bmm position wrong (idx {bmm_idx}, last_emb {last_emb}, first_top_addmm {})", top_addmms[0]);
|
||||
return None;
|
||||
}
|
||||
|
||||
// index.Tensor must exist between bmm and the first top addmm — that's
|
||||
// the (li, lj) gather of the lower-triangular pairs.
|
||||
let _index_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(i, n)| n.target == "torch.ops.aten.index.Tensor" && *i > bmm_idx)
|
||||
.map(|(i, _)| i)?;
|
||||
|
||||
// ---- 2. Extract embedding info (vocab, m_spa, indices, weights) -
|
||||
let mut vocab_sizes = Vec::with_capacity(n_sparse);
|
||||
let mut emb_weight_names = Vec::with_capacity(n_sparse);
|
||||
let mut index_input_names = Vec::with_capacity(n_sparse);
|
||||
let mut batch_opt: Option<usize> = None;
|
||||
let mut m_spa_opt: Option<usize> = None;
|
||||
|
||||
for &i in &emb_node_idxs {
|
||||
let n = &nodes[i];
|
||||
// Validate the bag invariants the megakernel relies on.
|
||||
// arg ordering: (weight, indices, offsets, scale_grad_by_freq, mode,
|
||||
// sparse, per_sample_weights, include_last_offset, padding_idx)
|
||||
let weight_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
|
||||
let indices_name = n.inputs.get(1)?.arg.as_tensor_name()?.to_string();
|
||||
let offsets_name = n.inputs.get(2)?.arg.as_tensor_name()?.to_string();
|
||||
|
||||
// mode must be 0 (sum) — anything else falls back.
|
||||
let mode = n.inputs.get(4).and_then(|a| a.arg.as_int()).unwrap_or(0);
|
||||
if mode != 0 {
|
||||
dbgln!("embedding_bag mode={mode} != 0 (sum)");
|
||||
return None;
|
||||
}
|
||||
// per_sample_weights must be None (no tensor arg in slot 6).
|
||||
if let Some(arg) = n.inputs.get(6)
|
||||
&& arg.arg.as_tensor_name().is_some()
|
||||
{
|
||||
dbgln!("embedding_bag has per_sample_weights — not supported");
|
||||
return None;
|
||||
}
|
||||
// include_last_offset must be false.
|
||||
if matches!(
|
||||
n.inputs.get(7).and_then(|a| a.arg.as_bool()),
|
||||
Some(true)
|
||||
) {
|
||||
dbgln!("embedding_bag include_last_offset=true — not supported");
|
||||
return None;
|
||||
}
|
||||
|
||||
let weight_meta = parsed.tensor_meta(&weight_name)?;
|
||||
if weight_meta.sizes.len() != 2 {
|
||||
dbgln!("embedding weight has non-2D shape");
|
||||
return None;
|
||||
}
|
||||
let v = weight_meta.sizes[0].hint()? as usize;
|
||||
let m = weight_meta.sizes[1].hint()? as usize;
|
||||
if let Some(prev) = m_spa_opt
|
||||
&& prev != m
|
||||
{
|
||||
dbgln!("inconsistent m_spa across embeddings ({prev} vs {m})");
|
||||
return None;
|
||||
}
|
||||
m_spa_opt = Some(m);
|
||||
|
||||
// Bag-size-1: indices.len == offsets.len == batch.
|
||||
let idx_meta = parsed.tensor_meta(&indices_name)?;
|
||||
let off_meta = parsed.tensor_meta(&offsets_name)?;
|
||||
if idx_meta.sizes.len() != 1 || off_meta.sizes.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let idx_len = idx_meta.sizes[0].hint()? as usize;
|
||||
let off_len = off_meta.sizes[0].hint()? as usize;
|
||||
if idx_len != off_len {
|
||||
dbgln!(
|
||||
"non-uniform bag (indices={idx_len}, offsets={off_len}) — fallback"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
if let Some(prev) = batch_opt
|
||||
&& prev != idx_len
|
||||
{
|
||||
dbgln!("inconsistent batch across embeddings ({prev} vs {idx_len})");
|
||||
return None;
|
||||
}
|
||||
batch_opt = Some(idx_len);
|
||||
|
||||
vocab_sizes.push(v);
|
||||
emb_weight_names.push(weight_name);
|
||||
index_input_names.push(indices_name);
|
||||
}
|
||||
let m_spa = m_spa_opt?;
|
||||
let batch = batch_opt?;
|
||||
|
||||
// ---- 3. Reconstruct bot/top MLP widths --------------------------
|
||||
//
|
||||
// addmm(bias, input, weight^T) → output (B, out)
|
||||
// inputs[0] = bias (out,) — gives us the layer's out_features
|
||||
// inputs[1] = input (B, in_w) — first addmm in each chain tells us in_w
|
||||
// inputs[2] = weight^T — usually produced by a `permute.default`
|
||||
// whose input is the (out, in) weight param.
|
||||
|
||||
let extract_chain_shape = |chain: &[usize]| -> Option<Vec<usize>> {
|
||||
let mut ln = Vec::with_capacity(chain.len() + 1);
|
||||
for (i, &node_idx) in chain.iter().enumerate() {
|
||||
let n = &nodes[node_idx];
|
||||
let bias_name = n.inputs.first()?.arg.as_tensor_name()?;
|
||||
let bias_meta = parsed.tensor_meta(bias_name)?;
|
||||
if bias_meta.sizes.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let out = bias_meta.sizes[0].hint()? as usize;
|
||||
if i == 0 {
|
||||
let input_name = n.inputs.get(1)?.arg.as_tensor_name()?;
|
||||
let in_meta = parsed.tensor_meta(input_name)?;
|
||||
if in_meta.sizes.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let in_w = in_meta.sizes[1].hint()? as usize;
|
||||
ln.push(in_w);
|
||||
}
|
||||
ln.push(out);
|
||||
}
|
||||
Some(ln)
|
||||
};
|
||||
let ln_bot = extract_chain_shape(&bot_addmms)?;
|
||||
let ln_top = extract_chain_shape(&top_addmms)?;
|
||||
|
||||
// ---- 4. Shape consistency checks --------------------------------
|
||||
if *ln_bot.last()? != m_spa {
|
||||
dbgln!("ln_bot.last() = {} != m_spa {m_spa}", ln_bot.last()?);
|
||||
return None;
|
||||
}
|
||||
let n_feat = 1 + n_sparse;
|
||||
let n_pairs = n_feat * (n_feat - 1) / 2;
|
||||
if ln_top[0] != m_spa + n_pairs {
|
||||
dbgln!(
|
||||
"ln_top[0] = {} != m_spa+n_pairs = {}",
|
||||
ln_top[0],
|
||||
m_spa + n_pairs
|
||||
);
|
||||
return None;
|
||||
}
|
||||
if *ln_top.last()? != 1 {
|
||||
dbgln!("ln_top.last() = {} != 1", ln_top.last()?);
|
||||
return None;
|
||||
}
|
||||
if vocab_sizes.len() != n_sparse {
|
||||
return None;
|
||||
}
|
||||
if ln_bot.len() < 2 || ln_top.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// ---- 5. Pull weight + bias parameter names ----------------------
|
||||
let extract_weights = |chain: &[usize]| -> Option<Vec<(String, String)>> {
|
||||
let mut out = Vec::with_capacity(chain.len());
|
||||
for &node_idx in chain {
|
||||
let n = &nodes[node_idx];
|
||||
let bias_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
|
||||
let mat2_name = n.inputs.get(2)?.arg.as_tensor_name()?;
|
||||
let weight_name = resolve_weight_param(nodes, mat2_name)?;
|
||||
out.push((weight_name, bias_name));
|
||||
}
|
||||
Some(out)
|
||||
};
|
||||
let bot_weight_names = extract_weights(&bot_addmms)?;
|
||||
let top_weight_names = extract_weights(&top_addmms)?;
|
||||
|
||||
// ---- 6. dense_input + output names ------------------------------
|
||||
let dense_input_name = nodes[bot_addmms[0]]
|
||||
.inputs
|
||||
.get(1)?
|
||||
.arg
|
||||
.as_tensor_name()?
|
||||
.to_string();
|
||||
// Validate it's actually a user input (not an intermediate).
|
||||
let user_input_names: std::collections::HashSet<&str> = parsed
|
||||
.classify_inputs()
|
||||
.iter()
|
||||
.filter_map(|i| match i {
|
||||
crate::pt2_parser::InputKind::UserInput { graph_name } => Some(graph_name.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.map(|s| s.to_string())
|
||||
.collect::<std::collections::HashSet<String>>()
|
||||
.iter()
|
||||
.map(|s| -> &str { unsafe { std::mem::transmute::<&str, &str>(s.as_str()) } })
|
||||
.collect();
|
||||
let _ = user_input_names; // suppress dead_code if not used; cleaner check below
|
||||
// (Simpler: just check the name is in classified user inputs by string.)
|
||||
let inputs = parsed.classify_inputs();
|
||||
let is_user = inputs.iter().any(|i| {
|
||||
matches!(
|
||||
i,
|
||||
crate::pt2_parser::InputKind::UserInput { graph_name } if graph_name == &dense_input_name
|
||||
)
|
||||
});
|
||||
if !is_user {
|
||||
dbgln!("dense_input candidate {dense_input_name} is not a user input");
|
||||
return None;
|
||||
}
|
||||
|
||||
let output_name = nodes[sigmoid_idx]
|
||||
.outputs
|
||||
.first()?
|
||||
.as_tensor
|
||||
.as_ref()?
|
||||
.name
|
||||
.clone();
|
||||
|
||||
let shape = DlrmShape {
|
||||
batch,
|
||||
n_dense_in: ln_bot[0],
|
||||
ln_bot,
|
||||
n_sparse,
|
||||
vocab_sizes,
|
||||
m_spa,
|
||||
ln_top,
|
||||
dense_input_name,
|
||||
index_input_names,
|
||||
emb_weight_names,
|
||||
bot_weight_names,
|
||||
top_weight_names,
|
||||
output_name,
|
||||
};
|
||||
dbgln!(
|
||||
"matched DLRM: batch={} ln_bot={:?} n_sparse={} vocabs={:?} m_spa={} ln_top={:?}",
|
||||
shape.batch,
|
||||
shape.ln_bot,
|
||||
shape.n_sparse,
|
||||
shape.vocab_sizes,
|
||||
shape.m_spa,
|
||||
shape.ln_top
|
||||
);
|
||||
Some(shape)
|
||||
}
|
||||
|
||||
/// Walk back from an addmm's `mat2` argument to the underlying weight
|
||||
/// parameter. PyTorch's `nn.Linear` decomposes to
|
||||
/// `permute(weight) → addmm(bias, x, permuted)`, so we expect mat2 to be
|
||||
/// the output of a `permute.default` node whose input is the weight.
|
||||
/// If mat2 is itself a graph input (no producing node), it IS the weight.
|
||||
fn resolve_weight_param(nodes: &[Node], name: &str) -> Option<String> {
|
||||
for n in nodes {
|
||||
let Some(first_out) = n.outputs.first().and_then(|o| o.as_tensor.as_ref()) else {
|
||||
continue;
|
||||
};
|
||||
if first_out.name == name {
|
||||
// mat2 was produced by an op. Only `permute.default` is expected;
|
||||
// anything else is unfamiliar and we should fall back.
|
||||
if n.target == "torch.ops.aten.permute.default" {
|
||||
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
|
||||
} else if n.target == "torch.ops.aten.t.default" {
|
||||
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
|
||||
} else {
|
||||
dbgln!(
|
||||
"addmm mat2 produced by unexpected op '{}' — fallback",
|
||||
n.target
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
// No producing node — mat2 is a graph input (param) directly.
|
||||
Some(name.to_string())
|
||||
}
|
||||
|
||||
/// Build the megakernel CustomOp inputs vec in the canonical order
|
||||
/// expected by [`DlrmMegaKernel`] and insert it into the translator's
|
||||
/// luminal graph. Registers the result under `shape.output_name` so the
|
||||
/// downstream output-emission loop finds it.
|
||||
pub(super) fn emit_megakernel(t: &mut Translator<'_>, shape: DlrmShape) -> Result<()> {
|
||||
// Resolve every input tensor by PT2 graph_name through Translator.tensors.
|
||||
let mut inputs: Vec<GraphTensor> = Vec::new();
|
||||
inputs.push(
|
||||
t.get_tensor(&shape.dense_input_name)
|
||||
.with_context(|| format!("dense input {} not in tensors", shape.dense_input_name))?,
|
||||
);
|
||||
for n in &shape.index_input_names {
|
||||
inputs.push(t.get_tensor(n).with_context(|| format!("index input {n} not in tensors"))?);
|
||||
}
|
||||
for n in &shape.emb_weight_names {
|
||||
inputs.push(t.get_tensor(n).with_context(|| format!("emb weight {n} not in tensors"))?);
|
||||
}
|
||||
for (w, b) in &shape.bot_weight_names {
|
||||
inputs.push(t.get_tensor(w).with_context(|| format!("bot weight {w} not in tensors"))?);
|
||||
inputs.push(t.get_tensor(b).with_context(|| format!("bot bias {b} not in tensors"))?);
|
||||
}
|
||||
for (w, b) in &shape.top_weight_names {
|
||||
inputs.push(t.get_tensor(w).with_context(|| format!("top weight {w} not in tensors"))?);
|
||||
inputs.push(t.get_tensor(b).with_context(|| format!("top bias {b} not in tensors"))?);
|
||||
}
|
||||
|
||||
let kernel = DlrmMegaKernel {
|
||||
batch: shape.batch,
|
||||
n_dense_in: shape.n_dense_in,
|
||||
ln_bot: shape.ln_bot.clone(),
|
||||
n_sparse: shape.n_sparse,
|
||||
vocab_sizes: shape.vocab_sizes.clone(),
|
||||
m_spa: shape.m_spa,
|
||||
ln_top: shape.ln_top.clone(),
|
||||
};
|
||||
let out = t.graph.custom_op(
|
||||
DlrmMegaCustom(kernel),
|
||||
inputs,
|
||||
(shape.batch, 1usize),
|
||||
DType::F32,
|
||||
);
|
||||
t.tensors.insert(shape.output_name.clone(), out);
|
||||
dbgln!("emitted DlrmMegaCustom; output={}", shape.output_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -6,6 +6,8 @@ mod attention;
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
#[cfg(feature = "cuda")]
|
||||
mod dlrm_pattern;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
@@ -70,12 +72,31 @@ impl<'a> Translator<'a> {
|
||||
/// Translate the whole parsed PT2 graph into the luminal graph.
///
/// Creates the graph inputs, then either takes the DLRM megakernel fast
/// path (CUDA builds only) or translates the nodes one by one, and in both
/// cases finishes by emitting the user outputs.
///
/// # Errors
/// Propagates input-creation, megakernel-emission and output-emission
/// failures; a per-node failure is annotated with the node's position and
/// target.
fn translate_graph(&mut self) -> Result<()> {
    self.create_inputs()?;

    // Fast path: when the entire forward matches the DLRM family shape, a
    // single DlrmMegaCustom op replaces the node walk. A failed match
    // yields None and the standard dispatch below runs instead — no
    // semantic difference, just slower (~503µs vs ~30µs at bs=2048 for
    // MiniDLRM). CUDA-only, because the megakernel is a CUDA CustomOp.
    #[cfg(feature = "cuda")]
    if let Some(dlrm) = dlrm_pattern::match_dlrm(self.parsed) {
        dlrm_pattern::emit_megakernel(self, dlrm)?;
        return self.emit_outputs();
    }

    let graph_nodes = &self.parsed.program.graph_module.graph.nodes;
    for (position, graph_node) in graph_nodes.iter().enumerate() {
        self.translate_node(graph_node).with_context(|| {
            format!("Failed to translate node {position}: {}", graph_node.target)
        })?;
    }
    self.emit_outputs()
}
|
||||
|
||||
/// Walks the parsed graph's user outputs, applies the wrap/cast rules
|
||||
/// that downstream codegen relies on, then attaches an `Output` node
|
||||
/// per user-output. Shared by the normal dispatch path and the DLRM
|
||||
/// megakernel fast path.
|
||||
fn emit_outputs(&mut self) -> Result<()> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
@@ -84,12 +105,20 @@ impl<'a> Translator<'a> {
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
// The `+ 0.0` wrap pulls double duty: it materializes a fresh
|
||||
// buffer for outputs that alias an Input (passthrough
|
||||
// `return x`), AND it acts as an anchor that survives egglog
|
||||
// rewriting, so the downstream runtime can find the producer
|
||||
// node for outputs whose original op (e.g. Reduce with
|
||||
// keepdims, Conv) gets folded away during optimization.
|
||||
// Removing it broke 24 test_hlir_ops tests with "Cannot find
|
||||
// output tensor!" — keep it until that anchor invariant is
|
||||
// refactored elsewhere.
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -256,6 +256,97 @@ impl<'a> Translator<'a> {
|
||||
Ok(weight.gather(ids_expanded + arange_expanded))
|
||||
}
|
||||
|
||||
/// `aten._embedding_bag` / `aten._embedding_bag_forward_only`
|
||||
///
|
||||
/// Signature: (weight, indices, offsets, scale_grad_by_freq=False, mode=0,
|
||||
/// sparse=False, per_sample_weights=None, include_last_offset=False,
|
||||
/// padding_idx=-1) -> (output, offset2bag, bag_size, max_indices)
|
||||
///
|
||||
/// Strategy: for the bag-size-uniform case (N indices spread evenly across
|
||||
/// B bags, i.e. N % B == 0), reshape gather output [N, D] into [B, K, D]
|
||||
/// and reduce along K according to `mode`. We deliberately read uniformity
|
||||
/// off the *static shapes* of `indices` and `offsets` — non-uniform bags
|
||||
/// require a runtime segment-sum primitive we don't yet have.
|
||||
///
|
||||
/// DLRM hits the K=1 special case (offsets=[0,1,...,B-1], indices=[B]) per
|
||||
/// sparse table per sample — the same lookup pattern as `aten.embedding`.
|
||||
/// Only the first tuple element is materialized; the bookkeeping outputs
|
||||
/// (offset2bag, bag_size, max_indices) are inference-time dead ends.
|
||||
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
let offsets = self.get_input_tensor(node, 2)?;
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
let include_last_offset = self.get_bool_arg(node, 7).unwrap_or(false);
|
||||
|
||||
if let Some(arg) = node.inputs.get(6)
|
||||
&& arg.arg.as_tensor_name().is_some()
|
||||
{
|
||||
bail!("_embedding_bag: per_sample_weights not supported");
|
||||
}
|
||||
|
||||
if indices.shape.len() != 1 || offsets.shape.len() != 1 {
|
||||
bail!(
|
||||
"_embedding_bag: expected 1-D indices and offsets, got shapes {:?}, {:?}",
|
||||
indices.shape.dims,
|
||||
offsets.shape.dims
|
||||
);
|
||||
}
|
||||
let n = indices.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("_embedding_bag: indices length must be statically known")?;
|
||||
let b_raw = offsets.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("_embedding_bag: offsets length must be statically known")?;
|
||||
let b = if include_last_offset { b_raw - 1 } else { b_raw };
|
||||
if b == 0 {
|
||||
bail!("_embedding_bag: empty bag set");
|
||||
}
|
||||
if n % b != 0 {
|
||||
bail!(
|
||||
"_embedding_bag: non-uniform bag size not supported (indices={n}, bags={b})"
|
||||
);
|
||||
}
|
||||
let k = n / b;
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
|
||||
// Step 1: gather weight rows. Same construction as translate_embedding —
|
||||
// flatten the (idx, hidden) pair into a single offset into the weight
|
||||
// matrix and gather. Result: [N, D].
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let ids_expanded = (indices_int * hidden_dim).expand_dim(1, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim);
|
||||
let arange_expanded = arange.expand_dim(0, indices.shape.dims[0]);
|
||||
let gathered = weight.gather(ids_expanded + arange_expanded);
|
||||
|
||||
// Step 2: bag-size-1 → already [B, D]; skip reshape/reduce.
|
||||
if k == 1 {
|
||||
return Ok(gathered);
|
||||
}
|
||||
|
||||
// Step 3: reshape [B*K, D] → [B, K, D] (contiguous, identity stride view).
|
||||
let bag_shape = vec![
|
||||
Expression::from(b),
|
||||
Expression::from(k),
|
||||
hidden_dim,
|
||||
];
|
||||
let mut bagged = GraphTensor {
|
||||
id: gathered.id,
|
||||
graph_ref: gathered.graph_ref,
|
||||
shape: ShapeTracker::new(bag_shape),
|
||||
dtype: gathered.dtype,
|
||||
};
|
||||
|
||||
// Step 4: reduce along axis=1.
|
||||
bagged = match mode {
|
||||
0 => bagged.sum(1),
|
||||
1 => bagged.mean(1),
|
||||
2 => bagged.max(1),
|
||||
m => bail!("_embedding_bag: unsupported mode {m} (0=sum, 1=mean, 2=max)"),
|
||||
};
|
||||
Ok(bagged)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let source = self.get_input_tensor(node, 0)?;
|
||||
|
||||
@@ -318,6 +409,20 @@ impl<'a> Translator<'a> {
|
||||
|
||||
let index_names = &index_names;
|
||||
|
||||
// Prefix-of-Nones case: `source[:, ..., :, idx_0, idx_1, ..., idx_{m-1}]`
|
||||
// — indices apply to dims [first..first+m), not [0..m). The original
|
||||
// multi-index path below assumes first==0 and silently mis-strides
|
||||
// (and mis-flattens) when called with a prefix; route to the
|
||||
// prefix-aware path before falling through. Suffix-of-Nones after the
|
||||
// indices is not yet supported here.
|
||||
if first_non_none_dim > 0 {
|
||||
return self.translate_index_tensor_with_prefix(
|
||||
source,
|
||||
index_names,
|
||||
first_non_none_dim,
|
||||
);
|
||||
}
|
||||
|
||||
let src_shape = source.shape.dims;
|
||||
let n_indexed = index_names.len();
|
||||
|
||||
@@ -398,6 +503,132 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Advanced indexing with a `None` prefix: `source[:, ..., :, i0, i1, ...]`.
|
||||
///
|
||||
/// Output shape: `prefix_dims ++ idx_shape ++ suffix_dims` where
|
||||
/// `prefix_dims = src.shape[..first]`, `suffix_dims = src.shape[first+m..]`,
|
||||
/// and `idx_shape` is the broadcast shape of the m index tensors.
|
||||
///
|
||||
/// Currently supports the no-suffix case (indices land on the trailing
|
||||
/// dims). DLRM's dot interaction hits this: `Z[:, li, lj]` with
|
||||
/// `Z: [B, ni, nj]`, `li, lj: [L]`.
|
||||
fn translate_index_tensor_with_prefix(
|
||||
&mut self,
|
||||
source: GraphTensor,
|
||||
index_names: &[crate::pt2_schema::TensorName],
|
||||
first: usize,
|
||||
) -> Result<GraphTensor> {
|
||||
let src_shape = source.shape.dims;
|
||||
let n_indexed = index_names.len();
|
||||
let src_rank = src_shape.len();
|
||||
if first + n_indexed > src_rank {
|
||||
bail!(
|
||||
"index.Tensor (prefix): {n_indexed} indices starting at dim {first} \
|
||||
exceed source rank {src_rank}"
|
||||
);
|
||||
}
|
||||
let prefix_dims: Vec<Expression> = src_shape[..first].to_vec();
|
||||
let indexed_dims: Vec<Expression> = src_shape[first..first + n_indexed].to_vec();
|
||||
let suffix_dims: Vec<Expression> = src_shape[first + n_indexed..].to_vec();
|
||||
if !suffix_dims.is_empty() {
|
||||
bail!(
|
||||
"index.Tensor (prefix): trailing-dim suffix after indices not \
|
||||
supported (prefix={} indexed={} suffix={})",
|
||||
prefix_dims.len(),
|
||||
n_indexed,
|
||||
suffix_dims.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Per-axis strides within the indexed subspace (right-to-left product).
|
||||
let mut strides = vec![Expression::from(1usize); n_indexed];
|
||||
for i in (0..n_indexed - 1).rev() {
|
||||
strides[i] = strides[i + 1] * indexed_dims[i + 1];
|
||||
}
|
||||
let indexed_size = indexed_dims
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(Expression::from(1usize), |a, b| a * b);
|
||||
|
||||
// Collapse the m index tensors into a single flat index in the indexed
|
||||
// subspace. Negative entries get normalized per axis.
|
||||
let mut flat_idx: Option<GraphTensor> = None;
|
||||
for (i, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_t = self.get_tensor(&idx_name.name)?.cast(DType::Int);
|
||||
let axis_size = indexed_dims[i];
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_t.shape);
|
||||
let is_neg = idx_t.lt(zero).cast(DType::Int);
|
||||
let idx_norm = idx_t + is_neg * axis_size;
|
||||
let stride = strides[i];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
idx_norm
|
||||
} else {
|
||||
idx_norm * stride
|
||||
};
|
||||
flat_idx = Some(match flat_idx {
|
||||
Some(acc) => {
|
||||
let (a, w) = broadcast_binary(acc, weighted);
|
||||
a + w
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
let flat_idx = flat_idx.context("index.Tensor (prefix): no indices")?;
|
||||
let idx_shape: Vec<Expression> = flat_idx.shape.dims.to_vec();
|
||||
|
||||
// Build the absolute flat index over `source` viewed as 1D, shape
|
||||
// `prefix_dims ++ idx_shape`:
|
||||
// abs[p..., k...] = flat_prefix(p...) * indexed_size + flat_idx[k...]
|
||||
// Construct by promoting `flat_idx` to the full rank then adding a
|
||||
// broadcast prefix-offset tensor.
|
||||
let mut full_shape: Vec<Expression> = prefix_dims.clone();
|
||||
full_shape.extend_from_slice(&idx_shape);
|
||||
|
||||
// Promote flat_idx: insert prefix_dims leading axes, then expand.
|
||||
let mut idx_promoted = flat_idx;
|
||||
for _ in 0..prefix_dims.len() {
|
||||
idx_promoted = idx_promoted.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
idx_promoted.shape.expand(full_shape.clone());
|
||||
|
||||
// Prefix offset: for each prefix dim pi (right-to-left), accumulate
|
||||
// arange(prefix_dims[pi]) * (product_of_more_inner_prefix_dims * indexed_size).
|
||||
let mut prefix_offset: Option<GraphTensor> = None;
|
||||
let mut cum_stride = indexed_size;
|
||||
for (pi, pd) in prefix_dims.iter().enumerate().rev() {
|
||||
let ar = self.graph.arange(*pd) * cum_stride;
|
||||
// arange is shape [pd]; lift it into full_shape at position pi.
|
||||
let mut ar_promoted = ar;
|
||||
for _ in 0..pi {
|
||||
ar_promoted = ar_promoted.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let trailing = full_shape.len() - pi - 1;
|
||||
for _ in 0..trailing {
|
||||
let r = ar_promoted.shape.len();
|
||||
ar_promoted = ar_promoted.expand_dim(r, Expression::from(1usize));
|
||||
}
|
||||
ar_promoted.shape.expand(full_shape.clone());
|
||||
prefix_offset = Some(match prefix_offset {
|
||||
Some(acc) => acc + ar_promoted,
|
||||
None => ar_promoted,
|
||||
});
|
||||
cum_stride = cum_stride * *pd;
|
||||
}
|
||||
|
||||
let final_idx = match prefix_offset {
|
||||
Some(po) => idx_promoted + po,
|
||||
None => idx_promoted,
|
||||
};
|
||||
|
||||
// Flatten source to 1D and gather with the absolute index.
|
||||
let total: Expression = src_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(Expression::from(1usize), |a, b| a * b);
|
||||
let fully_flat = reshape_tensor(source, vec![total]);
|
||||
Ok(fully_flat.gather(final_idx))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_gather(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
|
||||
@@ -43,6 +43,19 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
# Pre-zip + caches for the hot path. The CudaRuntime now preserves
|
||||
# external-pointer registrations across execute() calls and treats
|
||||
# set_device_ptr as a no-op when the pointer is unchanged — caching
|
||||
# the (name, ptr) here avoids the pyo3 round-trip entirely in tight
|
||||
# loops where PyTorch's caching allocator keeps re-handing back the
|
||||
# same tensor (e.g. inference loops with reused activation buffers).
|
||||
self._input_specs = list(zip(self._input_names, self._input_dtypes))
|
||||
self._last_input_ptrs: dict[str, int] = {}
|
||||
# Output dtype/zero-copy decisions are properties of the compiled
|
||||
# graph and never change; computing them lazily and caching avoids
|
||||
# ~10µs of pyo3 calls per iter.
|
||||
self._output_torch_dtypes_cache = None
|
||||
self._output_zero_copy_cache = None
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -89,22 +102,41 @@ class CompiledModel:
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs already in the expected dtype + contiguous, we
|
||||
# skip the detach/contiguous/to chain (those allocate new Tensor
|
||||
# objects even when they're no-ops) and short-circuit set_input_device_ptr
|
||||
# when the pointer hasn't moved since the last call. The runtime
|
||||
# treats same-ptr re-registration as a no-op too, but skipping the
|
||||
# pyo3 round-trip here saves another ~5µs per input.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
graph = self._graph
|
||||
last_input_ptrs = self._last_input_ptrs
|
||||
if self._supports_device_ptrs:
|
||||
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
|
||||
if (
|
||||
tensor.is_cuda
|
||||
and tensor.dtype is expected_dtype
|
||||
and tensor.is_contiguous()
|
||||
):
|
||||
t = tensor
|
||||
else:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
ptr = t.data_ptr()
|
||||
if last_input_ptrs.get(name) != ptr:
|
||||
graph.set_input_device_ptr(name, ptr, t.numel() * t.element_size())
|
||||
last_input_ptrs[name] = ptr
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
else:
|
||||
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
graph.set_input_from_ptr(
|
||||
name,
|
||||
t.data_ptr(),
|
||||
t.numel() * t.element_size(),
|
||||
_torch_dtype_code(t.dtype),
|
||||
)
|
||||
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
|
||||
209
crates/luminal_python/tests/test_dlrm.py
Normal file
209
crates/luminal_python/tests/test_dlrm.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""End-to-end compile tests for a faithful DLRM-style recommender.
|
||||
|
||||
`MiniDLRM` below mirrors `DLRM_Net` from facebookresearch/dlrm:
|
||||
bottom-MLP on dense features, an `EmbeddingBag` per sparse table, dot-product
|
||||
interaction over the (1 + n_sparse) feature vectors, and a top-MLP. The
|
||||
forward signature `(dense_x, lS_o, lS_i)` matches DLRM exactly.
|
||||
|
||||
This is the smallest model that exercises the three translator paths added for
|
||||
DLRM:
|
||||
- `aten._embedding_bag_forward_only.default` (uniform-bag-size lowering)
|
||||
- `aten.index.Tensor` with a `None` prefix (`Z[:, li, lj]`)
|
||||
- the existing `aten.bmm` / `aten.cat` paths under the above feeders
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn as nn
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class MiniDLRM(nn.Module):
    """Compact DLRM replica: bottom MLP, sum-mode EmbeddingBags, dot interaction, top MLP.

    `forward(dense_x, lS_o, lS_i)` takes the dense features plus one offsets
    tensor and one indices tensor per sparse table.
    """

    def __init__(
        self,
        m_spa: int,
        ln_emb: list[int],
        ln_bot: list[int],
        ln_top: list[int],
    ) -> None:
        super().__init__()
        assert ln_bot[-1] == m_spa, "bottom MLP must end at m_spa"
        # One feature vector from the bottom MLP plus one per embedding table;
        # the interaction keeps the strictly-lower-triangular pairs.
        n_feat = len(ln_emb) + 1
        n_pairs = n_feat * (n_feat - 1) // 2
        assert ln_top[0] == n_pairs + m_spa, (
            f"top MLP entry width must equal n_pairs ({n_pairs}) + m_spa ({m_spa}) "
            f"= {n_pairs + m_spa}, got {ln_top[0]}"
        )
        self.m_spa = m_spa
        # Construction order (tables, bottom, top) fixes both the state_dict
        # layout and the order in which seeded initialisation draws from the RNG.
        tables = []
        for vocab in ln_emb:
            tables.append(nn.EmbeddingBag(int(vocab), m_spa, mode="sum"))
        self.emb_l = nn.ModuleList(tables)
        self.bot_l = self._build_mlp(ln_bot, sigmoid_last=False)
        self.top_l = self._build_mlp(ln_top, sigmoid_last=True)

    @staticmethod
    def _build_mlp(sizes: list[int], sigmoid_last: bool) -> nn.Sequential:
        """Stack Linear+activation pairs; the final activation is Sigmoid when requested, ReLU otherwise."""
        final_idx = len(sizes) - 2
        stack: list[nn.Module] = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            use_sigmoid = sigmoid_last and idx == final_idx
            stack.append(nn.Sigmoid() if use_sigmoid else nn.ReLU())
        return nn.Sequential(*stack)

    def _apply_emb(
        self, lS_o: list[torch.Tensor], lS_i: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Run table k on (indices k, offsets k) for every embedding table."""
        pooled = []
        for k, bag in enumerate(self.emb_l):
            pooled.append(bag(lS_i[k], lS_o[k]))
        return pooled

    def _interact(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate `x` with the pairwise dot products among `x` and the embedding outputs."""
        batch_size, width = x.shape
        stacked = torch.cat([x] + ly, dim=1).view((batch_size, -1, width))
        gram = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
        _, n_rows, n_cols = gram.shape
        # Row/column coordinates of the strictly-lower-triangular entries.
        row_ids = [r for r in range(n_rows) for _ in range(r)]
        col_ids = [c for r in range(n_cols) for c in range(r)]
        li = torch.tensor(row_ids, device=x.device)
        lj = torch.tensor(col_ids, device=x.device)
        lower = gram[:, li, lj]
        return torch.cat([x, lower], dim=1)

    def forward(
        self,
        dense_x: torch.Tensor,
        lS_o: list[torch.Tensor],
        lS_i: list[torch.Tensor],
    ) -> torch.Tensor:
        """Bottom MLP -> embedding lookups -> dot interaction -> top MLP."""
        bottom = self.bot_l(dense_x)
        sparse = self._apply_emb(lS_o, lS_i)
        return self.top_l(self._interact(bottom, sparse))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_inputs(
|
||||
batch_size: int,
|
||||
dense_dim: int,
|
||||
ln_emb: list[int],
|
||||
bag_size: int,
|
||||
device: torch.device,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
|
||||
dense_x = torch.rand(batch_size, dense_dim, device=device)
|
||||
if bag_size == 1:
|
||||
offsets = [
|
||||
torch.arange(batch_size, dtype=torch.long, device=device)
|
||||
for _ in ln_emb
|
||||
]
|
||||
indices = [
|
||||
torch.randint(0, int(n), (batch_size,), dtype=torch.long, device=device)
|
||||
for n in ln_emb
|
||||
]
|
||||
else:
|
||||
offsets = [
|
||||
torch.arange(
|
||||
0,
|
||||
batch_size * bag_size,
|
||||
bag_size,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
for _ in ln_emb
|
||||
]
|
||||
indices = [
|
||||
torch.randint(
|
||||
0, int(n), (batch_size * bag_size,), dtype=torch.long, device=device
|
||||
)
|
||||
for n in ln_emb
|
||||
]
|
||||
return dense_x, offsets, indices
|
||||
|
||||
|
||||
def _build_model(
    m_spa: int,
    ln_emb: list[int],
    ln_bot: list[int],
    device: torch.device,
) -> MiniDLRM:
    """Build a seeded, eval-mode `MiniDLRM` whose top MLP is `[n_pairs + m_spa, 8, 1]`."""
    torch.manual_seed(0)
    feature_count = len(ln_emb) + 1
    pair_count = feature_count * (feature_count - 1) // 2
    top_sizes = [pair_count + m_spa, 8, 1]
    return MiniDLRM(m_spa, ln_emb, ln_bot, top_sizes).to(device).eval()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dlrm_dot_bag1_smallbatch(device: torch.device) -> None:
    """Canonical DLRM eval path: a single lookup per sample for each sparse table (batch of 2)."""
    emb_dim = 4
    vocab_sizes = [10, 20, 30]
    net = _build_model(emb_dim, vocab_sizes, [13, 8, emb_dim], device)
    sample = _make_inputs(
        batch_size=2, dense_dim=13, ln_emb=vocab_sizes, bag_size=1, device=device
    )

    luminal_fn: Callable = torch.compile(net, backend=luminal_backend)
    with torch.no_grad():
        reference = net(*sample)
        actual = luminal_fn(*sample)
    assert torch.allclose(actual, reference, atol=1e-5)
|
||||
|
||||
|
||||
def test_dlrm_dot_bag1_largerbatch(device: torch.device) -> None:
    """Batch of 64 with one lookup per bag; checks the result still matches eager at a larger batch."""
    emb_dim = 4
    vocab_sizes = [10, 20, 30]
    net = _build_model(emb_dim, vocab_sizes, [13, 8, emb_dim], device)
    sample = _make_inputs(
        batch_size=64, dense_dim=13, ln_emb=vocab_sizes, bag_size=1, device=device
    )

    luminal_fn: Callable = torch.compile(net, backend=luminal_backend)
    with torch.no_grad():
        reference = net(*sample)
        actual = luminal_fn(*sample)
    # Looser tolerance than the small-batch tests (1e-4 instead of 1e-5).
    assert torch.allclose(actual, reference, atol=1e-4)
|
||||
|
||||
|
||||
def test_dlrm_dot_multihot(device: torch.device) -> None:
    """Uniform multi-hot bags (three lookups per bag) must match eager output."""
    emb_dim = 4
    vocab_sizes = [10, 20, 30]
    net = _build_model(emb_dim, vocab_sizes, [13, 8, emb_dim], device)
    sample = _make_inputs(
        batch_size=4, dense_dim=13, ln_emb=vocab_sizes, bag_size=3, device=device
    )

    luminal_fn: Callable = torch.compile(net, backend=luminal_backend)
    with torch.no_grad():
        reference = net(*sample)
        actual = luminal_fn(*sample)
    assert torch.allclose(actual, reference, atol=1e-5)
|
||||
|
||||
|
||||
def test_dlrm_dot_larger_tables(device: torch.device) -> None:
    """Same single-lookup setup with larger vocabularies (50/100/200 rows)."""
    emb_dim = 4
    vocab_sizes = [50, 100, 200]
    net = _build_model(emb_dim, vocab_sizes, [13, 8, emb_dim], device)
    sample = _make_inputs(
        batch_size=4, dense_dim=13, ln_emb=vocab_sizes, bag_size=1, device=device
    )

    luminal_fn: Callable = torch.compile(net, backend=luminal_backend)
    with torch.no_grad():
        reference = net(*sample)
        actual = luminal_fn(*sample)
    assert torch.allclose(actual, reference, atol=1e-5)
|
||||
17
examples/dlrm/Cargo.toml
Normal file
17
examples/dlrm/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "dlrm"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[[bin]]
|
||||
name = "dlrm"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
safetensors = "0.7.0"
|
||||
memmap2 = "0.9.9"
|
||||
bytemuck = "1.24.0"
|
||||
rand = "0.9.2"
|
||||
306
examples/dlrm/bench.py
Normal file
306
examples/dlrm/bench.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""DLRM inference latency benchmark across PyTorch backends + luminal.
|
||||
|
||||
Backends measured:
|
||||
1. PyTorch eager
|
||||
2. torch.compile (default backend = inductor, mode="reduce-overhead")
|
||||
3. AOTInductor (export → aoti_compile_and_package → load → run)
|
||||
4. CUDA graphs (capture-replay around the eager model)
|
||||
5. PyTorch + luminal_backend (torch.compile with our PT2 → luminal backend)
|
||||
|
||||
The rust luminal path is measured separately by the dlrm binary's --bench
|
||||
flag and the two results are combined in the rank table later.
|
||||
|
||||
Methodology:
|
||||
- Same MiniDLRM at the small config, batch_size=2048 (the BATCH constant below;
export.py uses the same value so the comparison is apples-to-apples).
|
||||
- 50 warmup iters per backend, 500 measured iters.
|
||||
- Per-iteration latency via paired cudaEvent_record + elapsed_time.
|
||||
- Report mean / p50 / p99 in microseconds; also dump every measurement
|
||||
to /tmp/dlrm_bench_<backend>.txt so other consumers can re-aggregate.
|
||||
|
||||
Run:
|
||||
/lambda/nfs/tucker-fs/second/luminal/crates/luminal_python/.venv/bin/python \
|
||||
examples/dlrm/bench.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import statistics
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# MiniDLRM lives in tests.
|
||||
TESTS_DIR = (
|
||||
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
)
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
DEVICE = torch.device("cuda")
|
||||
WARMUP = 50
|
||||
ITERS = 500
|
||||
|
||||
M_SPA = 4
|
||||
LN_EMB = [10, 20, 30]
|
||||
LN_BOT = [13, 8, M_SPA]
|
||||
LN_TOP = [10, 8, 1]
|
||||
# Real-workload DLRM batch — kernel work dominates per-launch overhead.
|
||||
BATCH = 2048
|
||||
|
||||
|
||||
def make_model() -> torch.nn.Module:
    """Return a freshly seeded, eval-mode MiniDLRM on DEVICE using the module-level config."""
    torch.manual_seed(0)
    net = MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP)
    net = net.to(DEVICE)
    return net.eval()
|
||||
|
||||
|
||||
def make_inputs() -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    """Return seeded `(dense_x, offsets, indices)` for BATCH samples, one lookup per bag."""
    torch.manual_seed(42)
    dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)
    indices = []
    for vocab in LN_EMB:
        indices.append(
            torch.randint(0, vocab, (BATCH,), dtype=torch.long, device=DEVICE)
        )
    # One index per bag, so the offsets are simply 0..BATCH-1 for every table.
    offsets = []
    for _ in LN_EMB:
        offsets.append(torch.arange(BATCH, dtype=torch.long, device=DEVICE))
    return dense_x, offsets, indices
|
||||
|
||||
|
||||
def time_callable(fn: Callable[[], torch.Tensor], iters: int) -> list[float]:
    """Call `fn` `iters` times, bracketing each call with CUDA events.

    Returns one latency per call in microseconds (`elapsed_time` reports
    milliseconds, hence the factor of 1000).
    """
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]

    torch.cuda.synchronize()
    for begin, finish in zip(starts, stops):
        begin.record()
        fn()
        finish.record()
    # Synchronize before reading any elapsed time.
    torch.cuda.synchronize()

    latencies = []
    for begin, finish in zip(starts, stops):
        latencies.append(begin.elapsed_time(finish) * 1000.0)
    return latencies
|
||||
|
||||
|
||||
def report(name: str, samples_us: list[float]) -> dict[str, float]:
    """Print mean/p50/p99 for one backend, dump the sorted samples, and return the summary.

    Args:
        name: Backend label. It is also used, with spaces replaced by
            underscores and parentheses removed, in the dump file name
            `/tmp/dlrm_bench_<name>.txt`.
        samples_us: Per-iteration latencies in microseconds; must be non-empty.

    Returns:
        A dict with keys `name`, `mean`, `p50`, `p99` and `n`. Note that `name`
        is a string even though the annotation says `float`.

    Raises:
        ValueError: If `samples_us` is empty (previously a bare
            ZeroDivisionError from the mean computation).
    """
    if not samples_us:
        raise ValueError(f"no latency samples recorded for backend {name!r}")

    ordered = sorted(samples_us)
    n = len(ordered)
    mean = sum(ordered) / n
    p50 = ordered[n // 2]
    # int(n * 0.99) is always <= n - 1 for n >= 1, so this index is in range.
    p99 = ordered[int(n * 0.99)]
    print(f" {name:<32s} mean={mean:8.2f}µs p50={p50:8.2f}µs p99={p99:8.2f}µs")

    # Dump every sample for downstream aggregation.
    slug = name.replace(" ", "_").replace("(", "").replace(")", "")
    out_path = f"/tmp/dlrm_bench_{slug}.txt"
    Path(out_path).write_text("\n".join(f"{s:.4f}" for s in ordered))
    return {"name": name, "mean": mean, "p50": p50, "p99": p99, "n": n}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_eager() -> dict[str, float]:
    """Benchmark the plain eager forward pass."""
    net = make_model()
    sample = make_inputs()

    def fn() -> torch.Tensor:
        with torch.no_grad():
            return net(*sample)

    # Untimed warmup iterations before measuring.
    for _ in range(WARMUP):
        fn()
    samples_us = time_callable(fn, ITERS)
    return report("eager", samples_us)
|
||||
|
||||
|
||||
def bench_torch_compile() -> dict[str, float]:
    """Benchmark `torch.compile` with its default backend in `reduce-overhead` mode."""
    torch._dynamo.reset()
    net = make_model()
    sample = make_inputs()
    optimized = torch.compile(net, mode="reduce-overhead")

    def fn() -> torch.Tensor:
        with torch.no_grad():
            return optimized(*sample)

    # Compilation is triggered by the first warmup call, outside the timed region.
    for _ in range(WARMUP):
        fn()
    samples_us = time_callable(fn, ITERS)
    return report("torch.compile (inductor)", samples_us)
|
||||
|
||||
|
||||
def bench_aoti() -> dict[str, float]:
    """AOTInductor: export → compile-and-package → load → run.

    The model's `forward` takes two lists of tensors, so it is wrapped in a
    module with a flat positional signature before export. The wrapper accepts
    `*args` and splits them by the number of sparse tables, so it follows
    `LN_EMB` instead of assuming exactly three tables.
    """
    torch._dynamo.reset()
    model = make_model()
    dense_x, offsets, indices = make_inputs()
    n_sparse = len(offsets)

    class FlatWrapper(torch.nn.Module):
        """Adapts `(dense_x, *offsets, *indices)` to `model(dense_x, offsets, indices)`."""

        def __init__(self, m: torch.nn.Module, n_sparse: int) -> None:
            super().__init__()
            self.m = m
            self.n_sparse = n_sparse

        def forward(self, *args: torch.Tensor) -> torch.Tensor:
            n = self.n_sparse
            dense = args[0]
            offs = list(args[1 : 1 + n])
            idxs = list(args[1 + n : 1 + 2 * n])
            return self.m(dense, offs, idxs)

    flat_model = FlatWrapper(model, n_sparse).to(DEVICE).eval()
    flat_inputs = (dense_x, *offsets, *indices)

    with torch.no_grad():
        ep = torch.export.export(flat_model, flat_inputs)

    # The package directory is kept alive for the whole measurement in case the
    # loaded runner still refers to files inside it.
    with tempfile.TemporaryDirectory() as tmp:
        pkg_path = os.path.join(tmp, "dlrm.pt2")
        torch._inductor.aoti_compile_and_package(ep, package_path=pkg_path)
        loaded = torch._inductor.aoti_load_package(pkg_path)

        @torch.no_grad()
        def fn() -> torch.Tensor:
            return loaded(*flat_inputs)

        for _ in range(WARMUP):
            fn()
        return report("AOTInductor", time_callable(fn, ITERS))
|
||||
|
||||
|
||||
def bench_cuda_graphs() -> dict[str, float]:
    """Capture the eager forward pass in a CUDA graph and time its replay.

    `MiniDLRM._interact` builds its `li`/`lj` index tensors with
    `torch.tensor([...], device=...)` on every call, i.e. a host→device copy
    that the original author notes cannot be captured. The index tensors are
    therefore built once here and `_interact` is replaced, on this private
    model instance, by a version that closes over them.
    """
    import types

    net = make_model()
    feature_count = len(LN_EMB) + 1
    row_ids = [i for i in range(feature_count) for _ in range(i)]
    col_ids = [j for i in range(feature_count) for j in range(i)]
    li_const = torch.tensor(row_ids, device=DEVICE)
    lj_const = torch.tensor(col_ids, device=DEVICE)

    def _interact_static(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
        rows, width = x.shape
        stacked = torch.cat([x] + ly, dim=1).view((rows, -1, width))
        gram = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
        lower = gram[:, li_const, lj_const]
        return torch.cat([x, lower], dim=1)

    # Bind as a method so `self` resolves to the model.
    net._interact = types.MethodType(_interact_static, net)

    dense_x, offsets, indices = make_inputs()
    static_dense = dense_x.clone()
    static_offsets = [t.clone() for t in offsets]
    static_indices = [t.clone() for t in indices]

    def fwd() -> torch.Tensor:
        with torch.no_grad():
            return net(static_dense, static_offsets, static_indices)

    # Warm up on a side stream, then capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            fwd()
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = fwd()

    def fn() -> torch.Tensor:
        # A real workload would copy fresh inputs into the static_* tensors
        # here; for replay-latency measurement the inputs stay constant.
        with torch.no_grad():
            graph.replay()
            return static_out

    for _ in range(WARMUP):
        fn()
    samples_us = time_callable(fn, ITERS)
    return report("CUDA graphs (eager capture)", samples_us)
|
||||
|
||||
|
||||
def bench_luminal_backend() -> dict[str, float]:
    """Benchmark `torch.compile` routed through `luminal_backend`."""
    torch._dynamo.reset()
    net = make_model()
    sample = make_inputs()
    optimized = torch.compile(net, backend=luminal_backend)

    def fn() -> torch.Tensor:
        with torch.no_grad():
            return optimized(*sample)

    # The first warmup call pays the compile cost, outside the timed region.
    for _ in range(WARMUP):
        fn()
    samples_us = time_callable(fn, ITERS)
    return report("luminal_backend (PT2)", samples_us)
|
||||
|
||||
|
||||
def main() -> None:
    """Run every backend benchmark, merge externally produced samples, and print a ranking.

    A backend that raises is reported as FAILED and left out of the ranking.
    Sample files that exist but contain no samples are skipped, and the
    ranking is skipped altogether when no backend produced a result, so a
    partial run still ends with whatever output is available instead of a
    traceback.
    """
    torch.cuda.synchronize()
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch: {torch.__version__}")
    print(f"Config: m_spa={M_SPA} ln_emb={LN_EMB} batch={BATCH} iters={ITERS}\n")

    rows = []
    for fn in (bench_eager, bench_torch_compile, bench_aoti, bench_cuda_graphs, bench_luminal_backend):
        t0 = time.perf_counter()
        try:
            rows.append(fn())
        except Exception as e:
            print(f" FAILED {fn.__name__}: {type(e).__name__}: {e}")
        print(f" (setup+bench took {time.perf_counter() - t0:.1f}s)\n")

    # Pull in any externally-produced rust samples (the rust luminal binary
    # writes both --bench and --mega samples to /tmp).
    for label, path_str in [
        ("rust luminal", "/tmp/dlrm_bench_rust_luminal.txt"),
        ("DLRM megakernel", "/tmp/dlrm_bench_megakernel.txt"),
    ]:
        p = Path(path_str)
        if not p.exists():
            continue
        samples_us = sorted(
            float(s) for s in p.read_text().splitlines() if s.strip()
        )
        n = len(samples_us)
        if n == 0:
            # An empty file would otherwise divide by zero below.
            print(f" {label}: no samples in {path_str}, skipping")
            continue
        rows.append({
            "name": label,
            "mean": sum(samples_us) / n,
            "p50": samples_us[n // 2],
            "p99": samples_us[int(n * 0.99)],
            "n": n,
        })
        print(f" {label:<32s} mean={rows[-1]['mean']:8.2f}µs "
              f"p50={rows[-1]['p50']:8.2f}µs p99={rows[-1]['p99']:8.2f}µs "
              f"(from {path_str})")

    if not rows:
        print("No benchmark results to rank.")
        return

    # Rank by mean latency.
    rows.sort(key=lambda r: r["mean"])
    print("=" * 60)
    print("Ranking (mean latency, lower is better):\n")
    fastest = rows[0]["mean"]
    print(f" {'#':<3}{'backend':<32s}{'mean µs':>10s}{'vs fastest':>14s}")
    for i, r in enumerate(rows):
        ratio = r["mean"] / fastest
        print(f" {i + 1:<3}{r['name']:<32s}{r['mean']:>10.2f}{ratio:>13.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
328
examples/dlrm/bench_sweep.py
Normal file
328
examples/dlrm/bench_sweep.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""DLRM latency sweep: batch_size × n_sparse_tables × backend.
|
||||
|
||||
Reuses the per-backend timing primitives from `bench.py` but parameterises
|
||||
the model config so we can see how each backend scales along both DLRM's
|
||||
key axes: batch size (parallelism / kernel utilisation) and number of
|
||||
sparse tables (kernel launch count, host-side dispatch cost).
|
||||
|
||||
For each (batch, n_sparse) cell, runs:
|
||||
- PyTorch eager
|
||||
- torch.compile (mode='reduce-overhead')
|
||||
- AOTInductor
|
||||
- CUDA graphs (eager capture)
|
||||
- luminal_backend (PT2)
|
||||
|
||||
The rust luminal path can't be invoked from python; we skip it here. The
|
||||
single-config bench.py remains the cross-check that includes rust.
|
||||
|
||||
Output is one table per backend with rows = batch, cols = n_sparse, plus a
|
||||
final per-cell "winner" matrix.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
TESTS_DIR = (
|
||||
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
)
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
DEVICE = torch.device("cuda")
|
||||
WARMUP = 25
|
||||
ITERS = 200 # halved vs bench.py to keep sweep wall-clock reasonable
|
||||
M_SPA = 4
|
||||
|
||||
# Sweep grid — real-workload DLRM batches where matmul efficiency is what's
|
||||
# actually being compared (sub-100 batches were launch-overhead dominated and
|
||||
# said more about wrapper cost than backend quality).
|
||||
BATCH_SIZES = [256, 1024, 2048, 4096]
|
||||
N_SPARSE_LIST = [3, 8, 16]
|
||||
|
||||
|
||||
def make_model(n_sparse: int):
    """Build a seeded, eval-mode MiniDLRM with `n_sparse` tables; returns `(model, ln_emb)`."""
    torch.manual_seed(0)
    # Vocab sizes grow from small to medium so the lookups touch tables of
    # different sizes without making setup time explode.
    vocab_pool = [10, 20, 30, 40, 60, 80, 100, 120, 160, 200, 240, 320, 400, 500, 640, 800]
    ln_emb = vocab_pool[:n_sparse]
    feature_count = n_sparse + 1
    pair_count = feature_count * (feature_count - 1) // 2
    bottom_sizes = [13, 8, M_SPA]
    top_sizes = [pair_count + M_SPA, 8, 1]
    net = MiniDLRM(M_SPA, ln_emb, bottom_sizes, top_sizes).to(DEVICE).eval()
    return net, ln_emb
|
||||
|
||||
|
||||
def make_inputs(batch: int, ln_emb: list[int]):
    """Return seeded `(dense_x, offsets, indices)` for `batch` samples, one lookup per bag."""
    torch.manual_seed(42)
    dense_x = torch.rand(batch, 13, device=DEVICE)
    indices = []
    for vocab in ln_emb:
        indices.append(
            torch.randint(0, vocab, (batch,), dtype=torch.long, device=DEVICE)
        )
    offsets = []
    for _ in ln_emb:
        offsets.append(torch.arange(batch, dtype=torch.long, device=DEVICE))
    return dense_x, offsets, indices
|
||||
|
||||
|
||||
def time_callable(fn, iters: int) -> list[float]:
    """Call `fn` `iters` times between CUDA event pairs; returns per-call microseconds."""
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]

    torch.cuda.synchronize()
    for begin, finish in zip(starts, stops):
        begin.record()
        fn()
        finish.record()
    # Synchronize before reading any elapsed time.
    torch.cuda.synchronize()

    # elapsed_time() is in milliseconds; scale to microseconds.
    latencies = []
    for begin, finish in zip(starts, stops):
        latencies.append(begin.elapsed_time(finish) * 1000.0)
    return latencies
|
||||
|
||||
|
||||
def mean_us(samples: list[float]) -> float:
    """Arithmetic mean of the samples (raises ZeroDivisionError for an empty list)."""
    total = sum(samples)
    count = len(samples)
    return total / count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_eager(model, inputs):
    """Mean latency (µs) of the plain eager forward pass."""

    def fn():
        with torch.no_grad():
            return model(*inputs)

    # Untimed warmup iterations before measuring.
    for _ in range(WARMUP):
        fn()
    samples = time_callable(fn, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
def bench_torch_compile(model, inputs):
    """Mean latency (µs) of `torch.compile` in `reduce-overhead` mode."""
    torch._dynamo.reset()
    optimized = torch.compile(model, mode="reduce-overhead")

    def fn():
        with torch.no_grad():
            return optimized(*inputs)

    # The first warmup call triggers compilation, outside the timed region.
    for _ in range(WARMUP):
        fn()
    samples = time_callable(fn, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
def bench_aoti(model, inputs):
    """Mean latency (µs) of the AOTInductor-packaged model (export → package → load → run)."""
    torch._dynamo.reset()
    dense_x, offsets, indices = inputs
    n_sparse = len(offsets)

    class FlatWrapper(torch.nn.Module):
        """Gives the model a flat positional signature: `(dense_x, *offsets, *indices)`."""

        def __init__(self, m, n_sparse: int):
            super().__init__()
            self.m = m
            self.n_sparse = n_sparse

        def forward(self, *args):
            count = self.n_sparse
            dense, rest = args[0], args[1:]
            offs = list(rest[:count])
            idxs = list(rest[count : 2 * count])
            return self.m(dense, offs, idxs)

    flat_model = FlatWrapper(model, n_sparse).to(DEVICE).eval()
    flat_inputs = (dense_x, *offsets, *indices)

    with torch.no_grad():
        ep = torch.export.export(flat_model, flat_inputs)

    with tempfile.TemporaryDirectory() as tmp:
        pkg = os.path.join(tmp, "dlrm.pt2")
        torch._inductor.aoti_compile_and_package(ep, package_path=pkg)
        loaded = torch._inductor.aoti_load_package(pkg)

        def fn():
            with torch.no_grad():
                return loaded(*flat_inputs)

        for _ in range(WARMUP):
            fn()
        samples = time_callable(fn, ITERS)
        return mean_us(samples)
|
||||
|
||||
|
||||
def bench_cuda_graphs(model, inputs):
    """Capture the eager forward pass as a CUDA graph and return its mean replay latency (µs).

    `_interact` is temporarily replaced by a version that uses pre-built
    `li`/`lj` index tensors so capture works (same trick the single-config
    bench uses). `model` is shared with the backends that run afterwards, so
    the override is removed again as soon as capture has finished; replay
    does not re-run the Python forward and therefore does not need it.
    """
    dense_x, offsets, indices = inputs
    n_sparse = len(offsets)
    n_feat = 1 + n_sparse
    li = torch.tensor([i for i in range(n_feat) for _ in range(i)], device=DEVICE)
    lj = torch.tensor([j for i in range(n_feat) for j in range(i)], device=DEVICE)

    def _interact_static(self, x, ly):
        bs, d = x.shape
        T = torch.cat([x] + ly, dim=1).view((bs, -1, d))
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        Zflat = Z[:, li, lj]
        return torch.cat([x, Zflat], dim=1)

    static_dense = dense_x.clone()
    static_offsets = [o.clone() for o in offsets]
    static_indices = [i.clone() for i in indices]

    @torch.no_grad()
    def fwd():
        return model(static_dense, static_offsets, static_indices)

    # Instance-level override; the class method is untouched.
    model._interact = types.MethodType(_interact_static, model)
    try:
        # Warm up on a side stream, then capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                fwd()
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            _ = fwd()
    finally:
        # Drop the override so later backends see the model's own `_interact`
        # again, even if warmup or capture raised.
        model.__dict__.pop("_interact", None)

    @torch.no_grad()
    def fn():
        g.replay()

    for _ in range(WARMUP):
        fn()
    return mean_us(time_callable(fn, ITERS))
|
||||
|
||||
|
||||
def bench_luminal_backend(model, inputs):
    """Mean latency (µs) of `torch.compile` routed through `luminal_backend`."""
    torch._dynamo.reset()
    optimized = torch.compile(model, backend=luminal_backend)

    def fn():
        with torch.no_grad():
            return optimized(*inputs)

    # The first warmup call pays the compile cost, outside the timed region.
    for _ in range(WARMUP):
        fn()
    samples = time_callable(fn, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
BACKENDS = [
|
||||
("eager", bench_eager),
|
||||
("torch.compile", bench_torch_compile),
|
||||
("AOTInductor", bench_aoti),
|
||||
("CUDA graphs", bench_cuda_graphs),
|
||||
("luminal_backend", bench_luminal_backend),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fmt(v: float) -> str:
    """Format a latency as a 7-wide, one-decimal number; NaN (a missing cell) becomes a dash."""
    is_nan = v != v  # NaN is the only float that differs from itself
    return " - " if is_nan else f"{v:7.1f}"
|
||||
|
||||
|
||||
def main() -> None:
    """Run every backend over the (n_sparse × batch) grid and print the result tables.

    A backend that raises is reported as FAILED and its cell is left empty in
    the tables. Garbage collection, CUDA cache release and a dynamo reset run
    after every backend, including failed ones, so one failure does not leave
    state behind for the next measurement.
    """
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch: {torch.__version__}")
    print(
        f"Sweep: batch ∈ {BATCH_SIZES}, n_sparse ∈ {N_SPARSE_LIST}, "
        f"backends ∈ {[b[0] for b in BACKENDS]}, iters={ITERS}\n"
    )

    # results[backend_name][(batch, n_sparse)] = mean µs
    results: dict[str, dict[tuple[int, int], float]] = {
        name: {} for name, _ in BACKENDS
    }

    total_cells = len(BATCH_SIZES) * len(N_SPARSE_LIST) * len(BACKENDS)
    cell = 0
    for n_sparse in N_SPARSE_LIST:
        for batch in BATCH_SIZES:
            model, ln_emb = make_model(n_sparse)
            inputs = make_inputs(batch, ln_emb)
            for name, fn in BACKENDS:
                cell += 1
                t0 = time.perf_counter()
                try:
                    mu = fn(model, inputs)
                except Exception as e:
                    # An exception with an empty message has no lines at all;
                    # indexing [-1] directly would raise inside this handler.
                    msg_lines = str(e).splitlines()
                    last_line = msg_lines[-1][:80] if msg_lines else ""
                    print(
                        f" [{cell:>3}/{total_cells}] "
                        f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
                        f"FAILED: {type(e).__name__}: {last_line}"
                    )
                else:
                    results[name][(batch, n_sparse)] = mu
                    print(
                        f" [{cell:>3}/{total_cells}] "
                        f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
                        f"mean={mu:>7.1f}µs (took {time.perf_counter() - t0:.1f}s)"
                    )
                finally:
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch._dynamo.reset()

    # ---- Print one table per backend -----------------------------------
    print("\n" + "=" * 78)
    print("Latency in µs by backend (rows = batch, cols = n_sparse)")
    for name, _ in BACKENDS:
        print(f"\n {name}:")
        header = " " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST)
        print(header)
        for bs in BATCH_SIZES:
            row = f" bs={bs:<4} "
            for ns in N_SPARSE_LIST:
                v = results[name].get((bs, ns), float("nan"))
                row += f" {fmt(v)} "
            print(row)

    # ---- Print "fastest backend per cell" matrix -----------------------
    print("\n" + "=" * 78)
    print("Winner per cell (lowest mean µs):")
    print("\n " + "".join(f" n_sp={ns:<14}" for ns in N_SPARSE_LIST))
    for bs in BATCH_SIZES:
        row = f" bs={bs:<4} "
        for ns in N_SPARSE_LIST:
            options = [
                (name, results[name].get((bs, ns), float("inf"))) for name, _ in BACKENDS
            ]
            # Drop backends with no measurement (missing, NaN or infinite).
            options = [(n, v) for n, v in options if v == v and v != float("inf")]
            if not options:
                row += " - "
                continue
            winner = min(options, key=lambda x: x[1])
            row += f" {winner[0]:<13s} {winner[1]:>6.1f}"
        print(row)

    # ---- luminal_backend vs eager: scaling story -----------------------
    print("\n" + "=" * 78)
    print("luminal_backend / eager (lower than 1.0 = luminal wins this cell):")
    print("\n " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST))
    for bs in BATCH_SIZES:
        row = f" bs={bs:<4} "
        for ns in N_SPARSE_LIST:
            le = results.get("luminal_backend", {}).get((bs, ns), float("nan"))
            eg = results.get("eager", {}).get((bs, ns), float("nan"))
            if eg != eg or le != le or eg == 0:
                row += " - "
            else:
                row += f" {le / eg:>5.2f}x"
        print(row)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
89
examples/dlrm/export.py
Normal file
89
examples/dlrm/export.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Three-way DLRM equivalence harness.
|
||||
|
||||
Builds the MiniDLRM at the fixed config used by examples/dlrm/src/main.rs,
|
||||
serializes weights + sample inputs + the PyTorch eager output to safetensors
|
||||
files that the rust binary loads. Also runs the PyTorch + luminal_backend
|
||||
path so the comparison happens in one place.
|
||||
|
||||
Saves:
|
||||
/tmp/dlrm_weights.safetensors — state_dict with PyTorch names
|
||||
/tmp/dlrm_inputs.safetensors — dense_x, indices_{0..2}, and `expected`
|
||||
(the PyTorch eager output, fp32)
|
||||
|
||||
Then run:
|
||||
cargo run --release --manifest-path examples/dlrm/Cargo.toml
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# Import MiniDLRM from the test file we already authored.
|
||||
TESTS_DIR = Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
# Backend (and venv) shared with the test runner.
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
M_SPA = 4
|
||||
LN_EMB = [10, 20, 30]
|
||||
LN_BOT = [13, 8, M_SPA]
|
||||
LN_TOP = [10, 8, 1]
|
||||
# Match the rust binary's BATCH_SIZE — real-workload DLRM batch where
|
||||
# compute-bound matmul efficiency is what's being measured.
|
||||
BATCH = 2048
|
||||
DEVICE = torch.device("cuda")
|
||||
|
||||
|
||||
def main() -> None:
    """Build MiniDLRM, compare eager vs. luminal_backend, and export test data.

    Steps, in order:
      1. Seed torch and build the model on ``DEVICE`` in eval mode.
      2. Draw a random dense batch and one random index per row per table;
         offsets are ``arange(BATCH)`` for every table.
      3. Run the model eagerly and through ``torch.compile`` with
         ``luminal_backend``; assert the two agree to within 1e-5.
      4. Write the state_dict to ``/tmp/dlrm_weights.safetensors`` and the
         inputs plus the eager output to ``/tmp/dlrm_inputs.safetensors``.

    Raises:
        AssertionError: if eager and luminal_backend outputs differ by 1e-5
            or more in max absolute value.
    """
    torch.manual_seed(0)

    model = MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP).to(DEVICE).eval()

    dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)

    # One random row id per batch element for each embedding table.
    indices = []
    for vocab_size in LN_EMB:
        indices.append(
            torch.randint(0, vocab_size, (BATCH,), dtype=torch.long, device=DEVICE)
        )

    # Offsets 0..BATCH-1 for every table.
    offsets = []
    for _ in LN_EMB:
        offsets.append(torch.arange(BATCH, dtype=torch.long, device=DEVICE))

    with torch.no_grad():
        eager_out = model(dense_x, offsets, indices)

    compiled = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        luminal_out = compiled(dense_x, offsets, indices)

    abs_err = (luminal_out - eager_out).abs()
    max_diff_lum = abs_err.max().item()
    print(f"PyTorch eager output : {eager_out.flatten().tolist()}")
    print(f"PyTorch + luminal : {luminal_out.flatten().tolist()}")
    print(f" max |diff| eager vs luminal_backend : {max_diff_lum:.3e}")
    assert max_diff_lum < 1e-5, "PT eager and luminal_backend disagree"

    # Save weights — state_dict names already match what rust uses.
    weights = {}
    for name, tensor in model.state_dict().items():
        weights[name] = tensor.detach().cpu()
    save_file(weights, "/tmp/dlrm_weights.safetensors")
    print(f" wrote /tmp/dlrm_weights.safetensors ({len(weights)} tensors)")

    inputs = {}
    inputs["dense_x"] = dense_x.detach().cpu().contiguous()
    inputs["expected"] = eager_out.detach().cpu().contiguous()
    for k, ix in enumerate(indices):
        # Rust reads i32 indices.
        host_ix = ix.detach().cpu()
        inputs[f"indices_{k}"] = host_ix.to(torch.int32).contiguous()
    save_file(inputs, "/tmp/dlrm_inputs.safetensors")
    print(f" wrote /tmp/dlrm_inputs.safetensors ({len(inputs)} tensors)")

    print(
        "\nNext: cargo run --release --manifest-path examples/dlrm/Cargo.toml --bin dlrm"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
637
examples/dlrm/src/main.rs
Normal file
637
examples/dlrm/src/main.rs
Normal file
@@ -0,0 +1,637 @@
|
||||
//! Pure-rust DLRM mirroring `MiniDLRM` from
|
||||
//! `crates/luminal_python/tests/test_dlrm.py`.
|
||||
//!
|
||||
//! Loads weights + sample inputs + expected output produced by
|
||||
//! `examples/dlrm/export.py`, runs the same compute graph through luminal's
|
||||
//! CUDA runtime, and prints max-abs diff vs the saved PyTorch eager output.
|
||||
//!
|
||||
//! Topology (fixed for now — same as MiniDLRM at the small-config we test):
|
||||
//! m_spa = 4
|
||||
//! ln_emb = [10, 20, 30] (3 sparse tables)
|
||||
//! ln_bot = [13, 8, 4] (Linear-ReLU-Linear-ReLU)
|
||||
//! ln_top = [10, 8, 1] (Linear-ReLU-Linear-Sigmoid)
|
||||
//! batch_size = 2048, bag_size = 1
|
||||
//!
|
||||
//! Weight name convention matches the PyTorch state_dict (so
|
||||
//! `runtime.load_safetensors` matches by name with no remapping):
|
||||
//! emb_l.{k}.weight (V_k, m_spa)
|
||||
//! bot_l.{0,2}.{weight,bias} Linear in_features → out_features
|
||||
//! top_l.{0,2}.{weight,bias} same
|
||||
//! PyTorch stores Linear weight as (out, in); we permute when matmul'ing.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_nn::gather_rows;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
|
||||
const M_SPA: usize = 4;
|
||||
const LN_EMB: [usize; 3] = [10, 20, 30];
|
||||
const LN_BOT: [usize; 3] = [13, 8, M_SPA];
|
||||
const LN_TOP: [usize; 3] = [10, 8, 1];
|
||||
// Real-workload DLRM batch — large enough that kernel work dominates the
|
||||
// per-launch overhead and the compute-bound performance is what's measured.
|
||||
const BATCH_SIZE: usize = 2048;
|
||||
|
||||
/// Linear with bias whose weight matches PyTorch's `nn.Linear` storage:
|
||||
/// shape `(out, in)`. Forward computes `input @ weight.T + bias`.
|
||||
struct Linear {
|
||||
weight: GraphTensor, // (out_features, in_features)
|
||||
bias: GraphTensor, // (out_features,)
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn new(cx: &mut Graph, prefix: &str, in_features: usize, out_features: usize) -> Self {
|
||||
Self {
|
||||
weight: cx
|
||||
.named_tensor(format!("{prefix}.weight"), (out_features, in_features))
|
||||
.persist(),
|
||||
bias: cx
|
||||
.named_tensor(format!("{prefix}.bias"), out_features)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
let out_features = self.weight.shape.dims[0];
|
||||
let mm = input.matmul(self.weight.permute((1, 0)));
|
||||
// Broadcast bias (out,) → output shape (..., out).
|
||||
let bias_b = self.bias.expand_dim(0, mm.shape.dims[0]);
|
||||
// bias_b shape: (B, out) — matches `mm` for 2-D input.
|
||||
let _ = out_features;
|
||||
mm + bias_b
|
||||
}
|
||||
}
|
||||
|
||||
// We use luminal's primitive .relu() / .sigmoid() rather than hand-rolling
|
||||
// them out of maximum / exp / reciprocal so the HLIR generated here matches
|
||||
// what the PT2 translator emits for `aten.relu.default` / `aten.sigmoid.default`
|
||||
// op-for-op. See dispatch.rs: both route through these same primitives.
|
||||
|
||||
/// Bottom MLP: `Linear → ReLU → Linear → ReLU` over the two `layers`.
fn bot_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
    let [first, second] = layers;
    let hidden = first.forward(x).relu();
    second.forward(hidden).relu()
}
|
||||
|
||||
/// Top MLP: `Linear → ReLU → Linear → Sigmoid` over the two `layers`.
fn top_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
    let [first, second] = layers;
    let hidden = first.forward(x).relu();
    second.forward(hidden).sigmoid()
}
|
||||
|
||||
/// Dot interaction: cat(dense, sparse...) → reshape → bmm → flat-tri-upper indexing.
|
||||
/// Matches MiniDLRM._interact in the python test.
|
||||
fn interact_features(
|
||||
cx: &mut Graph,
|
||||
dense: GraphTensor,
|
||||
sparse: &[GraphTensor],
|
||||
) -> GraphTensor {
|
||||
let batch = dense.shape.dims[0];
|
||||
let d = dense.shape.dims[1];
|
||||
let n_feat = 1 + sparse.len();
|
||||
|
||||
// T = cat([dense, *sparse], dim=1).view(B, n_feat, d)
|
||||
let mut t = dense;
|
||||
for s in sparse {
|
||||
t = t.concat_along(*s, 1);
|
||||
}
|
||||
// Reshape (B, n_feat * d) → (B, n_feat, d). concat_along leaves a contiguous
|
||||
// tensor so a fresh ShapeTracker is safe.
|
||||
let bagged = GraphTensor {
|
||||
id: t.id,
|
||||
graph_ref: t.graph_ref,
|
||||
shape: ShapeTracker::new((batch, Expression::from(n_feat), d)),
|
||||
dtype: t.dtype,
|
||||
};
|
||||
|
||||
// Z = bmm(T, T.transpose(1, 2)) → (B, n_feat, n_feat)
|
||||
let z = bagged.matmul(bagged.permute((0, 2, 1)));
|
||||
|
||||
// Strictly-lower-triangular indices into (n_feat, n_feat). For n_feat=4
|
||||
// these are 6 (i,j) pairs: (1,0),(2,0),(2,1),(3,0),(3,1),(3,2).
|
||||
let mut li = Vec::new();
|
||||
let mut lj = Vec::new();
|
||||
for i in 0..n_feat {
|
||||
for j in 0..i {
|
||||
li.push(i as i32);
|
||||
lj.push(j as i32);
|
||||
}
|
||||
}
|
||||
let n_pairs = li.len();
|
||||
|
||||
// Build flat_idx_per_pair[k] = li[k] * n_feat + lj[k] (constant across batch).
|
||||
let mut flat_idx_per_pair = Vec::with_capacity(n_pairs);
|
||||
for k in 0..n_pairs {
|
||||
flat_idx_per_pair.push(li[k] * n_feat as i32 + lj[k]);
|
||||
}
|
||||
|
||||
// Absolute flat index into Z viewed as 1D for each (b, k):
|
||||
// abs[b, k] = b * (n_feat*n_feat) + flat_idx_per_pair[k]
|
||||
let row_stride = n_feat * n_feat; // entries per batch in Z
|
||||
let arange_b = cx.arange(batch); // (B,) ints, values 0..B
|
||||
let abs_idx = arange_b.expand_dim(1, Expression::from(n_pairs))
|
||||
* Expression::from(row_stride);
|
||||
// pair_idx_const: (n_pairs,) ints, captured as a graph input we set once.
|
||||
let pair_idx = cx
|
||||
.named_tensor("__dot_pair_idx", n_pairs)
|
||||
.as_dtype(DType::Int)
|
||||
.persist();
|
||||
let abs_idx = abs_idx + pair_idx.expand_dim(0, batch);
|
||||
|
||||
// Gather Z as 1D.
|
||||
let z_flat = GraphTensor {
|
||||
id: z.id,
|
||||
graph_ref: z.graph_ref,
|
||||
shape: ShapeTracker::new(batch * row_stride),
|
||||
dtype: z.dtype,
|
||||
};
|
||||
let zflat_indexed = z_flat.gather(abs_idx); // (B, n_pairs)
|
||||
|
||||
// R = cat(dense, zflat_indexed, dim=1) → (B, d + n_pairs)
|
||||
dense.concat_along(zflat_indexed, 1)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Parse args: optional --bench / --stats / --mega, then positional paths.
|
||||
let mut bench_mode = false;
|
||||
let mut stats_mode = false;
|
||||
let mut mega_mode = false;
|
||||
let mut positional: Vec<String> = Vec::new();
|
||||
for arg in std::env::args().skip(1) {
|
||||
if arg == "--bench" {
|
||||
bench_mode = true;
|
||||
} else if arg == "--stats" {
|
||||
stats_mode = true;
|
||||
} else if arg == "--mega" {
|
||||
mega_mode = true;
|
||||
} else {
|
||||
positional.push(arg);
|
||||
}
|
||||
}
|
||||
let weights_path = positional
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/dlrm_weights.safetensors".to_string());
|
||||
let inputs_path = positional
|
||||
.get(1)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/dlrm_inputs.safetensors".to_string());
|
||||
|
||||
if mega_mode {
|
||||
run_megakernel(&weights_path, &inputs_path, bench_mode);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(
|
||||
Path::new(&weights_path).exists(),
|
||||
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
assert!(
|
||||
Path::new(&inputs_path).exists(),
|
||||
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
|
||||
// ---- Build graph -----------------------------------------------------
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dense_in = cx
|
||||
.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
|
||||
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
|
||||
.as_dtype(DType::Int)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Embedding tables (bag_size=1 → just row gather).
|
||||
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
|
||||
.persist()
|
||||
})
|
||||
.collect();
|
||||
let sparse_feats: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| gather_rows(emb_weights[k], idx_tensors[k], M_SPA))
|
||||
.collect();
|
||||
|
||||
// Bottom MLP: Linear 13→8, ReLU, Linear 8→4, ReLU.
|
||||
let bot = [
|
||||
Linear::new(&mut cx, "bot_l.0", LN_BOT[0], LN_BOT[1]),
|
||||
Linear::new(&mut cx, "bot_l.2", LN_BOT[1], LN_BOT[2]),
|
||||
];
|
||||
let dense_out = bot_forward(&bot, dense_in);
|
||||
|
||||
// Dot interaction → (B, n_pairs + m_spa) = (B, 10) for our config.
|
||||
let interacted = interact_features(&mut cx, dense_out, &sparse_feats);
|
||||
|
||||
// Top MLP: Linear 10→8, ReLU, Linear 8→1, Sigmoid.
|
||||
let top = [
|
||||
Linear::new(&mut cx, "top_l.0", LN_TOP[0], LN_TOP[1]),
|
||||
Linear::new(&mut cx, "top_l.2", LN_TOP[1], LN_TOP[2]),
|
||||
];
|
||||
let out = top_forward(&top, interacted).output();
|
||||
|
||||
// ---- Compile + load weights ------------------------------------------
|
||||
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, &weights_path);
|
||||
|
||||
// Set the strictly-lower-triangular pair index constant.
|
||||
let n_feat = 1 + LN_EMB.len();
|
||||
let mut pair_idx_vals = Vec::new();
|
||||
for i in 0..n_feat {
|
||||
for j in 0..i {
|
||||
pair_idx_vals.push((i * n_feat + j) as i32);
|
||||
}
|
||||
}
|
||||
// Find the named input by walking the graph.
|
||||
let pair_idx_id = find_named_input(&cx, "__dot_pair_idx")
|
||||
.expect("pair_idx tensor not found in graph");
|
||||
runtime.set_data(pair_idx_id, pair_idx_vals);
|
||||
|
||||
// Load inputs + expected output from safetensors.
|
||||
let inputs_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.map(&std::fs::File::open(&inputs_path).unwrap())
|
||||
.unwrap()
|
||||
};
|
||||
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
|
||||
|
||||
let dense_x: Vec<f32> = bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
runtime.set_data(dense_in, dense_x);
|
||||
for (k, idx_t) in idx_tensors.iter().enumerate() {
|
||||
let ix: Vec<i32> = bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec();
|
||||
runtime.set_data(*idx_t, ix);
|
||||
}
|
||||
|
||||
// ---- Search (small budget — graph is tiny) ---------------------------
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(8).trials(1), &mut rng);
|
||||
|
||||
// ---- Execute and compare ---------------------------------------------
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let result = runtime.get_f32(out);
|
||||
|
||||
let expected_bytes = inputs_st.tensor("expected").unwrap().data();
|
||||
let expected: &[f32] = bytemuck::cast_slice(expected_bytes);
|
||||
|
||||
println!("rust output : {result:?}");
|
||||
println!("expected : {expected:?}");
|
||||
|
||||
let max_diff = result
|
||||
.iter()
|
||||
.zip(expected.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0_f32, f32::max);
|
||||
println!("max |diff| : {max_diff:.3e}");
|
||||
|
||||
assert!(
|
||||
max_diff < 1e-4,
|
||||
"rust output diverges from PyTorch eager (max_diff={max_diff})"
|
||||
);
|
||||
println!("OK — rust luminal matches PyTorch eager within 1e-4.");
|
||||
|
||||
if stats_mode {
|
||||
let host_ops = runtime.host_ops();
|
||||
println!("\n=== Active bucket host-op inventory ({} ops) ===", host_ops.len());
|
||||
let mut by_type: std::collections::BTreeMap<String, usize> =
|
||||
std::collections::BTreeMap::new();
|
||||
for op in &host_ops {
|
||||
let s = format!("{op:?}");
|
||||
let head = s.split_whitespace().next().unwrap_or(&s).to_string();
|
||||
*by_type.entry(head).or_insert(0) += 1;
|
||||
}
|
||||
for (k, v) in &by_type {
|
||||
println!(" {v:>3} {k}");
|
||||
}
|
||||
// Per-op detail: extract the cuBLASLt epilogue + shape signature so
|
||||
// we can see at a glance whether bias/relu fusion fired (the egglog
|
||||
// rewrites map matmul+add+maximum_f32(0) -> EPILOGUE_RELU_BIAS).
|
||||
println!("\n=== cuBLASLt op detail ===");
|
||||
for op in &host_ops {
|
||||
let s = format!("{op:?}");
|
||||
if !s.starts_with("CuBlasLt") {
|
||||
continue;
|
||||
}
|
||||
let epilogue = extract_field(&s, "epilogue:");
|
||||
let shape = (extract_field(&s, "m:"), extract_field(&s, "n:"), extract_field(&s, "k:"));
|
||||
println!(
|
||||
" m={:<8} n={:<8} k={:<8} epilogue={}",
|
||||
shape.0, shape.1, shape.2, epilogue
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if bench_mode {
|
||||
// Cache input vectors so the bench loop can re-set_data each iter (the
|
||||
// PyTorch backends do an equivalent staging step under the hood).
|
||||
let dense_vec: Vec<f32> =
|
||||
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
|
||||
.map(|k| {
|
||||
bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec()
|
||||
})
|
||||
.collect();
|
||||
bench_rust(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
out,
|
||||
dense_in,
|
||||
&idx_tensors,
|
||||
dense_vec,
|
||||
idx_vecs,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Time `runtime.execute` directly. Inputs are already loaded once before
|
||||
/// `--bench` and not re-uploaded between calls, mirroring CUDA-graph replay
|
||||
/// semantics. Synchronizes the stream once at the end and divides total
|
||||
/// elapsed by `iters` for a steady-state mean; also prints per-iter samples
|
||||
/// to /tmp/dlrm_bench_rust_luminal.txt for the python aggregator.
|
||||
fn bench_rust(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
out: GraphTensor,
|
||||
dense_in: GraphTensor,
|
||||
idx_tensors: &[GraphTensor],
|
||||
dense_vec: Vec<f32>,
|
||||
idx_vecs: Vec<Vec<i32>>,
|
||||
) {
|
||||
bench_through_luminal(
|
||||
cx,
|
||||
runtime,
|
||||
out,
|
||||
dense_in,
|
||||
idx_tensors,
|
||||
dense_vec,
|
||||
idx_vecs,
|
||||
"/tmp/dlrm_bench_rust_luminal.txt",
|
||||
"[bench] rust luminal",
|
||||
);
|
||||
}
|
||||
|
||||
/// Shared steady-state bench for any luminal graph + runtime. Re-sets
|
||||
/// inputs every iter, calls `execute`, then `get_f32` to force a stream
|
||||
/// sync. Dumps per-iter µs samples to `samples_path` for
|
||||
/// `examples/dlrm/bench.py` to merge into its ranking.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn bench_through_luminal(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
out: GraphTensor,
|
||||
dense_in: GraphTensor,
|
||||
idx_tensors: &[GraphTensor],
|
||||
dense_vec: Vec<f32>,
|
||||
idx_vecs: Vec<Vec<i32>>,
|
||||
samples_path: &str,
|
||||
label: &str,
|
||||
) {
|
||||
const WARMUP: usize = 50;
|
||||
const ITERS: usize = 500;
|
||||
use std::time::Instant;
|
||||
|
||||
let bench_once = |runtime: &mut CudaRuntime| {
|
||||
runtime.set_data(dense_in, dense_vec.clone());
|
||||
for (k, t) in idx_tensors.iter().enumerate() {
|
||||
runtime.set_data(*t, idx_vecs[k].clone());
|
||||
}
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let _ = runtime.get_f32(out);
|
||||
};
|
||||
|
||||
for _ in 0..WARMUP {
|
||||
bench_once(runtime);
|
||||
}
|
||||
|
||||
let mut samples = Vec::with_capacity(ITERS);
|
||||
for _ in 0..ITERS {
|
||||
let t0 = Instant::now();
|
||||
bench_once(runtime);
|
||||
samples.push(t0.elapsed().as_secs_f64() * 1e6);
|
||||
}
|
||||
samples.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let mean = samples.iter().sum::<f64>() / ITERS as f64;
|
||||
let p50 = samples[ITERS / 2];
|
||||
let p99 = samples[(ITERS as f64 * 0.99) as usize];
|
||||
println!(
|
||||
"\n{label}: mean={mean:.2}µs p50={p50:.2}µs p99={p99:.2}µs (n={ITERS})"
|
||||
);
|
||||
|
||||
let body = samples
|
||||
.iter()
|
||||
.map(|s| format!("{s:.4}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
std::fs::write(samples_path, body).expect("write bench samples");
|
||||
println!(" per-iter samples -> {samples_path}");
|
||||
}
|
||||
|
||||
/// `--mega`: build a luminal Graph whose entire forward is a single
|
||||
/// [`DlrmMegaCustom`] op, then run it through the standard
|
||||
/// `CudaRuntime` flow (load_safetensors → search → execute → get_f32).
|
||||
/// Verifies bitwise vs the saved PyTorch eager output, optionally
|
||||
/// benches steady-state per-call latency through the same `bench_rust`
|
||||
/// path the non-mega rust binary uses.
|
||||
///
|
||||
/// The point: same kernel as the PT2-backend fast path (the parameterized
|
||||
/// `DlrmMegaKernel` in `luminal_cuda_lite::kernel::dlrm_megakernel`),
|
||||
/// just constructed by hand instead of via the translator's pattern
|
||||
/// matcher. Everything past the `cx.custom_op` call — buffer
|
||||
/// management, weight loading, input registration, kernel dispatch,
|
||||
/// output retrieval — is luminal's runtime.
|
||||
fn run_megakernel(weights_path: &str, inputs_path: &str, bench: bool) {
|
||||
assert!(
|
||||
Path::new(weights_path).exists(),
|
||||
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
assert!(
|
||||
Path::new(inputs_path).exists(),
|
||||
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// ---- User inputs -----------------------------------------------------
|
||||
let dense_in = cx.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
|
||||
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
|
||||
.as_dtype(DType::Int)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// ---- Weights — names must match safetensors keys so the runtime's
|
||||
// load_safetensors matches by Input label.
|
||||
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
|
||||
.persist()
|
||||
})
|
||||
.collect();
|
||||
// PyTorch's nn.Linear stores weight as (out_features, in_features).
|
||||
let bot_l0_w = cx
|
||||
.named_tensor("bot_l.0.weight", (LN_BOT[1], LN_BOT[0]))
|
||||
.persist();
|
||||
let bot_l0_b = cx.named_tensor("bot_l.0.bias", LN_BOT[1]).persist();
|
||||
let bot_l1_w = cx
|
||||
.named_tensor("bot_l.2.weight", (LN_BOT[2], LN_BOT[1]))
|
||||
.persist();
|
||||
let bot_l1_b = cx.named_tensor("bot_l.2.bias", LN_BOT[2]).persist();
|
||||
let top_l0_w = cx
|
||||
.named_tensor("top_l.0.weight", (LN_TOP[1], LN_TOP[0]))
|
||||
.persist();
|
||||
let top_l0_b = cx.named_tensor("top_l.0.bias", LN_TOP[1]).persist();
|
||||
let top_l1_w = cx
|
||||
.named_tensor("top_l.2.weight", (LN_TOP[2], LN_TOP[1]))
|
||||
.persist();
|
||||
let top_l1_b = cx.named_tensor("top_l.2.bias", LN_TOP[2]).persist();
|
||||
|
||||
// ---- One CustomOp does the whole forward ----------------------------
|
||||
// Input order MUST match what DlrmMegaKernel's CUDA source expects:
|
||||
// dense, indices..., emb_weights..., bot Linears (w then b each),
|
||||
// top Linears (w then b each). See `kernel::dlrm_megakernel`.
|
||||
let mut inputs: Vec<GraphTensor> = vec![dense_in];
|
||||
inputs.extend(idx_tensors.iter().copied());
|
||||
inputs.extend(emb_weights.iter().copied());
|
||||
inputs.extend([
|
||||
bot_l0_w, bot_l0_b, bot_l1_w, bot_l1_b, top_l0_w, top_l0_b, top_l1_w, top_l1_b,
|
||||
]);
|
||||
|
||||
let kernel = DlrmMegaKernel {
|
||||
batch: BATCH_SIZE,
|
||||
n_dense_in: LN_BOT[0],
|
||||
ln_bot: LN_BOT.to_vec(),
|
||||
n_sparse: LN_EMB.len(),
|
||||
vocab_sizes: LN_EMB.to_vec(),
|
||||
m_spa: M_SPA,
|
||||
ln_top: LN_TOP.to_vec(),
|
||||
};
|
||||
let out = cx
|
||||
.custom_op(
|
||||
DlrmMegaCustom(kernel),
|
||||
inputs,
|
||||
(BATCH_SIZE, 1usize),
|
||||
DType::F32,
|
||||
)
|
||||
.output();
|
||||
|
||||
// ---- Compile + load weights — same path as the non-mega flow -------
|
||||
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, weights_path);
|
||||
|
||||
// ---- Inputs ---------------------------------------------------------
|
||||
let inputs_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.map(&std::fs::File::open(inputs_path).unwrap())
|
||||
.unwrap()
|
||||
};
|
||||
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
|
||||
let dense_vec: Vec<f32> =
|
||||
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
runtime.set_data(dense_in, dense_vec.clone());
|
||||
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
|
||||
.map(|k| {
|
||||
bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec()
|
||||
})
|
||||
.collect();
|
||||
for (k, idx_t) in idx_tensors.iter().enumerate() {
|
||||
runtime.set_data(*idx_t, idx_vecs[k].clone());
|
||||
}
|
||||
|
||||
// ---- Search ---------------------------------------------------------
|
||||
// Single-CustomOp graph: nothing to search over. One trial.
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(1).trials(1), &mut rng);
|
||||
|
||||
// ---- Execute + verify -----------------------------------------------
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let result = runtime.get_f32(out);
|
||||
|
||||
let expected: &[f32] =
|
||||
bytemuck::cast_slice(inputs_st.tensor("expected").unwrap().data());
|
||||
let max_diff = result
|
||||
.iter()
|
||||
.zip(expected.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0_f32, f32::max);
|
||||
println!(
|
||||
"[mega] output[..4]={:?} expected[..4]={:?} max|diff|={:.3e}",
|
||||
&result[..result.len().min(4)],
|
||||
&expected[..result.len().min(4)],
|
||||
max_diff
|
||||
);
|
||||
assert!(
|
||||
max_diff < 1e-4,
|
||||
"megakernel output diverges from PyTorch eager (max_diff={max_diff})"
|
||||
);
|
||||
println!("[mega] OK — luminal megakernel matches PyTorch eager within 1e-4");
|
||||
|
||||
// Inventory the host ops — should be exactly 1 (the DlrmMegaCustom).
|
||||
let host_ops = runtime.host_ops();
|
||||
println!("[mega] active bucket host-op count: {}", host_ops.len());
|
||||
|
||||
if bench {
|
||||
// Reuse the shared bench loop. Writes per-iter µs samples to
|
||||
// /tmp/dlrm_bench_megakernel.txt so examples/dlrm/bench.py picks
|
||||
// them up under the "DLRM megakernel" row.
|
||||
bench_through_luminal(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
out,
|
||||
dense_in,
|
||||
&idx_tensors,
|
||||
dense_vec,
|
||||
idx_vecs,
|
||||
"/tmp/dlrm_bench_megakernel.txt",
|
||||
"[mega]",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull a `field: value` entry out of a Debug-formatted struct dump.
///
/// `field` is the key including its trailing colon (e.g. `"epilogue:"`).
/// Only occurrences that start at an identifier boundary are considered, so
/// a short key such as `"m:"` does not match inside a longer field name like
/// `dim:` or `stream:`. Returns the text between the key and the next `,` or
/// `}` (or the end of the string), trimmed. Returns `"?"` when the key is
/// absent.
///
/// Limitation: a value that itself contains `,` or `}` (nested structs,
/// tuples) is cut at the first such character.
fn extract_field(s: &str, field: &str) -> String {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    // First occurrence of `field` not preceded by an identifier character.
    let value_start = s.match_indices(field).find_map(|(idx, _)| {
        let at_boundary = s[..idx]
            .chars()
            .next_back()
            .map_or(true, |prev| !is_ident_char(prev));
        at_boundary.then_some(idx + field.len())
    });
    let Some(start) = value_start else {
        return "?".to_string();
    };

    let tail = &s[start..];
    let end = tail.find([',', '}']).unwrap_or(tail.len());
    tail[..end].trim().to_string()
}
|
||||
|
||||
/// Walk the graph looking for an [`Input`] op with the given label. Used to
|
||||
/// recover a `NodeIndex` we can `set_data` against when the original
|
||||
/// `GraphTensor` handle isn't in scope.
|
||||
fn find_named_input(cx: &Graph, label: &str) -> Option<NodeIndex> {
|
||||
use luminal::hlir::Input;
|
||||
for n in cx.graph.node_indices() {
|
||||
if let Some(Input { label: l, .. }) =
|
||||
(*cx.graph[n]).as_any().downcast_ref::<Input>()
|
||||
{
|
||||
if l == label {
|
||||
return Some(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
Reference in New Issue
Block a user