mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
main
...
vanilla-py
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c41ede0e5b |
483
crates/luminal_cuda_lite/src/kernel/dlrm_interact.rs
Normal file
483
crates/luminal_cuda_lite/src/kernel/dlrm_interact.rs
Normal file
@@ -0,0 +1,483 @@
|
||||
//! Fused DLRM pairwise-dot interaction.
|
||||
//!
|
||||
//! Replaces the cat→bmm(T,Tᵀ)→tril-gather chain with a single kernel
|
||||
//! that reads N separate `(batch, d)` tensors and writes the strict
|
||||
//! lower-triangular pairwise dot products directly into the output —
|
||||
//! `out[b, p] = Σ_d v_i[b, d] * v_j[b, d]` for each ordered pair (i, j)
|
||||
//! with i > j.
|
||||
//!
|
||||
//! Why this matters for the DLRM forward: the natural luminal lowering
|
||||
//! materializes the `(B, F, D)` stacked tensor, then the full `(B, F, F)`
|
||||
//! BMM output, then a flat gather to pull out F(F-1)/2 pairs. That's
|
||||
//! ~12 small kernels and an `F²·B` intermediate even though only half
|
||||
//! of those elements are kept. The fused version uses N pointer args
|
||||
//! (one per feature vector), computes only the F(F-1)/2 dot products,
|
||||
//! and writes directly to the final `(B, F(F-1)/2)` buffer.
|
||||
//!
|
||||
//! All shapes are static. The kernel source is generated with the
|
||||
//! exact pair table baked in (so the inner loop is a fixed `D`-element
|
||||
//! reduction with no shape-dependent branching).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairwiseDotLowerTriKernel {
|
||||
pub batch: usize,
|
||||
pub num_features: usize, // F
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
impl PairwiseDotLowerTriKernel {
|
||||
fn pair_count(&self) -> usize {
|
||||
self.num_features * (self.num_features - 1) / 2
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for PairwiseDotLowerTriKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let f = self.num_features;
|
||||
let p = self.pair_count();
|
||||
// Pair table (i, j) with i > j, in strict-lower-tri (row-major over
|
||||
// i then j) order — same convention as torch.tril_indices(F, F, -1).
|
||||
let mut pairs: Vec<(usize, usize)> = Vec::with_capacity(p);
|
||||
for i in 0..f {
|
||||
for j in 0..i {
|
||||
pairs.push((i, j));
|
||||
}
|
||||
}
|
||||
// Build kernel params signature: one pointer per input feature.
|
||||
let in_params: String = (0..f)
|
||||
.map(|k| format!(", const float* __restrict__ v{k}"))
|
||||
.collect::<Vec<_>>()
|
||||
.concat();
|
||||
// For each pair p, generate one branch in the switch that selects
|
||||
// the two input pointers to dot-product. With F small (DLRM has
|
||||
// F=4), the branch is fully unrolled.
|
||||
let mut pair_switch = String::new();
|
||||
for (pidx, (i, j)) in pairs.iter().enumerate() {
|
||||
pair_switch += &format!(
|
||||
" case {pidx}: pa = v{i}; pb = v{j}; break;\n"
|
||||
);
|
||||
}
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void dlrm_pairwise_dot_lower_tri_kernel(
|
||||
float* __restrict__ out{in_params}
|
||||
) {{
|
||||
const int B = {batch};
|
||||
const int D = {d};
|
||||
const int P = {p};
|
||||
int b = blockIdx.x;
|
||||
int p = blockIdx.y;
|
||||
int t = threadIdx.x;
|
||||
if (b >= B || p >= P) return;
|
||||
|
||||
const float* pa = nullptr;
|
||||
const float* pb = nullptr;
|
||||
switch (p) {{
|
||||
{pair_switch}
|
||||
default: return;
|
||||
}}
|
||||
|
||||
// Block-wide reduction of dot(pa[b], pb[b]) over D using shared mem.
|
||||
extern __shared__ float smem[];
|
||||
float partial = 0.0f;
|
||||
for (int d = t; d < D; d += blockDim.x) {{
|
||||
partial += pa[b * D + d] * pb[b * D + d];
|
||||
}}
|
||||
smem[t] = partial;
|
||||
__syncthreads();
|
||||
// Power-of-two tree reduce. blockDim.x must be a power of two.
|
||||
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {{
|
||||
if (t < stride) {{
|
||||
smem[t] += smem[t + stride];
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
if (t == 0) {{
|
||||
out[b * P + p] = smem[0];
|
||||
}}
|
||||
}}
|
||||
",
|
||||
batch = self.batch,
|
||||
d = self.d,
|
||||
p = p,
|
||||
pair_switch = pair_switch,
|
||||
in_params = in_params,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module
|
||||
.load_function("dlrm_pairwise_dot_lower_tri_kernel")
|
||||
.unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
// Pick a power-of-two thread count ≤ D, ≥ 32 where possible.
|
||||
let mut threads = 1usize;
|
||||
while threads * 2 <= self.d.max(32) {
|
||||
threads *= 2;
|
||||
}
|
||||
let threads = threads.max(32).min(1024);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(self.batch),
|
||||
Expression::from(p),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(threads),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(threads * 4),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.pair_count())
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Each pair reads 2 vectors of D floats per batch row. F-choose-2
|
||||
// pairs, so per-batch each input vector is read F-1 times.
|
||||
Expression::from(self.batch * self.num_features * (self.num_features - 1) * self.d * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 2D-1 flops per dot product (D mul + D-1 add).
|
||||
Expression::from(self.batch * self.pair_count() * (2 * self.d - 1))
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"DLRMPairwiseDotLowerTri"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairwiseDotLowerTriCustom(pub PairwiseDotLowerTriKernel);
|
||||
|
||||
impl CustomOp for PairwiseDotLowerTriCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-input variant of [`PairwiseDotLowerTriKernel`] that consumes the
|
||||
/// dense MLP output and a stacked embedding output without requiring
|
||||
/// the caller to first slice the stack into individual (B, D) views.
|
||||
///
|
||||
/// Treats feature 0 as `dense_out[b, t]` and features 1..=num_emb as
|
||||
/// `emb_stack[b, k-1, t]`. Output pair table is the strict lower tri
|
||||
/// of an `F × F` matrix where `F = num_emb + 1`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairwiseDotLowerTriStackedKernel {
|
||||
pub batch: usize,
|
||||
pub num_emb: usize, // N (excluding the dense feature)
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
impl PairwiseDotLowerTriStackedKernel {
|
||||
fn num_features(&self) -> usize {
|
||||
self.num_emb + 1
|
||||
}
|
||||
fn pair_count(&self) -> usize {
|
||||
let f = self.num_features();
|
||||
f * (f - 1) / 2
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for PairwiseDotLowerTriStackedKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let f = self.num_features();
|
||||
let p = self.pair_count();
|
||||
let n_emb = self.num_emb;
|
||||
let d_ = self.d;
|
||||
|
||||
// Block-per-batch layout. Each block:
|
||||
// 1. Cooperatively loads all F feature vectors for batch b into
|
||||
// shared memory once — F*D floats total. Feature 0 = dense[b];
|
||||
// features 1..F = emb_stack[b, k-1, :].
|
||||
// 2. Each thread `tid` strides over pairs `p = tid, tid+blockDim.x,
|
||||
// …, P-1`. For each, derives (i, j) such that i > j and writes
|
||||
// the dot product of feat[i] and feat[j].
|
||||
//
|
||||
// Compared to the previous (B, P) grid-of-one-block-per-output
|
||||
// layout this:
|
||||
// - Cuts launch count by P× (e.g. 528× at num_cat=32).
|
||||
// - Reads each feature vector once per batch instead of (F-1)
|
||||
// times — F(F-1) reads → F reads, an ~(F-1)/2× memory traffic
|
||||
// reduction (e.g. 16× at num_cat=32, F=33).
|
||||
// - Reuses cached features across all P pairs at shared-memory
|
||||
// latency instead of refetching from global per pair.
|
||||
//
|
||||
// Pair-index → (i, j) is computed from `p` directly using the
|
||||
// closed-form for strict lower-tri row indexing:
|
||||
// row i contains i pairs (j ∈ [0, i)); cumulative row starts
|
||||
// at `i*(i-1)/2`; so `i = floor((1+sqrt(1+8p))/2)` and
|
||||
// `j = p - i*(i-1)/2`. We do a tiny defensive adjustment
|
||||
// afterwards to absorb sqrtf rounding.
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void dlrm_pairwise_dot_lower_tri_stacked_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ dense, // (B, D)
|
||||
const float* __restrict__ emb_stack // (B, N, D)
|
||||
) {{
|
||||
const int B = {batch};
|
||||
const int D = {d};
|
||||
const int N = {n_emb};
|
||||
const int F = {f};
|
||||
const int P = {p};
|
||||
int b = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int tcount = blockDim.x;
|
||||
if (b >= B) return;
|
||||
|
||||
// Shared feature cache: F * D floats.
|
||||
extern __shared__ float feat[];
|
||||
for (int i = tid; i < F * D; i += tcount) {{
|
||||
int feat_idx = i / D;
|
||||
int dim = i - feat_idx * D;
|
||||
if (feat_idx == 0) {{
|
||||
feat[i] = dense[b * D + dim];
|
||||
}} else {{
|
||||
int slot = feat_idx - 1;
|
||||
feat[i] = emb_stack[(b * N + slot) * D + dim];
|
||||
}}
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
// Each thread handles a strided slice of the P pairs.
|
||||
for (int p = tid; p < P; p += tcount) {{
|
||||
float t = sqrtf(8.0f * (float)p + 1.0f);
|
||||
int pi = (int)((t + 1.0f) * 0.5f);
|
||||
// Adjust for fp rounding — pi*(pi-1)/2 must be the largest
|
||||
// row-start ≤ p.
|
||||
while (pi * (pi - 1) / 2 > p) pi--;
|
||||
while ((pi + 1) * pi / 2 <= p) pi++;
|
||||
int pj = p - pi * (pi - 1) / 2;
|
||||
|
||||
float acc = 0.0f;
|
||||
#pragma unroll
|
||||
for (int d = 0; d < {d}; ++d) {{
|
||||
acc += feat[pi * {d} + d] * feat[pj * {d} + d];
|
||||
}}
|
||||
out[b * P + p] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
batch = self.batch,
|
||||
d = d_,
|
||||
n_emb = n_emb,
|
||||
f = f,
|
||||
p = p,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module
|
||||
.load_function("dlrm_pairwise_dot_lower_tri_stacked_kernel")
|
||||
.unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
// Block size: enough threads to cover both the feature-load phase
|
||||
// (F*D elements) and the pair computation (P elements) without
|
||||
// serial waves dominating, capped at 1024 (max CUDA block size)
|
||||
// and rounded down to a multiple of 32 for warp alignment.
|
||||
let want = std::cmp::max(f * d_, p);
|
||||
let threads = want.clamp(32, 1024).next_multiple_of(32);
|
||||
let threads = threads.min(1024);
|
||||
let shared_bytes = f * d_ * 4;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(self.batch),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(threads),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(shared_bytes),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.pair_count())
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_features() * (self.num_features() - 1) * self.d * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
Expression::from(self.batch * self.pair_count() * (2 * self.d - 1))
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"DLRMPairwiseDotLowerTriStacked"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairwiseDotLowerTriStackedCustom(pub PairwiseDotLowerTriStackedKernel);
|
||||
|
||||
impl CustomOp for PairwiseDotLowerTriStackedCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pairwise lower-tri dot product over `dense_out` plus a stacked
|
||||
/// embedding output. Avoids the per-table slice that the variadic
|
||||
/// variant would otherwise need to materialize.
|
||||
///
|
||||
/// * `dense_out`: `(batch, d)` — feature 0 in the pair table.
|
||||
/// * `emb_stack`: `(batch, num_emb, d)` — features 1..=num_emb.
|
||||
///
|
||||
/// Returns `(batch, (num_emb+1) * num_emb / 2)`, same strict-lower-tri
|
||||
/// ordering as [`dlrm_pairwise_dot_lower_tri`].
|
||||
pub fn dlrm_pairwise_dot_lower_tri_stacked(
|
||||
dense_out: GraphTensor,
|
||||
emb_stack: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(dense_out.dtype, DType::F32, "dense_out must be F32");
|
||||
assert_eq!(emb_stack.dtype, DType::F32, "emb_stack must be F32");
|
||||
let dd = dense_out.dims();
|
||||
let sd = emb_stack.dims();
|
||||
assert_eq!(dd.len(), 2, "dense_out must be 2D");
|
||||
assert_eq!(sd.len(), 3, "emb_stack must be 3D (batch, num_emb, d)");
|
||||
let batch = dd[0].to_usize().expect("batch must be static");
|
||||
let d = dd[1].to_usize().expect("d must be static");
|
||||
assert_eq!(sd[0].to_usize().unwrap(), batch);
|
||||
let num_emb = sd[1].to_usize().expect("num_emb must be static");
|
||||
assert_eq!(sd[2].to_usize().unwrap(), d);
|
||||
let kern = PairwiseDotLowerTriStackedKernel {
|
||||
batch,
|
||||
num_emb,
|
||||
d,
|
||||
};
|
||||
let f = num_emb + 1;
|
||||
let p = f * (f - 1) / 2;
|
||||
let cx = unsafe { &mut *dense_out.graph_ref };
|
||||
cx.custom_op(
|
||||
PairwiseDotLowerTriStackedCustom(kern),
|
||||
vec![dense_out, emb_stack],
|
||||
(batch, p),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
|
||||
/// Strict-lower-triangular pairwise dot product of N feature vectors.
|
||||
///
|
||||
/// * `features`: N tensors, each `(batch, d)`, all F32, all the same shape.
|
||||
///
|
||||
/// Returns `(batch, N*(N-1)/2)` with pair ordering matching
|
||||
/// `torch.tril_indices(N, N, -1)` (row-major: (1,0), (2,0), (2,1), …).
|
||||
pub fn dlrm_pairwise_dot_lower_tri(features: Vec<GraphTensor>) -> GraphTensor {
|
||||
assert!(features.len() >= 2, "need at least 2 feature vectors");
|
||||
let first = features[0];
|
||||
let dims = first.dims();
|
||||
assert_eq!(dims.len(), 2, "each feature vector must be 2D (batch, d)");
|
||||
let batch = dims[0].to_usize().expect("batch must be static");
|
||||
let d = dims[1].to_usize().expect("d must be static");
|
||||
let f = features.len();
|
||||
for v in &features {
|
||||
assert_eq!(v.dtype, DType::F32, "features must all be F32");
|
||||
let vd = v.dims();
|
||||
assert_eq!(vd.len(), 2, "features must all be 2D");
|
||||
assert_eq!(vd[0].to_usize().unwrap(), batch, "batch mismatch");
|
||||
assert_eq!(vd[1].to_usize().unwrap(), d, "d mismatch");
|
||||
}
|
||||
let kern = PairwiseDotLowerTriKernel {
|
||||
batch,
|
||||
num_features: f,
|
||||
d,
|
||||
};
|
||||
let p = f * (f - 1) / 2;
|
||||
let cx = unsafe { &mut *first.graph_ref };
|
||||
cx.custom_op(
|
||||
PairwiseDotLowerTriCustom(kern),
|
||||
features,
|
||||
(batch, p),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
757
crates/luminal_cuda_lite/src/kernel/embedding_bag.rs
Normal file
757
crates/luminal_cuda_lite/src/kernel/embedding_bag.rs
Normal file
@@ -0,0 +1,757 @@
|
||||
//! Single-kernel fused EmbeddingBag (sum-pool) operator.
|
||||
//!
|
||||
//! DLRM-style embedding lookups in luminal currently lower into a chain
|
||||
//! of broadcast-iota + multiply + add + Gather + SumReduce kernels (~6
|
||||
//! kernels per table). For a model with even a handful of tables that
|
||||
//! eats most of the per-iter launch budget once everything else is
|
||||
//! captured into a single CUDA graph.
|
||||
//!
|
||||
//! This op collapses the whole pattern — `gather(table, idx) → sum(L)` —
|
||||
//! into one kernel. Same template as `Matmul2DKernel`: implement
|
||||
//! [`KernelOp`], wrap in a [`CustomOp`] so the user-facing call comes
|
||||
//! out as a `dyn KernelOp` in the LLIR (which means it can be absorbed
|
||||
//! into the same CudaGraphOp as everything around it — no extra host
|
||||
//! op, no extra CUDA launch outside the graph).
|
||||
//!
|
||||
//! Semantics: `out[b, d] = Σ_l table[indices[b, l], d]` with
|
||||
//! table: (n_emb, d), F32, row-major
|
||||
//! indices: (batch, bag), I32, row-major
|
||||
//! out: (batch, d), F32, row-major
|
||||
//!
|
||||
//! Fixed-shape: `n_emb`, `d`, `batch`, `bag` are static (baked into
|
||||
//! the kernel source via #defines), matching how the rest of the
|
||||
//! `kernel::` ops in this crate handle shape.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// One-kernel fused EmbeddingBag with sum pooling and fixed bag size.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingBagSumKernel {
|
||||
pub batch: usize,
|
||||
pub bag: usize,
|
||||
pub d: usize,
|
||||
pub n_emb: usize,
|
||||
}
|
||||
|
||||
impl KernelOp for EmbeddingBagSumKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
// One block per batch row, `d` threads per block. Each thread sums
|
||||
// one output column over the `bag` indices. This is the standard
|
||||
// bag-size-1..L pattern and is memory-bandwidth bound on `table`,
|
||||
// which is exactly the right roofline for this op.
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void embedding_bag_sum_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ table,
|
||||
const int* __restrict__ indices
|
||||
) {{
|
||||
const int B = {batch};
|
||||
const int L = {bag};
|
||||
const int D = {d};
|
||||
const int N = {n_emb};
|
||||
int b = blockIdx.x;
|
||||
int d = threadIdx.x;
|
||||
if (b >= B || d >= D) return;
|
||||
float acc = 0.0f;
|
||||
#pragma unroll 4
|
||||
for (int l = 0; l < L; ++l) {{
|
||||
int row = indices[b * L + l];
|
||||
// Index is from user input; trust it (matches torch.EmbeddingBag).
|
||||
acc += table[row * D + d];
|
||||
}}
|
||||
out[b * D + d] = acc;
|
||||
(void)N;
|
||||
}}
|
||||
",
|
||||
batch = self.batch,
|
||||
bag = self.bag,
|
||||
d = self.d,
|
||||
n_emb = self.n_emb,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("embedding_bag_sum_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(self.batch),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(self.d),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// For each output element, L reads from table (4 bytes each), plus
|
||||
// L reads from indices (4 bytes each, shared across D threads — we
|
||||
// just bill once per output to keep this readable).
|
||||
Expression::from(self.batch * self.d * self.bag * 4 + self.batch * self.bag * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// L adds per output element. Pointer math doesn't count.
|
||||
Expression::from(self.batch * self.d * self.bag)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"EmbeddingBagSum"
|
||||
}
|
||||
}
|
||||
|
||||
/// CustomOp wrapper for [`EmbeddingBagSumKernel`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingBagSumCustom(pub EmbeddingBagSumKernel);
|
||||
|
||||
impl CustomOp for EmbeddingBagSumCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// One-kernel fused multi-table EmbeddingBag with sum pooling.
|
||||
///
|
||||
/// Folds all `num_tables` independent embedding lookups into a single
|
||||
/// CUDA kernel launch. Reads from one big weight tensor that is the
|
||||
/// row-wise concatenation of every table; per-table row offsets are
|
||||
/// baked into the kernel source. Per-table index tensors stay separate.
|
||||
/// Output is `(batch, num_tables, d)` so downstream ops can consume it
|
||||
/// as a single stacked tensor (matches v3's `index_select + reshape`
|
||||
/// trick — Inductor fuses gather+sum across all tables; this kernel
|
||||
/// just does it directly).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StackedEmbeddingBagKernel {
|
||||
pub batch: usize,
|
||||
pub bag: usize,
|
||||
pub d: usize,
|
||||
pub num_tables: usize,
|
||||
/// Cumulative row counts: `row_offsets[k]` = number of rows in all
|
||||
/// tables strictly before table `k`. Length = `num_tables + 1`.
|
||||
/// `row_offsets[num_tables]` = total rows in the stacked weight.
|
||||
pub row_offsets: Vec<usize>,
|
||||
}
|
||||
|
||||
impl KernelOp for StackedEmbeddingBagKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert_eq!(
|
||||
self.row_offsets.len(),
|
||||
self.num_tables + 1,
|
||||
"row_offsets must have num_tables+1 entries"
|
||||
);
|
||||
// One index pointer per table — variadic via generated kernel signature.
|
||||
let idx_params: String = (0..self.num_tables)
|
||||
.map(|k| format!(", const int* __restrict__ idx_{k}"))
|
||||
.collect::<Vec<_>>()
|
||||
.concat();
|
||||
// For each table k, generate a `case k` branch that picks the right
|
||||
// index pointer and row offset. The case body is the same fused
|
||||
// gather+sum loop as the single-table kernel.
|
||||
let mut switch = String::new();
|
||||
for k in 0..self.num_tables {
|
||||
let off = self.row_offsets[k];
|
||||
switch += &format!(
|
||||
" case {k}: {{ const int* __restrict__ idx_ptr = idx_{k}; const int row_off = {off}; for (int l = 0; l < L; ++l) {{ int row = idx_ptr[b * L + l] + row_off; acc += weight[row * D + d]; }} break; }}\n"
|
||||
);
|
||||
}
|
||||
|
||||
// Grid is (B,); one block per batch row. Block holds *all* (k, d)
|
||||
// output threads together. The previous (B, N) grid had 16-thread
|
||||
// blocks at D=16, which left each SM under-occupied (Hopper's
|
||||
// max-blocks-per-SM × 16 threads ≪ 64 warps/SM, so the warp
|
||||
// scheduler couldn't hide memory latency). With one batch row
|
||||
// per block we get K·D threads (e.g. 512 at K=32, D=16), which
|
||||
// is 16 warps — enough for the SM to overlap pending loads with
|
||||
// compute on other warps. Each block now produces (K, D) outputs
|
||||
// instead of (1, D), so total block count drops from B·K to B
|
||||
// (e.g. 65k → 2k at K=32, B=2048).
|
||||
//
|
||||
// Threads stride over `total = K · D` if the requested block
|
||||
// size exceeds 1024 (CUDA max). At D=16 this only kicks in for
|
||||
// K > 64, well above the DLRM range.
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void stacked_embedding_bag_sum_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ weight{idx_params}
|
||||
) {{
|
||||
const int B = {batch};
|
||||
const int L = {bag};
|
||||
const int D = {d};
|
||||
const int K = {num_tables};
|
||||
const int total = K * D;
|
||||
int b = blockIdx.x;
|
||||
if (b >= B) return;
|
||||
for (int tid = threadIdx.x; tid < total; tid += blockDim.x) {{
|
||||
int k = tid / D;
|
||||
int d = tid - k * D;
|
||||
float acc = 0.0f;
|
||||
switch (k) {{
|
||||
{switch}
|
||||
default: continue;
|
||||
}}
|
||||
// Output laid out as (B, K, D) row-major.
|
||||
out[(b * K + k) * D + d] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
batch = self.batch,
|
||||
bag = self.bag,
|
||||
d = self.d,
|
||||
num_tables = self.num_tables,
|
||||
idx_params = idx_params,
|
||||
switch = switch,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module
|
||||
.load_function("stacked_embedding_bag_sum_kernel")
|
||||
.unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
// Block size: enough threads to cover K·D output cells per batch
|
||||
// row, rounded up to a warp (32) for full warp utilization, capped
|
||||
// at 1024 (CUDA max block dim). Lower bound of 32 ensures we never
|
||||
// launch sub-warp blocks when K·D < 32 (e.g. N=1).
|
||||
let total = self.num_tables * self.d;
|
||||
let block_threads = total.next_multiple_of(32).clamp(32, 1024);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(self.batch),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(block_threads),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_tables * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per output element, L reads from weight. Index reads ~negligible
|
||||
// (D threads share the same L indices per output row).
|
||||
Expression::from(self.batch * self.num_tables * self.d * self.bag * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_tables * self.d * self.bag)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"StackedEmbeddingBagSum"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StackedEmbeddingBagSumCustom(pub StackedEmbeddingBagKernel);
|
||||
|
||||
impl CustomOp for StackedEmbeddingBagSumCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stacked-table fused EmbeddingBag with sum pooling.
|
||||
///
|
||||
/// * `stacked_weight`: `(sum_k rows_per_table[k], d)` F32, row-major.
|
||||
/// The k-th table's rows occupy indices `[row_offsets[k], row_offsets[k+1])`
|
||||
/// where `row_offsets[k] = sum_{j<k} rows_per_table[j]`.
|
||||
/// * `indices`: list of `num_tables` tensors, each `(batch, bag)` I32.
|
||||
/// Index values for table k are in `[0, rows_per_table[k])` — the
|
||||
/// per-table row offset is added inside the kernel.
|
||||
/// * `row_offsets`: cumulative starting row index for each table
|
||||
/// (length `num_tables + 1`).
|
||||
///
|
||||
/// Returns `(batch, num_tables, d)` F32. Use `slice_along` + `squeeze`
|
||||
/// (or the bundled `dlrm_pairwise_dot_lower_tri_stacked` op) to consume
|
||||
/// per-table outputs downstream.
|
||||
pub fn stacked_embedding_bag_sum_kernel(
|
||||
stacked_weight: GraphTensor,
|
||||
indices: Vec<GraphTensor>,
|
||||
row_offsets: &[usize],
|
||||
) -> GraphTensor {
|
||||
assert_eq!(
|
||||
stacked_weight.dtype,
|
||||
DType::F32,
|
||||
"stacked_embedding_bag_sum_kernel: weight must be F32"
|
||||
);
|
||||
let num_tables = indices.len();
|
||||
assert!(num_tables >= 1, "need at least one index tensor");
|
||||
assert_eq!(
|
||||
row_offsets.len(),
|
||||
num_tables + 1,
|
||||
"row_offsets must have num_tables+1 entries"
|
||||
);
|
||||
let w_dims = stacked_weight.dims();
|
||||
assert_eq!(w_dims.len(), 2, "stacked weight must be 2D (total_rows, d)");
|
||||
let total_rows = w_dims[0].to_usize().expect("total_rows must be static");
|
||||
assert_eq!(
|
||||
total_rows, row_offsets[num_tables],
|
||||
"row_offsets[-1] must equal weight total_rows"
|
||||
);
|
||||
let d = w_dims[1].to_usize().expect("d must be static");
|
||||
let i_dims = indices[0].dims();
|
||||
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
|
||||
let batch = i_dims[0].to_usize().expect("batch must be static");
|
||||
let bag = i_dims[1].to_usize().expect("bag must be static");
|
||||
for idx in &indices {
|
||||
assert_eq!(idx.dtype, DType::Int, "indices must be Int");
|
||||
let id = idx.dims();
|
||||
assert_eq!(id.len(), 2);
|
||||
assert_eq!(id[0].to_usize().unwrap(), batch);
|
||||
assert_eq!(id[1].to_usize().unwrap(), bag);
|
||||
}
|
||||
let kern = StackedEmbeddingBagKernel {
|
||||
batch,
|
||||
bag,
|
||||
d,
|
||||
num_tables,
|
||||
row_offsets: row_offsets.to_vec(),
|
||||
};
|
||||
let cx = unsafe { &mut *stacked_weight.graph_ref };
|
||||
let mut inputs = vec![stacked_weight];
|
||||
inputs.extend(indices);
|
||||
cx.custom_op(
|
||||
StackedEmbeddingBagSumCustom(kern),
|
||||
inputs,
|
||||
(batch, num_tables, d),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
|
||||
/// Fused EmbeddingBag with sum pooling (single table).
|
||||
///
|
||||
/// * `table`: `(n_emb, d)` F32, row-major.
|
||||
/// * `indices`: `(batch, bag)` I32, row-major. Values must be in `[0, n_emb)`.
|
||||
///
|
||||
/// Returns: `(batch, d)` F32, row-major. Each output row is the sum of
|
||||
/// `bag` looked-up rows from `table`.
|
||||
///
|
||||
/// All dimensions must be static. The returned tensor's graph node is a
|
||||
/// `dyn KernelOp` in LLIR, so it lives inside the same CudaGraphOp as
|
||||
/// surrounding kernel ops and benefits from the same CUDA-graph replay.
|
||||
pub fn embedding_bag_sum_kernel(table: GraphTensor, indices: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(table.dtype, DType::F32, "embedding_bag_sum_kernel: table must be F32");
|
||||
assert_eq!(
|
||||
indices.dtype,
|
||||
DType::Int,
|
||||
"embedding_bag_sum_kernel: indices must be Int"
|
||||
);
|
||||
let t_dims = table.dims();
|
||||
let i_dims = indices.dims();
|
||||
assert_eq!(t_dims.len(), 2, "table must be 2D (n_emb, d)");
|
||||
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
|
||||
let n_emb = t_dims[0].to_usize().expect("n_emb must be static");
|
||||
let d = t_dims[1].to_usize().expect("d must be static");
|
||||
let batch = i_dims[0].to_usize().expect("batch must be static");
|
||||
let bag = i_dims[1].to_usize().expect("bag must be static");
|
||||
|
||||
let kern = EmbeddingBagSumKernel {
|
||||
batch,
|
||||
bag,
|
||||
d,
|
||||
n_emb,
|
||||
};
|
||||
let cx = unsafe { &mut *table.graph_ref };
|
||||
cx.custom_op(
|
||||
EmbeddingBagSumCustom(kern),
|
||||
vec![table, indices],
|
||||
(batch, d),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-table EmbeddingBag (one kernel for K independent (weight, idx) pairs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Folds K independent `EmbeddingBag(sum)` lookups into a single CUDA
|
||||
/// kernel launch. Used by the vanilla-DLRMv1 translator path where the
|
||||
/// model has K separate `nn.EmbeddingBag` modules — each one would
|
||||
/// otherwise lower to its own (~5 µs) launch.
|
||||
///
|
||||
/// Inputs (in `KernelOp`-order):
|
||||
/// - `weight_0, weight_1, ..., weight_{K-1}` — each `(n_emb_k, d)` F32.
|
||||
/// **The per-table `n_emb` may differ**; only `d` and bag size `L`
|
||||
/// must match across tables.
|
||||
/// - `idx_0, idx_1, ..., idx_{K-1}` — each `(batch, L)` Int (i32).
|
||||
///
|
||||
/// Two packed staging buffers carry the K weight + K idx device pointers
|
||||
/// into the kernel (`build_params` fills them per execution via
|
||||
/// `cuMemcpyHtoD`). The hot loop reads each pointer from shared memory
|
||||
/// — no per-table switch needed.
|
||||
///
|
||||
/// Output shape: `(batch, num_tables, d)` F32, row-major.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiTableEmbeddingBagSumKernel {
|
||||
pub batch: usize,
|
||||
pub bag: usize,
|
||||
pub d: usize,
|
||||
pub num_tables: usize,
|
||||
}
|
||||
|
||||
impl KernelOp for MultiTableEmbeddingBagSumKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
// Layout (mirrors worktree's StackedEmbeddingBagSumKernel):
|
||||
// - One block per batch row (B blocks).
|
||||
// - Each block produces (K, D) output cells, striding over K·D
|
||||
// threads (rounded up to a warp).
|
||||
// - K weight pointers + K idx pointers come in via two packed
|
||||
// staging buffers populated in `build_params`.
|
||||
// - Shared memory caches both pointer arrays so the hot loop
|
||||
// reads at shmem latency.
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void multi_table_embedding_bag_sum_kernel(
|
||||
float* __restrict__ out,
|
||||
const long* __restrict__ w_ptrs_packed,
|
||||
const long* __restrict__ idx_ptrs_packed
|
||||
) {{
|
||||
const int B = {batch};
|
||||
const int L = {bag};
|
||||
const int D = {d};
|
||||
const int K = {num_tables};
|
||||
const int total = K * D;
|
||||
int b = blockIdx.x;
|
||||
if (b >= B) return;
|
||||
|
||||
__shared__ const float* s_w_ptrs[K];
|
||||
__shared__ const int* s_idx_ptrs[K];
|
||||
if (threadIdx.x < K) {{
|
||||
s_w_ptrs[threadIdx.x] = (const float*)(w_ptrs_packed[threadIdx.x]);
|
||||
s_idx_ptrs[threadIdx.x] = (const int*)(idx_ptrs_packed[threadIdx.x]);
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
for (int tid = threadIdx.x; tid < total; tid += blockDim.x) {{
|
||||
int k = tid / D;
|
||||
int d = tid - k * D;
|
||||
const float* w = s_w_ptrs[k];
|
||||
const int* idx = s_idx_ptrs[k];
|
||||
float acc = 0.0f;
|
||||
#pragma unroll 4
|
||||
for (int l = 0; l < L; ++l) {{
|
||||
int row = idx[b * L + l];
|
||||
acc += w[row * D + d];
|
||||
}}
|
||||
// (B, K, D) row-major.
|
||||
out[(b * K + k) * D + d] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
batch = self.batch,
|
||||
bag = self.bag,
|
||||
d = self.d,
|
||||
num_tables = self.num_tables,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module
|
||||
.load_function("multi_table_embedding_bag_sum_kernel")
|
||||
.unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let total = self.num_tables * self.d;
|
||||
let block_threads = total.next_multiple_of(32).clamp(32, 1024);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(self.batch),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(block_threads),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_tables * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_tables * self.d * self.bag * 4
|
||||
+ self.batch * self.num_tables * self.bag * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
Expression::from(self.batch * self.num_tables * self.d * self.bag)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MultiTableEmbeddingBagSum"
|
||||
}
|
||||
|
||||
/// Two staging buffers: one for K weight ptrs, one for K idx ptrs.
|
||||
/// Each is `K * 8` bytes (an array of u64s, written as `long*` on
|
||||
/// the device side).
|
||||
fn allocate_internal_buffers(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<CudaSlice<u8>> {
|
||||
let buf_size = self.num_tables * 8;
|
||||
vec![
|
||||
stream
|
||||
.alloc_zeros::<u8>(buf_size)
|
||||
.expect("alloc MultiTableEmbBag w-ptr staging buffer"),
|
||||
stream
|
||||
.alloc_zeros::<u8>(buf_size)
|
||||
.expect("alloc MultiTableEmbBag idx-ptr staging buffer"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Pack the K weight + K idx pointers into the two staging buffers
|
||||
/// each execution, then emit `[out, w_buf, idx_buf]` as kernel params.
|
||||
///
|
||||
/// `input_ptrs` layout: `[w_0, w_1, ..., w_{K-1}, idx_0, ..., idx_{K-1}]`.
|
||||
/// `cuMemcpyHtoD_v2` is a blocking host call so by the time we return
|
||||
/// the staging buffers are populated and the subsequent CUDA-graph
|
||||
/// node-param update reads stable device pointers.
|
||||
fn build_params(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
internal_bufs: &[CudaSlice<u8>],
|
||||
_dyn_dims_ptr: u64,
|
||||
) -> Vec<u64> {
|
||||
assert_eq!(
|
||||
input_ptrs.len(),
|
||||
2 * self.num_tables,
|
||||
"MultiTableEmbeddingBagSum: expected {} input pointers (K weights + K idx), got {}",
|
||||
2 * self.num_tables,
|
||||
input_ptrs.len(),
|
||||
);
|
||||
let (w_ptrs, idx_ptrs) = input_ptrs.split_at(self.num_tables);
|
||||
let w_buf = &internal_bufs[0];
|
||||
let idx_buf = &internal_bufs[1];
|
||||
let w_dev_ptr: u64 = w_buf.device_ptr(stream).0;
|
||||
let idx_dev_ptr: u64 = idx_buf.device_ptr(stream).0;
|
||||
unsafe {
|
||||
let r1 = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
w_dev_ptr,
|
||||
w_ptrs.as_ptr() as *const std::ffi::c_void,
|
||||
w_ptrs.len() * 8,
|
||||
);
|
||||
assert_eq!(
|
||||
r1,
|
||||
cudarc::driver::sys::CUresult::CUDA_SUCCESS,
|
||||
"cuMemcpyHtoD_v2 for MultiTableEmbBag w-ptr staging failed: {r1:?}",
|
||||
);
|
||||
let r2 = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
idx_dev_ptr,
|
||||
idx_ptrs.as_ptr() as *const std::ffi::c_void,
|
||||
idx_ptrs.len() * 8,
|
||||
);
|
||||
assert_eq!(
|
||||
r2,
|
||||
cudarc::driver::sys::CUresult::CUDA_SUCCESS,
|
||||
"cuMemcpyHtoD_v2 for MultiTableEmbBag idx-ptr staging failed: {r2:?}",
|
||||
);
|
||||
}
|
||||
vec![output_ptr, w_dev_ptr, idx_dev_ptr]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiTableEmbeddingBagSumCustom(pub MultiTableEmbeddingBagSumKernel);
|
||||
|
||||
impl CustomOp for MultiTableEmbeddingBagSumCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Frontend helper: K independent EmbeddingBag(sum) lookups in one
|
||||
/// kernel launch. Returns `(batch, num_tables, d)` F32, row-major;
|
||||
/// slice along axis 1 (`out.slice_along(k..k+1, 1).squeeze(1)`) to
|
||||
/// recover the k-th table's `(batch, d)` output.
|
||||
///
|
||||
/// * `weights`: K `(n_emb_k, d)` F32 tensors. Per-table `n_emb` may
|
||||
/// differ; only `d` must be shared.
|
||||
/// * `indices`: K `(batch, bag)` Int tensors (cast `.cast(DType::Int)`
|
||||
/// on the caller side if your indices are i64).
|
||||
pub fn multi_table_embedding_bag_sum_kernel(
|
||||
weights: Vec<GraphTensor>,
|
||||
indices: Vec<GraphTensor>,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(
|
||||
weights.len(),
|
||||
indices.len(),
|
||||
"multi_table_embedding_bag_sum_kernel: need one weight per index tensor"
|
||||
);
|
||||
let num_tables = weights.len();
|
||||
assert!(num_tables >= 1, "need at least one table");
|
||||
let first_w = weights[0];
|
||||
let first_idx = indices[0];
|
||||
let w_dims = first_w.dims();
|
||||
let i_dims = first_idx.dims();
|
||||
assert_eq!(w_dims.len(), 2, "weights must be 2D (n_emb, d)");
|
||||
assert_eq!(i_dims.len(), 2, "indices must be 2D (batch, bag)");
|
||||
let d = w_dims[1].to_usize().expect("d must be static");
|
||||
let batch = i_dims[0].to_usize().expect("batch must be static");
|
||||
let bag = i_dims[1].to_usize().expect("bag must be static");
|
||||
for w in &weights {
|
||||
assert_eq!(w.dtype, DType::F32, "weights must all be F32");
|
||||
let wd = w.dims();
|
||||
assert_eq!(wd.len(), 2, "weight must be 2D");
|
||||
assert_eq!(
|
||||
wd[1].to_usize().unwrap(),
|
||||
d,
|
||||
"all weights must share inner dim"
|
||||
);
|
||||
}
|
||||
for idx in &indices {
|
||||
assert_eq!(idx.dtype, DType::Int, "indices must all be Int (i32)");
|
||||
let id = idx.dims();
|
||||
assert_eq!(id.len(), 2);
|
||||
assert_eq!(id[0].to_usize().unwrap(), batch);
|
||||
assert_eq!(id[1].to_usize().unwrap(), bag);
|
||||
}
|
||||
let kern = MultiTableEmbeddingBagSumKernel {
|
||||
batch,
|
||||
bag,
|
||||
d,
|
||||
num_tables,
|
||||
};
|
||||
let mut inputs = weights;
|
||||
inputs.extend(indices);
|
||||
let cx = unsafe { &mut *first_w.graph_ref };
|
||||
cx.custom_op(
|
||||
MultiTableEmbeddingBagSumCustom(kern),
|
||||
inputs,
|
||||
(batch, num_tables, d),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
@@ -52,6 +52,19 @@ use crate::kernel::KernelOp;
|
||||
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
|
||||
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
|
||||
/// `(N,)` and broadcast across batches.
|
||||
/// Activation epilogue fused into the matmul kernel's store path.
|
||||
///
|
||||
/// Saves one full pass over the output buffer per MLP layer — the same
|
||||
/// trick cuBLASLt does with `CUBLASLT_EPILOGUE_RELU_BIAS` etc., but
|
||||
/// inside our custom kernel so we don't have to invoke cuBLASLt.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum Activation {
|
||||
#[default]
|
||||
None,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DKernel {
|
||||
pub m: usize,
|
||||
@@ -65,6 +78,18 @@ pub struct Matmul2DKernel {
|
||||
pub has_bias: bool,
|
||||
/// Storage dtype of B. Currently F32 or BF16 are supported.
|
||||
pub weight_dtype: DType,
|
||||
/// Activation applied to `acc + bias` before writing to C.
|
||||
/// Defaults to None; ReLU and Sigmoid avoid a separate elementwise
|
||||
/// pass over the matmul output.
|
||||
pub activation: Activation,
|
||||
/// When `Some(split)`, A is read from two source pointers:
|
||||
/// columns `0..split` → `A_lo`, stride `split` per row
|
||||
/// columns `split..K` → `A_hi`, stride `K - split` per row
|
||||
/// This lets a `cat(A_lo, A_hi)` materialization be skipped entirely —
|
||||
/// the K-loop's A-load branches on the column index instead. `None`
|
||||
/// keeps the existing single-pointer path. Only supported for
|
||||
/// `batch == 1` (DLRM's use case); the kernel asserts on this.
|
||||
pub a_split: Option<usize>,
|
||||
}
|
||||
|
||||
const TILE: usize = 16;
|
||||
@@ -93,6 +118,46 @@ impl KernelOp for Matmul2DKernel {
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let activation_apply = match self.activation {
|
||||
Activation::None => "",
|
||||
// Branchless ReLU; keeps the fully-occupied write path simple.
|
||||
Activation::Relu => " acc = fmaxf(acc, 0.0f);\n",
|
||||
// Sigmoid: 1/(1+exp(-acc)). Used by DLRM's final layer.
|
||||
Activation::Sigmoid => " acc = 1.0f / (1.0f + __expf(-acc));\n",
|
||||
};
|
||||
// A-input parameter declaration + per-K-tile load expression depend
|
||||
// on whether the caller asked for the dual-source (split) path.
|
||||
// Single-source (default) keeps the original `const float* A` and
|
||||
// reads `A[a_m * K + a_k]`. Split mode takes two pointer args
|
||||
// (A_lo / A_hi) and selects between them at runtime by comparing
|
||||
// `a_k` against the compile-time-baked split column.
|
||||
let (a_param_decl, a_load_expr) = if let Some(split) = self.a_split {
|
||||
assert!(
|
||||
split > 0 && split < self.k,
|
||||
"Matmul2DKernel a_split must be in 1..K; got split={split}, K={}",
|
||||
self.k
|
||||
);
|
||||
assert_eq!(
|
||||
self.batch, 1,
|
||||
"Matmul2DKernel a_split path only supports batch=1 (got batch={})",
|
||||
self.batch
|
||||
);
|
||||
let hi = self.k - split;
|
||||
(
|
||||
"const float* __restrict__ A_lo, const float* __restrict__ A_hi"
|
||||
.to_string(),
|
||||
format!(
|
||||
"((a_k < {split}) \
|
||||
? A_lo[a_m * {split} + a_k] \
|
||||
: A_hi[a_m * {hi} + (a_k - {split})])"
|
||||
),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
"const float* __restrict__ A".to_string(),
|
||||
"A[a_batch_off + a_m * K + a_k]".to_string(),
|
||||
)
|
||||
};
|
||||
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
|
||||
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
|
||||
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
|
||||
@@ -122,7 +187,7 @@ impl KernelOp for Matmul2DKernel {
|
||||
"
|
||||
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ A,
|
||||
{a_param_decl},
|
||||
{b_param_type}{bias_param}
|
||||
) {{
|
||||
const int M = {m};
|
||||
@@ -156,10 +221,12 @@ impl KernelOp for Matmul2DKernel {
|
||||
for (int t = 0; t < n_tiles; ++t) {{
|
||||
int k0 = t * TILE;
|
||||
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]. In single-source
|
||||
// mode this is `A[a_batch_off + a_m * K + a_k]`. In split mode the
|
||||
// load expression branches on `a_k < split` (baked in by the host).
|
||||
int a_m = a_m_base + ty;
|
||||
int a_k = k0 + tx;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? ({a_load_expr}) : 0.0f;
|
||||
|
||||
// Load B tile depending on transpose_b
|
||||
int b_n_or_k = b_n_base + tx; // for transpose_b=true this is N; for =false this is N
|
||||
@@ -182,7 +249,7 @@ impl KernelOp for Matmul2DKernel {
|
||||
|
||||
if (m_global < M && n_global < N) {{
|
||||
int n = n_global;
|
||||
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
{bias_add}{activation_apply} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
@@ -195,7 +262,10 @@ impl KernelOp for Matmul2DKernel {
|
||||
b_param_type = b_param_type,
|
||||
bias_param = bias_param,
|
||||
bias_add = bias_add,
|
||||
activation_apply = activation_apply,
|
||||
bf16_include = bf16_include,
|
||||
a_param_decl = a_param_decl,
|
||||
a_load_expr = a_load_expr,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
@@ -264,7 +334,20 @@ impl KernelOp for Matmul2DKernel {
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Matmul2D"
|
||||
match (self.has_bias, self.activation, self.a_split.is_some()) {
|
||||
(true, Activation::Relu, false) => "Matmul2D_BiasRelu",
|
||||
(true, Activation::Sigmoid, false) => "Matmul2D_BiasSigmoid",
|
||||
(true, Activation::None, false) => "Matmul2D_Bias",
|
||||
(false, Activation::Relu, false) => "Matmul2D_Relu",
|
||||
(false, Activation::Sigmoid, false) => "Matmul2D_Sigmoid",
|
||||
(false, Activation::None, false) => "Matmul2D",
|
||||
(true, Activation::Relu, true) => "Matmul2D_BiasRelu_SplitA",
|
||||
(true, Activation::Sigmoid, true) => "Matmul2D_BiasSigmoid_SplitA",
|
||||
(true, Activation::None, true) => "Matmul2D_Bias_SplitA",
|
||||
(false, Activation::Relu, true) => "Matmul2D_Relu_SplitA",
|
||||
(false, Activation::Sigmoid, true) => "Matmul2D_Sigmoid_SplitA",
|
||||
(false, Activation::None, true) => "Matmul2D_SplitA",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,20 +363,82 @@ impl CustomOp for Matmul2DCustom {
|
||||
|
||||
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
|
||||
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None, Activation::None)
|
||||
}
|
||||
|
||||
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
|
||||
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
|
||||
/// pattern produced by linear / projection layers (`x @ w.t()`).
|
||||
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None, Activation::None)
|
||||
}
|
||||
|
||||
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
|
||||
/// `(N,)`, row-major F32 throughout.
|
||||
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::None)
|
||||
}
|
||||
|
||||
/// Like [`linear_bias`] but applies ReLU in the kernel epilogue. Saves
|
||||
/// one full pass over the output buffer per layer.
|
||||
pub fn linear_bias_relu(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::Relu)
|
||||
}
|
||||
|
||||
/// Like [`linear_bias`] but applies Sigmoid in the kernel epilogue.
|
||||
/// Used for the final layer of binary-classifier MLPs (DLRM CTR head).
|
||||
pub fn linear_bias_sigmoid(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias), Activation::Sigmoid)
|
||||
}
|
||||
|
||||
/// Two-A-input variant of [`linear_bias`].
|
||||
///
|
||||
/// Computes `cat(a_lo, a_hi) @ bᵀ + bias` *without* materializing the
|
||||
/// concat — the K-loop's A-load reads from `a_lo` for columns `0..K_lo`
|
||||
/// and from `a_hi` for columns `K_lo..K_lo+K_hi`. Logically equivalent
|
||||
/// to feeding `concat_along(a_lo, a_hi, 1)` into [`linear_bias`], but
|
||||
/// skips ~9 scaffolding kernels (Iota + Cast + Gather + masked-add) per
|
||||
/// concat call.
|
||||
///
|
||||
/// Shapes:
|
||||
/// * `a_lo`: `(M, K_lo)` F32
|
||||
/// * `a_hi`: `(M, K_hi)` F32
|
||||
/// * `b`: `(N, K_lo + K_hi)` F32 (transposed convention, same as
|
||||
/// [`linear_bias`])
|
||||
/// * `bias`: `(N,)` F32
|
||||
///
|
||||
/// Output: `(M, N)` F32. Only 2D inputs are supported (batch=1).
|
||||
pub fn linear_bias_split_a(
|
||||
a_lo: GraphTensor,
|
||||
a_hi: GraphTensor,
|
||||
b: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::None)
|
||||
}
|
||||
|
||||
/// Like [`linear_bias_split_a`] but applies ReLU in the kernel epilogue.
|
||||
/// Use this for hidden MLP layers that consume a concat of two upstream
|
||||
/// tensors — the natural shape of DLRM's top-MLP first layer (which reads
|
||||
/// `cat(dense_out, interactions)`).
|
||||
pub fn linear_bias_relu_split_a(
|
||||
a_lo: GraphTensor,
|
||||
a_hi: GraphTensor,
|
||||
b: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::Relu)
|
||||
}
|
||||
|
||||
/// Like [`linear_bias_split_a`] but applies Sigmoid in the kernel
|
||||
/// epilogue.
|
||||
pub fn linear_bias_sigmoid_split_a(
|
||||
a_lo: GraphTensor,
|
||||
a_hi: GraphTensor,
|
||||
b: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
matmul_inner_split_a(a_lo, a_hi, b, Some(bias), Activation::Sigmoid)
|
||||
}
|
||||
|
||||
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
|
||||
@@ -321,12 +466,12 @@ pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor
|
||||
|
||||
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
|
||||
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None, Activation::None)
|
||||
}
|
||||
|
||||
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
|
||||
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None, Activation::None)
|
||||
}
|
||||
|
||||
fn matmul_inner(
|
||||
@@ -334,6 +479,7 @@ fn matmul_inner(
|
||||
b: GraphTensor,
|
||||
transpose_b: bool,
|
||||
bias: Option<GraphTensor>,
|
||||
activation: Activation,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
|
||||
let weight_dtype = b.dtype;
|
||||
@@ -412,6 +558,8 @@ fn matmul_inner(
|
||||
transpose_b,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
activation,
|
||||
a_split: None,
|
||||
};
|
||||
let cx = unsafe { &mut *a.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
@@ -425,3 +573,71 @@ fn matmul_inner(
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal helper for the split-A path. Validates shapes and dispatches
|
||||
/// to a [`Matmul2DKernel`] with `a_split = Some(K_lo)`. Always uses
|
||||
/// `transpose_b = true` (linear-projection convention; matches
|
||||
/// [`linear_bias`]). Only 2D inputs are supported.
|
||||
fn matmul_inner_split_a(
|
||||
a_lo: GraphTensor,
|
||||
a_hi: GraphTensor,
|
||||
b: GraphTensor,
|
||||
bias: Option<GraphTensor>,
|
||||
activation: Activation,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a_lo.dtype, DType::F32, "split-A matmul requires F32 A_lo");
|
||||
assert_eq!(a_hi.dtype, DType::F32, "split-A matmul requires F32 A_hi");
|
||||
let weight_dtype = b.dtype;
|
||||
assert_eq!(
|
||||
weight_dtype,
|
||||
DType::F32,
|
||||
"split-A matmul currently only supports F32 B (got {weight_dtype:?})"
|
||||
);
|
||||
let lo_dims = a_lo.dims();
|
||||
let hi_dims = a_hi.dims();
|
||||
let b_dims = b.dims();
|
||||
assert_eq!(lo_dims.len(), 2, "split-A matmul A_lo must be 2D");
|
||||
assert_eq!(hi_dims.len(), 2, "split-A matmul A_hi must be 2D");
|
||||
assert_eq!(b_dims.len(), 2, "split-A matmul B must be 2D");
|
||||
let m = lo_dims[0].to_usize().expect("M must be a static dim");
|
||||
let m_hi = hi_dims[0].to_usize().expect("M (A_hi) must be a static dim");
|
||||
assert_eq!(m, m_hi, "split-A matmul: A_lo and A_hi must have the same M");
|
||||
let k_lo = lo_dims[1].to_usize().expect("K_lo must be a static dim");
|
||||
let k_hi = hi_dims[1].to_usize().expect("K_hi must be a static dim");
|
||||
let k = k_lo + k_hi;
|
||||
let n = b_dims[0].to_usize().expect("N must be a static dim");
|
||||
let k_b = b_dims[1].to_usize().expect("K (B) must be a static dim");
|
||||
assert_eq!(
|
||||
k, k_b,
|
||||
"split-A matmul: A_lo.K + A_hi.K = {k} must equal B.K = {k_b}"
|
||||
);
|
||||
let has_bias = bias.is_some();
|
||||
if let Some(bias) = bias {
|
||||
let bdims = bias.dims();
|
||||
assert_eq!(bdims.len(), 1, "split-A matmul bias must be 1D");
|
||||
assert_eq!(
|
||||
bdims[0].to_usize().expect("bias dim must be static"),
|
||||
n,
|
||||
"split-A matmul bias size must equal N"
|
||||
);
|
||||
assert_eq!(bias.dtype, DType::F32, "split-A matmul bias must be F32");
|
||||
}
|
||||
let kern = Matmul2DKernel {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch: 1,
|
||||
transpose_b: true,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
activation,
|
||||
a_split: Some(k_lo),
|
||||
};
|
||||
let cx = unsafe { &mut *a_lo.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
vec![a_lo, a_hi, b, bias]
|
||||
} else {
|
||||
vec![a_lo, a_hi, b]
|
||||
};
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod dlrm_interact;
|
||||
pub mod embedding_bag;
|
||||
pub mod fusion;
|
||||
pub mod generic_matmul;
|
||||
pub mod hlir;
|
||||
@@ -20,10 +22,22 @@ pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use dlrm_interact::{
|
||||
PairwiseDotLowerTriCustom, PairwiseDotLowerTriKernel, PairwiseDotLowerTriStackedCustom,
|
||||
PairwiseDotLowerTriStackedKernel, dlrm_pairwise_dot_lower_tri,
|
||||
dlrm_pairwise_dot_lower_tri_stacked,
|
||||
};
|
||||
pub use embedding_bag::{
|
||||
EmbeddingBagSumCustom, EmbeddingBagSumKernel, MultiTableEmbeddingBagSumCustom,
|
||||
MultiTableEmbeddingBagSumKernel, StackedEmbeddingBagKernel,
|
||||
StackedEmbeddingBagSumCustom, embedding_bag_sum_kernel,
|
||||
multi_table_embedding_bag_sum_kernel, stacked_embedding_bag_sum_kernel,
|
||||
};
|
||||
pub use generic_matmul::GenericMatmul;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
Activation, Matmul2DCustom, Matmul2DKernel, linear_bias, linear_bias_relu,
|
||||
linear_bias_relu_split_a, linear_bias_sigmoid, linear_bias_sigmoid_split_a,
|
||||
linear_bias_split_a, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t, matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
|
||||
@@ -215,6 +215,25 @@ impl CudaRuntime {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read per-kernel GPU elapsed times (ms), if event-record nodes were
|
||||
/// inserted at graph build time.
|
||||
///
|
||||
/// The per-kernel event recording infra from `origin/dlrm-fused-kernels`
|
||||
/// is not ported on this branch yet — this stub returns empty so the
|
||||
/// dlrm example's optional `LUMINAL_KERNEL_TIMING=1` path falls back to
|
||||
/// "(no per-kernel timings available — events not recorded)".
|
||||
pub fn read_per_kernel_timings_ms(&self) -> Vec<(&'static str, f32)> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Synchronize the runtime's CUDA stream. Use this after `execute()` if
|
||||
/// you need GPU-side completion time (e.g. for benchmarking) — `execute`
|
||||
/// itself no longer syncs at the end so it stays capturable by
|
||||
/// `torch.cuda.CUDAGraph` and similar external graph-capture machinery.
|
||||
pub fn synchronize_stream(&self) {
|
||||
let _ = self.cuda_stream.synchronize();
|
||||
}
|
||||
|
||||
fn bucket_buffer(
|
||||
bucket: &CompiledBucket,
|
||||
stream: &Arc<CudaStream>,
|
||||
@@ -1675,8 +1694,15 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Single sync at end - CUDA stream ordering guarantees sequential execution
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
// Sync only when profiling (kernel search timing needs an accurate
|
||||
// total). In the regular execute path, dropping the sync lets the
|
||||
// call be captured by `torch.cuda.CUDAGraph` (or any external graph
|
||||
// capture) — PyTorch syncs on tensor reads, so correctness is
|
||||
// preserved. The CPU-side `last_total_time_us` becomes a dispatch-
|
||||
// time measurement in that case.
|
||||
if self.profiling {
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
}
|
||||
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
|
||||
// Populate last_kernel_stats from HostOps that report stats
|
||||
@@ -1717,9 +1743,22 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
let to_consume: Vec<NodeIndex> = self
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.copied()
|
||||
.iter()
|
||||
// Don't consume external device pointers — they're non-owning
|
||||
// views over caller-provided memory and they represent stable
|
||||
// input slots that the wrapper *may* skip re-registering on a
|
||||
// hot iter if the underlying pointer is unchanged. Removing
|
||||
// them here would force the wrapper to re-register every
|
||||
// iter even when nothing changed (which the prior behavior
|
||||
// assumed). Internal `CudaInput::Buffer` entries — e.g.
|
||||
// weights loaded via `set_data_bytes` and one-shot CPU input
|
||||
// copies — still get consumed when they're not preserved by
|
||||
// the bucket.
|
||||
.filter(|(hlir_node, input)| {
|
||||
!inputs_with_outputs.contains(hlir_node)
|
||||
&& !matches!(input, CudaInput::Ptr(_))
|
||||
})
|
||||
.map(|(n, _)| *n)
|
||||
.collect();
|
||||
|
||||
for hlir_node in to_consume {
|
||||
|
||||
@@ -391,6 +391,73 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolve an input or output tensor name to an opaque integer ID.
|
||||
/// One FFI hop at compile time, then per-iter `run_with_ptrs` and
|
||||
/// the `*_by_id` setters can skip the string-keyed HashMap lookup.
|
||||
/// Returns the underlying `NodeIndex.index() as u32` — caller should
|
||||
/// treat it as opaque.
|
||||
fn tensor_id(&self, name: &str) -> PyResult<u32> {
|
||||
self.tensor_ids
|
||||
.get(name)
|
||||
.map(|n| n.index() as u32)
|
||||
.ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown tensor: {}",
|
||||
name
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// One-shot batched: register all input + output device pointers and
|
||||
/// execute. Collapses the per-iter `set_input_device_ptr` ×
|
||||
/// (n_inputs) + `set_output_device_ptr` × (n_outputs) + `run` chain
|
||||
/// into a single Python→Rust FFI crossing.
|
||||
///
|
||||
/// Inputs and outputs use the opaque IDs returned by `tensor_id(name)`
|
||||
/// to skip the per-call string lookup. Returns a parallel vector of
|
||||
/// "is zero-copy" booleans for each output (same as
|
||||
/// `output_is_zero_copy(name)` on the unbatched path), so the Python
|
||||
/// caller can fall back to a DtoD copy when a kernel aliased an
|
||||
/// output instead of writing into the registered buffer.
|
||||
///
|
||||
/// Safety contract:
|
||||
/// * Every `device_ptr` must point to a valid CUDA allocation with
|
||||
/// at least `n_bytes` bytes.
|
||||
/// * Pointers must remain valid through the duration of `run()`.
|
||||
/// * `tensor_id`s must come from `self.tensor_id(name)` on this
|
||||
/// same graph — using stale IDs from a different graph is UB
|
||||
/// (will likely panic in `NodeIndex` lookups, but not guaranteed).
|
||||
fn run_with_ptrs(
|
||||
&mut self,
|
||||
inputs: Vec<(u32, u64, usize)>,
|
||||
outputs: Vec<(u32, u64, usize)>,
|
||||
) -> PyResult<Vec<bool>> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"run_with_ptrs requires a GPU backend",
|
||||
));
|
||||
}
|
||||
// Register inputs.
|
||||
for (id, ptr, n) in &inputs {
|
||||
let node_id = NodeIndex::new(*id as usize);
|
||||
unsafe { self.runtime.set_device_ptr(node_id, *ptr, *n) };
|
||||
}
|
||||
// Register outputs.
|
||||
for (id, ptr, n) in &outputs {
|
||||
let node_id = NodeIndex::new(*id as usize);
|
||||
unsafe { self.runtime.set_output_device_ptr(node_id, *ptr, *n) };
|
||||
}
|
||||
// Execute.
|
||||
self.runtime.execute(&self.graph.dyn_map);
|
||||
// Report zero-copy status for each output (parallel to `outputs`
|
||||
// input order). Aliased outputs need a DtoD copy on the Python
|
||||
// side, same as the unbatched path.
|
||||
Ok(outputs
|
||||
.iter()
|
||||
.map(|(id, _, _)| self.runtime.output_is_zero_copy(NodeIndex::new(*id as usize)))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Register a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
|
||||
@@ -119,30 +119,147 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
"torch.ops.aten.mm.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
a.matmul(b)
|
||||
}
|
||||
// bmm: batched 3-D matmul. Fast path under cuda + F32 when
|
||||
// B was produced by a permute([0, 2, 1]) (i.e. `T @ T.T` —
|
||||
// the DLRM pairwise-interaction pattern): route to
|
||||
// `matmul_3d_t` with the original (B, F, D) tensor, which
|
||||
// uses the fused Matmul2DKernel and avoids the
|
||||
// expand+mul+sum-reduce decomposition that produces ~25
|
||||
// small kernels per bmm.
|
||||
"torch.ops.aten.bmm.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
|
||||
let b_src = node.inputs.get(1).and_then(|n| n.arg.as_tensor_name())
|
||||
.and_then(|n| self.transpose_2d_source.get(n).cloned());
|
||||
|
||||
let f32_all = a.dtype == DType::F32 && b.dtype == DType::F32;
|
||||
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
|
||||
if cfg!(feature = "cuda")
|
||||
&& backend_is_cuda
|
||||
&& f32_all
|
||||
&& a.shape.dims.len() == 3
|
||||
&& b.shape.dims.len() == 3
|
||||
&& let Some(orig_name) = b_src
|
||||
&& let Some(orig_b) = self.tensors.get(&orig_name).copied()
|
||||
&& orig_b.shape.dims.len() == 3
|
||||
{
|
||||
// a: (B, M, K), orig_b: (B, N, K) — matmul_3d_t does
|
||||
// a @ orig_b.t() = (B, M, K) @ (B, K, N) = (B, M, N).
|
||||
luminal_cuda_lite::kernel::matmul_3d_t(a, orig_b)
|
||||
} else {
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
a.matmul(b)
|
||||
}
|
||||
}
|
||||
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
//
|
||||
// Fast path (CUDA, the common nn.Linear case): when
|
||||
// * shapes are 2-D F32, alpha=beta=1
|
||||
// * mat2 was produced by `aten.permute([1,0])` of a 2-D
|
||||
// tensor (`weight.t()` from nn.Linear)
|
||||
// * bias is 1-D
|
||||
// we lower to the fused `linear_bias` kernel using the
|
||||
// *original* (N,K) weight — bypassing the
|
||||
// expand+mul+sum-reduce decomposition that otherwise
|
||||
// produces ~25 small kernels per Linear layer (~3.7 ms on
|
||||
// tiny shapes due to launch overhead).
|
||||
//
|
||||
// The transpose detection comes from
|
||||
// `translate_permute`, which populates
|
||||
// `transpose_2d_source` whenever it sees a 2-D permute.
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
let mat2 = self.get_input_tensor(node, 2)?;
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
|
||||
let mat2_src = node.inputs.get(2).and_then(|n| n.arg.as_tensor_name())
|
||||
.and_then(|n| self.transpose_2d_source.get(n).cloned());
|
||||
|
||||
let unit_scale = (alpha - 1.0).abs() < 1e-7 && (beta - 1.0).abs() < 1e-7;
|
||||
let f32_all = mat1.dtype == DType::F32
|
||||
&& mat2.dtype == DType::F32
|
||||
&& input.dtype == DType::F32;
|
||||
let two_d = mat1.shape.dims.len() == 2 && mat2.shape.dims.len() == 2;
|
||||
let bias_is_1d = input.shape.dims.len() == 1;
|
||||
|
||||
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
|
||||
if cfg!(feature = "cuda")
|
||||
&& backend_is_cuda
|
||||
&& two_d
|
||||
&& f32_all
|
||||
&& unit_scale
|
||||
&& bias_is_1d
|
||||
&& let Some(weight_name) = mat2_src
|
||||
&& let Some(orig_weight) = self.tensors.get(&weight_name).copied()
|
||||
&& orig_weight.shape.dims.len() == 2
|
||||
{
|
||||
// Forward-looking fusion: if this addmm has exactly one
|
||||
// consumer and that consumer is `aten.relu.default` or
|
||||
// `aten.sigmoid.default`, emit the fused
|
||||
// `linear_bias_relu` / `linear_bias_sigmoid` kernel and
|
||||
// mark the consumer as absorbed so we don't emit a
|
||||
// redundant unary op downstream. This collapses the
|
||||
// standard nn.Linear+ReLU MLP layer to one kernel.
|
||||
let addmm_out: Option<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let fuse_act = addmm_out
|
||||
.as_deref()
|
||||
.and_then(|n| self.unique_consumer(n));
|
||||
let (fused, absorbed) = match fuse_act {
|
||||
Some((target, out_name)) if target == "torch.ops.aten.relu.default" => {
|
||||
(
|
||||
Some(luminal_cuda_lite::kernel::linear_bias_relu(
|
||||
mat1, orig_weight, input,
|
||||
)),
|
||||
Some(out_name),
|
||||
)
|
||||
}
|
||||
Some((target, out_name)) if target == "torch.ops.aten.sigmoid.default" => {
|
||||
(
|
||||
Some(luminal_cuda_lite::kernel::linear_bias_sigmoid(
|
||||
mat1, orig_weight, input,
|
||||
)),
|
||||
Some(out_name),
|
||||
)
|
||||
}
|
||||
_ => (None, None),
|
||||
};
|
||||
if let Some(t) = fused {
|
||||
if let Some(out_name) = absorbed {
|
||||
self.absorbed_nodes.insert(out_name.clone());
|
||||
self.tensors.insert(out_name, t);
|
||||
}
|
||||
t
|
||||
} else {
|
||||
// No unary consumer to fuse; plain linear+bias.
|
||||
luminal_cuda_lite::kernel::linear_bias(mat1, orig_weight, input)
|
||||
}
|
||||
} else {
|
||||
// Generic fallback (non-cuda, scaled, or unknown
|
||||
// mat2 source).
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
}
|
||||
}
|
||||
|
||||
// Convolution
|
||||
"torch.ops.aten.convolution.default" => self.translate_conv(node)?,
|
||||
|
||||
// Reduction ops
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_sum_with_embbag_peephole(node)?,
|
||||
"torch.ops.aten.mean.dim" => self.translate_reduction(node, ReductionOp::Mean)?,
|
||||
"torch.ops.aten.amax.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
|
||||
@@ -151,10 +268,25 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
|
||||
// EmbeddingBag (sum-pool, fixed bag size). PT2 export decomposes
|
||||
// nn.EmbeddingBag → `_embedding_bag` (with backward) or
|
||||
// `_embedding_bag_forward_only` (without). Both return the same
|
||||
// 4-tuple `(output, offset2bag, bag_size, max_indices)` and only
|
||||
// the first slot is consumed at inference. The two op variants
|
||||
// are math-identical for the forward path, so route both through
|
||||
// the same handler. The handler stores into `tensors` itself
|
||||
// (multi-output op) so we return early afterwards.
|
||||
"torch.ops.aten._embedding_bag_forward_only.default"
|
||||
| "torch.ops.aten._embedding_bag.default" => {
|
||||
self.translate_embedding_bag_forward_only(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -514,6 +646,51 @@ impl<'a> Translator<'a> {
|
||||
};
|
||||
|
||||
if !output_name.is_empty() {
|
||||
// Record the chain (FX target + first input name) keyed by
|
||||
// output name so multi-node peepholes (e.g. the EmbBag fast
|
||||
// path that detects sum ← view ← index_select) can walk
|
||||
// back without re-scanning all parsed nodes.
|
||||
//
|
||||
// For variadic ops (e.g. `aten.cat.default` whose first
|
||||
// arg is `as_tensors`) fall back to the first entry of the
|
||||
// variadic tensor list. The DLRM PairwiseDot peephole
|
||||
// needs `node_chain[cat]` to walk back from `bmm → cat`.
|
||||
let first_input_name: Option<String> = node
|
||||
.inputs
|
||||
.first()
|
||||
.and_then(|i| {
|
||||
i.arg
|
||||
.as_tensor_name()
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
i.arg
|
||||
.as_tensors()
|
||||
.and_then(|ts| ts.first().map(|tn| tn.name.clone()))
|
||||
})
|
||||
});
|
||||
if let Some(first_input) = first_input_name {
|
||||
self.node_chain.insert(
|
||||
output_name.clone(),
|
||||
(node.target.clone(), first_input),
|
||||
);
|
||||
}
|
||||
// Also record the full input-name list (in order, including
|
||||
// entries that come from `as_tensors` for variadic ops like
|
||||
// `aten.cat`). Used by the DLRM PairwiseDot peephole which
|
||||
// needs all cat inputs and both bmm inputs.
|
||||
let mut all_inputs: Vec<String> = Vec::new();
|
||||
for inp in &node.inputs {
|
||||
if let Some(names) = inp.arg.as_tensors() {
|
||||
for tn in names {
|
||||
all_inputs.push(tn.name.clone());
|
||||
}
|
||||
} else if let Some(name) = inp.arg.as_tensor_name() {
|
||||
all_inputs.push(name.to_string());
|
||||
}
|
||||
}
|
||||
if !all_inputs.is_empty() {
|
||||
self.op_inputs.insert(output_name.clone(), all_inputs);
|
||||
}
|
||||
self.tensors.insert(output_name, result);
|
||||
}
|
||||
Ok(())
|
||||
@@ -521,6 +698,69 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Peephole for the DLRM-v3 embedding-bag pattern:
|
||||
/// `sum(dim=[1], keepdim=False)( view([?, L, D])( index_select(W, 0, IDX) ) )`
|
||||
/// substitutes the fused `embedding_bag_sum_kernel(W, IDX.view(?, L))`
|
||||
/// — same kernel as the hand-rolled DLRM example uses. Falls back to
|
||||
/// the generic reduction path when the chain doesn't match.
|
||||
pub(crate) fn translate_sum_with_embbag_peephole(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
) -> Result<GraphTensor> {
|
||||
let dims = self.get_ints_arg(node, 1).unwrap_or_default();
|
||||
let keepdim = self.get_bool_arg(node, 2).unwrap_or(false);
|
||||
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
|
||||
|
||||
// Only attempt fast-path under cuda + the specific sum(dim=[1]) pattern.
|
||||
if cfg!(feature = "cuda")
|
||||
&& backend_is_cuda
|
||||
&& dims.len() == 1
|
||||
&& dims[0] == 1
|
||||
&& !keepdim
|
||||
&& let Some(sum_input_name) = node.inputs.first().and_then(|i| i.arg.as_tensor_name())
|
||||
&& let Some((view_target, view_src)) = self.node_chain.get(sum_input_name).cloned()
|
||||
&& view_target == "torch.ops.aten.view.default"
|
||||
&& let Some((is_target, is_src)) = self.node_chain.get(&view_src).cloned()
|
||||
&& is_target == "torch.ops.aten.index_select.default"
|
||||
// Pull the FX index_select node so we can grab its dim + index args.
|
||||
&& let Some(is_node) = self.parsed.program.graph_module.graph.nodes.iter()
|
||||
.find(|n| n.outputs.first()
|
||||
.and_then(|o| o.as_tensor.as_ref())
|
||||
.map(|t| t.name == view_src)
|
||||
.unwrap_or(false))
|
||||
{
|
||||
let weight = self.tensors.get(&is_src).copied();
|
||||
let idx_name = is_node.inputs.get(2).and_then(|i| i.arg.as_tensor_name());
|
||||
let is_dim = self.get_int_arg(is_node, 1).unwrap_or(-1);
|
||||
let in_tensor = self.tensors.get(sum_input_name).copied();
|
||||
if is_dim == 0
|
||||
&& let Some(w) = weight
|
||||
&& let Some(idx_n) = idx_name
|
||||
&& let Some(idx) = self.tensors.get(idx_n).copied()
|
||||
&& let Some(inp) = in_tensor
|
||||
&& w.shape.dims.len() == 2
|
||||
&& idx.shape.dims.len() == 1
|
||||
&& inp.shape.dims.len() == 3
|
||||
// ensure view's middle dim == idx's bag dim divides idx total
|
||||
&& inp.dtype == DType::F32
|
||||
&& w.dtype == DType::F32
|
||||
{
|
||||
let l = inp.shape.dims[1];
|
||||
let kb = inp.shape.dims[0];
|
||||
let d = inp.shape.dims[2];
|
||||
// Reshape flat indices (K*B*L,) to (K*B, L).
|
||||
let idx_2d = reshape_tensor(idx, vec![kb, l]);
|
||||
// embedding_bag_sum_kernel expects (n_emb, d) weights +
|
||||
// (batch, bag) indices, returns (batch, d).
|
||||
let _ = d; // d already encoded in `w.shape.dims[1]`
|
||||
return Ok(luminal_cuda_lite::kernel::embedding_bag_sum_kernel(w, idx_2d));
|
||||
}
|
||||
}
|
||||
|
||||
// Generic fallback.
|
||||
self.translate_reduction(node, ReductionOp::Sum)
|
||||
}
|
||||
|
||||
fn translate_scalar_comparison(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
|
||||
@@ -51,6 +51,42 @@ pub(crate) struct Translator<'a> {
|
||||
pub(crate) output_ids: Vec<(String, NodeIndex)>,
|
||||
/// Extra tensor metadata from inlined subgraphs.
|
||||
pub(crate) extra_tensor_values: HashMap<String, TensorMeta>,
|
||||
/// Peephole: maps an output-tensor name produced by a `permute([1,0])`
|
||||
/// (i.e. a 2-D transpose) back to its input-tensor name. Used by the
|
||||
/// addmm dispatch to detect `aten.addmm(bias, x, weight.t())` and
|
||||
/// route it through the fused `Matmul2DKernel` (`matmul_2d_t`) with
|
||||
/// the original weight, instead of through the generic
|
||||
/// expand+mul+sum decomposition that materializes ~25 small kernels.
|
||||
pub(crate) transpose_2d_source: HashMap<String, String>,
|
||||
/// Trace each emitted node by its first output's name → (FX target,
|
||||
/// first input's name). Used by the EmbBag peephole to walk back
|
||||
/// `sum.dim_IntList → aten.view.default → aten.index_select.default`
|
||||
/// and substitute the fused `embedding_bag_sum_kernel` for the slow
|
||||
/// expand+gather decomposition. Populated by `record_node_chain`
|
||||
/// after dispatching each op.
|
||||
pub(crate) node_chain: HashMap<String, (String, String)>,
|
||||
/// Per-node side table mapping the primary output name → list of all
|
||||
/// input tensor names (in order). Lets multi-input peepholes — e.g.
|
||||
/// `index.Tensor(bmm(cat([…]), permute(cat([…]))), [None, li, lj])`
|
||||
/// → `dlrm_pairwise_dot_lower_tri([…])` — walk back through cat and
|
||||
/// bmm without re-scanning the FX node array. Populated alongside
|
||||
/// `node_chain` after each translated op.
|
||||
pub(crate) op_inputs: HashMap<String, Vec<String>>,
|
||||
/// Tensor name → list of *consumer output-tensor names*. Built once at
|
||||
/// the start of `translate_graph` from the parsed FX node array.
|
||||
/// (The pt2_schema's `Node` has no name field; nodes are identified by
|
||||
/// their primary output tensor name.) Used by forward-looking fusions:
|
||||
/// e.g. when the addmm fast path fires and the single consumer is
|
||||
/// `relu`/`sigmoid`, we emit the fused `linear_bias_relu`/
|
||||
/// `linear_bias_sigmoid` kernel and absorb the consumer node via
|
||||
/// `absorbed_nodes`.
|
||||
pub(crate) consumers: HashMap<String, Vec<String>>,
|
||||
/// Set of *output tensor names* whose producing FX node was absorbed
|
||||
/// into an earlier node's emission (e.g. a `relu` folded into
|
||||
/// `linear_bias` by the addmm fast path). `translate_graph` short-
|
||||
/// circuits these nodes. The absorbed node's output tensor must be
|
||||
/// pre-populated under its name by the absorbing node.
|
||||
pub(crate) absorbed_nodes: std::collections::HashSet<String>,
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
@@ -61,17 +97,112 @@ impl<'a> Translator<'a> {
|
||||
graph: Graph::new(),
|
||||
tensors: HashMap::new(),
|
||||
sym_map,
|
||||
transpose_2d_source: HashMap::new(),
|
||||
node_chain: HashMap::new(),
|
||||
op_inputs: HashMap::new(),
|
||||
consumers: HashMap::new(),
|
||||
absorbed_nodes: std::collections::HashSet::new(),
|
||||
user_input_ids: Vec::new(),
|
||||
output_ids: Vec::new(),
|
||||
extra_tensor_values: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper: extract a node's primary output tensor name. Nodes in the
|
||||
/// pt2 schema are identified by this (no separate name field).
|
||||
fn node_out_name(node: &Node) -> Option<String> {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.or_else(|| {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a `tensor_name → consumer output-tensor names` map for the
|
||||
/// parsed graph. One pass over all FX nodes; each node contributes to
|
||||
/// the `consumers` entry of every tensor it reads (Argument::Tensor or
|
||||
/// Argument::Tensors). The consumer is keyed by its primary output
|
||||
/// tensor name. Used by forward-looking fast paths to detect when an
|
||||
/// op's output has a single downstream consumer of a known kind
|
||||
/// (relu/sigmoid) and emit a fused kernel that absorbs the consumer.
|
||||
fn build_consumers(&mut self) {
|
||||
let nodes = &self.parsed.program.graph_module.graph.nodes;
|
||||
for node in nodes {
|
||||
let Some(consumer_out) = Self::node_out_name(node) else {
|
||||
continue;
|
||||
};
|
||||
for inp in &node.inputs {
|
||||
match &inp.arg {
|
||||
Argument::Tensor(t) => {
|
||||
self.consumers
|
||||
.entry(t.as_tensor.name.clone())
|
||||
.or_default()
|
||||
.push(consumer_out.clone());
|
||||
}
|
||||
Argument::Tensors(ts) => {
|
||||
for t in &ts.as_tensors {
|
||||
self.consumers
|
||||
.entry(t.name.clone())
|
||||
.or_default()
|
||||
.push(consumer_out.clone());
|
||||
}
|
||||
}
|
||||
Argument::OptionalTensors(ots) => {
|
||||
for ot in &ots.as_optional_tensors {
|
||||
if let crate::pt2_schema::OptionalTensorEntry::Tensor(t) = ot {
|
||||
self.consumers
|
||||
.entry(t.as_tensor.name.clone())
|
||||
.or_default()
|
||||
.push(consumer_out.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the unique consumer of `tensor_name`, if there is exactly one
|
||||
/// and we can find the corresponding FX node. Returns `(target,
|
||||
/// output_tensor_name)`. None if the consumer set is empty, has more
|
||||
/// than one entry, or the FX node lookup fails.
|
||||
pub(crate) fn unique_consumer(&self, tensor_name: &str) -> Option<(String, String)> {
|
||||
let consumers = self.consumers.get(tensor_name)?;
|
||||
if consumers.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let consumer_out = &consumers[0];
|
||||
let node = self
|
||||
.parsed
|
||||
.program
|
||||
.graph_module
|
||||
.graph
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|n| Self::node_out_name(n).as_deref() == Some(consumer_out.as_str()))?;
|
||||
Some((node.target.clone(), consumer_out.clone()))
|
||||
}
|
||||
|
||||
fn translate_graph(&mut self) -> Result<()> {
|
||||
self.create_inputs()?;
|
||||
self.build_consumers();
|
||||
|
||||
let nodes = &self.parsed.program.graph_module.graph.nodes;
|
||||
for (i, node) in nodes.iter().enumerate() {
|
||||
// Skip nodes whose translation was absorbed by an earlier
|
||||
// node's fast path (e.g. a `relu` folded into a fused
|
||||
// `linear_bias_relu`). The absorbing node has already
|
||||
// populated `tensors` under this node's output name.
|
||||
if let Some(out_name) = Self::node_out_name(node)
|
||||
&& self.absorbed_nodes.contains(&out_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
self.translate_node(node)
|
||||
.with_context(|| format!("Failed to translate node {i}: {}", node.target))?;
|
||||
}
|
||||
@@ -90,9 +221,64 @@ impl<'a> Translator<'a> {
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
|
||||
// Post-translation dead-code elimination. luminal's egglog DOES
|
||||
// prune unreachable subgraphs in the common case (e.g. an unused
|
||||
// `x*2.0` next to a returned `x+1.0`), but in some patterns the
|
||||
// optimizer holds onto subgraphs that were created and then
|
||||
// superseded by a translator peephole — most notably the DLRM
|
||||
// PairwiseDot path where `index.Tensor(bmm(cat(...), perm(cat(...))), ...)`
|
||||
// is replaced with a fused custom op but the original bmm/cat
|
||||
// pad-and-add chain remains in the HLIR. Walk back from every
|
||||
// `Output` HLIR node, mark reachable producers, and drop the rest.
|
||||
// Preserves `Input` nodes unconditionally so the runtime's input
|
||||
// signature stays intact even when an input is unused (a few
|
||||
// models pass dead constants alongside live tensors).
|
||||
self.dce();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sweep the HLIR graph: remove every node not reachable backward
|
||||
/// from an `Output` HLIR sink. Inputs are kept regardless so the
|
||||
/// runtime input contract is preserved.
|
||||
fn dce(&mut self) {
|
||||
use luminal::hlir::{Input, Output};
|
||||
use petgraph::Direction;
|
||||
use std::collections::HashSet;
|
||||
|
||||
let mut keep: HashSet<NodeIndex> = HashSet::new();
|
||||
let mut stack: Vec<NodeIndex> = Vec::new();
|
||||
let node_ids: Vec<NodeIndex> = self.graph.graph.node_indices().collect();
|
||||
for n in &node_ids {
|
||||
if self.graph.try_get_op::<Output>(*n).is_some() {
|
||||
if keep.insert(*n) {
|
||||
stack.push(*n);
|
||||
}
|
||||
}
|
||||
if self.graph.try_get_op::<Input>(*n).is_some() {
|
||||
keep.insert(*n);
|
||||
}
|
||||
}
|
||||
while let Some(n) = stack.pop() {
|
||||
// Walk incoming edges — operands of `n`.
|
||||
let preds: Vec<NodeIndex> = self
|
||||
.graph
|
||||
.graph
|
||||
.neighbors_directed(n, Direction::Incoming)
|
||||
.collect();
|
||||
for pred in preds {
|
||||
if keep.insert(pred) {
|
||||
stack.push(pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
for n in node_ids {
|
||||
if !keep.contains(&n) {
|
||||
self.graph.graph.remove_node(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_inputs(&mut self) -> Result<()> {
|
||||
let inputs = self.parsed.classify_inputs();
|
||||
for input in &inputs {
|
||||
|
||||
@@ -72,6 +72,24 @@ impl<'a> Translator<'a> {
|
||||
.iter()
|
||||
.map(|&d| normalize_dim(d, a.shape.len()))
|
||||
.collect();
|
||||
// Record matmul-compatible inner-axis transposes so addmm /
|
||||
// bmm can route them through the fused Matmul2DKernel /
|
||||
// matmul_3d_t with the *original* input. The view-transposed
|
||||
// tensor has non-contiguous strides that the SGEMM kernel
|
||||
// doesn't honor, so we need the original. We recognize two
|
||||
// patterns:
|
||||
// * 2-D permute [1, 0] — `weight.t()` from nn.Linear
|
||||
// * 3-D permute [0, 2, 1] — `T.transpose(1, 2)` for bmm
|
||||
let is_inner_transpose = (axes == [1usize, 0usize] && a.shape.dims.len() == 2)
|
||||
|| (axes == [0usize, 2usize, 1usize] && a.shape.dims.len() == 3);
|
||||
if is_inner_transpose
|
||||
&& let Some(src_name) = node.inputs.first().and_then(|i| i.arg.as_tensor_name())
|
||||
&& let Some(out_ref) = node.outputs.first()
|
||||
&& let Some(out_t) = out_ref.as_tensor.as_ref()
|
||||
{
|
||||
self.transpose_2d_source
|
||||
.insert(out_t.name.clone(), src_name.to_string());
|
||||
}
|
||||
Ok(a.permute(axes))
|
||||
}
|
||||
|
||||
@@ -256,7 +274,231 @@ impl<'a> Translator<'a> {
|
||||
Ok(weight.gather(ids_expanded + arange_expanded))
|
||||
}
|
||||
|
||||
/// `aten.index_select(input, dim, index)` — pick rows/slices of `input`
|
||||
/// along `dim` using a 1-D `index` tensor. Output shape is
|
||||
/// `input.shape` with `dim` replaced by `index.shape[0]`.
|
||||
///
|
||||
/// For the DLRM v3 use case this is `index_select(emb_weight, 0,
|
||||
/// flat_indices)` — a 2-D source and 1-D index along dim 0. We lower
|
||||
/// it the same way `translate_embedding` does: build a flat-rows
|
||||
/// gather index `(index * hidden_dim) + arange(hidden_dim)` and read
|
||||
/// the flattened weight in one pass. Higher-rank sources and non-zero
|
||||
/// `dim` are not yet wired (would need stride math over Expression
|
||||
/// shapes); they error out cleanly so they're easy to add when the
|
||||
/// next model surfaces them.
|
||||
pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let source = self.get_input_tensor(node, 0)?;
|
||||
let dim_raw = self.get_int_arg(node, 1)?;
|
||||
let index = self.get_input_tensor(node, 2)?;
|
||||
|
||||
let rank = source.shape.dims.len();
|
||||
anyhow::ensure!(
|
||||
rank == 2,
|
||||
"translate_index_select: only 2-D source supported (got rank {rank}); \
|
||||
extend this when a model needs higher rank."
|
||||
);
|
||||
let dim = if dim_raw < 0 {
|
||||
dim_raw + rank as i64
|
||||
} else {
|
||||
dim_raw
|
||||
};
|
||||
anyhow::ensure!(
|
||||
dim == 0,
|
||||
"translate_index_select: only dim=0 supported (got {dim}); \
|
||||
extend this when a model needs another axis."
|
||||
);
|
||||
anyhow::ensure!(
|
||||
index.shape.dims.len() == 1,
|
||||
"translate_index_select: index must be 1-D (got rank {})",
|
||||
index.shape.dims.len()
|
||||
);
|
||||
|
||||
// Same lowering as `translate_embedding`: build a flat gather index
|
||||
// that combines the row-base offsets (`index * hidden_dim`) with a
|
||||
// per-row `arange(hidden_dim)` broadcast.
|
||||
let hidden_dim = source.shape.dims[1];
|
||||
let n_idx = index.shape.dims[0];
|
||||
let index_int = index.cast(DType::Int);
|
||||
let base_expanded = (index_int * hidden_dim).expand_dim(1, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim);
|
||||
let arange_expanded = arange.expand_dim(0, n_idx);
|
||||
Ok(source.gather(base_expanded + arange_expanded))
|
||||
}
|
||||
|
||||
/// `aten._embedding_bag_forward_only(weight, indices, offsets,
|
||||
/// scale_grad_by_freq, mode, sparse, per_sample_weights,
|
||||
/// include_last_offset, padding_idx)` →
|
||||
/// `(output, offset2bag, bag_size, max_indices)`.
|
||||
///
|
||||
/// PyTorch decomposes `nn.EmbeddingBag` to this op. For the DLRM use
|
||||
/// case all bags share a fixed stride `L = indices.len() / offsets.len()`
|
||||
/// and `mode == 0` (sum). We detect that and lower to the fused
|
||||
/// [`embedding_bag_sum_kernel`] on CUDA, or to a generic
|
||||
/// `gather → reshape → sum` chain on CPU.
|
||||
///
|
||||
/// Only `output` (tuple slot 0) is computed — `offset2bag`, `bag_size`
|
||||
/// and `max_indices` are training-time dead ends for inference DLRM
|
||||
/// and never read by any downstream `getitem`.
|
||||
pub(crate) fn translate_embedding_bag_forward_only(&mut self, node: &Node) -> Result<()> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
let offsets = self.get_input_tensor(node, 2)?;
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
anyhow::ensure!(
|
||||
mode == 0,
|
||||
"translate_embedding_bag_forward_only: only mode=0 (sum) supported (got {mode}); \
|
||||
vanilla DLRM uses sum-pooled bags. Extend this when a model needs mean/max."
|
||||
);
|
||||
// per_sample_weights is input index 6 and may be None / absent.
|
||||
let has_per_sample_weights = node
|
||||
.inputs
|
||||
.get(6)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.is_some();
|
||||
anyhow::ensure!(
|
||||
!has_per_sample_weights,
|
||||
"translate_embedding_bag_forward_only: per_sample_weights not supported \
|
||||
(DLRM doesn't use them)."
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
weight.shape.dims.len() == 2,
|
||||
"translate_embedding_bag_forward_only: weight must be 2-D (got rank {})",
|
||||
weight.shape.dims.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
indices.shape.dims.len() == 1,
|
||||
"translate_embedding_bag_forward_only: indices must be 1-D (got rank {})",
|
||||
indices.shape.dims.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
offsets.shape.dims.len() == 1,
|
||||
"translate_embedding_bag_forward_only: offsets must be 1-D (got rank {})",
|
||||
offsets.shape.dims.len()
|
||||
);
|
||||
|
||||
let n_idx = indices.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("translate_embedding_bag_forward_only: indices length must be static")?;
|
||||
let batch = offsets.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("translate_embedding_bag_forward_only: offsets length must be static")?;
|
||||
anyhow::ensure!(
|
||||
n_idx % batch == 0,
|
||||
"translate_embedding_bag_forward_only: indices length ({n_idx}) must be a \
|
||||
multiple of offsets length ({batch}); variable bag sizes not supported."
|
||||
);
|
||||
let bag = n_idx / batch;
|
||||
let d = weight.shape.dims[1]
|
||||
.to_usize()
|
||||
.context("translate_embedding_bag_forward_only: weight dim 1 must be static")?;
|
||||
|
||||
// Reshape indices (B*L,) → (B, L) and cast to i32 (luminal kernel
|
||||
// wants Int). Then either use the fused kernel under CUDA or
|
||||
// a host-portable gather+sum lowering.
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let indices_2d = {
|
||||
let new_shape = ShapeTracker::new(vec![
|
||||
Expression::from(batch),
|
||||
Expression::from(bag),
|
||||
]);
|
||||
GraphTensor {
|
||||
id: indices_int.id,
|
||||
graph_ref: indices_int.graph_ref,
|
||||
shape: new_shape,
|
||||
dtype: indices_int.dtype,
|
||||
}
|
||||
};
|
||||
|
||||
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
|
||||
|
||||
let result = if cfg!(feature = "cuda") && backend_is_cuda && weight.dtype == DType::F32 {
|
||||
// Fused CUDA path: one kernel for the whole bag-sum.
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
luminal_cuda_lite::kernel::embedding_bag_sum_kernel(weight, indices_2d)
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
// Unreachable — gated above, but keep the compiler happy.
|
||||
unreachable!()
|
||||
}
|
||||
} else {
|
||||
// Generic fallback: gather (B*L, D) then reshape and sum.
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
let ids_expanded = (indices_2d * hidden_dim).expand_dim(2, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim);
|
||||
let arange_expanded = arange.expand_dim(0, batch).expand_dim(0, bag);
|
||||
// Note: weight.gather expects the gather indices to broadcast
|
||||
// against weight's row-flattened layout; we want (B, L, D)
|
||||
// out, then sum along L.
|
||||
let _ = d; // hidden_dim is used; keep `d` reachable for debug only.
|
||||
let gathered = weight.gather(ids_expanded + arange_expanded);
|
||||
gathered.sum(1)
|
||||
};
|
||||
|
||||
// Record the output under outputs[0][0] (the tuple's first slot).
|
||||
// The other three slots are dead under inference and there's no
|
||||
// downstream `getitem` that reads them — but if there ever is,
|
||||
// we'd need to materialize them too.
|
||||
let out_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
.or_else(|| {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
})
|
||||
.context(
|
||||
"translate_embedding_bag_forward_only: missing output[0] name in FX node",
|
||||
)?;
|
||||
self.tensors.insert(out_name.clone(), result);
|
||||
// Record node_chain / op_inputs for the *primary* output (tuple
|
||||
// slot 0). Multi-output ops normally skip the bookkeeping at the
|
||||
// bottom of `translate_node` (they return early), but later
|
||||
// peepholes — specifically the stacked-emb-bag fusion inside the
|
||||
// pairwise-dot peephole — need to identify which tensor sources
|
||||
// came from an embedding_bag, so we record explicitly.
|
||||
let first_input_name: Option<String> = node
|
||||
.inputs
|
||||
.first()
|
||||
.and_then(|i| i.arg.as_tensor_name().map(|s| s.to_string()));
|
||||
if let Some(first_input) = first_input_name {
|
||||
self.node_chain
|
||||
.insert(out_name.clone(), (node.target.clone(), first_input));
|
||||
}
|
||||
let mut all_inputs: Vec<String> = Vec::new();
|
||||
for inp in &node.inputs {
|
||||
if let Some(names) = inp.arg.as_tensors() {
|
||||
for tn in names {
|
||||
all_inputs.push(tn.name.clone());
|
||||
}
|
||||
} else if let Some(name) = inp.arg.as_tensor_name() {
|
||||
all_inputs.push(name.to_string());
|
||||
}
|
||||
}
|
||||
if !all_inputs.is_empty() {
|
||||
self.op_inputs.insert(out_name, all_inputs);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
// Try the DLRM PairwiseDot peephole before falling back to the
|
||||
// generic gather-based lowering. Detects the
|
||||
// Z[:, li, lj] ← bmm(T, T.transpose(1, 2)) ← cat([t1.unsqueeze(1), ..., tF.unsqueeze(1)], dim=1)
|
||||
// pattern that vanilla `nn.Sequential` DLRM (DLRMv1) emits — the
|
||||
// version a user writes with one `EmbeddingBag` per categorical
|
||||
// table. Replacing it with `dlrm_pairwise_dot_lower_tri` collapses
|
||||
// the cat-then-bmm-then-gather chain (which lowers to ~40
|
||||
// small Iota/Cast/Gather/FusedRegion kernels via pad_along+add)
|
||||
// into a single CUDA kernel.
|
||||
if let Some(t) = self.try_translate_pairwise_dot_lower_tri(node)? {
|
||||
return Ok(t);
|
||||
}
|
||||
|
||||
let source = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// Handle indices as_tensors (all non-None) or as individual args with None entries
|
||||
@@ -308,6 +550,98 @@ impl<'a> Translator<'a> {
|
||||
expanded.shape.expand(target);
|
||||
return Ok(source.gather_elements(expanded, first_non_none_dim));
|
||||
}
|
||||
|
||||
// Multi-index advanced indexing through leading dims that
|
||||
// pass through (e.g. DLRM's `Z[:, li, lj]` where Z has
|
||||
// shape `(B, F, F)` and output[b, p] = Z[b, li[p], lj[p]]).
|
||||
//
|
||||
// Strategy: reduce to the proven single-index simple
|
||||
// case. Combine the multi-axis indices into one (`li * F
|
||||
// + lj`) and reshape the source so the indexed region
|
||||
// becomes a single dim. Then take the exact `gather_elements`
|
||||
// path the rest of this translator already uses.
|
||||
//
|
||||
// Supported shape pattern (DLRM): exactly one leading
|
||||
// passthrough dim, no trailing dims after the indexed
|
||||
// region, per-axis indices all 1-D of the same length.
|
||||
if first_non_none_dim > 0 {
|
||||
let src_dims = source.shape.dims;
|
||||
let src_rank = src_dims.len();
|
||||
let n_idx = index_names.len();
|
||||
let trailing_start = first_non_none_dim + n_idx;
|
||||
anyhow::ensure!(
|
||||
first_non_none_dim == 1,
|
||||
"index.Tensor: leading-dim passthrough only supported for \
|
||||
exactly one leading dim (got {first_non_none_dim})."
|
||||
);
|
||||
anyhow::ensure!(
|
||||
trailing_start == src_rank,
|
||||
"index.Tensor: trailing dims after indexed region not yet supported."
|
||||
);
|
||||
let mut idx_tensors: Vec<GraphTensor> = Vec::with_capacity(n_idx);
|
||||
for n in &index_names {
|
||||
idx_tensors.push(self.get_tensor(&n.name)?.cast(DType::Int));
|
||||
}
|
||||
let idx0_shape = idx_tensors[0].shape.dims;
|
||||
anyhow::ensure!(
|
||||
idx0_shape.len() == 1,
|
||||
"index.Tensor: only 1-D per-axis indices supported (got rank {})",
|
||||
idx0_shape.len()
|
||||
);
|
||||
for it in idx_tensors.iter().skip(1) {
|
||||
anyhow::ensure!(
|
||||
it.shape.dims == idx0_shape,
|
||||
"index.Tensor: per-axis indices must share a common shape"
|
||||
);
|
||||
}
|
||||
// strides over indexed axes (no trailing dims).
|
||||
let mut strides_idx: Vec<Expression> = vec![Expression::from(1usize); n_idx];
|
||||
for i in (0..n_idx - 1).rev() {
|
||||
strides_idx[i] =
|
||||
strides_idx[i + 1] * src_dims[first_non_none_dim + i + 1];
|
||||
}
|
||||
// combined[p] = sum_i idx_i * stride_i (1-D)
|
||||
let mut combined: Option<GraphTensor> = None;
|
||||
for (i, it) in idx_tensors.into_iter().enumerate() {
|
||||
let weighted = if strides_idx[i].to_usize() == Some(1) {
|
||||
it
|
||||
} else {
|
||||
it * strides_idx[i]
|
||||
};
|
||||
combined = Some(match combined {
|
||||
Some(acc) => {
|
||||
let (a, b) = broadcast_binary(acc, weighted);
|
||||
a + b
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
let combined = combined.context("index.Tensor: no indices")?;
|
||||
|
||||
// Indexed region size, then a (leading, indexed_size) reshape.
|
||||
let mut indexed_size = Expression::from(1usize);
|
||||
for d in &src_dims[first_non_none_dim..trailing_start] {
|
||||
indexed_size *= *d;
|
||||
}
|
||||
let leading_dim = src_dims[0];
|
||||
let flat_source =
|
||||
reshape_tensor(source, vec![leading_dim, indexed_size]);
|
||||
|
||||
// Now dispatch through the exact single-index simple
|
||||
// case lowering — known-good. Add unit leading dims
|
||||
// to match flat_source rank, then expand to the full
|
||||
// (leading_dim, pair_count) shape.
|
||||
let mut expanded = combined;
|
||||
let flat_rank = 2; // (leading, indexed_size)
|
||||
for _ in 0..(flat_rank - expanded.shape.len()) {
|
||||
expanded = expanded.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let idx_dim_size = expanded.shape.dims[1];
|
||||
let mut target: Vec<Expression> = vec![leading_dim, indexed_size];
|
||||
target[1] = idx_dim_size;
|
||||
expanded.shape.expand(target);
|
||||
return Ok(flat_source.gather_elements(expanded, 1));
|
||||
}
|
||||
} else {
|
||||
bail!(
|
||||
"index.Tensor: unsupported indices format: {:?}",
|
||||
@@ -552,4 +886,394 @@ impl<'a> Translator<'a> {
|
||||
|
||||
first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
|
||||
}
|
||||
|
||||
/// DLRM PairwiseDot peephole: detect
|
||||
/// `aten.index.Tensor(bmm, [None, li, lj])`
|
||||
/// where
|
||||
/// `bmm = aten.bmm.default(T, T_permuted)`
|
||||
/// `T_permuted = aten.permute.default(T, [0, 2, 1])`
|
||||
/// `T = aten.cat.default([unsqueeze_a, unsqueeze_b, …], dim=1)`
|
||||
/// each `unsqueeze_k = aten.unsqueeze.default(t_k, 1)`
|
||||
/// and lower to `dlrm_pairwise_dot_lower_tri([t_0, t_1, …])`.
|
||||
///
|
||||
/// Why this matters: at DLRM nc=3 the generic lowering produces
|
||||
/// ~80 CUDA-graph kernels from the cat+bmm+gather chain alone (the
|
||||
/// `pad_along + add` decomposition of cat fans out into many
|
||||
/// Iota/Cast/Gather/FusedRegion launches). The fused kernel
|
||||
/// computes the F(F-1)/2 dot products directly with one launch.
|
||||
///
|
||||
/// Returns `Ok(Some(out))` on match, `Ok(None)` if the pattern
|
||||
/// doesn't apply, `Err(_)` only if matching diagnostics surface a
|
||||
/// genuine bug.
|
||||
fn try_translate_pairwise_dot_lower_tri(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
) -> Result<Option<GraphTensor>> {
|
||||
// CUDA-only fast path. The kernel is in luminal_cuda_lite.
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
let _ = node;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let backend_is_cuda = true; // LUMINAL_BACKEND_CUDA replaced by cfg!(feature="cuda") gate above.
|
||||
if !backend_is_cuda {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// 1. Detect [None, li, lj] index list.
|
||||
let opt_tensors =
|
||||
match node.inputs.get(1).and_then(|i| i.arg.as_optional_tensors()) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
// Expect [None, li, lj]: three entries, first None, last two
|
||||
// are tensors.
|
||||
if opt_tensors.len() != 3 {
|
||||
return Ok(None);
|
||||
}
|
||||
use crate::pt2_schema::OptionalTensorEntry;
|
||||
let (li_name, lj_name) =
|
||||
match (&opt_tensors[0], &opt_tensors[1], &opt_tensors[2]) {
|
||||
(OptionalTensorEntry::None(_), OptionalTensorEntry::Tensor(li), OptionalTensorEntry::Tensor(lj)) => {
|
||||
(li.as_tensor.name.clone(), lj.as_tensor.name.clone())
|
||||
}
|
||||
_ => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
// 2. Source must be a bmm of (T, T_permuted).
|
||||
let source_name = match node.inputs.first().and_then(|i| i.arg.as_tensor_name()) {
|
||||
Some(s) => s.to_string(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
let bmm_info = match self.node_chain.get(&source_name) {
|
||||
Some(x) => x.clone(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
if bmm_info.0 != "torch.ops.aten.bmm.default" {
|
||||
return Ok(None);
|
||||
}
|
||||
let bmm_inputs = match self.op_inputs.get(&source_name) {
|
||||
Some(v) => v.clone(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
if bmm_inputs.len() != 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
let (bmm_a, bmm_b) = (bmm_inputs[0].clone(), bmm_inputs[1].clone());
|
||||
// 3. Both bmm inputs must descend from the same cat — one
|
||||
// directly, one via a [0, 2, 1] permute. The permute is
|
||||
// already recorded in `transpose_2d_source` (we extended
|
||||
// it to cover 3-D `[0, 2, 1]` for bmm fast paths).
|
||||
let permute_src = self.transpose_2d_source.get(&bmm_b).cloned();
|
||||
let (cat_name, _has_transpose) = if permute_src.as_deref() == Some(bmm_a.as_str()) {
|
||||
(bmm_a.clone(), true)
|
||||
} else if self
|
||||
.transpose_2d_source
|
||||
.get(&bmm_a)
|
||||
.map(|s| s.as_str() == bmm_b.as_str())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
(bmm_b.clone(), true)
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
let cat_info = match self.node_chain.get(&cat_name) {
|
||||
Some(x) => x.clone(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
if cat_info.0 != "torch.ops.aten.cat.default" {
|
||||
return Ok(None);
|
||||
}
|
||||
let cat_inputs = match self.op_inputs.get(&cat_name) {
|
||||
Some(v) => v.clone(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
if cat_inputs.len() < 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
// 4. Each cat input should be `unsqueeze(t_k, 1)` — peel the
|
||||
// unsqueeze so we recover the original (B, D) tensor.
|
||||
// Also track the FX *source name* of each feature so the
|
||||
// multi-call fusion below can recognize emb-bag-rooted
|
||||
// features and emit the stacked path.
|
||||
let mut feature_tensors: Vec<GraphTensor> = Vec::with_capacity(cat_inputs.len());
|
||||
let mut feature_source_names: Vec<String> = Vec::with_capacity(cat_inputs.len());
|
||||
for ci in &cat_inputs {
|
||||
let unsqueeze_info = match self.node_chain.get(ci) {
|
||||
Some(x) => x.clone(),
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
if unsqueeze_info.0 != "torch.ops.aten.unsqueeze.default" {
|
||||
return Ok(None);
|
||||
}
|
||||
// unsqueeze's first input is the source tensor name.
|
||||
let src = unsqueeze_info.1;
|
||||
let t = self.get_tensor(&src)?;
|
||||
if t.dtype != DType::F32 || t.shape.dims.len() != 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
feature_tensors.push(t);
|
||||
feature_source_names.push(src);
|
||||
}
|
||||
// 5. Sanity check on li/lj — they must be the strict lower-tri
|
||||
// pair table for F = feature_tensors.len(). We don't
|
||||
// materialize them; just verify the index buffers have
|
||||
// the right length, then trust they're tril-indices.
|
||||
// A user passing arbitrary indices through this exact
|
||||
// chain would silently get tril results; gating on
|
||||
// `index buffer length == F*(F-1)/2` catches the common
|
||||
// case without invasive constant-folding work.
|
||||
let f = feature_tensors.len();
|
||||
let pair_count = f * (f - 1) / 2;
|
||||
let li_t = self.get_tensor(&li_name)?;
|
||||
let lj_t = self.get_tensor(&lj_name)?;
|
||||
if li_t.shape.dims.len() != 1
|
||||
|| lj_t.shape.dims.len() != 1
|
||||
|| li_t.shape.dims[0].to_usize() != Some(pair_count)
|
||||
|| lj_t.shape.dims[0].to_usize() != Some(pair_count)
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// 6. Multi-call fusion: detect when exactly one feature is the
|
||||
// dense MLP output and all others came from
|
||||
// `_embedding_bag.default` calls with matching `(rows, D)`.
|
||||
// Then collapse the N separate per-table embedding-bag
|
||||
// kernels into one fused `stacked_embedding_bag_sum_kernel`
|
||||
// and use the stacked-variant pairwise dot. Mirrors the
|
||||
// hand-written DLRM rust example's kernel shape.
|
||||
if let Some((dense, emb_stack)) =
|
||||
self.try_fuse_stacked_emb_bag(&feature_tensors, &feature_source_names)?
|
||||
{
|
||||
let out = luminal_cuda_lite::kernel::dlrm_pairwise_dot_lower_tri_stacked(
|
||||
dense, emb_stack,
|
||||
);
|
||||
return Ok(Some(out));
|
||||
}
|
||||
|
||||
// Fallback: per-feature variadic kernel. The bmm, cat,
|
||||
// and unsqueeze nodes left dangling in the HLIR get picked
|
||||
// up by `Translator::dce()` after every FX node is translated
|
||||
// (walks back from `Output` HLIR sinks and drops everything
|
||||
// unreachable). luminal's egglog optimizer leaves some of
|
||||
// these subgraphs alive on its own, so the explicit pass is
|
||||
// load-bearing for this peephole.
|
||||
let out = luminal_cuda_lite::kernel::dlrm_pairwise_dot_lower_tri(feature_tensors);
|
||||
Ok(Some(out))
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-call fusion: scan a list of feature sources for the DLRM
|
||||
/// pattern "1 dense MLP output + N `_embedding_bag.default` outputs"
|
||||
/// and, when matched, emit a single `stacked_embedding_bag_sum_kernel`
|
||||
/// over a concatenated weight tensor. Returns `(dense, emb_stack)`
|
||||
/// where `dense` is the (B, D) MLP output and `emb_stack` is
|
||||
/// `(B, num_emb, D)` for the stacked-pairwise-dot kernel to consume.
|
||||
///
|
||||
/// Requirements:
|
||||
/// * All embedding-bag weights are 2-D F32 with the same `(rows, D)`
|
||||
/// * All bag indices have the same `(batch, bag_size)` shape
|
||||
/// * The dense feature is the *first* cat input (DLRMv1 emits
|
||||
/// `cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)`)
|
||||
#[cfg(feature = "cuda")]
|
||||
fn try_fuse_stacked_emb_bag(
|
||||
&mut self,
|
||||
feature_tensors: &[GraphTensor],
|
||||
feature_source_names: &[String],
|
||||
) -> Result<Option<(GraphTensor, GraphTensor)>> {
|
||||
use crate::pt2_schema::OptionalTensorEntry;
|
||||
if feature_tensors.len() < 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
// Look up each feature source's producing FX node target. The
|
||||
// dense MLP output has target ≠ `_embedding_bag*`. We require the
|
||||
// *first* feature to be dense (DLRMv1 convention) and the rest to
|
||||
// be emb-bag.
|
||||
let producer_target = |name: &str| -> Option<String> {
|
||||
self.node_chain.get(name).map(|(target, _)| target.clone())
|
||||
};
|
||||
let dense_target = producer_target(&feature_source_names[0]);
|
||||
if let Some(t) = &dense_target
|
||||
&& (t == "torch.ops.aten._embedding_bag.default"
|
||||
|| t == "torch.ops.aten._embedding_bag_forward_only.default")
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
for src in feature_source_names.iter().skip(1) {
|
||||
match producer_target(src).as_deref() {
|
||||
Some("torch.ops.aten._embedding_bag.default")
|
||||
| Some("torch.ops.aten._embedding_bag_forward_only.default") => {}
|
||||
_ => return Ok(None),
|
||||
}
|
||||
}
|
||||
// For each emb-bag feature source, pull (weight, indices, offsets)
|
||||
// by walking back through the parsed FX node array. The FX nodes
|
||||
// for emb-bag store the output tensor name(s) in `outputs[0].as_tensors`.
|
||||
let mut emb_weights: Vec<GraphTensor> = Vec::new();
|
||||
let mut emb_indices: Vec<GraphTensor> = Vec::new();
|
||||
let mut emb_rows: Option<usize> = None;
|
||||
let mut emb_d: Option<usize> = None;
|
||||
let mut emb_batch: Option<usize> = None;
|
||||
let mut emb_bag: Option<usize> = None;
|
||||
for src in feature_source_names.iter().skip(1) {
|
||||
// Find the FX emb-bag node whose primary output is `src`.
|
||||
// emb-bag is a multi-output op: `outputs[0].as_tensors[0].name`
|
||||
// is what downstream `getitem(node, 0)` references, which is
|
||||
// what our translator stores in `tensors`. Some exports drop
|
||||
// the getitem entirely and just inline the name as a single
|
||||
// tensor output, so check both shapes.
|
||||
let node_opt = self
|
||||
.parsed
|
||||
.program
|
||||
.graph_module
|
||||
.graph
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|n| {
|
||||
if n.target != "torch.ops.aten._embedding_bag.default"
|
||||
&& n.target != "torch.ops.aten._embedding_bag_forward_only.default"
|
||||
{
|
||||
return false;
|
||||
}
|
||||
let Some(out) = n.outputs.first() else {
|
||||
return false;
|
||||
};
|
||||
if let Some(ts) = out.as_tensors.as_ref() {
|
||||
return ts.iter().any(|t| t.name == *src);
|
||||
}
|
||||
if let Some(t) = out.as_tensor.as_ref() {
|
||||
return t.name == *src;
|
||||
}
|
||||
false
|
||||
});
|
||||
let Some(node) = node_opt else {
|
||||
return Ok(None);
|
||||
};
|
||||
// _embedding_bag(weight, indices, offsets, ...)
|
||||
let weight_name = node
|
||||
.inputs
|
||||
.first()
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|s| s.to_string());
|
||||
let indices_name = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|s| s.to_string());
|
||||
let offsets_name = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|s| s.to_string());
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
if mode != 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
// per_sample_weights at index 6
|
||||
let has_psw = node
|
||||
.inputs
|
||||
.get(6)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.is_some();
|
||||
if has_psw {
|
||||
return Ok(None);
|
||||
}
|
||||
let (Some(wn), Some(in_), Some(on)) = (weight_name, indices_name, offsets_name) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let w = self.get_tensor(&wn)?;
|
||||
let i = self.get_tensor(&in_)?;
|
||||
let o = self.get_tensor(&on)?;
|
||||
if w.dtype != DType::F32 || w.shape.dims.len() != 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
if i.shape.dims.len() != 1 || o.shape.dims.len() != 1 {
|
||||
return Ok(None);
|
||||
}
|
||||
let rows = w.shape.dims[0].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("stacked emb-bag fusion: rows must be static")
|
||||
})?;
|
||||
let d = w.shape.dims[1].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("stacked emb-bag fusion: D must be static")
|
||||
})?;
|
||||
let n_idx = i.shape.dims[0].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("stacked emb-bag fusion: indices length must be static")
|
||||
})?;
|
||||
let batch = o.shape.dims[0].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("stacked emb-bag fusion: offsets length must be static")
|
||||
})?;
|
||||
if n_idx % batch != 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
let bag = n_idx / batch;
|
||||
// All tables must agree on (rows, d, batch, bag).
|
||||
if emb_rows.get_or_insert(rows) != &rows {
|
||||
return Ok(None);
|
||||
}
|
||||
if emb_d.get_or_insert(d) != &d {
|
||||
return Ok(None);
|
||||
}
|
||||
if emb_batch.get_or_insert(batch) != &batch {
|
||||
return Ok(None);
|
||||
}
|
||||
if emb_bag.get_or_insert(bag) != &bag {
|
||||
return Ok(None);
|
||||
}
|
||||
// Reshape indices to (B, bag) of int32 for the kernel.
|
||||
let i_int = i.cast(DType::Int);
|
||||
let new_shape = ShapeTracker::new(vec![
|
||||
Expression::from(batch),
|
||||
Expression::from(bag),
|
||||
]);
|
||||
let indices_2d = GraphTensor {
|
||||
id: i_int.id,
|
||||
graph_ref: i_int.graph_ref,
|
||||
shape: new_shape,
|
||||
dtype: i_int.dtype,
|
||||
};
|
||||
emb_weights.push(w);
|
||||
emb_indices.push(indices_2d);
|
||||
// Silence unused warning when no tables match the size.
|
||||
let _ = OptionalTensorEntry::None;
|
||||
}
|
||||
|
||||
let _ = emb_rows; // (already validated via per-table equality)
|
||||
// Use the multi-table kernel: takes N (weight, idx) pairs and
|
||||
// produces (B, num_emb, D) in one launch. Crucially this avoids
|
||||
// the `concat_along`-of-persistent-weights expansion (which would
|
||||
// emit pad+add HLIR kernels per pair) — the kernel reads each
|
||||
// table's weight pointer directly via a packed staging buffer.
|
||||
let emb_stack = luminal_cuda_lite::kernel::multi_table_embedding_bag_sum_kernel(
|
||||
emb_weights,
|
||||
emb_indices,
|
||||
);
|
||||
Ok(Some((feature_tensors[0], emb_stack)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn try_fuse_stacked_emb_bag(
|
||||
&mut self,
|
||||
_feature_tensors: &[GraphTensor],
|
||||
_feature_source_names: &[String],
|
||||
) -> Result<Option<(GraphTensor, GraphTensor)>> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,63 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
# Pre-resolve opaque integer ids for inputs and outputs so the
|
||||
# per-iter `run_with_ptrs` path can skip the string-keyed lookup
|
||||
# inside Rust. One pyo3 call per name at compile time, no per-iter
|
||||
# name handling thereafter. Falls back to None when the Rust side
|
||||
# is an older build that doesn't expose `tensor_id` /
|
||||
# `run_with_ptrs`.
|
||||
self._batched_ptrs_supported = (
|
||||
self._supports_device_ptrs
|
||||
and hasattr(graph_result, "tensor_id")
|
||||
and hasattr(graph_result, "run_with_ptrs")
|
||||
)
|
||||
if self._batched_ptrs_supported:
|
||||
self._input_ids = [graph_result.tensor_id(n) for n in self._input_names]
|
||||
self._output_ids = [graph_result.tensor_id(n) for n in self._output_names]
|
||||
else:
|
||||
self._input_ids = None
|
||||
self._output_ids = None
|
||||
# Output dtype codes — cached at __init__ instead of re-fetched in
|
||||
# each call (was a per-iter pyo3 attribute access on the dynamic
|
||||
# path; for static-shape models the codes don't change).
|
||||
self._output_dtype_codes_cached = (
|
||||
None if self._has_dynamic_dims else list(graph_result.output_dtypes)
|
||||
)
|
||||
# Per-input "have-we-seen-this-ptr-before" cache. In a hot bench
|
||||
# loop the same user tensor objects are passed each iter, so most
|
||||
# of the per-iter Python work (detach + contiguous + dtype cast +
|
||||
# data_ptr + numel + element_size) repeats with identical inputs.
|
||||
# Cache (id(orig_tensor), orig_data_ptr, cast_tensor, cast_ptr,
|
||||
# cast_n_bytes) per input slot; on hit, skip everything and rely
|
||||
# on luminal's previously-registered pointer (the runtime keeps
|
||||
# CudaInput::Ptr entries across `execute()` calls thanks to the
|
||||
# consume-step filter in `cuda_lite/runtime.rs`). The cast tensor
|
||||
# reference is held inside the cache so PyTorch's caching
|
||||
# allocator can't recycle the converted buffer.
|
||||
#
|
||||
# Sharp edge: callers that mutate a user tensor in place via
|
||||
# `.copy_(...)` keep the same `id()` and `data_ptr()` and will
|
||||
# silently get stale cached data. Fresh-tensor callers
|
||||
# (`make_inputs(...)` each iter, or new outputs from upstream)
|
||||
# cache-miss naturally and pay the full cold-path cost. If a
|
||||
# future model needs in-place input mutation, swap this check
|
||||
# for one that also looks at `tensor._version` (autograd's
|
||||
# mutation counter) — but PyTorch flags `_version` as private,
|
||||
# so don't reach for it unless an actual model needs it.
|
||||
self._input_cache_ids = [None] * len(self._input_names)
|
||||
self._input_cache_orig_ptrs = [0] * len(self._input_names)
|
||||
self._input_cache_cast_tensors = [None] * len(self._input_names)
|
||||
self._input_cache_specs = [None] * len(self._input_names)
|
||||
# Cached output tensors mirror the input-side cache. For static-
|
||||
# shape models the output tensors can be reused across calls if
|
||||
# the input device is unchanged — saves ~3 μs/output of
|
||||
# `torch.empty` + the FFI to register the device pointer.
|
||||
# NB: callers that stash the returned tensor must `.clone()`
|
||||
# before the next call; the default contract returns a fresh
|
||||
# tensor each call so leave this opt-in via env var for now.
|
||||
self._output_cache_tensors = None
|
||||
self._output_cache_specs = None
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -87,6 +144,90 @@ class CompiledModel:
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
|
||||
# Batched CUDA fast path. When all user inputs are on CUDA and all
|
||||
# outputs are floating-point, collapse the per-iter `set_input_device_ptr`
|
||||
# × N, `set_output_device_ptr` × M, and `run` calls into a single
|
||||
# `run_with_ptrs` FFI crossing. The Rust side iterates the
|
||||
# (id, ptr, n_bytes) tuples without paying per-call pyo3
|
||||
# marshalling cost.
|
||||
all_cuda = bool(user_inputs) and all(t.is_cuda for t in user_inputs)
|
||||
if self._batched_ptrs_supported and all_cuda:
|
||||
output_shapes = (
|
||||
self._graph.resolve_output_shapes()
|
||||
if self._has_dynamic_dims
|
||||
else self._output_shapes
|
||||
)
|
||||
output_dtype_codes = (
|
||||
self._graph.output_dtypes
|
||||
if self._output_dtype_codes_cached is None
|
||||
else self._output_dtype_codes_cached
|
||||
)
|
||||
output_dtypes = [
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._output_names))
|
||||
]
|
||||
if all(d.is_floating_point for d in output_dtypes):
|
||||
# Build the input-spec list. Hot path: when the user passes
|
||||
# the same tensor object as last call AND that tensor's
|
||||
# data_ptr is unchanged, the previous registration is still
|
||||
# valid in the luminal runtime — skip both the Python-side
|
||||
# cast/contiguous/data_ptr work AND the Rust-side
|
||||
# `set_device_ptr` call by omitting it from `input_specs`.
|
||||
input_specs = []
|
||||
_cache_ids = self._input_cache_ids
|
||||
_cache_orig = self._input_cache_orig_ptrs
|
||||
_cache_cast = self._input_cache_cast_tensors
|
||||
_cache_spec = self._input_cache_specs
|
||||
for i, (id_, tensor, expected_dtype) in enumerate(zip(
|
||||
self._input_ids, user_inputs, self._input_dtypes
|
||||
)):
|
||||
orig_id = id(tensor)
|
||||
orig_ptr = tensor.data_ptr()
|
||||
if _cache_ids[i] == orig_id and _cache_orig[i] == orig_ptr:
|
||||
# Pointer unchanged — luminal already has the
|
||||
# registration. Skip everything.
|
||||
continue
|
||||
# Cold path: cast (if needed) and update cache.
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
cast_ptr = t.data_ptr()
|
||||
spec = (id_, cast_ptr, n_bytes)
|
||||
input_specs.append(spec)
|
||||
_cache_ids[i] = orig_id
|
||||
_cache_orig[i] = orig_ptr
|
||||
_cache_cast[i] = t # keep alive
|
||||
_cache_spec[i] = spec
|
||||
# Outputs: pre-allocate fresh each call so the caller gets
|
||||
# a unique tensor (matches the unbatched path's contract;
|
||||
# callers that stash results don't get blown away on the
|
||||
# next iteration). The output_specs list is always passed
|
||||
# through to `run_with_ptrs`.
|
||||
output_tensors = []
|
||||
output_specs = []
|
||||
for id_, shape, dt in zip(self._output_ids, output_shapes, output_dtypes):
|
||||
out = torch.empty(shape, dtype=dt, device=input_device)
|
||||
output_specs.append((id_, out.data_ptr(), out.numel() * out.element_size()))
|
||||
output_tensors.append(out)
|
||||
zero_copy_flags = self._graph.run_with_ptrs(input_specs, output_specs)
|
||||
# For any output the runtime had to alias (not zero-copy),
|
||||
# request the DtoD copy explicitly into the registered buffer.
|
||||
# In DLRMv1 and similar models this never fires, but it's the
|
||||
# same fallback the unbatched path has.
|
||||
for ok, name, tensor in zip(zero_copy_flags, self._output_names, output_tensors):
|
||||
if not ok:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, tensor.data_ptr(), tensor.numel() * tensor.element_size()
|
||||
)
|
||||
# Return tuple unconditionally — the unbatched path below
|
||||
# also returns `tuple(outputs)` even for single-output
|
||||
# models, and torch.compile / dynamo's output handling
|
||||
# depends on that contract. Returning a bare Tensor here
|
||||
# made dynamo iterate the first dim and reshape the
|
||||
# output to a 1-element slice.
|
||||
return tuple(output_tensors)
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
|
||||
11
examples/dlrm/.gitignore
vendored
Normal file
11
examples/dlrm/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
# Correctness-test fixtures generated by `correctness_dump.py`.
|
||||
# Regenerate with: python correctness_dump.py --num-cat N --rows R --out-dir weights_N
|
||||
weights/
|
||||
weights_*/
|
||||
|
||||
# Python bytecode
|
||||
__pycache__/
|
||||
|
||||
# Intermediate sweep outputs from sweep_all.py — `results.csv` is the
|
||||
# canonical published result; everything else is regenerable.
|
||||
quick*.csv
|
||||
18
examples/dlrm/Cargo.toml
Normal file
18
examples/dlrm/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "dlrm"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[[bin]]
|
||||
name = "dlrm"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "check"
|
||||
path = "src/bin/check.rs"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
rand = "0.9.2"
|
||||
bytemuck = "1"
|
||||
236
examples/dlrm/RESULTS.md
Normal file
236
examples/dlrm/RESULTS.md
Normal file
@@ -0,0 +1,236 @@
|
||||
# DLRMv1 via `torch.compile(backend=luminal_backend)` — sweep results
|
||||
|
||||
## TL;DR
|
||||
|
||||
- **Compiles end-to-end**: vanilla DLRMv1 (`nn.EmbeddingBag` × num_cat, then
|
||||
pairwise-dot interaction, then top MLP) lands on `torch.compile` with
|
||||
`backend=luminal_backend` for all 55 (`batch`, `num_cat`) cells in the
|
||||
sweep.
|
||||
- **Correctness**: max abs diff ≤ 1.8 × 10⁻⁷ vs PyTorch eager at
|
||||
`num_cat ∈ {2, 4, 8, 16, 32}` (essentially fp32 noise).
|
||||
- **Beats `pt_eager` in 55/55 cells (100%)**, by **4.7–7.1×**.
|
||||
- **Beats `graph_safe_inductor_cg` at the highest cell** (`nc=32, b=2048`):
|
||||
174 μs vs 241 μs (1.38× faster). Still slower at smaller cells —
|
||||
100 μs of fixed Python-wrapper overhead per call vs PT-CUDAGraph's
|
||||
~22 μs replay floor.
|
||||
|
||||
## Hardware / setup
|
||||
|
||||
- NVIDIA GH200 480GB, CUDA 12.8, driver 570.148.08.
|
||||
- DLRMv1 inline (matches `examples/dlrm/sweep_pytorch.py`'s `DLRMv1`).
|
||||
Fixed dials: `m_den=3`, `m_spa=16`, `bag=2`, `rows=4096`.
|
||||
- Harness: 5 rounds × 20 iters × 10 warmup before round 0,
|
||||
median-of-round-medians, CUDA-event timing.
|
||||
|
||||
## Sweep dimensions
|
||||
|
||||
| dim | values |
|
||||
|---|---|
|
||||
| `batch` | 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 |
|
||||
| `num_cat` | 2, 4, 8, 16, 32 |
|
||||
|
||||
55 cells × 3 variants = 165 timings. CSV: `results.csv`.
|
||||
|
||||
## Variants
|
||||
|
||||
- `pt_eager`: vanilla DLRMv1 forward, no compile, no graph capture.
|
||||
- `graph_safe_inductor_cg`: wrap DLRMv1 in `GraphSafeDLRM` (pre-bakes the
|
||||
`li`/`lj` lower-tri index buffers so the forward is capturable), then
|
||||
`torch.compile(backend="inductor", mode="max-autotune-no-cudagraphs",
|
||||
fullgraph=False, dynamic=False)`, then manual
|
||||
`torch.cuda.CUDAGraph` capture/replay. **This is the named PT baseline.**
|
||||
- `luminal_compiled`: `torch.compile(backend=luminal_backend, fullgraph=
|
||||
False, dynamic=False)`, no external CUDAGraph wrap. luminal's
|
||||
`cuda_lite` runtime captures/replays a CUDA graph internally on
|
||||
every `execute()` call.
|
||||
|
||||
## Wall-clock results (median-of-round-medians, μs)
|
||||
|
||||
### `luminal_compiled`
|
||||
|
||||
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|
||||
|---|---|---|---|---|---|---|
|
||||
| 2 | 101 | 95 | 96 | 100 | 98 | 102 |
|
||||
| 4 | 100 | 99 | 110 | 103 | 104 | 112 |
|
||||
| 8 | 109 | 106 | 116 | 110 | 111 | 121 |
|
||||
| 16 | 131 | 133 | 134 | 136 | 140 | 135 |
|
||||
| 32 | 172 | 172 | 178 | 174 | 176 | 174 |
|
||||
|
||||
### `graph_safe_inductor_cg` — the named PT baseline
|
||||
|
||||
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|
||||
|---|---|---|---|---|---|---|
|
||||
| 2 | 22 | 24 | 26 | 27 | 30 | 34 |
|
||||
| 4 | 27 | 29 | 31 | 33 | 35 | 45 |
|
||||
| 8 | 46 | 48 | 52 | 56 | 57 | 66 |
|
||||
| 16 | 75 | 79 | 91 | 96 | 96 | 121 |
|
||||
| 32 | 132 | 138 | 164 | 157 | 170 | 241 |
|
||||
|
||||
### `pt_eager`
|
||||
|
||||
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|
||||
|---|---|---|---|---|---|---|
|
||||
| 2 | 506 | 453 | 472 | 496 | 578 | 555 |
|
||||
| 4 | 514 | 488 | 569 | 535 | 605 | 588 |
|
||||
| 8 | 624 | 632 | 709 | 639 | 676 | 604 |
|
||||
| 16 | 786 | 783 | 879 | 813 | 862 | 865 |
|
||||
| 32 | 1125 | 1133 | 1180 | 1169 | 1176 | 1233 |
|
||||
|
||||
## Speedup vs `graph_safe_inductor_cg` (>1 = luminal faster)
|
||||
|
||||
| nc \ batch | 2 | 16 | 64 | 256 | 1024 | 2048 |
|
||||
|---|---|---|---|---|---|---|
|
||||
| 2 | 0.22× | 0.25× | 0.27× | 0.28× | 0.30× | 0.33× |
|
||||
| 4 | 0.27× | 0.29× | 0.29× | 0.32× | 0.34× | 0.40× |
|
||||
| 8 | 0.42× | 0.45× | 0.45× | 0.51× | 0.52× | 0.55× |
|
||||
| 16 | 0.57× | 0.60× | 0.68× | 0.71× | 0.68× | 0.90× |
|
||||
| 32 | 0.77× | 0.80× | 0.92× | 0.91× | 0.97× | **1.38×** |
|
||||
|
||||
## Speedup vs `pt_eager` (>1 = luminal faster)
|
||||
|
||||
luminal_compiled is **4.7–7.1× faster** at every cell. Won at 55/55 cells (100%).
|
||||
|
||||
## What changed to get here
|
||||
|
||||
Three rounds of wrapper-overhead reduction (all in `crates/luminal_python/`
|
||||
and `crates/luminal_cuda_lite/`):
|
||||
|
||||
### Round 1: Translator + cuda_lite kernel matching
|
||||
Made `torch.compile(model, backend=luminal_backend)` emit the same kernel
|
||||
set as the hand-written rust DLRM:
|
||||
|
||||
| FX subgraph | luminal kernel |
|
||||
|---|---|
|
||||
| `addmm(bias, x, weight.t())` + `relu/sigmoid` consumer | `linear_bias_(relu|sigmoid)` (one fused kernel per MLP layer) |
|
||||
| N × `_embedding_bag(W_k, idx_k, off_k)` | `multi_table_embedding_bag_sum_kernel` (one fused kernel for all tables) |
|
||||
| `index.Tensor(bmm(cat(unsqueezes), permute([0,2,1])), [None, li, lj])` | `dlrm_pairwise_dot_lower_tri_stacked` (one fused kernel) |
|
||||
| `bmm(A, permute(B, [0,2,1]))` (generic) | `matmul_3d_t` |
|
||||
|
||||
Plus a post-translation DCE pass to drop the now-unreachable
|
||||
`bmm/cat/permute` chain superseded by the pairwise-dot peephole.
|
||||
|
||||
### Round 2: Batched FFI + by-id setter
|
||||
Added `tensor_id(name) → u32` and `run_with_ptrs(inputs, outputs)` on the
|
||||
Rust side (one pyo3 hop instead of N), and cached the IDs once at
|
||||
`CompiledModel.__init__`. The Python wrapper now passes
|
||||
`[(id, ptr, n_bytes), …]` lists instead of N separate
|
||||
`set_input_device_ptr(name, ptr, n_bytes)` calls.
|
||||
|
||||
### Round 3: Skip re-registration on unchanged inputs
|
||||
The Python wrapper now caches `(id(tensor), data_ptr)` per input slot.
|
||||
On a hot bench loop the same user tensors are passed each iteration —
|
||||
all 65 inputs hit the cache, the `input_specs` list passed to
|
||||
`run_with_ptrs` is empty, and the runtime relies on the previously-
|
||||
registered pointers. Required one runtime change: in
|
||||
`cuda_lite/runtime.rs::execute()`'s post-run consume step, don't drop
|
||||
`CudaInput::Ptr` entries — they're non-owning views over caller memory
|
||||
and must persist across `execute()` calls for the skip path to work.
|
||||
|
||||
## Where the per-iter time goes (nc=32, b=2048)
|
||||
|
||||
Before any of the rounds:
|
||||
|
||||
| section | μs/iter |
|
||||
|---|---|
|
||||
| `set_input_device_ptr` × 65 (Python loop + FFI) | 995 |
|
||||
| `_graph.run()` (CUDA-graph replay) | 61 |
|
||||
| set_output + alloc + collect | 13 |
|
||||
| **total** | **1069** |
|
||||
|
||||
After all three rounds:
|
||||
|
||||
| section | μs/iter |
|
||||
|---|---|
|
||||
| input cache check (0 cold-path inputs after warmup) | 15 |
|
||||
| `torch.empty` + register output | 11 |
|
||||
| `run_with_ptrs` (batched FFI, all caches hit, runtime executes graph) | 56 |
|
||||
| **total** | **82** |
|
||||
|
||||
**13× per-iter reduction**, none of it in the kernels themselves.
|
||||
|
||||
## Hand-written rust DLRM (`examples/dlrm/src/main.rs`)
|
||||
|
||||
The hand-written rust uses the same `cuda_lite` kernels directly — no
|
||||
Python, no `torch.compile`. On `nc=32, b=2048`: **104 μs**
|
||||
(median-of-round-medians with explicit `synchronize_stream` for accurate
|
||||
GPU time). The `luminal_compiled` figure for the same cell is 174 μs —
|
||||
the residual 70 μs gap is roughly:
|
||||
|
||||
- `torch.empty` + output registration (~11 μs)
|
||||
- Python-side cache check + 1 round-trip into `run_with_ptrs` (~10 μs)
|
||||
- Marshalling 1 input spec + 1 output spec across the pyo3 boundary
|
||||
- pt2 backend wrapper invocation by torch.compile (a few μs of
|
||||
dynamo/eval_frame work)
|
||||
- The remaining ~40 μs is the runtime's per-call exec_op iteration +
|
||||
buffer_map building inside `cuda_lite/src/runtime.rs::execute()` —
|
||||
a structural cost of how the runtime dispatches host ops, paid once
|
||||
per call regardless of input count.
|
||||
|
||||
## What's left
|
||||
|
||||
In scope, deferred:
|
||||
|
||||
- `linear_bias_relu_split_a` peephole on top MLP first layer. Saves
|
||||
one materialized `cat` + one small kernel. Modest cell-by-cell win.
|
||||
|
||||
Out of scope:
|
||||
|
||||
- The remaining ~80–100 μs of wrapper / runtime fixed overhead is a
|
||||
combination of `torch.compile`'s dynamo eval-frame dispatch, the
|
||||
Python `__call__` setup work, and the runtime's per-execute toposort
|
||||
+ buffer_map build. Closing it further would either need a deeper
|
||||
rewrite of `runtime.rs::execute()` (probably worth it independent of
|
||||
DLRM) or a way to skip dynamo's per-call overhead.
|
||||
|
||||
## Files this work touches
|
||||
|
||||
In scope:
|
||||
|
||||
- `crates/luminal_python/rust/src/translator/{mod,dispatch,movement}.rs` —
|
||||
peephole infra, op handlers, multi-call fusions, post-translation DCE.
|
||||
- `crates/luminal_cuda_lite/src/kernel/{embedding_bag,dlrm_interact,
|
||||
matmul2d,mod}.rs` — fused kernels for embedding bag (single + multi-
|
||||
table + stacked), pairwise dot (variadic + stacked), and fused-
|
||||
activation linear (relu/sigmoid + split-A).
|
||||
- `crates/luminal_cuda_lite/src/runtime.rs` — gate end-of-execute
|
||||
stream sync on `profiling`; expose `synchronize_stream()` and
|
||||
`read_per_kernel_timings_ms()` stub; preserve `CudaInput::Ptr`
|
||||
entries across `execute()` calls (the unchanged-input cache fix).
|
||||
|
||||
Wrapper layer (relaxed scope; minimal targeted changes):
|
||||
|
||||
- `crates/luminal_python/rust/src/compiled_graph.rs` — `tensor_id(name)`
|
||||
+ `run_with_ptrs(inputs, outputs)` for batched FFI.
|
||||
- `crates/luminal_python/src/luminal/compiled_model.py` — fast-path
|
||||
`run_with_ptrs` call with the per-input cache.
|
||||
|
||||
Leaf consumer (allowed):
|
||||
|
||||
- `examples/dlrm/` — bench harness, hand-written rust DLRM, correctness
|
||||
check, sweep scripts, this report. Cherry-picked from
|
||||
`origin/dlrm-fused-kernels` then adapted (inline DLRMv1, no
|
||||
upstream `dlrm_s_pytorch` dependency, added third PT variant).
|
||||
|
||||
## Reproduction
|
||||
|
||||
From `examples/dlrm`:
|
||||
|
||||
```bash
|
||||
# Build
|
||||
(cd ../.. && cargo build -p dlrm --release)
|
||||
(cd ../../crates/luminal_python && CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml \
|
||||
--features cuda -r)
|
||||
|
||||
# Correctness check (PyTorch eager → luminal hand-written rust)
|
||||
for nc in 2 4 8 16 32; do
|
||||
python correctness_dump.py --num-cat $nc --rows 4096
|
||||
(cd ../.. && ./target/release/check)
|
||||
done
|
||||
|
||||
# Full 55-cell sweep, 3 variants
|
||||
python sweep_all.py --variants pt_eager graph_safe_inductor_cg luminal_compiled
|
||||
|
||||
# Pretty summary
|
||||
python summarize.py results.csv
|
||||
```
|
||||
348
examples/dlrm/bench_pytorch.py
Normal file
348
examples/dlrm/bench_pytorch.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""PyTorch reference for the DLRM `_make_dlrm_batch_2048` config.
|
||||
|
||||
Mirrors `bench_luminal_exact.py` from https://github.com/jss8649/tmp-dlrm-bench
|
||||
in spirit (shape-for-shape, same indices, same MLPs, same interaction op)
|
||||
but reimplements the model inline so we don't need the upstream
|
||||
facebookresearch/dlrm package.
|
||||
|
||||
Measures:
|
||||
* eager
|
||||
* torch.compile (inductor, default)
|
||||
* torch.compile (inductor, mode="reduce-overhead" → CUDA-graph capture)
|
||||
* the v3 fused trick (index_select + reshape + sum on a stacked table)
|
||||
with reduce-overhead — that's the WINNER in the reference repo.
|
||||
|
||||
Harness: 5 rounds × 20 iters, 10 warmup, median-of-round-medians (CUDA-event
|
||||
timing). Matches the luminal-side harness verbatim.
|
||||
|
||||
Usage:
|
||||
python bench_pytorch.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# ---- Config (matches `_make_dlrm_batch_2048`) -----------------------------
|
||||
BATCH = 2048
|
||||
M_DEN = 3
|
||||
M_SPA = 16
|
||||
INDICES_PER_BAG = 2
|
||||
LN_EMB = np.array([4096, 2048, 1024], dtype=np.int64)
|
||||
NUM_EMB = len(LN_EMB)
|
||||
NUM_FEA = NUM_EMB + 1
|
||||
PAIR_COUNT = NUM_FEA * (NUM_FEA - 1) // 2 # strict lower tri, no diagonal
|
||||
TOP_IN = PAIR_COUNT + M_SPA # 6 + 16 = 22
|
||||
LN_BOT = [M_DEN, 64, M_SPA] # [3, 64, 16]
|
||||
LN_TOP_TAIL = [64, 32, 1] # ln_top = [22, 64, 32, 1]
|
||||
SEED = 0
|
||||
|
||||
|
||||
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
|
||||
layers: List[nn.Module] = []
|
||||
for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
|
||||
layers.append(nn.Linear(a, b, bias=True))
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
# ---- v1-shape model: EmbeddingBag per table (the "natural" expression
|
||||
# a user writes; same shape as the upstream DLRM forward) ------------
|
||||
class DLRMv1(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
ln_top = [TOP_IN] + LN_TOP_TAIL # [22, 64, 32, 1]
|
||||
self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
|
||||
# sigmoid_top in the upstream is `len(ln_top) - 2` = 2 → final linear
|
||||
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
|
||||
self.emb = nn.ModuleList(
|
||||
[
|
||||
nn.EmbeddingBag(int(n), M_SPA, mode="sum", sparse=False)
|
||||
for n in LN_EMB
|
||||
]
|
||||
)
|
||||
# Pre-compute strict lower-tri (no diagonal) indices.
|
||||
li, lj = [], []
|
||||
for i in range(NUM_FEA):
|
||||
for j in range(i):
|
||||
li.append(i)
|
||||
lj.append(j)
|
||||
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
|
||||
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
|
||||
|
||||
def forward(self, dense_x, lS_o, lS_i):
|
||||
x = self.bot(dense_x) # (B, M_SPA)
|
||||
ly = [
|
||||
self.emb[k](lS_i[k], lS_o[k]) for k in range(NUM_EMB)
|
||||
] # each (B, M_SPA)
|
||||
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1) # (B, F, M_SPA)
|
||||
Z = torch.bmm(T, T.transpose(1, 2)) # (B, F, F)
|
||||
Zflat = Z[:, self.li, self.lj] # (B, PAIRS)
|
||||
R = torch.cat([x, Zflat], dim=1) # (B, M_SPA + PAIRS)
|
||||
return self.top(R)
|
||||
|
||||
|
||||
# ---- v3 fused: stacked embedding table, index_select + reshape + sum ------
|
||||
# This is the winner per perf.md (Inductor can fuse gather+sum into
|
||||
# a single Triton kernel; opaque EmbeddingBag blocks that fusion). ----
|
||||
class DLRMv3(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
total_rows = int(LN_EMB.sum())
|
||||
self.num_emb = NUM_EMB
|
||||
self.m_spa = M_SPA
|
||||
self.L = INDICES_PER_BAG
|
||||
big_w = np.empty((total_rows, M_SPA), dtype=np.float32)
|
||||
starts = np.zeros(NUM_EMB, dtype=np.int64)
|
||||
s = 0
|
||||
for k, n in enumerate(LN_EMB):
|
||||
starts[k] = s
|
||||
big_w[s : s + int(n)] = np.random.uniform(
|
||||
-np.sqrt(1.0 / int(n)),
|
||||
np.sqrt(1.0 / int(n)),
|
||||
(int(n), M_SPA),
|
||||
).astype(np.float32)
|
||||
s += int(n)
|
||||
self.emb_weight = nn.Parameter(torch.from_numpy(big_w))
|
||||
self.register_buffer("row_offsets", torch.from_numpy(starts))
|
||||
|
||||
ln_top = [TOP_IN] + LN_TOP_TAIL
|
||||
self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
|
||||
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
|
||||
|
||||
li, lj = [], []
|
||||
for i in range(NUM_FEA):
|
||||
for j in range(i):
|
||||
li.append(i)
|
||||
lj.append(j)
|
||||
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
|
||||
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
|
||||
|
||||
def forward(self, dense_x, flat_indices):
|
||||
bs = dense_x.shape[0]
|
||||
gathered = torch.index_select(self.emb_weight, 0, flat_indices)
|
||||
gathered = gathered.view(self.num_emb * bs, self.L, self.m_spa)
|
||||
pooled = gathered.sum(dim=1) # (num_emb*B, m_spa)
|
||||
ly = pooled.view(self.num_emb, bs, self.m_spa).transpose(0, 1) # (B, num_emb, m_spa)
|
||||
x = self.bot(dense_x)
|
||||
T = torch.cat([x.unsqueeze(1), ly], dim=1)
|
||||
Z = torch.bmm(T, T.transpose(1, 2))
|
||||
Zflat = Z[:, self.li, self.lj]
|
||||
R = torch.cat([x, Zflat], dim=1)
|
||||
return self.top(R)
|
||||
|
||||
|
||||
# ---- Deterministic inputs matching `_make_dlrm_batch_2048` ---------------
|
||||
def make_v1_inputs(device):
|
||||
dense_x = (
|
||||
torch.linspace(-1.0, 1.0, BATCH * M_DEN, dtype=torch.float32, device=device)
|
||||
.reshape(BATCH, M_DEN)
|
||||
)
|
||||
total = BATCH * INDICES_PER_BAG
|
||||
positions = torch.arange(total, dtype=torch.int64, device=device)
|
||||
offsets = torch.arange(0, total, INDICES_PER_BAG, dtype=torch.int64, device=device)
|
||||
lS_o = [offsets.clone() for _ in range(NUM_EMB)]
|
||||
lS_i = [
|
||||
((positions * 3 + 1) % int(LN_EMB[0])).to(torch.int64),
|
||||
((positions * 5 + 2) % int(LN_EMB[1])).to(torch.int64),
|
||||
((positions * 7 + 3) % int(LN_EMB[2])).to(torch.int64),
|
||||
]
|
||||
return dense_x, lS_o, lS_i
|
||||
|
||||
|
||||
def make_v3_inputs(device, model: DLRMv3):
|
||||
dense_x, _, lS_i = make_v1_inputs(device)
|
||||
# Stack and add per-table row offsets so a single index_select pulls
|
||||
# from the unified table.
|
||||
stacked = torch.stack(lS_i, dim=0) # (NUM_EMB, B*L)
|
||||
flat = (stacked + model.row_offsets.view(NUM_EMB, 1)).reshape(-1)
|
||||
return dense_x, flat
|
||||
|
||||
|
||||
# ---- Timing harness (mirrors bench_luminal_exact.py) ---------------------
|
||||
@contextmanager
|
||||
def _relaxed_dynamo_limits():
|
||||
prev_r = torch._dynamo.config.recompile_limit
|
||||
prev_c = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
torch._dynamo.config.cache_size_limit = 16
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.recompile_limit = prev_r
|
||||
torch._dynamo.config.cache_size_limit = prev_c
|
||||
|
||||
|
||||
def _compile_inductor(model):
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
return torch.compile(copy.deepcopy(model), backend="inductor")
|
||||
|
||||
|
||||
def _compile_inductor_ro(model):
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
return torch.compile(
|
||||
copy.deepcopy(model), backend="inductor", mode="reduce-overhead"
|
||||
)
|
||||
|
||||
|
||||
def _timed_cuda_runs(model, inputs, warmup, timed, mark_step=False):
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup):
|
||||
if mark_step:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
_ = model(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed)]
|
||||
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed)]
|
||||
with torch.no_grad():
|
||||
for i in range(timed):
|
||||
if mark_step:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
starts[i].record()
|
||||
_ = model(*inputs)
|
||||
ends[i].record()
|
||||
torch.cuda.synchronize()
|
||||
return np.array(
|
||||
[s.elapsed_time(e) for s, e in zip(starts, ends)], dtype=np.float64
|
||||
)
|
||||
|
||||
|
||||
def _timed_rounds(model, inputs, *, warmup, timed, rounds, mark_step=False):
|
||||
arr = _timed_cuda_runs(model, inputs, warmup, timed, mark_step=mark_step)
|
||||
round_medians = [float(np.median(arr))]
|
||||
for _ in range(rounds - 1):
|
||||
arr = _timed_cuda_runs(model, inputs, 0, timed, mark_step=mark_step)
|
||||
round_medians.append(float(np.median(arr)))
|
||||
return round_medians
|
||||
|
||||
|
||||
def _manual_cudagraph_capture(model, inputs):
|
||||
"""Capture model(inputs) into a CUDA graph and return a replayable
|
||||
closure that runs zero new launches (just cuGraphLaunch)."""
|
||||
# Warm up.
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s), torch.no_grad():
|
||||
for _ in range(3):
|
||||
_ = model(*inputs)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g), torch.no_grad():
|
||||
out = model(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def replay():
|
||||
g.replay()
|
||||
return out
|
||||
return replay
|
||||
|
||||
|
||||
def main():
|
||||
print(f"torch={torch.__version__} cuda={torch.version.cuda}")
|
||||
print(f"device={torch.cuda.get_device_name(0)} cap={torch.cuda.get_device_capability(0)}")
|
||||
print(f"float32 matmul precision: {torch.get_float32_matmul_precision()}")
|
||||
device = torch.device("cuda")
|
||||
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
|
||||
# v1: EmbeddingBag-style
|
||||
v1 = DLRMv1().to(device).eval()
|
||||
v1_inputs = make_v1_inputs(device)
|
||||
v1_eager = copy.deepcopy(v1).to(device).eval()
|
||||
v1_ind = _compile_inductor(v1)
|
||||
v1_ind_ro = _compile_inductor_ro(v1)
|
||||
|
||||
# v3: index_select + reshape + sum on stacked table
|
||||
v3 = DLRMv3().to(device).eval()
|
||||
v3_inputs = make_v3_inputs(device, v3)
|
||||
v3_ind_ro = _compile_inductor_ro(v3)
|
||||
|
||||
# Prime
|
||||
with torch.no_grad():
|
||||
_ = v1_eager(*v1_inputs)
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
_ = v1_ind_ro(*v1_inputs)
|
||||
_ = v1_ind(*v1_inputs)
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
_ = v3_ind_ro(*v3_inputs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Manual CUDA graph capture for v3 eager
|
||||
v3_eager_replay = _manual_cudagraph_capture(v3, v3_inputs)
|
||||
# Manual CUDA graph capture for v1 eager
|
||||
v1_eager_replay = _manual_cudagraph_capture(v1, v1_inputs)
|
||||
|
||||
rounds, iters, warmup = 5, 20, 10
|
||||
|
||||
def report(label, model, inputs, mark_step=False, *, raw_replay=False):
|
||||
if raw_replay:
|
||||
# Time the replay closure directly.
|
||||
replay = model
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup):
|
||||
replay()
|
||||
torch.cuda.synchronize()
|
||||
starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
||||
ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
||||
with torch.no_grad():
|
||||
for i in range(iters):
|
||||
starts[i].record()
|
||||
replay()
|
||||
ends[i].record()
|
||||
torch.cuda.synchronize()
|
||||
arr = np.array([s.elapsed_time(e) for s, e in zip(starts, ends)])
|
||||
rms = [float(np.median(arr))]
|
||||
for _ in range(rounds - 1):
|
||||
with torch.no_grad():
|
||||
for i in range(iters):
|
||||
starts[i].record()
|
||||
replay()
|
||||
ends[i].record()
|
||||
torch.cuda.synchronize()
|
||||
arr = np.array(
|
||||
[s.elapsed_time(e) for s, e in zip(starts, ends)]
|
||||
)
|
||||
rms.append(float(np.median(arr)))
|
||||
else:
|
||||
rms = _timed_rounds(
|
||||
model, inputs, warmup=warmup, timed=iters, rounds=rounds,
|
||||
mark_step=mark_step,
|
||||
)
|
||||
rms_sorted = sorted(rms)
|
||||
med = rms_sorted[len(rms_sorted) // 2]
|
||||
tput = BATCH / (med / 1000.0)
|
||||
print(
|
||||
f" {label:<60} median {med:7.3f} ms ({tput:>9,.0f} samples/s) "
|
||||
f"round medians: [{', '.join(f'{v:.3f}' for v in rms)}]"
|
||||
)
|
||||
|
||||
print()
|
||||
print(f"PyTorch reference (5 rounds x 20 iters, 10 warmup), batch={BATCH}:")
|
||||
report("v1 eager", v1_eager, v1_inputs)
|
||||
report("v1 torch.compile(inductor)", v1_ind, v1_inputs)
|
||||
report("v1 torch.compile(reduce-overhead)", v1_ind_ro, v1_inputs, mark_step=True)
|
||||
report("v1 eager + manual CUDAGraph", v1_eager_replay, None, raw_replay=True)
|
||||
report("v3 torch.compile(reduce-overhead)", v3_ind_ro, v3_inputs, mark_step=True)
|
||||
report("v3 eager + manual CUDAGraph", v3_eager_replay, None, raw_replay=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
172
examples/dlrm/correctness_dump.py
Normal file
172
examples/dlrm/correctness_dump.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Dump a deterministic PyTorch DLRMv1 forward to disk so a parallel
|
||||
luminal binary can load the exact same weights/inputs and verify it
|
||||
produces the same output.
|
||||
|
||||
Writes one f32 little-endian binary blob per tensor, plus a manifest
|
||||
JSON describing the shapes. All paths are under `weights/`.
|
||||
|
||||
Usage: python correctness_dump.py [--num-cat N] [--rows R]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
SEED = 1234
|
||||
BATCH = 2048
|
||||
M_DEN = 3
|
||||
M_SPA = 16
|
||||
L = 2
|
||||
LN_BOT = [M_DEN, 64, M_SPA]
|
||||
LN_TOP_TAIL = [64, 32, 1]
|
||||
|
||||
|
||||
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
|
||||
layers: List[nn.Module] = []
|
||||
for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
|
||||
layers.append(nn.Linear(a, b, bias=True))
|
||||
layers.append(nn.Sigmoid() if i == sigmoid_layer else nn.ReLU())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class DLRMv1(nn.Module):
|
||||
def __init__(self, num_cat: int, rows: int):
|
||||
super().__init__()
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
self.num_cat = num_cat
|
||||
ni = num_cat + 1
|
||||
num_int = ni * (ni - 1) // 2 + M_SPA
|
||||
ln_top = [num_int] + LN_TOP_TAIL
|
||||
self.bot = _build_mlp(LN_BOT, sigmoid_layer=-1)
|
||||
# `sigmoid_top=2` in upstream means sigmoid on the final linear,
|
||||
# which corresponds to sigmoid_layer = len(layers in Sequential) - 1
|
||||
# Our `_build_mlp` indexes by Linear-index, so the last linear has
|
||||
# index len(ln_top) - 2.
|
||||
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
|
||||
self.emb = nn.ModuleList(
|
||||
[nn.EmbeddingBag(rows, M_SPA, mode="sum", sparse=False) for _ in range(num_cat)]
|
||||
)
|
||||
li, lj = [], []
|
||||
for i in range(ni):
|
||||
for j in range(i):
|
||||
li.append(i)
|
||||
lj.append(j)
|
||||
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
|
||||
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
|
||||
|
||||
def forward(self, dense_x, lS_o, lS_i):
|
||||
x = self.bot(dense_x)
|
||||
ly = [self.emb[k](lS_i[k], lS_o[k]) for k in range(self.num_cat)]
|
||||
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)
|
||||
Z = torch.bmm(T, T.transpose(1, 2))
|
||||
Zflat = Z[:, self.li, self.lj]
|
||||
R = torch.cat([x, Zflat], dim=1)
|
||||
return self.top(R)
|
||||
|
||||
|
||||
def build_indices(table_idx: int, batch: int, bag: int, rows: int) -> np.ndarray:
|
||||
# Match sweep_categories.py and the luminal Rust binary exactly.
|
||||
s = 2 * table_idx + 3
|
||||
o = table_idx + 1
|
||||
pos = np.arange(batch * bag, dtype=np.int64)
|
||||
return (pos * s + o) % rows
|
||||
|
||||
|
||||
def build_dense_x(batch: int, m_den: int) -> np.ndarray:
|
||||
total = batch * m_den
|
||||
return np.linspace(-1.0, 1.0, num=total, dtype=np.float32).reshape(batch, m_den)
|
||||
|
||||
|
||||
def write_f32(path: Path, arr: np.ndarray) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
arr.astype(np.float32, copy=False).tofile(path)
|
||||
|
||||
|
||||
def write_i32(path: Path, arr: np.ndarray) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
arr.astype(np.int32, copy=False).tofile(path)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--num-cat", type=int, default=3)
|
||||
ap.add_argument("--rows", type=int, default=4096)
|
||||
ap.add_argument("--out-dir", type=str, default="weights")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = DLRMv1(args.num_cat, args.rows).to(device).eval()
|
||||
|
||||
dense_np = build_dense_x(BATCH, M_DEN)
|
||||
dense = torch.from_numpy(dense_np).to(device)
|
||||
offsets = torch.arange(0, BATCH * L, L, dtype=torch.int64, device=device)
|
||||
lS_o = [offsets.clone() for _ in range(args.num_cat)]
|
||||
lS_i = [
|
||||
torch.from_numpy(build_indices(k, BATCH, L, args.rows)).to(device)
|
||||
for k in range(args.num_cat)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(dense, lS_o, lS_i)
|
||||
print(f"output[:8] = {out.detach().cpu().numpy().flatten()[:8]}")
|
||||
print(f"output stats: min={out.min().item():.6f} max={out.max().item():.6f} "
|
||||
f"mean={out.mean().item():.6f}")
|
||||
|
||||
# ---- dump weights ----
|
||||
# Bottom MLP: linears at indices 0 and 2 in the Sequential (Linear, ReLU,
|
||||
# Linear). Luminal stores W as (out, in), matching PyTorch's nn.Linear.
|
||||
bot_lins = [m for m in model.bot if isinstance(m, nn.Linear)]
|
||||
top_lins = [m for m in model.top if isinstance(m, nn.Linear)]
|
||||
for i, l in enumerate(bot_lins):
|
||||
write_f32(out_dir / f"bot_{i}_w.bin", l.weight.detach().cpu().numpy())
|
||||
write_f32(out_dir / f"bot_{i}_b.bin", l.bias.detach().cpu().numpy())
|
||||
for i, l in enumerate(top_lins):
|
||||
write_f32(out_dir / f"top_{i}_w.bin", l.weight.detach().cpu().numpy())
|
||||
write_f32(out_dir / f"top_{i}_b.bin", l.bias.detach().cpu().numpy())
|
||||
for k, e in enumerate(model.emb):
|
||||
write_f32(out_dir / f"emb_{k}.bin", e.weight.detach().cpu().numpy())
|
||||
|
||||
# ---- dump inputs ----
|
||||
write_f32(out_dir / "dense.bin", dense_np)
|
||||
for k, idx in enumerate(lS_i):
|
||||
write_i32(out_dir / f"idx_{k}.bin", idx.cpu().numpy())
|
||||
|
||||
# ---- dump expected output ----
|
||||
write_f32(out_dir / "expected.bin", out.detach().cpu().numpy())
|
||||
|
||||
manifest = {
|
||||
"num_cat": args.num_cat,
|
||||
"rows": args.rows,
|
||||
"batch": BATCH,
|
||||
"m_den": M_DEN,
|
||||
"m_spa": M_SPA,
|
||||
"indices_per_bag": L,
|
||||
"ln_bot": LN_BOT,
|
||||
"ln_top": [args.num_cat * (args.num_cat + 1) // 2 + M_SPA] + LN_TOP_TAIL,
|
||||
"bot_layer_shapes": [list(l.weight.shape) for l in bot_lins],
|
||||
"top_layer_shapes": [list(l.weight.shape) for l in top_lins],
|
||||
"output_shape": list(out.shape),
|
||||
"output_head": out.detach().cpu().numpy().flatten()[:8].tolist(),
|
||||
}
|
||||
with open(out_dir / "manifest.json", "w") as f:
|
||||
json.dump(manifest, f, indent=2)
|
||||
|
||||
print(f"\nWrote weights/inputs/expected to {out_dir}/")
|
||||
print(f" manifest: {out_dir}/manifest.json")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
166
examples/dlrm/results.csv
Normal file
166
examples/dlrm/results.csv
Normal file
@@ -0,0 +1,166 @@
|
||||
variant,num_cat,batch,m_spa,bag,rows,ms,samples_per_sec,status
|
||||
pt_eager,2,2,16,2,4096,0.5056319832801819,3955.4459886524933,ok
|
||||
graph_safe_inductor_cg,2,2,16,2,4096,0.022352000698447227,89477.44888621707,ok
|
||||
luminal_compiled,2,2,16,2,4096,0.10134400054812431,19734.764654867537,ok
|
||||
pt_eager,2,4,16,2,4096,0.4413599967956543,9062.89656752006,ok
|
||||
graph_safe_inductor_cg,2,4,16,2,4096,0.022592000663280487,177053.8191644678,ok
|
||||
luminal_compiled,2,4,16,2,4096,0.09467199817299843,42251.141596173125,ok
|
||||
pt_eager,2,8,16,2,4096,0.4634399861097336,17262.21353309327,ok
|
||||
graph_safe_inductor_cg,2,8,16,2,4096,0.02404800057411194,332667.9893966789,ok
|
||||
luminal_compiled,2,8,16,2,4096,0.09742399677634239,82115.29258408182,ok
|
||||
pt_eager,2,16,16,2,4096,0.4532639980316162,35299.516549920125,ok
|
||||
graph_safe_inductor_cg,2,16,16,2,4096,0.023744000121951103,673854.4439783823,ok
|
||||
luminal_compiled,2,16,16,2,4096,0.09547200053930283,167588.40193584614,ok
|
||||
pt_eager,2,32,16,2,4096,0.5251840054988861,60931.025440507,ok
|
||||
graph_safe_inductor_cg,2,32,16,2,4096,0.02582399919629097,1239157.411552122,ok
|
||||
luminal_compiled,2,32,16,2,4096,0.09985600039362907,320461.46324564423,ok
|
||||
pt_eager,2,64,16,2,4096,0.47198399901390076,135597.81715844796,ok
|
||||
graph_safe_inductor_cg,2,64,16,2,4096,0.02619200013577938,2443494.184034204,ok
|
||||
luminal_compiled,2,64,16,2,4096,0.09620799869298935,665225.3541229069,ok
|
||||
pt_eager,2,128,16,2,4096,0.5091679990291595,251390.50420305296,ok
|
||||
graph_safe_inductor_cg,2,128,16,2,4096,0.027855999767780304,4595060.348472987,ok
|
||||
luminal_compiled,2,128,16,2,4096,0.10607999935746193,1206636.508062876,ok
|
||||
pt_eager,2,256,16,2,4096,0.49588799476623535,516245.6092946552,ok
|
||||
graph_safe_inductor_cg,2,256,16,2,4096,0.027456000447273254,9324009.172116116,ok
|
||||
luminal_compiled,2,256,16,2,4096,0.09959999844431877,2570281.1646439577,ok
|
||||
pt_eager,2,512,16,2,4096,0.499007984995842,1026035.6855898134,ok
|
||||
graph_safe_inductor_cg,2,512,16,2,4096,0.036959998309612274,13852814.486380616,ok
|
||||
luminal_compiled,2,512,16,2,4096,0.10465599969029427,4892218.32971973,ok
|
||||
pt_eager,2,1024,16,2,4096,0.5781759917736053,1771087.0298484561,ok
|
||||
graph_safe_inductor_cg,2,1024,16,2,4096,0.029888000339269638,34261241.58110951,ok
|
||||
luminal_compiled,2,1024,16,2,4096,0.098191998898983,10428548.267496424,ok
|
||||
pt_eager,2,2048,16,2,4096,0.5551839768886566,3688867.2678871835,ok
|
||||
graph_safe_inductor_cg,2,2048,16,2,4096,0.03388800099492073,60434370.27480501,ok
|
||||
luminal_compiled,2,2048,16,2,4096,0.10235200077295303,20009379.245483134,ok
|
||||
pt_eager,4,2,16,2,4096,0.5139839947223663,3891.171749580102,ok
|
||||
graph_safe_inductor_cg,4,2,16,2,4096,0.02729600016027689,73270.8084794982,ok
|
||||
luminal_compiled,4,2,16,2,4096,0.09959999844431877,20080.32159878092,ok
|
||||
pt_eager,4,4,16,2,4096,0.512287974357605,7808.10833011626,ok
|
||||
graph_safe_inductor_cg,4,4,16,2,4096,0.027664000168442726,144592.24897500317,ok
|
||||
luminal_compiled,4,4,16,2,4096,0.10916800051927567,36640.77367885587,ok
|
||||
pt_eager,4,8,16,2,4096,0.5412319898605347,14781.092303988627,ok
|
||||
graph_safe_inductor_cg,4,8,16,2,4096,0.02937600016593933,272331.15314574994,ok
|
||||
luminal_compiled,4,8,16,2,4096,0.10155199840664864,78777.37637387782,ok
|
||||
pt_eager,4,16,16,2,4096,0.48787200450897217,32795.48703784202,ok
|
||||
graph_safe_inductor_cg,4,16,16,2,4096,0.028704000636935234,557413.5885229809,ok
|
||||
luminal_compiled,4,16,16,2,4096,0.09879999980330467,161943.32016046048,ok
|
||||
pt_eager,4,32,16,2,4096,0.5792959928512573,55239.46375409585,ok
|
||||
graph_safe_inductor_cg,4,32,16,2,4096,0.036240000277757645,883002.2007378418,ok
|
||||
luminal_compiled,4,32,16,2,4096,0.10119999945163727,316205.5353102305,ok
|
||||
pt_eager,4,64,16,2,4096,0.5686399936676025,112549.24154598081,ok
|
||||
graph_safe_inductor_cg,4,64,16,2,4096,0.03144000098109245,2035623.3461471149,ok
|
||||
luminal_compiled,4,64,16,2,4096,0.1101440005004406,581057.5220549029,ok
|
||||
pt_eager,4,128,16,2,4096,0.5731199979782104,223338.9175941239,ok
|
||||
graph_safe_inductor_cg,4,128,16,2,4096,0.03721600025892258,3439380.887507165,ok
|
||||
luminal_compiled,4,128,16,2,4096,0.10672000050544739,1199400.294169474,ok
|
||||
pt_eager,4,256,16,2,4096,0.5351200103759766,478397.3595383469,ok
|
||||
graph_safe_inductor_cg,4,256,16,2,4096,0.0326399989426136,7843137.508983668,ok
|
||||
luminal_compiled,4,256,16,2,4096,0.10263999924063683,2494154.344251451,ok
|
||||
pt_eager,4,512,16,2,4096,0.5590240061283112,915881.9556712956,ok
|
||||
graph_safe_inductor_cg,4,512,16,2,4096,0.0337119996547699,15187470.492500357,ok
|
||||
luminal_compiled,4,512,16,2,4096,0.10249600186944008,4995316.799304895,ok
|
||||
pt_eager,4,1024,16,2,4096,0.6048000156879425,1693121.6492037117,ok
|
||||
graph_safe_inductor_cg,4,1024,16,2,4096,0.03484800085425377,29384755.937154543,ok
|
||||
luminal_compiled,4,1024,16,2,4096,0.10395199805498123,9850700.507539993,ok
|
||||
pt_eager,4,2048,16,2,4096,0.5883680284023285,3480814.5601677205,ok
|
||||
graph_safe_inductor_cg,4,2048,16,2,4096,0.04467200115323067,45845271.02278446,ok
|
||||
luminal_compiled,4,2048,16,2,4096,0.11158400028944016,18353885.814163756,ok
|
||||
pt_eager,8,2,16,2,4096,0.6239520013332367,3205.3747655692696,ok
|
||||
graph_safe_inductor_cg,8,2,16,2,4096,0.04572800174355507,43736.877268683216,ok
|
||||
luminal_compiled,8,2,16,2,4096,0.10910400003194809,18331.13359193389,ok
|
||||
pt_eager,8,4,16,2,4096,0.6168799996376038,6484.243292617471,ok
|
||||
graph_safe_inductor_cg,8,4,16,2,4096,0.045903999358415604,87138.37695857928,ok
|
||||
luminal_compiled,8,4,16,2,4096,0.10825600102543831,36949.4527980954,ok
|
||||
pt_eager,8,8,16,2,4096,0.6259680092334747,12780.205828403836,ok
|
||||
graph_safe_inductor_cg,8,8,16,2,4096,0.04761600121855736,168010.74838855147,ok
|
||||
luminal_compiled,8,8,16,2,4096,0.11219200119376183,71306.33124355768,ok
|
||||
pt_eager,8,16,16,2,4096,0.6317119896411896,25327.997983840625,ok
|
||||
graph_safe_inductor_cg,8,16,16,2,4096,0.047807998955249786,334672.02873261116,ok
|
||||
luminal_compiled,8,16,16,2,4096,0.10617600008845329,150693.18854233247,ok
|
||||
pt_eager,8,32,16,2,4096,0.6701280176639557,47752.0699873898,ok
|
||||
graph_safe_inductor_cg,8,32,16,2,4096,0.0514880008995533,621504.0289178838,ok
|
||||
luminal_compiled,8,32,16,2,4096,0.11286400258541107,283527.0703409942,ok
|
||||
pt_eager,8,64,16,2,4096,0.7092800140380859,90232.34651098376,ok
|
||||
graph_safe_inductor_cg,8,64,16,2,4096,0.052239999175071716,1225114.873863551,ok
|
||||
luminal_compiled,8,64,16,2,4096,0.11633599922060966,550130.659716395,ok
|
||||
pt_eager,8,128,16,2,4096,0.6788640022277832,188550.28338511224,ok
|
||||
graph_safe_inductor_cg,8,128,16,2,4096,0.053888000547885895,2375296.8879640796,ok
|
||||
luminal_compiled,8,128,16,2,4096,0.11127999797463417,1150251.6384766388,ok
|
||||
pt_eager,8,256,16,2,4096,0.6391039788722992,400560.7983410035,ok
|
||||
graph_safe_inductor_cg,8,256,16,2,4096,0.05567999929189682,4597701.207895956,ok
|
||||
luminal_compiled,8,256,16,2,4096,0.10991999879479408,2328966.5466419603,ok
|
||||
pt_eager,8,512,16,2,4096,0.6350559890270233,806228.1260971039,ok
|
||||
graph_safe_inductor_cg,8,512,16,2,4096,0.0562559999525547,9101251.429746367,ok
|
||||
luminal_compiled,8,512,16,2,4096,0.12015999853610992,4260985.404773753,ok
|
||||
pt_eager,8,1024,16,2,4096,0.6764000058174133,1513897.089285977,ok
|
||||
graph_safe_inductor_cg,8,1024,16,2,4096,0.05718399956822395,17907107.018254407,ok
|
||||
luminal_compiled,8,1024,16,2,4096,0.11070400103926659,9249891.515996683,ok
|
||||
pt_eager,8,2048,16,2,4096,0.6043839752674103,3388574.2901999685,ok
|
||||
graph_safe_inductor_cg,8,2048,16,2,4096,0.06619199737906456,30940296.12479633,ok
|
||||
luminal_compiled,8,2048,16,2,4096,0.12067200243473053,16971625.221084144,ok
|
||||
pt_eager,16,2,16,2,4096,0.7858880162239075,2544.891840455523,ok
|
||||
graph_safe_inductor_cg,16,2,16,2,4096,0.07545600086450577,26505.512843058623,ok
|
||||
luminal_compiled,16,2,16,2,4096,0.13145600259304047,15214.215863474643,ok
|
||||
pt_eager,16,4,16,2,4096,0.7631199955940247,5241.639615125452,ok
|
||||
graph_safe_inductor_cg,16,4,16,2,4096,0.07680000364780426,52083.330859507856,ok
|
||||
luminal_compiled,16,4,16,2,4096,0.12828800082206726,31179.84514816717,ok
|
||||
pt_eager,16,8,16,2,4096,0.7696959972381592,10393.713919139223,ok
|
||||
graph_safe_inductor_cg,16,8,16,2,4096,0.08003199845552444,99960.0179226535,ok
|
||||
luminal_compiled,16,8,16,2,4096,0.1287200003862381,62150.40379113693,ok
|
||||
pt_eager,16,16,16,2,4096,0.7825759947299957,20445.29874126834,ok
|
||||
graph_safe_inductor_cg,16,16,16,2,4096,0.07948800176382065,201288.24029996534,ok
|
||||
luminal_compiled,16,16,16,2,4096,0.13308800011873245,120221.20691366495,ok
|
||||
pt_eager,16,32,16,2,4096,0.8636959791183472,37050.074069657356,ok
|
||||
graph_safe_inductor_cg,16,32,16,2,4096,0.08278399705886841,386548.1389747891,ok
|
||||
luminal_compiled,16,32,16,2,4096,0.13150399923324585,243338.6070886124,ok
|
||||
pt_eager,16,64,16,2,4096,0.878896027803421,72818.62470120825,ok
|
||||
graph_safe_inductor_cg,16,64,16,2,4096,0.09142400324344635,700034.9769149688,ok
|
||||
luminal_compiled,16,64,16,2,4096,0.13363199681043625,478927.2144962949,ok
|
||||
pt_eager,16,128,16,2,4096,0.8490720093250275,150752.82024872545,ok
|
||||
graph_safe_inductor_cg,16,128,16,2,4096,0.0907679982483387,1410188.6399410903,ok
|
||||
luminal_compiled,16,128,16,2,4096,0.1300320029258728,984373.0552467828,ok
|
||||
pt_eager,16,256,16,2,4096,0.812527984380722,315066.07147212717,ok
|
||||
graph_safe_inductor_cg,16,256,16,2,4096,0.09612800180912018,2663115.795419685,ok
|
||||
luminal_compiled,16,256,16,2,4096,0.13608000427484512,1881246.2665929128,ok
|
||||
pt_eager,16,512,16,2,4096,0.8759520053863525,584506.9100266221,ok
|
||||
graph_safe_inductor_cg,16,512,16,2,4096,0.09095999971032143,5628847.863132768,ok
|
||||
luminal_compiled,16,512,16,2,4096,0.13556800037622452,3776702.4561777995,ok
|
||||
pt_eager,16,1024,16,2,4096,0.8617600202560425,1188265.8465587131,ok
|
||||
graph_safe_inductor_cg,16,1024,16,2,4096,0.09561599791049957,10709504.919443557,ok
|
||||
luminal_compiled,16,1024,16,2,4096,0.14022399485111237,7302601.819947201,ok
|
||||
pt_eager,16,2048,16,2,4096,0.8651839792728424,2367126.587019415,ok
|
||||
graph_safe_inductor_cg,16,2048,16,2,4096,0.12144000083208084,16864295.00961416,ok
|
||||
luminal_compiled,16,2048,16,2,4096,0.13519999384880066,15147929.68326875,ok
|
||||
pt_eager,32,2,16,2,4096,1.1254720091819763,1777.0321995423542,ok
|
||||
graph_safe_inductor_cg,32,2,16,2,4096,0.13230399787425995,15116.70117407015,ok
|
||||
luminal_compiled,32,2,16,2,4096,0.1716800034046173,11649.58038407291,ok
|
||||
pt_eager,32,4,16,2,4096,1.1279360055923462,3546.300481736428,ok
|
||||
graph_safe_inductor_cg,32,4,16,2,4096,0.14241600036621094,28086.731755661804,ok
|
||||
luminal_compiled,32,4,16,2,4096,0.17348799854516983,23056.34991205774,ok
|
||||
pt_eager,32,8,16,2,4096,1.0751680135726929,7440.69754588092,ok
|
||||
graph_safe_inductor_cg,32,8,16,2,4096,0.1366880014538765,58527.4487512314,ok
|
||||
luminal_compiled,32,8,16,2,4096,0.16710400581359863,47874.37596752677,ok
|
||||
pt_eager,32,16,16,2,4096,1.1331039667129517,14120.504799232836,ok
|
||||
graph_safe_inductor_cg,32,16,16,2,4096,0.13777600228786469,116130.52878809854,ok
|
||||
luminal_compiled,32,16,16,2,4096,0.17158400267362595,93248.78631275425,ok
|
||||
pt_eager,32,32,16,2,4096,1.1613919734954834,27553.143753601504,ok
|
||||
graph_safe_inductor_cg,32,32,16,2,4096,0.14433600008487701,221704.91063339947,ok
|
||||
luminal_compiled,32,32,16,2,4096,0.17432000488042831,183570.44001891708,ok
|
||||
pt_eager,32,64,16,2,4096,1.1802719831466675,54224.789636514644,ok
|
||||
graph_safe_inductor_cg,32,64,16,2,4096,0.16435199975967407,389408.10025789076,ok
|
||||
luminal_compiled,32,64,16,2,4096,0.17791999876499176,359712.2327127224,ok
|
||||
pt_eager,32,128,16,2,4096,1.1843199729919434,108078.9000599509,ok
|
||||
graph_safe_inductor_cg,32,128,16,2,4096,0.15452799946069717,828328.8494429494,ok
|
||||
luminal_compiled,32,128,16,2,4096,0.1780960038304329,718713.4873720702,ok
|
||||
pt_eager,32,256,16,2,4096,1.168511986732483,219082.0487138128,ok
|
||||
graph_safe_inductor_cg,32,256,16,2,4096,0.15727999806404114,1627670.4167796476,ok
|
||||
luminal_compiled,32,256,16,2,4096,0.17375999689102173,1473296.5272815775,ok
|
||||
pt_eager,32,512,16,2,4096,1.0871520042419434,470955.30156062293,ok
|
||||
graph_safe_inductor_cg,32,512,16,2,4096,0.16273599863052368,3146200.0068125455,ok
|
||||
luminal_compiled,32,512,16,2,4096,0.1789119988679886,2861742.10360135,ok
|
||||
pt_eager,32,1024,16,2,4096,1.1761599779129028,870629.8626289675,ok
|
||||
graph_safe_inductor_cg,32,1024,16,2,4096,0.16993600130081177,6025797.901336805,ok
|
||||
luminal_compiled,32,1024,16,2,4096,0.17558399587869644,5831966.603080603,ok
|
||||
pt_eager,32,2048,16,2,4096,1.2334399819374084,1660396.962958127,ok
|
||||
graph_safe_inductor_cg,32,2048,16,2,4096,0.24084799736738205,8503288.473999824,ok
|
||||
luminal_compiled,32,2048,16,2,4096,0.17404799908399582,11766868.971654378,ok
|
||||
|
255
examples/dlrm/src/bin/check.rs
Normal file
255
examples/dlrm/src/bin/check.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Numerical correctness check: load weights and inputs dumped by
|
||||
//! `correctness_dump.py`, run the luminal DLRM forward, and compare
|
||||
//! element-wise against PyTorch's expected output.
|
||||
//!
|
||||
//! Usage: `cargo run --release --bin check -- [weights/]`.
|
||||
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{
|
||||
dlrm_pairwise_dot_lower_tri_stacked, linear_bias, linear_bias_relu, linear_bias_relu_split_a,
|
||||
linear_bias_sigmoid, stacked_embedding_bag_sum_kernel,
|
||||
};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const BATCH: usize = 2048;
|
||||
const M_DEN: usize = 3;
|
||||
const M_SPA: usize = 16;
|
||||
const L: usize = 2;
|
||||
const LN_BOT: &[usize] = &[M_DEN, 64, M_SPA];
|
||||
const LN_TOP_TAIL: &[usize] = &[64, 32, 1];
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum Act {
|
||||
None,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
}
|
||||
|
||||
struct LinearWB {
|
||||
w: GraphTensor,
|
||||
b: GraphTensor,
|
||||
act: Act,
|
||||
}
|
||||
|
||||
impl LinearWB {
|
||||
fn new(cx: &mut Graph, in_dim: usize, out_dim: usize, name: &str, act: Act) -> Self {
|
||||
Self {
|
||||
w: cx
|
||||
.named_tensor(format!("{name}_w").as_str(), (out_dim, in_dim))
|
||||
.persist(),
|
||||
b: cx
|
||||
.named_tensor(format!("{name}_b").as_str(), out_dim)
|
||||
.persist(),
|
||||
act,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
match self.act {
|
||||
Act::None => linear_bias(x, self.w, self.b),
|
||||
Act::Relu => linear_bias_relu(x, self.w, self.b),
|
||||
Act::Sigmoid => linear_bias_sigmoid(x, self.w, self.b),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_f32(path: &Path) -> Vec<f32> {
|
||||
let bytes = fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("can't read {}: {e}", path.display()));
|
||||
assert_eq!(bytes.len() % 4, 0, "{} not a multiple of 4 bytes", path.display());
|
||||
bytemuck::cast_slice::<u8, f32>(&bytes).to_vec()
|
||||
}
|
||||
|
||||
fn read_i32(path: &Path) -> Vec<i32> {
|
||||
let bytes = fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("can't read {}: {e}", path.display()));
|
||||
assert_eq!(bytes.len() % 4, 0, "{} not a multiple of 4 bytes", path.display());
|
||||
bytemuck::cast_slice::<u8, i32>(&bytes).to_vec()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let weights_dir: PathBuf = std::env::args()
|
||||
.nth(1)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("examples/dlrm/weights"));
|
||||
let manifest_path = weights_dir.join("manifest.json");
|
||||
let manifest_bytes = fs::read(&manifest_path)
|
||||
.unwrap_or_else(|e| panic!("can't read {}: {e}", manifest_path.display()));
|
||||
let manifest_text = std::str::from_utf8(&manifest_bytes).unwrap();
|
||||
// Extract num_cat and rows from the manifest with a tiny parse — we only
|
||||
// need two integers, so avoid pulling in serde_json.
|
||||
let extract = |key: &str| -> usize {
|
||||
let needle = format!("\"{key}\":");
|
||||
let i = manifest_text.find(&needle).expect("key not found");
|
||||
let rest = &manifest_text[i + needle.len()..];
|
||||
let s: String = rest
|
||||
.chars()
|
||||
.skip_while(|c| !c.is_ascii_digit())
|
||||
.take_while(|c| c.is_ascii_digit())
|
||||
.collect();
|
||||
s.parse().unwrap()
|
||||
};
|
||||
let num_cat = extract("num_cat");
|
||||
let rows = extract("rows");
|
||||
let num_fea = num_cat + 1;
|
||||
let pair_count = num_fea * (num_fea - 1) / 2;
|
||||
let top_in = pair_count + M_SPA;
|
||||
let ln_top: Vec<usize> = std::iter::once(top_in)
|
||||
.chain(LN_TOP_TAIL.iter().copied())
|
||||
.collect();
|
||||
println!(
|
||||
"check: num_cat={num_cat} rows={rows} F={num_fea} pairs={pair_count} top_in={top_in}"
|
||||
);
|
||||
println!(" ln_top={ln_top:?}");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dense = cx.named_tensor("dense", (BATCH, M_DEN)).persist();
|
||||
let stacked_w = cx
|
||||
.named_tensor("emb_stacked", (num_cat * rows, M_SPA))
|
||||
.persist();
|
||||
let mut sparse_inputs = Vec::with_capacity(num_cat);
|
||||
for i in 0..num_cat {
|
||||
sparse_inputs.push(
|
||||
cx.named_tensor(format!("idx_{i}").as_str(), (BATCH, L))
|
||||
.as_dtype(DType::Int)
|
||||
.persist(),
|
||||
);
|
||||
}
|
||||
|
||||
// Upstream DLRM applies ReLU to every bot layer (sigmoid_bot=-1 default).
|
||||
let mut bot_layers = Vec::new();
|
||||
for (i, win) in LN_BOT.windows(2).enumerate() {
|
||||
bot_layers.push(LinearWB::new(&mut cx, win[0], win[1], &format!("bot_{i}"), Act::Relu));
|
||||
}
|
||||
let mut h = dense;
|
||||
for l in bot_layers.iter() {
|
||||
h = l.forward(h);
|
||||
}
|
||||
let dense_out = h;
|
||||
|
||||
let row_offsets: Vec<usize> = (0..=num_cat).map(|k| k * rows).collect();
|
||||
let emb_stack =
|
||||
stacked_embedding_bag_sum_kernel(stacked_w, sparse_inputs.clone(), &row_offsets);
|
||||
let interactions = dlrm_pairwise_dot_lower_tri_stacked(dense_out, emb_stack);
|
||||
|
||||
let mut top_layers = Vec::new();
|
||||
let mut prev = top_in;
|
||||
let last_top = LN_TOP_TAIL.len() - 1;
|
||||
for (i, &h) in LN_TOP_TAIL.iter().enumerate() {
|
||||
let act = if i < last_top { Act::Relu } else { Act::Sigmoid };
|
||||
top_layers.push(LinearWB::new(&mut cx, prev, h, &format!("top_{i}"), act));
|
||||
prev = h;
|
||||
}
|
||||
// top_0 reads `dense_out` and `interactions` directly via the
|
||||
// split-A kernel — no materialized concat.
|
||||
let mut t = linear_bias_relu_split_a(
|
||||
dense_out,
|
||||
interactions,
|
||||
top_layers[0].w,
|
||||
top_layers[0].b,
|
||||
);
|
||||
for l in top_layers.iter().skip(1) {
|
||||
t = l.forward(t);
|
||||
}
|
||||
let out_t = t.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
|
||||
// Load weights/biases from disk.
|
||||
for (i, _win) in LN_BOT.windows(2).enumerate() {
|
||||
let l = &bot_layers[i];
|
||||
let w = read_f32(&weights_dir.join(format!("bot_{i}_w.bin")));
|
||||
let b = read_f32(&weights_dir.join(format!("bot_{i}_b.bin")));
|
||||
runtime.set_data(l.w, w);
|
||||
runtime.set_data(l.b, b);
|
||||
}
|
||||
for i in 0..top_layers.len() {
|
||||
let l = &top_layers[i];
|
||||
let w = read_f32(&weights_dir.join(format!("top_{i}_w.bin")));
|
||||
let b = read_f32(&weights_dir.join(format!("top_{i}_b.bin")));
|
||||
runtime.set_data(l.w, w);
|
||||
runtime.set_data(l.b, b);
|
||||
}
|
||||
// Read per-table weight files (as PyTorch dumped them) and concat them
|
||||
// into the single stacked weight that the fused kernel expects.
|
||||
let mut stacked = Vec::with_capacity(num_cat * rows * M_SPA);
|
||||
for i in 0..num_cat {
|
||||
let t = read_f32(&weights_dir.join(format!("emb_{i}.bin")));
|
||||
assert_eq!(t.len(), rows * M_SPA, "emb_{i}.bin shape mismatch");
|
||||
stacked.extend(t);
|
||||
}
|
||||
runtime.set_data(stacked_w, stacked);
|
||||
let dense_data = read_f32(&weights_dir.join("dense.bin"));
|
||||
runtime.set_data(dense, dense_data);
|
||||
for i in 0..num_cat {
|
||||
let idx = read_i32(&weights_dir.join(format!("idx_{i}.bin")));
|
||||
runtime.set_data(sparse_inputs[i], idx);
|
||||
}
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(50).trials(1).keep_best(2), &mut rng);
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let lum_out = runtime.get_f32(out_t);
|
||||
let expected = read_f32(&weights_dir.join("expected.bin"));
|
||||
assert_eq!(
|
||||
lum_out.len(),
|
||||
expected.len(),
|
||||
"output length mismatch: luminal={} expected={}",
|
||||
lum_out.len(),
|
||||
expected.len()
|
||||
);
|
||||
|
||||
let mut max_abs = 0.0f32;
|
||||
let mut sum_abs = 0.0f64;
|
||||
let mut max_rel = 0.0f32;
|
||||
let mut diff_count = 0usize;
|
||||
for (i, (&a, &b)) in lum_out.iter().zip(expected.iter()).enumerate() {
|
||||
let d = (a - b).abs();
|
||||
sum_abs += d as f64;
|
||||
if d > max_abs {
|
||||
max_abs = d;
|
||||
}
|
||||
let r = if b.abs() > 1e-8 { d / b.abs() } else { d };
|
||||
if r > max_rel {
|
||||
max_rel = r;
|
||||
}
|
||||
if d > 1e-4 {
|
||||
diff_count += 1;
|
||||
if diff_count <= 5 {
|
||||
println!(
|
||||
" diff @ {i}: luminal={a:.6} expected={b:.6} abs={d:.3e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mean_abs = sum_abs / lum_out.len() as f64;
|
||||
|
||||
println!();
|
||||
println!(
|
||||
" luminal head: {:?}",
|
||||
&lum_out[..8.min(lum_out.len())]
|
||||
);
|
||||
println!(
|
||||
" expected head: {:?}",
|
||||
&expected[..8.min(expected.len())]
|
||||
);
|
||||
println!(
|
||||
" max abs diff = {max_abs:.3e} mean abs diff = {mean_abs:.3e} max rel diff = {max_rel:.3e}"
|
||||
);
|
||||
println!(" elements with abs diff > 1e-4: {diff_count}/{}", lum_out.len());
|
||||
let tol: f32 = 1e-3;
|
||||
if max_abs < tol {
|
||||
println!("PASS (max abs diff < {tol})");
|
||||
} else {
|
||||
println!("FAIL (max abs diff >= {tol})");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
415
examples/dlrm/src/main.rs
Normal file
415
examples/dlrm/src/main.rs
Normal file
@@ -0,0 +1,415 @@
|
||||
//! DLRM forward-pass benchmark on luminal's CUDA backend.
|
||||
//!
|
||||
//! Mirrors `sweep_categories.py` from https://github.com/jss8649/tmp-dlrm-bench:
|
||||
//! batch=2048, m_spa=16, indices_per_bag=2, rows_per_table=4096,
|
||||
//! ln_bot=[3, 64, 16], top MLP scales with F = num_cat + 1:
|
||||
//! num_int = F*(F-1)/2 + m_spa, ln_top = [num_int, 64, 32, 1].
|
||||
//! arch_interaction_op="dot", arch_interaction_itself=False, sigmoid_top=2.
|
||||
//!
|
||||
//! CLI: `dlrm [--num-cat N] [--rows R]`. Defaults: num-cat=3, rows=4096.
|
||||
//!
|
||||
//! Harness: 5 rounds × 20 timed iters, 10 warmup before round 0,
|
||||
//! median-of-round-medians (same as bench_luminal_exact.py).
|
||||
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{
|
||||
dlrm_pairwise_dot_lower_tri_stacked, linear_bias, linear_bias_relu, linear_bias_relu_split_a,
|
||||
linear_bias_sigmoid, stacked_embedding_bag_sum_kernel,
|
||||
};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::time::Instant;
|
||||
|
||||
// ---- Fixed config (matches sweep_categories.py defaults; per-call
|
||||
// overridable from the CLI for sweep purposes) ----
|
||||
const M_DEN: usize = 3;
|
||||
const LN_TOP_TAIL: &[usize] = &[64, 32, 1]; // → ln_top = [num_int, 64, 32, 1]
|
||||
|
||||
const SEARCH_GRAPHS: usize = 200;
|
||||
const SEARCH_TRIALS: usize = 1;
|
||||
const SEARCH_KEEP_BEST: usize = 4;
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
|
||||
const PRE_ROUND_WARMUP: usize = 10;
|
||||
const TIMED_ITERS_PER_ROUND: usize = 20;
|
||||
const ROUNDS: usize = 5;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum Act {
|
||||
None,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
}
|
||||
|
||||
struct LinearWB {
|
||||
w: GraphTensor, // (out_dim, in_dim) — PyTorch/cuBLASLt convention
|
||||
b: GraphTensor, // (out_dim,)
|
||||
act: Act,
|
||||
}
|
||||
|
||||
impl LinearWB {
|
||||
fn new(cx: &mut Graph, in_dim: usize, out_dim: usize, name: &str, act: Act) -> Self {
|
||||
Self {
|
||||
w: cx
|
||||
.named_tensor(format!("{name}_w").as_str(), (out_dim, in_dim))
|
||||
.persist(),
|
||||
b: cx
|
||||
.named_tensor(format!("{name}_b").as_str(), out_dim)
|
||||
.persist(),
|
||||
act,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
match self.act {
|
||||
Act::None => linear_bias(x, self.w, self.b),
|
||||
Act::Relu => linear_bias_relu(x, self.w, self.b),
|
||||
Act::Sigmoid => linear_bias_sigmoid(x, self.w, self.b),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_indices(table_idx: usize, batch: usize, bag: usize, rows: usize) -> Vec<i32> {
|
||||
// Match sweep_categories.py's `_make_inputs`:
|
||||
// positions = arange(batch * L)
|
||||
// lS_i[k] = (positions * (2k+3) + (k+1)) % ROWS_PER_TABLE
|
||||
let n = rows as i64;
|
||||
let s = (2 * table_idx + 3) as i64;
|
||||
let o = (table_idx + 1) as i64;
|
||||
(0..(batch * bag) as i64)
|
||||
.map(|p| ((p * s + o).rem_euclid(n)) as i32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rand_normal(rng: &mut StdRng, n: usize, std: f32) -> Vec<f32> {
|
||||
let mut out = Vec::with_capacity(n);
|
||||
while out.len() < n {
|
||||
let u1: f32 = rng.random::<f32>().max(1e-9);
|
||||
let u2: f32 = rng.random::<f32>();
|
||||
let r = (-2.0 * u1.ln()).sqrt() * std;
|
||||
let z0 = r * (2.0 * std::f32::consts::PI * u2).cos();
|
||||
let z1 = r * (2.0 * std::f32::consts::PI * u2).sin();
|
||||
out.push(z0);
|
||||
if out.len() < n {
|
||||
out.push(z1);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn rand_uniform(rng: &mut StdRng, n: usize, hi: f32) -> Vec<f32> {
|
||||
(0..n).map(|_| (rng.random::<f32>() * 2.0 - 1.0) * hi).collect()
|
||||
}
|
||||
|
||||
fn build_dense_x(batch: usize, m_den: usize) -> Vec<f32> {
|
||||
let total = batch * m_den;
|
||||
(0..total)
|
||||
.map(|i| -1.0 + 2.0 * (i as f32) / ((total - 1) as f32))
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct Args {
|
||||
num_cat: usize,
|
||||
rows: usize,
|
||||
batch: usize,
|
||||
m_spa: usize,
|
||||
bag: usize,
|
||||
print_outputs: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut a = Args {
|
||||
num_cat: 3,
|
||||
rows: 4096,
|
||||
batch: 2048,
|
||||
m_spa: 16,
|
||||
bag: 2,
|
||||
print_outputs: false,
|
||||
};
|
||||
let mut args = std::env::args().skip(1);
|
||||
while let Some(arg) = args.next() {
|
||||
match arg.as_str() {
|
||||
"--num-cat" => {
|
||||
a.num_cat = args.next().expect("missing value for --num-cat").parse().unwrap();
|
||||
}
|
||||
"--rows" => {
|
||||
a.rows = args.next().expect("missing value for --rows").parse().unwrap();
|
||||
}
|
||||
"--batch" => {
|
||||
a.batch = args.next().expect("missing value for --batch").parse().unwrap();
|
||||
}
|
||||
"--m-spa" => {
|
||||
a.m_spa = args.next().expect("missing value for --m-spa").parse().unwrap();
|
||||
}
|
||||
"--bag" => {
|
||||
a.bag = args.next().expect("missing value for --bag").parse().unwrap();
|
||||
}
|
||||
"--print-outputs" => {
|
||||
a.print_outputs = true;
|
||||
}
|
||||
"-h" | "--help" => {
|
||||
eprintln!(
|
||||
"usage: dlrm [--num-cat N] [--rows R] [--batch B] [--m-spa D] [--bag L] [--print-outputs]"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => {
|
||||
eprintln!("unknown arg: {other}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
let Args {
|
||||
num_cat,
|
||||
rows,
|
||||
batch,
|
||||
m_spa,
|
||||
bag,
|
||||
print_outputs,
|
||||
} = args;
|
||||
let num_fea = num_cat + 1;
|
||||
let pair_count = num_fea * (num_fea - 1) / 2;
|
||||
let top_in = pair_count + m_spa;
|
||||
// ln_bot last entry must equal m_spa (bot output feeds into interaction
|
||||
// alongside emb rows that are m_spa wide).
|
||||
let ln_bot: Vec<usize> = vec![M_DEN, 64, m_spa];
|
||||
let ln_top: Vec<usize> = std::iter::once(top_in)
|
||||
.chain(LN_TOP_TAIL.iter().copied())
|
||||
.collect();
|
||||
|
||||
println!(
|
||||
"==== DLRM luminal config ==== num_cat={num_cat} F={num_fea} \
|
||||
rows/table={rows} batch={batch} m_spa={m_spa} L={bag}"
|
||||
);
|
||||
println!(" ln_bot={ln_bot:?} ln_top={ln_top:?} pairs={pair_count}");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let dense = cx
|
||||
.named_tensor("dense", (batch, M_DEN))
|
||||
.persist();
|
||||
let stacked_w = cx
|
||||
.named_tensor("emb_stacked", (num_cat * rows, m_spa))
|
||||
.persist();
|
||||
let mut sparse_inputs = Vec::with_capacity(num_cat);
|
||||
for i in 0..num_cat {
|
||||
sparse_inputs.push(
|
||||
cx.named_tensor(format!("idx_{i}").as_str(), (batch, bag))
|
||||
.as_dtype(DType::Int)
|
||||
.persist(),
|
||||
);
|
||||
}
|
||||
|
||||
// Upstream `dlrm_s_pytorch.DLRM_Net.create_mlp` with `sigmoid_bot=-1`
|
||||
// (the default) applies ReLU to every bot layer — including the final
|
||||
// one. The output of the bot MLP feeds the interaction op AFTER ReLU.
|
||||
let mut bot_layers = Vec::new();
|
||||
for (i, win) in ln_bot.windows(2).enumerate() {
|
||||
bot_layers.push(LinearWB::new(&mut cx, win[0], win[1], &format!("bot_{i}"), Act::Relu));
|
||||
}
|
||||
let mut h = dense;
|
||||
for l in bot_layers.iter() {
|
||||
h = l.forward(h);
|
||||
}
|
||||
let dense_out = h;
|
||||
|
||||
// One fused kernel for all num_cat tables. row_offsets[k] = k * rows
|
||||
// (uniform rows-per-table in this benchmark, matching upstream).
|
||||
let row_offsets: Vec<usize> = (0..=num_cat).map(|k| k * rows).collect();
|
||||
let emb_stack =
|
||||
stacked_embedding_bag_sum_kernel(stacked_w, sparse_inputs.clone(), &row_offsets);
|
||||
// emb_stack: (BATCH, num_cat, M_SPA)
|
||||
|
||||
// Feature interaction over [dense_out, emb_stack[:, 0, :], …,
|
||||
// emb_stack[:, num_cat-1, :]] in one fused kernel — no per-table slice.
|
||||
let interactions = dlrm_pairwise_dot_lower_tri_stacked(dense_out, emb_stack);
|
||||
|
||||
// Skip the materialized `cat(dense_out, interactions)`: the first top
|
||||
// layer reads both halves directly via the split-A matmul kernel. The
|
||||
// remaining top layers see the dense fused output and stay vanilla.
|
||||
let mut top_layers = Vec::new();
|
||||
let mut prev = top_in;
|
||||
let last_top = LN_TOP_TAIL.len() - 1;
|
||||
for (i, &h) in LN_TOP_TAIL.iter().enumerate() {
|
||||
let act = if i < last_top {
|
||||
Act::Relu
|
||||
} else {
|
||||
Act::Sigmoid
|
||||
};
|
||||
top_layers.push(LinearWB::new(&mut cx, prev, h, &format!("top_{i}"), act));
|
||||
prev = h;
|
||||
}
|
||||
// top_0: split-A matmul (no concat materialization).
|
||||
let mut t = linear_bias_relu_split_a(
|
||||
dense_out,
|
||||
interactions,
|
||||
top_layers[0].w,
|
||||
top_layers[0].b,
|
||||
);
|
||||
// top_1..: standard fused linear_bias_(relu|sigmoid) kernels.
|
||||
for l in top_layers.iter().skip(1) {
|
||||
t = l.forward(t);
|
||||
}
|
||||
let out = t.output();
|
||||
|
||||
println!("Building E-graph...");
|
||||
let t = Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
println!(" build: {:.2}s", t.elapsed().as_secs_f64());
|
||||
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
|
||||
for (i, win) in ln_bot.windows(2).enumerate() {
|
||||
let (a, b) = (win[0], win[1]);
|
||||
let l = &bot_layers[i];
|
||||
runtime.set_data(
|
||||
l.w,
|
||||
rand_normal(&mut rng, a * b, (2.0 / (a + b) as f32).sqrt()),
|
||||
);
|
||||
runtime.set_data(l.b, rand_normal(&mut rng, b, (1.0 / b as f32).sqrt()));
|
||||
}
|
||||
let mut top_shapes: Vec<(usize, usize)> = vec![];
|
||||
let mut prev = top_in;
|
||||
for &h in LN_TOP_TAIL.iter() {
|
||||
top_shapes.push((prev, h));
|
||||
prev = h;
|
||||
}
|
||||
for (i, &(a, b)) in top_shapes.iter().enumerate() {
|
||||
let l = &top_layers[i];
|
||||
runtime.set_data(
|
||||
l.w,
|
||||
rand_normal(&mut rng, a * b, (2.0 / (a + b) as f32).sqrt()),
|
||||
);
|
||||
runtime.set_data(l.b, rand_normal(&mut rng, b, (1.0 / b as f32).sqrt()));
|
||||
}
|
||||
|
||||
// One stacked weight: concatenated per-table uniform inits.
|
||||
let mut stacked_data: Vec<f32> = Vec::with_capacity(num_cat * rows * m_spa);
|
||||
for _ in 0..num_cat {
|
||||
stacked_data
|
||||
.extend(rand_uniform(&mut rng, rows * m_spa, 1.0 / (rows as f32).sqrt()));
|
||||
}
|
||||
runtime.set_data(stacked_w, stacked_data);
|
||||
|
||||
runtime.set_data(dense, build_dense_x(batch, M_DEN));
|
||||
for i in 0..num_cat {
|
||||
runtime.set_data(sparse_inputs[i], build_indices(i, batch, bag, rows));
|
||||
}
|
||||
|
||||
println!("Searching/compiling...");
|
||||
let t = Instant::now();
|
||||
runtime = cx.search_options(
|
||||
runtime,
|
||||
SearchOptions::new(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
println!(" search/compile: {:.2}s", t.elapsed().as_secs_f64());
|
||||
|
||||
{
|
||||
let host_ops = runtime.host_ops();
|
||||
let total = host_ops.len();
|
||||
let cublaslt = host_ops
|
||||
.iter()
|
||||
.filter(|op| format!("{op:?}").contains("CuBlasLt"))
|
||||
.count();
|
||||
let cudagraph = host_ops
|
||||
.iter()
|
||||
.filter(|op| format!("{op:?}").contains("CudaGraph"))
|
||||
.count();
|
||||
println!("Host ops: total={total} cuBLASLt={cublaslt} CudaGraph={cudagraph}");
|
||||
}
|
||||
|
||||
for _ in 0..PRE_ROUND_WARMUP {
|
||||
runtime.execute(&cx.dyn_map);
|
||||
}
|
||||
let _ = runtime.get_f32(out);
|
||||
|
||||
let mut round_medians = Vec::with_capacity(ROUNDS);
|
||||
for round in 0..ROUNDS {
|
||||
let mut times_us = Vec::with_capacity(TIMED_ITERS_PER_ROUND);
|
||||
for _ in 0..TIMED_ITERS_PER_ROUND {
|
||||
let t = Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
// execute() no longer syncs at end (so it stays capturable by
|
||||
// torch.cuda.CUDAGraph) — sync explicitly for accurate GPU time.
|
||||
runtime.synchronize_stream();
|
||||
times_us.push(t.elapsed().as_micros() as f64);
|
||||
}
|
||||
let _ = runtime.get_f32(out);
|
||||
times_us.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let median = times_us[times_us.len() / 2] / 1000.0;
|
||||
println!(
|
||||
" round {round}: median {median:.3} ms min {:.3} max {:.3}",
|
||||
times_us[0] / 1000.0,
|
||||
times_us[times_us.len() - 1] / 1000.0
|
||||
);
|
||||
round_medians.push(median);
|
||||
}
|
||||
round_medians.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let med_of_med = round_medians[round_medians.len() / 2];
|
||||
let tput = batch as f64 / (med_of_med / 1000.0);
|
||||
println!();
|
||||
println!(
|
||||
"==== cfg(num_cat={num_cat},batch={batch},m_spa={m_spa},bag={bag},rows={rows}): \
|
||||
luminal median-of-round-medians = {med_of_med:.4} ms ({tput:.0} samples/s)"
|
||||
);
|
||||
if print_outputs {
|
||||
let sample = runtime.get_f32(out);
|
||||
let head: Vec<f32> = sample.iter().take(8).copied().collect();
|
||||
println!(" output[0..8] = {head:?} (len={})", sample.len());
|
||||
}
|
||||
|
||||
// Per-kernel breakdown. Available when `LUMINAL_KERNEL_TIMING=1` was
|
||||
// set before the first execute (it gates the event-record-node
|
||||
// insertion at graph-build time). `get_f32` above already
|
||||
// synchronized the stream, so the events have valid data.
|
||||
if std::env::var_os("LUMINAL_KERNEL_TIMING").is_some() {
|
||||
let timings = runtime.read_per_kernel_timings_ms();
|
||||
if timings.is_empty() {
|
||||
println!(" (no per-kernel timings available — events not recorded)");
|
||||
} else {
|
||||
// Aggregate per kernel name in case the same op appears more
|
||||
// than once (e.g. two Matmul2D_BiasRelu calls in the bot MLP).
|
||||
let mut sum_per_name: std::collections::BTreeMap<&'static str, (f32, usize)> =
|
||||
std::collections::BTreeMap::new();
|
||||
let mut total = 0.0_f32;
|
||||
for (name, ms) in &timings {
|
||||
let e = sum_per_name.entry(*name).or_insert((0.0, 0));
|
||||
e.0 += *ms;
|
||||
e.1 += 1;
|
||||
total += *ms;
|
||||
}
|
||||
println!();
|
||||
println!(" per-kernel GPU time (single replay, ms):");
|
||||
println!(
|
||||
" {:>40} {:>3} {:>10} {:>10} {:>5}",
|
||||
"kernel", "n", "total_ms", "each_ms", "pct"
|
||||
);
|
||||
// Sort by total descending so the bottleneck is on top.
|
||||
let mut sorted: Vec<_> = sum_per_name.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.0.partial_cmp(&a.1.0).unwrap());
|
||||
for (name, (sum, n)) in sorted {
|
||||
let pct = if total > 0.0 { 100.0 * sum / total } else { 0.0 };
|
||||
println!(
|
||||
" {:>40} {:>3} {:>10.4} {:>10.4} {:>4.1}%",
|
||||
name,
|
||||
n,
|
||||
sum,
|
||||
sum / (*n as f32),
|
||||
pct
|
||||
);
|
||||
}
|
||||
println!(" {:>40} {:>10.4}", "TOTAL", total);
|
||||
}
|
||||
}
|
||||
}
|
||||
101
examples/dlrm/summarize.py
Normal file
101
examples/dlrm/summarize.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Summarize a results.csv from sweep_all.py: pretty per-cell tables and
|
||||
ratios of luminal_compiled vs the PT baselines.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("csv", default="results.csv", nargs="?")
|
||||
args = ap.parse_args()
|
||||
|
||||
p = Path(args.csv)
|
||||
rows = list(csv.DictReader(p.open()))
|
||||
# Index: (num_cat, batch, variant) -> ms
|
||||
table: dict[tuple[int, int, str], float | None] = {}
|
||||
nums_cat: set[int] = set()
|
||||
batches: set[int] = set()
|
||||
variants: set[str] = set()
|
||||
for r in rows:
|
||||
if r["status"] != "ok":
|
||||
continue
|
||||
nc = int(r["num_cat"])
|
||||
bs = int(r["batch"])
|
||||
v = r["variant"]
|
||||
ms = float(r["ms"]) if r["ms"] else None
|
||||
if ms is None:
|
||||
continue
|
||||
table[(nc, bs, v)] = ms
|
||||
nums_cat.add(nc)
|
||||
batches.add(bs)
|
||||
variants.add(v)
|
||||
|
||||
nums_cat_s = sorted(nums_cat)
|
||||
batches_s = sorted(batches)
|
||||
variants_s = sorted(variants)
|
||||
|
||||
def fmt_ms(ms: float | None) -> str:
|
||||
if ms is None:
|
||||
return " - "
|
||||
if ms < 1.0:
|
||||
return f"{ms*1000:>5.0f} us"
|
||||
return f"{ms:>6.3f} ms"
|
||||
|
||||
print(f"# DLRMv1 sweep — {len(rows)} rows, {len(variants_s)} variants")
|
||||
print()
|
||||
print(f"Configurations: batch ∈ {batches_s}, num_cat ∈ {nums_cat_s}")
|
||||
print()
|
||||
|
||||
# Per-variant table.
|
||||
for v in variants_s:
|
||||
print(f"## {v}")
|
||||
print()
|
||||
hdr = "| nc \\ batch | " + " | ".join(f"{b:>6}" for b in batches_s) + " |"
|
||||
print(hdr)
|
||||
print("|" + "|".join(["---"] * (len(batches_s) + 1)) + "|")
|
||||
for nc in nums_cat_s:
|
||||
cells = [fmt_ms(table.get((nc, bs, v))) for bs in batches_s]
|
||||
print(f"| {nc:>3} | " + " | ".join(cells) + " |")
|
||||
print()
|
||||
|
||||
# Speedup ratios: luminal_compiled vs each PT variant
|
||||
if "luminal_compiled" in variants_s:
|
||||
for ref in ["pt_eager", "graph_safe_inductor_cg", "graph_safe_cg", "v3_inductor_cg"]:
|
||||
if ref not in variants_s:
|
||||
continue
|
||||
print(f"## speedup luminal_compiled vs {ref} (>1 = luminal faster)")
|
||||
print()
|
||||
print("| nc \\ batch | " + " | ".join(f"{b:>6}" for b in batches_s) + " |")
|
||||
print("|" + "|".join(["---"] * (len(batches_s) + 1)) + "|")
|
||||
wins = 0
|
||||
cells_total = 0
|
||||
for nc in nums_cat_s:
|
||||
rcells = []
|
||||
for bs in batches_s:
|
||||
lum = table.get((nc, bs, "luminal_compiled"))
|
||||
pt = table.get((nc, bs, ref))
|
||||
if lum is None or pt is None or lum <= 0:
|
||||
rcells.append(" - ")
|
||||
continue
|
||||
cells_total += 1
|
||||
r = pt / lum
|
||||
if r >= 1.0:
|
||||
wins += 1
|
||||
rcells.append(f"{r:>5.2f}x")
|
||||
print(f"| {nc:>3} | " + " | ".join(rcells) + " |")
|
||||
print()
|
||||
if cells_total:
|
||||
print(f"luminal faster in {wins}/{cells_total} cells ({100*wins/cells_total:.0f}%)")
|
||||
print()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
115
examples/dlrm/sweep_all.py
Normal file
115
examples/dlrm/sweep_all.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Drive the full DLRM sweep across (batch, num_cat) and produce a CSV.
|
||||
|
||||
For each cell, we shell out to `sweep_pytorch.py` (one process per
|
||||
variant, so torch.compile cache state is fresh) and to `sweep_luminal.py`
|
||||
(also separate, so the luminal compiled graph builds from clean state).
|
||||
|
||||
Each cell × variant → one row in `results.csv`. Run from this dir:
|
||||
|
||||
python sweep_all.py [--variant ...] [--quick]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
PY = sys.executable
|
||||
|
||||
|
||||
def _powers(lo: int, hi: int) -> list[int]:
|
||||
out = []
|
||||
v = lo
|
||||
while v <= hi:
|
||||
out.append(v)
|
||||
v *= 2
|
||||
return out
|
||||
|
||||
|
||||
def _run_json(cmd: list[str]) -> dict | None:
|
||||
try:
|
||||
r = subprocess.run(cmd, cwd=HERE, capture_output=True, text=True, timeout=600)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f" TIMEOUT: {' '.join(cmd)}", file=sys.stderr)
|
||||
return None
|
||||
if r.returncode != 0:
|
||||
print(f" ERR : {' '.join(cmd)}\n stderr tail: {r.stderr[-400:]}", file=sys.stderr)
|
||||
return None
|
||||
for line in r.stdout.splitlines()[::-1]:
|
||||
s = line.strip()
|
||||
if s.startswith("{") and s.endswith("}"):
|
||||
try:
|
||||
return json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
PT_VARIANTS = ["pt_eager", "graph_safe_cg", "graph_safe_inductor_cg", "v3_inductor_cg"]
|
||||
LUMINAL_VARIANTS = ["luminal_compiled"]
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--batch", type=int, nargs="+", default=None)
|
||||
ap.add_argument("--num-cat", type=int, nargs="+", default=None)
|
||||
ap.add_argument("--variants", nargs="+", default=PT_VARIANTS + LUMINAL_VARIANTS)
|
||||
ap.add_argument("--quick", action="store_true",
|
||||
help="Smaller sweep (nc∈{2,8,32}, batch∈{32,512,2048}) for fast iteration.")
|
||||
ap.add_argument("--out", default="results.csv")
|
||||
args = ap.parse_args()
|
||||
|
||||
if args.quick:
|
||||
nums_cat = [2, 8, 32]
|
||||
batches = [32, 512, 2048]
|
||||
else:
|
||||
nums_cat = args.num_cat or _powers(2, 32)
|
||||
batches = args.batch or _powers(2, 2048)
|
||||
|
||||
rows: list[dict] = []
|
||||
t0 = time.time()
|
||||
total_cells = len(nums_cat) * len(batches) * len(args.variants)
|
||||
seen = 0
|
||||
for nc in nums_cat:
|
||||
for bs in batches:
|
||||
for variant in args.variants:
|
||||
seen += 1
|
||||
cmd_base = ["--num-cat", str(nc), "--batch", str(bs), "--json"]
|
||||
if variant == "luminal_compiled":
|
||||
cmd = [PY, "sweep_luminal.py"] + cmd_base
|
||||
else:
|
||||
cmd = [PY, "sweep_pytorch.py", "--variant", variant] + cmd_base
|
||||
r = _run_json(cmd)
|
||||
if r is None:
|
||||
print(f"[{seen}/{total_cells}] FAIL nc={nc:>2} batch={bs:>4} variant={variant}", flush=True)
|
||||
rows.append({"variant": variant, "num_cat": nc, "batch": bs, "ms": None,
|
||||
"samples_per_sec": None, "status": "failed"})
|
||||
else:
|
||||
r["status"] = "ok"
|
||||
rows.append(r)
|
||||
print(f"[{seen}/{total_cells}] {r['ms']:8.4f} ms nc={nc:>2} batch={bs:>4} variant={variant}", flush=True)
|
||||
|
||||
out_path = HERE / args.out
|
||||
with open(out_path, "w", newline="") as f:
|
||||
w = csv.DictWriter(f, fieldnames=["variant", "num_cat", "batch", "m_spa", "bag", "rows", "ms",
|
||||
"samples_per_sec", "status"])
|
||||
w.writeheader()
|
||||
for row in rows:
|
||||
row = dict(row)
|
||||
row.setdefault("m_spa", 16)
|
||||
row.setdefault("bag", 2)
|
||||
row.setdefault("rows", 4096)
|
||||
row.setdefault("status", "ok")
|
||||
w.writerow(row)
|
||||
print(f"\nWrote {out_path} ({len(rows)} rows, {time.time()-t0:.1f}s)")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
98
examples/dlrm/sweep_luminal.py
Normal file
98
examples/dlrm/sweep_luminal.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Time DLRMv1 under `torch.compile(model, backend=luminal_backend)`.
|
||||
|
||||
Sibling to sweep_pytorch.py — same shape, same harness, same JSON output.
|
||||
|
||||
Strategy: import DLRMv1/make_inputs from sweep_pytorch (so the model
|
||||
definition is the single source of truth across all variants), warm up,
|
||||
then time 5 rounds × 20 iters × 10 warmup as elsewhere.
|
||||
|
||||
Outputs one JSON line per variant (currently `luminal_compiled`).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Make the sibling import resolve regardless of CWD.
|
||||
sys.path.insert(0, "/home/ubuntu/luminal/examples/dlrm")
|
||||
from sweep_pytorch import DLRMv1, make_inputs, time_rounds # noqa: E402
|
||||
|
||||
import luminal # noqa: E402
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _relaxed_dynamo_limits():
|
||||
pr = torch._dynamo.config.recompile_limit
|
||||
pc = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.config.recompile_limit = 64
|
||||
torch._dynamo.config.cache_size_limit = 64
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.recompile_limit = pr
|
||||
torch._dynamo.config.cache_size_limit = pc
|
||||
|
||||
|
||||
def _run_luminal_compiled(num_cat: int, batch: int, rows: int, m_spa: int, bag: int, device):
|
||||
"""Time `torch.compile(model, backend=luminal_backend)`.
|
||||
|
||||
Do NOT wrap the compiled call in `torch.cuda.CUDAGraph` — luminal's
|
||||
`cuda_lite` runtime already captures and replays a CUDA graph
|
||||
internally (one host op per `execute()` call), so an external wrap
|
||||
would just be hiding Python-wrapper / FFI overhead rather than
|
||||
measuring luminal's actual perf.
|
||||
"""
|
||||
plain = DLRMv1(num_cat, rows, m_spa).eval().to(device)
|
||||
inputs = make_inputs(num_cat, batch, bag, rows, device)
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(plain, backend=luminal_backend, fullgraph=False, dynamic=False)
|
||||
# Warm up: triggers the export + translate + search.
|
||||
with torch.no_grad():
|
||||
for _ in range(3):
|
||||
_ = compiled(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
return time_rounds(lambda: compiled(*inputs))
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--num-cat", type=int, default=3)
|
||||
ap.add_argument("--rows", type=int, default=4096)
|
||||
ap.add_argument("--batch", type=int, default=2048)
|
||||
ap.add_argument("--m-spa", type=int, default=16)
|
||||
ap.add_argument("--bag", type=int, default=2)
|
||||
ap.add_argument("--json", action="store_true")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda")
|
||||
ms = _run_luminal_compiled(args.num_cat, args.batch, args.rows, args.m_spa, args.bag, device)
|
||||
rec = {
|
||||
"variant": "luminal_compiled",
|
||||
"num_cat": args.num_cat,
|
||||
"batch": args.batch,
|
||||
"m_spa": args.m_spa,
|
||||
"bag": args.bag,
|
||||
"rows": args.rows,
|
||||
"ms": ms,
|
||||
"samples_per_sec": args.batch / (ms / 1000.0),
|
||||
}
|
||||
if args.json:
|
||||
print(json.dumps(rec))
|
||||
else:
|
||||
print(
|
||||
f"==== luminal_compiled cfg: num_cat={args.num_cat} batch={args.batch} ===="
|
||||
)
|
||||
print(f" luminal_compiled {ms:8.4f} ms "
|
||||
f"({rec['samples_per_sec']:>12,.0f} samples/s)")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
332
examples/dlrm/sweep_pytorch.py
Normal file
332
examples/dlrm/sweep_pytorch.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Parameterized PyTorch sweep for DLRMv1 across (batch, num_cat).
|
||||
|
||||
Self-contained: defines DLRMv1 inline (vanilla `nn.EmbeddingBag` per table,
|
||||
strict lower-tri pairwise-dot interaction, top MLP), so it does NOT need
|
||||
facebookresearch/dlrm cloned locally.
|
||||
|
||||
Variants measured (toggle with `--variant`):
|
||||
|
||||
* `pt_eager` — eager forward of plain DLRMv1, no CG.
|
||||
* `graph_safe_cg` — `GraphSafeDLRM` wrap of plain DLRMv1, captured
|
||||
via `torch.cuda.CUDAGraph`, replayed.
|
||||
* `graph_safe_inductor_cg` — `GraphSafeDLRM` wrap, then `torch.compile(
|
||||
backend="inductor", mode="max-autotune-
|
||||
no-cudagraphs", fullgraph=False)`, then
|
||||
manual `torch.cuda.CUDAGraph` capture/replay.
|
||||
This is the user-named primary baseline.
|
||||
* `v3_inductor_cg` — `FusedDLRMv3` (stacked-emb-table rewrite,
|
||||
index_select+sum), `torch.compile(inductor,
|
||||
max-autotune-no-cudagraphs)`, manual CG.
|
||||
* `all` — all of the above.
|
||||
|
||||
Harness: 5 rounds × 20 timed iters with 10 warmup before round 0, median
|
||||
of CUDA-event-timed per-iter ms, then median across rounds.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# ---- Fixed dials matching the upstream sweep_categories.py defaults ------
|
||||
M_DEN = 3
|
||||
M_SPA = 16
|
||||
INDICES_PER_BAG = 2
|
||||
LN_BOT_TAIL = [64] # → ln_bot = [M_DEN, 64, M_SPA]
|
||||
LN_TOP_TAIL = [64, 32, 1] # → ln_top = [num_int, 64, 32, 1]
|
||||
SEED = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _relaxed_dynamo_limits():
|
||||
pr = torch._dynamo.config.recompile_limit
|
||||
pc = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.config.recompile_limit = 64
|
||||
torch._dynamo.config.cache_size_limit = 64
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.recompile_limit = pr
|
||||
torch._dynamo.config.cache_size_limit = pc
|
||||
|
||||
|
||||
def _build_mlp(layer_sizes: List[int], sigmoid_layer: int) -> nn.Sequential:
|
||||
layers: List[nn.Module] = []
|
||||
for i, (a, b) in enumerate(zip(layer_sizes, layer_sizes[1:])):
|
||||
layers.append(nn.Linear(a, b, bias=True))
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Vanilla DLRMv1 — what a user writes with nn.EmbeddingBag.
|
||||
# This is the model the luminal torch.compile backend must ingest.
|
||||
# ----------------------------------------------------------------------
|
||||
class DLRMv1(nn.Module):
|
||||
def __init__(self, num_cat: int, rows: int, m_spa: int):
|
||||
super().__init__()
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
self.num_cat = num_cat
|
||||
self.rows = rows
|
||||
self.m_spa = m_spa
|
||||
self.num_fea = num_cat + 1 # +1 for dense
|
||||
self.pair_count = self.num_fea * (self.num_fea - 1) // 2
|
||||
self.top_in = self.pair_count + m_spa
|
||||
ln_bot = [M_DEN] + LN_BOT_TAIL + [m_spa]
|
||||
ln_top = [self.top_in] + LN_TOP_TAIL
|
||||
self.bot = _build_mlp(ln_bot, sigmoid_layer=-1)
|
||||
self.top = _build_mlp(ln_top, sigmoid_layer=len(ln_top) - 2)
|
||||
self.emb = nn.ModuleList(
|
||||
[nn.EmbeddingBag(rows, m_spa, mode="sum", sparse=False) for _ in range(num_cat)]
|
||||
)
|
||||
li, lj = [], []
|
||||
for i in range(self.num_fea):
|
||||
for j in range(i):
|
||||
li.append(i)
|
||||
lj.append(j)
|
||||
self.register_buffer("li", torch.tensor(li, dtype=torch.long), persistent=False)
|
||||
self.register_buffer("lj", torch.tensor(lj, dtype=torch.long), persistent=False)
|
||||
|
||||
def forward(self, dense_x, lS_o, lS_i):
|
||||
x = self.bot(dense_x) # (B, M_SPA)
|
||||
ly = [self.emb[k](lS_i[k], lS_o[k]) for k in range(self.num_cat)] # each (B, M_SPA)
|
||||
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1) # (B, F, D)
|
||||
Z = torch.bmm(T, T.transpose(1, 2)) # (B, F, F)
|
||||
Zflat = Z[:, self.li, self.lj] # (B, PAIRS)
|
||||
R = torch.cat([x, Zflat], dim=1)
|
||||
return self.top(R)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Graph-safe wrapper: identical math, but the strict-lower-tri (li, lj)
|
||||
# indices are pre-baked into buffers (so the forward does no Python-side
|
||||
# allocs, making it `torch.cuda.CUDAGraph`-capturable).
|
||||
# ----------------------------------------------------------------------
|
||||
class GraphSafeDLRM(nn.Module):
|
||||
def __init__(self, inner: DLRMv1):
|
||||
super().__init__()
|
||||
self.inner = inner
|
||||
# Reuse the inner's li/lj buffers — they're already pre-computed.
|
||||
|
||||
def forward(self, dense_x, lS_o, lS_i):
|
||||
inner = self.inner
|
||||
x = inner.bot(dense_x)
|
||||
ly = [inner.emb[k](lS_i[k], lS_o[k]) for k in range(inner.num_cat)]
|
||||
T = torch.cat([x.unsqueeze(1)] + [e.unsqueeze(1) for e in ly], dim=1)
|
||||
Z = torch.bmm(T, T.transpose(1, 2))
|
||||
Zflat = Z[:, inner.li, inner.lj]
|
||||
R = torch.cat([x, Zflat], dim=1)
|
||||
return inner.top(R)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Fused v3: stacked embedding table + index_select + view + sum. Same
|
||||
# math as `mode="sum"` EmbeddingBag at fixed bag size; Inductor can
|
||||
# fuse the gather+sum into one Triton kernel.
|
||||
# ----------------------------------------------------------------------
|
||||
class FusedDLRMv3(nn.Module):
|
||||
def __init__(self, plain: DLRMv1, bag: int):
|
||||
super().__init__()
|
||||
self.num_emb = plain.num_cat
|
||||
self.m_spa = plain.m_spa
|
||||
self.L = bag
|
||||
self.num_fea = plain.num_fea
|
||||
total_rows = plain.rows * plain.num_cat
|
||||
big = torch.empty(total_rows, plain.m_spa, dtype=torch.float32)
|
||||
starts = np.arange(plain.num_cat, dtype=np.int64) * plain.rows
|
||||
for k in range(plain.num_cat):
|
||||
big[starts[k] : starts[k] + plain.rows].copy_(plain.emb[k].weight.detach())
|
||||
self.emb_weight = nn.Parameter(big)
|
||||
self.register_buffer("row_offsets", torch.from_numpy(starts))
|
||||
self.bot = copy.deepcopy(plain.bot)
|
||||
self.top = copy.deepcopy(plain.top)
|
||||
self.register_buffer("li", plain.li.clone(), persistent=False)
|
||||
self.register_buffer("lj", plain.lj.clone(), persistent=False)
|
||||
|
||||
def pack_indices(self, lS_i):
|
||||
return (torch.stack(lS_i, dim=0) + self.row_offsets.view(self.num_emb, 1)).reshape(-1)
|
||||
|
||||
def forward(self, dense_x, flat_indices):
|
||||
bs = dense_x.shape[0]
|
||||
g = (
|
||||
torch.index_select(self.emb_weight, 0, flat_indices)
|
||||
.view(self.num_emb * bs, self.L, self.m_spa)
|
||||
.sum(dim=1)
|
||||
)
|
||||
ly = g.view(self.num_emb, bs, self.m_spa).transpose(0, 1)
|
||||
x = self.bot(dense_x)
|
||||
T = torch.cat([x.unsqueeze(1), ly], dim=1)
|
||||
Z = torch.bmm(T, T.transpose(1, 2))
|
||||
Zflat = Z[:, self.li, self.lj]
|
||||
R = torch.cat([x, Zflat], dim=1)
|
||||
return self.top(R)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Inputs match sweep_categories.py's `_make_inputs` for parity with the
|
||||
# hand-written luminal example.
|
||||
# ----------------------------------------------------------------------
|
||||
def make_inputs(num_cat: int, batch: int, bag: int, rows: int, device):
|
||||
dense = torch.linspace(
|
||||
-1.0, 1.0, steps=batch * M_DEN, dtype=torch.float32, device=device
|
||||
).reshape(batch, M_DEN)
|
||||
total = batch * bag
|
||||
positions = torch.arange(total, dtype=torch.int64, device=device)
|
||||
offsets = torch.arange(0, total, bag, dtype=torch.int64, device=device)
|
||||
lS_o = [offsets.clone() for _ in range(num_cat)]
|
||||
lS_i = [((positions * (2 * k + 3) + (k + 1)) % rows).to(torch.int64) for k in range(num_cat)]
|
||||
return dense, lS_o, lS_i
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Timing: 5 rounds × 20 iters × 10 warmup, median-of-round-medians,
|
||||
# CUDA event timing.
|
||||
# ----------------------------------------------------------------------
|
||||
def time_rounds(call, rounds: int = 5, iters: int = 20, warmup: int = 10) -> float:
|
||||
medians = []
|
||||
for r in range(rounds):
|
||||
if r == 0:
|
||||
for _ in range(warmup):
|
||||
call()
|
||||
torch.cuda.synchronize()
|
||||
s = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
||||
e = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
|
||||
for i in range(iters):
|
||||
s[i].record()
|
||||
call()
|
||||
e[i].record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = np.array([a.elapsed_time(b) for a, b in zip(s, e)])
|
||||
medians.append(float(np.median(elapsed)))
|
||||
return float(np.median(medians))
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _capture_cudagraph(callable_fn, warmup: int = 3) -> torch.cuda.CUDAGraph:
|
||||
"""Warm up on a side stream, then capture a CUDA graph of one call."""
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s), torch.no_grad():
|
||||
for _ in range(warmup):
|
||||
callable_fn()
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
torch.cuda.synchronize()
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.no_grad(), torch.cuda.graph(g):
|
||||
_ = callable_fn()
|
||||
return g
|
||||
|
||||
|
||||
def _run_variant(variant: str, num_cat: int, batch: int, rows: int, m_spa: int, bag: int, device):
|
||||
plain = DLRMv1(num_cat, rows, m_spa).eval().to(device)
|
||||
inputs = make_inputs(num_cat, batch, bag, rows, device)
|
||||
results: dict = {}
|
||||
|
||||
if variant in ("pt_eager", "all"):
|
||||
with torch.no_grad():
|
||||
results["pt_eager"] = time_rounds(lambda: plain(*inputs))
|
||||
|
||||
if variant in ("graph_safe_cg", "all"):
|
||||
safe = GraphSafeDLRM(copy.deepcopy(plain)).to(device).eval()
|
||||
g = _capture_cudagraph(lambda: safe(*inputs))
|
||||
results["graph_safe_cg"] = time_rounds(lambda: g.replay())
|
||||
|
||||
if variant in ("graph_safe_inductor_cg", "all"):
|
||||
safe = GraphSafeDLRM(copy.deepcopy(plain)).to(device).eval()
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(
|
||||
safe,
|
||||
backend="inductor",
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
)
|
||||
g = _capture_cudagraph(lambda: compiled(*inputs), warmup=5)
|
||||
results["graph_safe_inductor_cg"] = time_rounds(lambda: g.replay())
|
||||
|
||||
if variant in ("v3_inductor_cg", "all"):
|
||||
v3 = FusedDLRMv3(plain, bag).to(device).eval()
|
||||
flat = v3.pack_indices(list(inputs[2])).contiguous()
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(
|
||||
v3.forward,
|
||||
backend="inductor",
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
)
|
||||
g = _capture_cudagraph(lambda: compiled(inputs[0], flat), warmup=5)
|
||||
results["v3_inductor_cg"] = time_rounds(lambda: g.replay())
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--num-cat", type=int, default=3)
|
||||
ap.add_argument("--rows", type=int, default=4096)
|
||||
ap.add_argument("--batch", type=int, default=2048)
|
||||
ap.add_argument("--m-spa", type=int, default=16)
|
||||
ap.add_argument("--bag", type=int, default=2)
|
||||
ap.add_argument(
|
||||
"--variant",
|
||||
choices=["pt_eager", "graph_safe_cg", "graph_safe_inductor_cg", "v3_inductor_cg", "all"],
|
||||
default="all",
|
||||
)
|
||||
ap.add_argument("--json", action="store_true", help="emit one JSON line per variant")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda")
|
||||
if not args.json:
|
||||
print(
|
||||
f"==== PyTorch sweep cfg: num_cat={args.num_cat} batch={args.batch} "
|
||||
f"m_spa={args.m_spa} bag={args.bag} rows={args.rows} ===="
|
||||
)
|
||||
|
||||
results = _run_variant(
|
||||
args.variant, args.num_cat, args.batch, args.rows, args.m_spa, args.bag, device
|
||||
)
|
||||
|
||||
if args.json:
|
||||
for variant, ms in results.items():
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"variant": variant,
|
||||
"num_cat": args.num_cat,
|
||||
"batch": args.batch,
|
||||
"m_spa": args.m_spa,
|
||||
"bag": args.bag,
|
||||
"rows": args.rows,
|
||||
"ms": ms,
|
||||
"samples_per_sec": args.batch / (ms / 1000.0),
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
print()
|
||||
for label, ms in results.items():
|
||||
tput = args.batch / (ms / 1000.0)
|
||||
print(f" {label:<32} {ms:8.4f} ms ({tput:>12,.0f} samples/s)")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user