mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
4 Commits
bump-egglo
...
codex/dlrm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6b0eb0ec1 | ||
|
|
1dcd0370ce | ||
|
|
6757a4e37b | ||
|
|
631451f8b8 |
43
README.md
43
README.md
@@ -45,49 +45,6 @@ cd ./examples/llama
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The CUDA test suites have a fast default path and explicit opt-in slow coverage.
|
||||
Slow tests include exhaustive cuBLASLt rewrite sweeps, CUDA search/genome fuzzing,
|
||||
large Python model compiles, and pretrained/full-width model tests.
|
||||
|
||||
Rust CUDA unit tests, matching the default CI path:
|
||||
|
||||
```bash
|
||||
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
|
||||
cargo test -p luminal_cuda_lite -- --test-threads=1
|
||||
```
|
||||
|
||||
Rust CUDA slow tests only:
|
||||
|
||||
```bash
|
||||
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
|
||||
cargo test -p luminal_cuda_lite -- --ignored --test-threads=1
|
||||
```
|
||||
|
||||
Rust CUDA full suite, including ignored slow tests:
|
||||
|
||||
```bash
|
||||
CUDARC_CUDA_VERSION=12080 CUDA_COMPUTE_CAP="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1 | tr -d .)" \
|
||||
cargo test -p luminal_cuda_lite -- --include-ignored --test-threads=1
|
||||
```
|
||||
|
||||
Python CUDA tests, matching the default CI path:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda MATURIN_PEP517_ARGS="--features cuda --profile release" CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
|
||||
```
|
||||
|
||||
Python CUDA full suite, including slow tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda MATURIN_PEP517_ARGS="--features cuda --profile release" CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Speed
|
||||
|
||||
@@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
|
||||
@@ -5,6 +5,7 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::host::describe_host_op;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
@@ -39,6 +40,26 @@ impl DynBackend for CudaLiteDynBackend {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
self.runtime
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.map(|name| (*name).to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
self.runtime
|
||||
.host_ops()
|
||||
.iter()
|
||||
.map(|op| describe_host_op(*op))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn print_execution_stats(&self) {
|
||||
self.runtime.print_execution_stats();
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -247,6 +247,10 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("CuBlasSgemmV2")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
@@ -419,7 +419,6 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
|
||||
match epilogue {
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
|
||||
@@ -978,6 +977,18 @@ impl CuBlasLt {
|
||||
&& normalize(self.stride_c) == normalize(self.stride_d)
|
||||
&& self.c_order == self.d_order
|
||||
}
|
||||
|
||||
pub(crate) fn debug_summary(&self) -> String {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
format!(
|
||||
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
|
||||
resolve(&self.m),
|
||||
resolve(&self.n),
|
||||
resolve(&self.k),
|
||||
resolve(&self.batch_count),
|
||||
epilogue_name(self.epilogue),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
@@ -1114,6 +1125,10 @@ impl HostOp for CuBlasLt {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("CuBlasLt")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use crate::kernel::CudaGraphOp;
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
@@ -79,6 +80,24 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
|
||||
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
|
||||
return op.debug_summary();
|
||||
}
|
||||
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
|
||||
let mut summary = op.debug_summary();
|
||||
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
|
||||
&& let Some(timing) = op.debug_timing_summary()
|
||||
{
|
||||
summary.push_str(" [");
|
||||
summary.push_str(&timing);
|
||||
summary.push(']');
|
||||
}
|
||||
return summary;
|
||||
}
|
||||
op.stats_name().unwrap_or("unknown").to_string()
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
|
||||
@@ -12,7 +12,10 @@ use luminal::{
|
||||
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
|
||||
hlir::{
|
||||
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
|
||||
Sin, Sqrt, SumReduce,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
@@ -65,6 +68,8 @@ pub type Ops = (
|
||||
KernelConstant,
|
||||
KernelCast,
|
||||
KernelEmbed,
|
||||
KernelConcat2D,
|
||||
KernelEmbeddingBagSum,
|
||||
);
|
||||
|
||||
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
|
||||
@@ -3266,14 +3271,20 @@ impl KernelOp for KernelEmbed {
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.embed_dim.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table) {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
@@ -3284,10 +3295,7 @@ extern \"C\" {{
|
||||
int token_id = token_ids[token_offset];
|
||||
out[out_offset + embed_idx] = embed_table[(long long)token_id * embed_dim + embed_idx];
|
||||
}}
|
||||
}}",
|
||||
vars.iter()
|
||||
.map(|i| format!("__constant__ int const_{i}[1];"))
|
||||
.join("\n"),
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
@@ -3298,10 +3306,8 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let constants = vars
|
||||
.into_iter()
|
||||
.map(|d| (d, module.get_global(&format!("const_{d}"), stream).unwrap()))
|
||||
.collect();
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
@@ -3353,3 +3359,361 @@ extern \"C\" {{
|
||||
"Embed"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelEmbeddingBagSum {
|
||||
n_bags: Expression,
|
||||
n_indices: Expression,
|
||||
hidden_dim: Expression,
|
||||
num_embeddings: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelEmbeddingBagSum {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelEmbeddingBagSum",
|
||||
&[
|
||||
("n_bags", EXPRESSION),
|
||||
("n_indices", EXPRESSION),
|
||||
("hidden_dim", EXPRESSION),
|
||||
("num_embeddings", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelEmbeddingBagSum {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert!(
|
||||
self.dtype == DType::F32,
|
||||
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
|
||||
self.dtype
|
||||
);
|
||||
let vars = self
|
||||
.n_bags
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.n_indices.dyn_vars())
|
||||
.chain(self.hidden_dim.dyn_vars())
|
||||
.chain(self.num_embeddings.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_bags = self.n_bags.to_kernel();
|
||||
let n_indices = self.n_indices.to_kernel();
|
||||
let hidden_dim = self.hidden_dim.to_kernel();
|
||||
let num_embeddings = self.num_embeddings.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
|
||||
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long bag = blockIdx.y;
|
||||
long long hidden_dim = {hidden_dim};
|
||||
long long n_bags = {n_bags};
|
||||
long long n_indices = {n_indices};
|
||||
long long num_embeddings = {num_embeddings};
|
||||
if (bag >= n_bags || dim >= hidden_dim) return;
|
||||
|
||||
int start_raw = offsets[bag];
|
||||
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
|
||||
int start = max(0, min(start_raw, (int)n_indices));
|
||||
int end = max(start, min(end_raw, (int)n_indices));
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int pos = start; pos < end; ++pos) {{
|
||||
int row = indices[pos];
|
||||
row = max(0, min(row, (int)num_embeddings - 1));
|
||||
sum += weight[(long long)row * hidden_dim + dim];
|
||||
}}
|
||||
|
||||
out[bag * hidden_dim + dim] = sum;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("embedding_bag_sum").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
|
||||
(self.hidden_dim.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.n_bags * self.hidden_dim
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.n_bags
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.n_indices.dyn_vars())
|
||||
.chain(self.hidden_dim.dyn_vars())
|
||||
.chain(self.num_embeddings.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Approximate: weights + indices + offsets
|
||||
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.n_indices * self.hidden_dim
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"EmbeddingBagSum"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConcat2D {
|
||||
rows: Expression,
|
||||
lhs_cols: Expression,
|
||||
rhs_cols: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelConcat2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelConcat2D",
|
||||
&[
|
||||
("rows", EXPRESSION),
|
||||
("lhs_cols", EXPRESSION),
|
||||
("rhs_cols", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite::<Concat2D, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelConcat2D {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert!(
|
||||
self.dtype == DType::F32,
|
||||
"KernelConcat2D only supports F32 today, got {:?}",
|
||||
self.dtype
|
||||
);
|
||||
let vars = self
|
||||
.rows
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.lhs_cols.dyn_vars())
|
||||
.chain(self.rhs_cols.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let rows = self.rows.to_kernel();
|
||||
let lhs_cols = self.lhs_cols.to_kernel();
|
||||
let rhs_cols = self.rhs_cols.to_kernel();
|
||||
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long total = {total};
|
||||
if (idx >= total) return;
|
||||
|
||||
long long rows = {rows};
|
||||
long long lhs_cols = {lhs_cols};
|
||||
long long rhs_cols = {rhs_cols};
|
||||
long long out_cols = lhs_cols + rhs_cols;
|
||||
if (rows == 0 || out_cols == 0) return;
|
||||
|
||||
long long row = idx / out_cols;
|
||||
long long col = idx - row * out_cols;
|
||||
if (col < lhs_cols) {{
|
||||
out[idx] = lhs[row * lhs_cols + col];
|
||||
}} else {{
|
||||
long long rhs_col = col - lhs_cols;
|
||||
out[idx] = rhs[row * rhs_cols + rhs_col];
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("concat_2d").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let output_size = self.output_size();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(output_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(output_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.rows * (self.lhs_cols + self.rhs_cols)
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.rows
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.lhs_cols.dyn_vars())
|
||||
.chain(self.rhs_cols.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Concat2D"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,8 +509,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
scatter_kernel,
|
||||
(n_src, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(n_src.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
//! that can be executed like any other HostOp.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
@@ -141,6 +143,8 @@ struct CudaGraphOpState {
|
||||
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
|
||||
/// Timing events for profiling
|
||||
timing_events: Vec<cudarc::driver::sys::CUevent>,
|
||||
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
|
||||
last_kernel_timings_us: Vec<(&'static str, f64)>,
|
||||
}
|
||||
|
||||
impl CudaGraphOpState {
|
||||
@@ -155,6 +159,7 @@ impl CudaGraphOpState {
|
||||
last_dyn_values: FxHashMap::default(),
|
||||
last_buffer_ptrs: FxHashMap::default(),
|
||||
timing_events: Vec::new(),
|
||||
last_kernel_timings_us: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,6 +197,41 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_summary(&self) -> String {
|
||||
let state = self.state.borrow();
|
||||
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
|
||||
for kernel in &state.kernels {
|
||||
*counts.entry(kernel.kernel_name).or_default() += 1;
|
||||
}
|
||||
let mut counts: Vec<_> = counts.into_iter().collect();
|
||||
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
|
||||
let top = counts
|
||||
.into_iter()
|
||||
.take(4)
|
||||
.map(|(name, count)| format!("{name}x{count}"))
|
||||
.join(", ");
|
||||
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
|
||||
}
|
||||
|
||||
pub fn debug_timing_summary(&self) -> Option<String> {
|
||||
let state = self.state.borrow();
|
||||
if state.last_kernel_timings_us.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
|
||||
for (name, us) in &state.last_kernel_timings_us {
|
||||
*totals.entry(*name).or_default() += *us;
|
||||
}
|
||||
let mut totals: Vec<_> = totals.into_iter().collect();
|
||||
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
|
||||
let top = totals
|
||||
.into_iter()
|
||||
.take(4)
|
||||
.map(|(name, us)| format!("{name}={us:.0}us"))
|
||||
.join(", ");
|
||||
Some(top)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -566,6 +606,23 @@ impl CudaGraphOp {
|
||||
// Launch the graph
|
||||
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
|
||||
|
||||
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
|
||||
&& state.timing_events.len() >= state.kernels.len() + 1
|
||||
{
|
||||
stream.synchronize()?;
|
||||
let ctx = stream.context().clone();
|
||||
state.last_kernel_timings_us.clear();
|
||||
for idx in 0..state.kernels.len() {
|
||||
let start_event = state.timing_events[idx];
|
||||
let end_event = state.timing_events[idx + 1];
|
||||
let kernel_name = state.kernels[idx].kernel_name;
|
||||
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
|
||||
.map(|ms| ms as f64 * 1_000.0)
|
||||
.unwrap_or(0.0);
|
||||
state.last_kernel_timings_us.push((kernel_name, us));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -584,8 +641,9 @@ impl CudaGraphOp {
|
||||
state.kernel_params.clear();
|
||||
state.kernel_params.reserve(num_kernels);
|
||||
|
||||
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
|
||||
let tracing_enabled = enabled!(Level::TRACE);
|
||||
if tracing_enabled {
|
||||
if tracing_enabled || profile_cuda_graph {
|
||||
let needed_events = num_kernels + 1;
|
||||
while state.timing_events.len() < needed_events {
|
||||
state.timing_events.push(create_cuda_event(&ctx)?);
|
||||
@@ -701,7 +759,7 @@ impl CudaGraphOp {
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
let timing_event = if tracing_enabled || profile_cuda_graph {
|
||||
Some(state.timing_events[idx])
|
||||
} else {
|
||||
None
|
||||
@@ -739,7 +797,9 @@ impl CudaGraphOp {
|
||||
prev_graph_node = Some(graph_node);
|
||||
}
|
||||
|
||||
if tracing_enabled && let Some(prev) = prev_graph_node {
|
||||
if (tracing_enabled || profile_cuda_graph)
|
||||
&& let Some(prev) = prev_graph_node
|
||||
{
|
||||
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::{
|
||||
host::{DeviceBuffer, HostOp},
|
||||
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
|
||||
host::{DeviceBuffer, HostOp, describe_host_op},
|
||||
kernel::{
|
||||
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
|
||||
record_cuda_graph_timings, record_event_on_stream,
|
||||
},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
|
||||
|
||||
@@ -60,6 +63,12 @@ pub struct KernelStats {
|
||||
pub tflops: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HostOpStats {
|
||||
pub name: String,
|
||||
pub execution_time_us: f64,
|
||||
}
|
||||
|
||||
impl Debug for ExecutableHostOp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "HostOp: ({:?})", self.internal)
|
||||
@@ -141,6 +150,7 @@ pub struct CudaRuntime {
|
||||
changed_hlir: FxHashSet<NodeIndex>,
|
||||
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
|
||||
pub last_kernel_stats: Vec<KernelStats>,
|
||||
pub last_host_op_stats: Vec<HostOpStats>,
|
||||
pub last_total_time_us: f64,
|
||||
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
/// When true, execute() skips input buffer consumption (used during search/profile)
|
||||
@@ -1203,6 +1213,7 @@ impl Runtime for CudaRuntime {
|
||||
changed_hlir: FxHashSet::default(),
|
||||
cuda_graph_timings: vec![],
|
||||
last_kernel_stats: vec![],
|
||||
last_host_op_stats: vec![],
|
||||
last_total_time_us: 0.0,
|
||||
kernel_cache: FxHashMap::default(),
|
||||
profiling: false,
|
||||
@@ -1454,6 +1465,8 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
|
||||
let mut host_timing_events = Vec::new();
|
||||
|
||||
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
@@ -1507,6 +1520,16 @@ impl Runtime for CudaRuntime {
|
||||
n_inputs = exec_op.inputs.len()
|
||||
)
|
||||
.entered();
|
||||
let host_op_timing = if profile_host_ops {
|
||||
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
|
||||
let ctx = exec_op.stream.context().clone();
|
||||
let start_event = create_cuda_event(&ctx).unwrap();
|
||||
let end_event = create_cuda_event(&ctx).unwrap();
|
||||
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
|
||||
Some((name, ctx, start_event, end_event))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
exec_op
|
||||
.internal
|
||||
.execute(
|
||||
@@ -1522,10 +1545,28 @@ impl Runtime for CudaRuntime {
|
||||
exec_op.internal.stats_name().unwrap_or("unknown")
|
||||
);
|
||||
});
|
||||
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
|
||||
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
|
||||
host_timing_events.push((name, ctx, start_event, end_event));
|
||||
}
|
||||
}
|
||||
// Single sync at end - CUDA stream ordering guarantees sequential execution
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
self.last_host_op_stats.clear();
|
||||
if profile_host_ops {
|
||||
for (name, ctx, start_event, end_event) in host_timing_events {
|
||||
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
|
||||
.map(|ms| ms as f64 * 1_000.0)
|
||||
.unwrap_or(0.0);
|
||||
self.last_host_op_stats.push(HostOpStats {
|
||||
name,
|
||||
execution_time_us,
|
||||
});
|
||||
destroy_cuda_event(&ctx, start_event);
|
||||
destroy_cuda_event(&ctx, end_event);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate last_kernel_stats from HostOps that report stats
|
||||
self.last_kernel_stats.clear();
|
||||
@@ -1861,6 +1902,16 @@ impl CudaRuntime {
|
||||
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
|
||||
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
|
||||
|
||||
if !self.last_host_op_stats.is_empty() {
|
||||
println!("\n=== Host Operation Statistics ===\n");
|
||||
println!("{:<20} {:>12}", "HostOp", "Time (us)");
|
||||
println!("{}", "-".repeat(34));
|
||||
for s in &self.last_host_op_stats {
|
||||
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
|
||||
}
|
||||
println!("{}", "-".repeat(34));
|
||||
}
|
||||
|
||||
// Print kernel stats
|
||||
if !self.last_kernel_stats.is_empty() {
|
||||
println!("\n=== Kernel Execution Statistics ===\n");
|
||||
|
||||
@@ -248,6 +248,23 @@ impl CompiledGraph {
|
||||
self.runtime.device_type()
|
||||
}
|
||||
|
||||
/// Names of kernels compiled into the active runtime bucket, if available.
|
||||
#[getter]
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
self.runtime.kernel_names()
|
||||
}
|
||||
|
||||
/// Names of host ops in the active runtime bucket, if available.
|
||||
#[getter]
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
self.runtime.host_op_names()
|
||||
}
|
||||
|
||||
/// Print backend execution statistics for the last run, if supported.
|
||||
fn print_execution_stats(&self) {
|
||||
self.runtime.print_execution_stats();
|
||||
}
|
||||
|
||||
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
|
||||
#[getter]
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, bail};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
@@ -8,21 +8,62 @@ use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
let alpha = match op {
|
||||
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
|
||||
BinaryOp::Mul | BinaryOp::Div => 1.0,
|
||||
};
|
||||
|
||||
let lhs = node.inputs[0]
|
||||
.arg
|
||||
.as_tensor_name()
|
||||
.map(|name| self.get_tensor(name))
|
||||
.transpose()?;
|
||||
let rhs = node.inputs[1]
|
||||
.arg
|
||||
.as_tensor_name()
|
||||
.map(|name| self.get_tensor(name))
|
||||
.transpose()?;
|
||||
|
||||
match (lhs, rhs) {
|
||||
(Some(a), Some(mut b)) => {
|
||||
if alpha != 1.0 {
|
||||
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
|
||||
}
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
}
|
||||
(Some(a), None) => {
|
||||
let mut val = self.get_float_arg(node, 1)? as f32;
|
||||
if alpha != 1.0 {
|
||||
val *= alpha;
|
||||
}
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
(None, Some(mut b)) => {
|
||||
if alpha != 1.0 {
|
||||
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
|
||||
}
|
||||
let lhs_val = self.get_float_arg(node, 0)? as f32;
|
||||
let a = self
|
||||
.graph
|
||||
.constant_float(lhs_val)
|
||||
.cast(b.dtype)
|
||||
.expand_rhs(b.shape);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
}
|
||||
(None, None) => bail!("{} expects at least one tensor operand", node.target),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -111,6 +111,7 @@ impl<'a> Translator<'a> {
|
||||
result
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
|
||||
"torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
@@ -133,8 +134,28 @@ impl<'a> Translator<'a> {
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
if alpha == 0.0 && beta == 0.0 {
|
||||
self.graph
|
||||
.constant_float(0.0)
|
||||
.cast(mm.dtype)
|
||||
.expand_rhs(mm.shape)
|
||||
} else if beta == 0.0 {
|
||||
if alpha == 1.0 { mm } else { mm * alpha }
|
||||
} else if alpha == 0.0 {
|
||||
let input = if beta == 1.0 { input } else { input * beta };
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(input.dtype)
|
||||
.expand_rhs(mm.shape);
|
||||
let (input, _) = broadcast_binary(input, zero);
|
||||
input
|
||||
} else {
|
||||
let input = if beta == 1.0 { input } else { input * beta };
|
||||
let mm = if alpha == 1.0 { mm } else { mm * alpha };
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input + mm
|
||||
}
|
||||
}
|
||||
|
||||
// Convolution
|
||||
@@ -147,8 +168,14 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
"torch.ops.aten._embedding_bag.default"
|
||||
| "torch.ops.aten._embedding_bag_forward_only.default" => {
|
||||
self.translate_embedding_bag(node)?
|
||||
}
|
||||
"<built-in function getitem>" => self.translate_getitem(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
@@ -163,6 +190,9 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
|
||||
self.translate_native_batch_norm_no_training(node)?
|
||||
}
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
|
||||
@@ -12,6 +12,64 @@ const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn try_concat_2d_fast(
|
||||
&mut self,
|
||||
lhs: GraphTensor,
|
||||
rhs: GraphTensor,
|
||||
axis: usize,
|
||||
) -> Option<GraphTensor> {
|
||||
if axis != 1
|
||||
|| lhs.dtype != DType::F32
|
||||
|| rhs.dtype != DType::F32
|
||||
|| lhs.shape.len() != 2
|
||||
|| rhs.shape.len() != 2
|
||||
|| !lhs.shape.is_contiguous()
|
||||
|| !rhs.shape.is_contiguous()
|
||||
|| lhs.shape.dims[0] != rhs.shape.dims[0]
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let rows = lhs.shape.dims[0];
|
||||
let lhs_cols = lhs.shape.dims[1];
|
||||
let rhs_cols = rhs.shape.dims[1];
|
||||
let id = self.graph.add_op(
|
||||
luminal::hlir::Concat2D {
|
||||
rows,
|
||||
lhs_cols,
|
||||
rhs_cols,
|
||||
},
|
||||
&[lhs.id, rhs.id],
|
||||
);
|
||||
|
||||
Some(GraphTensor::from_id(
|
||||
id,
|
||||
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
|
||||
lhs.graph_ref,
|
||||
lhs.dtype,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
|
||||
let index = self
|
||||
.get_int_arg(node, 2)
|
||||
.context("select.int: missing index")?;
|
||||
|
||||
let dim_size = a.shape.dims[dim]
|
||||
.to_usize()
|
||||
.context("select.int: symbolic dims are not supported for negative indices")?;
|
||||
let normalized_index = if index < 0 {
|
||||
(dim_size as i64 + index) as usize
|
||||
} else {
|
||||
index as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
|
||||
.squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
@@ -80,6 +138,43 @@ impl<'a> Translator<'a> {
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
|
||||
sizes
|
||||
.into_iter()
|
||||
.map(|size| {
|
||||
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
|
||||
Ok(Expression::from(size as usize))
|
||||
})
|
||||
.collect::<Result<_>>()?
|
||||
} else {
|
||||
self.get_exprs_arg(node, 1)?
|
||||
};
|
||||
|
||||
anyhow::ensure!(
|
||||
repeats.len() >= a.shape.len(),
|
||||
"repeat: repeats rank {} is smaller than input rank {}",
|
||||
repeats.len(),
|
||||
a.shape.len()
|
||||
);
|
||||
|
||||
while a.shape.len() < repeats.len() {
|
||||
a = a.unsqueeze(0);
|
||||
}
|
||||
|
||||
Ok(a.repeat(repeats))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let index = self.get_int_arg(node, 1)?;
|
||||
anyhow::ensure!(
|
||||
index == 0,
|
||||
"getitem: only tuple[0] access is supported today, got index={index}"
|
||||
);
|
||||
self.get_input_tensor(node, 0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1).unwrap_or(0);
|
||||
@@ -161,7 +256,11 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
for t in &tensors[1..] {
|
||||
result = result.concat_along(*t, dim);
|
||||
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
|
||||
result = fast;
|
||||
} else {
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -218,6 +317,79 @@ impl<'a> Translator<'a> {
|
||||
bail!("index.Tensor: no index tensors in optional_tensors list");
|
||||
}
|
||||
index_names = found_tensors;
|
||||
|
||||
// Multiple explicit index tensors after leading `None`s mean
|
||||
// "keep the prefix dims, then advanced-index the contiguous
|
||||
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
|
||||
if first_non_none_dim > 0
|
||||
&& index_names.len() > 1
|
||||
&& first_non_none_dim + index_names.len() == source.shape.len()
|
||||
{
|
||||
let src_dims = source.shape.dims;
|
||||
let indexed_dims = &src_dims[first_non_none_dim..];
|
||||
let n_indexed = index_names.len();
|
||||
|
||||
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
|
||||
for i in (0..n_indexed - 1).rev() {
|
||||
strides[i] = strides[i + 1] * indexed_dims[i + 1];
|
||||
}
|
||||
|
||||
let mut flat_idx: Option<GraphTensor> = None;
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
let axis_size = indexed_dims[dim_idx];
|
||||
let idx_int = idx_tensor.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
|
||||
let is_negative = idx_int.lt(zero).cast(DType::Int);
|
||||
let idx_int = idx_int + is_negative * axis_size;
|
||||
|
||||
let stride = strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
idx_int
|
||||
} else {
|
||||
idx_int * stride
|
||||
};
|
||||
|
||||
flat_idx = Some(match flat_idx {
|
||||
Some(acc) => {
|
||||
let (acc_b, w_b) = broadcast_binary(acc, weighted);
|
||||
acc_b + w_b
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
|
||||
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
|
||||
let idx_shape = flat_idx.shape.dims.to_vec();
|
||||
let mut idx_numel = Expression::from(1usize);
|
||||
for dim in &idx_shape {
|
||||
idx_numel *= *dim;
|
||||
}
|
||||
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
|
||||
|
||||
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
|
||||
let mut indexed_size = Expression::from(1usize);
|
||||
for dim in indexed_dims {
|
||||
indexed_size *= *dim;
|
||||
}
|
||||
let mut flat_source_shape = prefix_dims.clone();
|
||||
flat_source_shape.push(indexed_size);
|
||||
let flat_source = reshape_tensor(source, flat_source_shape);
|
||||
|
||||
let mut expanded_idx = flat_idx;
|
||||
for _ in 0..prefix_dims.len() {
|
||||
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let mut target = prefix_dims.clone();
|
||||
target.push(idx_numel);
|
||||
expanded_idx.shape.expand(target);
|
||||
|
||||
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
|
||||
let mut result_shape = prefix_dims;
|
||||
result_shape.extend_from_slice(&idx_shape);
|
||||
return Ok(reshape_tensor(gathered, result_shape));
|
||||
}
|
||||
|
||||
// Simple case: single non-None index on a specific dim → gather_elements
|
||||
if first_non_none_dim > 0 && index_names.len() == 1 {
|
||||
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
|
||||
@@ -28,6 +28,45 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_embedding_bag_generic(
|
||||
&mut self,
|
||||
weight: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
offsets: GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
let n_indices = indices.shape.dims[0];
|
||||
let n_bags = offsets.shape.dims[0];
|
||||
|
||||
// Gather per-index embeddings: [E] -> [E, D].
|
||||
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
|
||||
let gathered = weight.gather(ids_expanded + arange);
|
||||
|
||||
// Bag assignment per position:
|
||||
// bag_id[pos] = count(offsets <= pos) - 1
|
||||
// This supports empty bags too, because repeated offsets simply skip a
|
||||
// bag id when no positions land in that interval.
|
||||
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
|
||||
let starts = offsets.expand_dim(1, n_indices);
|
||||
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
|
||||
- self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(vec![n_indices]);
|
||||
|
||||
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
|
||||
let bag_ids = bag_ids.expand_dim(0, n_bags);
|
||||
let mask = bag_ids
|
||||
.eq(bag_axis)
|
||||
.expand_dim(2, hidden_dim)
|
||||
.cast(gathered.dtype);
|
||||
let gathered = gathered.expand_dim(0, n_bags);
|
||||
|
||||
Ok((gathered * mask).sum(1))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
.inputs
|
||||
@@ -180,6 +219,45 @@ impl<'a> Translator<'a> {
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
|
||||
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
|
||||
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
anyhow::ensure!(
|
||||
mode == 0,
|
||||
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
indices.shape.len() == 1 && offsets.shape.len() == 1,
|
||||
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
|
||||
indices.shape.len(),
|
||||
offsets.shape.len()
|
||||
);
|
||||
|
||||
if weight.dtype == DType::F32 {
|
||||
let id = self.graph.add_op(
|
||||
luminal::hlir::EmbeddingBagSum {
|
||||
n_bags: offsets.shape.dims[0],
|
||||
n_indices: indices.shape.dims[0],
|
||||
hidden_dim: weight.shape.dims[1],
|
||||
num_embeddings: weight.shape.dims[0],
|
||||
},
|
||||
&[weight.id, indices.id, offsets.id],
|
||||
);
|
||||
return Ok(GraphTensor::from_id(
|
||||
id,
|
||||
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
|
||||
weight.graph_ref,
|
||||
DType::F32,
|
||||
));
|
||||
}
|
||||
|
||||
self.translate_embedding_bag_generic(weight, indices, offsets)
|
||||
}
|
||||
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
|
||||
@@ -21,6 +21,30 @@ const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn expand_channel_parameter(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
parameter: GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
anyhow::ensure!(
|
||||
input.shape.len() >= 2,
|
||||
"batch_norm: expected rank >= 2 input, got rank {}",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
parameter.shape.len() == 1,
|
||||
"batch_norm: expected 1D channel parameter, got rank {}",
|
||||
parameter.shape.len()
|
||||
);
|
||||
|
||||
let mut expanded = parameter.unsqueeze(0);
|
||||
for axis in 2..input.shape.len() {
|
||||
expanded = expanded.unsqueeze(axis);
|
||||
}
|
||||
expanded.shape.expand(input.dims().to_vec());
|
||||
Ok(expanded)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
@@ -101,6 +125,30 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_native_batch_norm_no_training(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
|
||||
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
|
||||
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
|
||||
|
||||
let mut result = (input - running_mean) / (running_var + eps).sqrt();
|
||||
|
||||
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
|
||||
result = result * weight;
|
||||
}
|
||||
|
||||
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
|
||||
result = result + bias;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
|
||||
@@ -8,12 +8,14 @@ from .compiled_model import CompiledModel
|
||||
# Import Rust extension components (built by maturin)
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .main import luminal_backend, register_backend
|
||||
from .pt2 import compile
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"compile",
|
||||
"luminal_backend",
|
||||
"register_backend",
|
||||
"CompiledGraph",
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
|
||||
from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
@@ -28,6 +27,14 @@ class CompiledModel:
|
||||
self._input_names = input_names or graph_result.input_names
|
||||
self._output_names = graph_result.output_names
|
||||
self._output_shapes = graph_result.output_shapes
|
||||
output_dtype_codes = graph_result.output_dtypes
|
||||
self._output_dtypes = [
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._output_names))
|
||||
]
|
||||
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
@@ -35,6 +42,9 @@ class CompiledModel:
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Cache converted/contiguous views for repeated calls with the same input
|
||||
# tensors so we don't rebuild sparse index buffers every forward.
|
||||
self._prepared_input_cache = [None] * len(self._input_names)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
@@ -43,6 +53,80 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
self._single_float_output_fast_path = (
|
||||
self._supports_device_ptrs
|
||||
and not self._has_dynamic_dims
|
||||
and len(self._output_names) == 1
|
||||
and self._output_dtypes[0].is_floating_point
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
|
||||
return (
|
||||
id(tensor),
|
||||
getattr(tensor, "_version", None),
|
||||
tensor.data_ptr(),
|
||||
expected_dtype,
|
||||
)
|
||||
|
||||
def _prepare_input_tensor(
|
||||
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
detached = tensor.detach()
|
||||
needs_preparation = (
|
||||
detached.dtype != expected_dtype or not detached.is_contiguous()
|
||||
)
|
||||
if not needs_preparation:
|
||||
return detached
|
||||
|
||||
cache_key = self._input_cache_key(detached, expected_dtype)
|
||||
cached = self._prepared_input_cache[index]
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
|
||||
prepared = detached.contiguous().to(expected_dtype)
|
||||
self._prepared_input_cache[index] = (cache_key, prepared)
|
||||
return prepared
|
||||
|
||||
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
"""Bind the current user inputs into the Rust graph."""
|
||||
input_refs = []
|
||||
for index, (name, tensor, expected_dtype) in enumerate(
|
||||
zip(self._input_names, user_inputs, self._input_dtypes)
|
||||
):
|
||||
prepared = self._prepare_input_tensor(index, tensor, expected_dtype)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
n_bytes = prepared.numel() * prepared.element_size()
|
||||
self._graph.set_input_device_ptr(name, prepared.data_ptr(), n_bytes)
|
||||
input_refs.append(prepared)
|
||||
else:
|
||||
if prepared.device.type != "cpu":
|
||||
prepared = prepared.cpu()
|
||||
n_bytes = prepared.numel() * prepared.element_size()
|
||||
dtype_code = _torch_dtype_code(prepared.dtype)
|
||||
self._graph.set_input_from_ptr(
|
||||
name, prepared.data_ptr(), n_bytes, dtype_code
|
||||
)
|
||||
return input_refs
|
||||
|
||||
def _run_static_single_float_output(
|
||||
self, user_inputs: List[torch.Tensor], input_device: torch.device
|
||||
):
|
||||
_input_refs = self._bind_user_inputs(user_inputs)
|
||||
output_name = self._output_names[0]
|
||||
output_dtype = self._output_dtypes[0]
|
||||
out = torch.empty(
|
||||
self._static_output_shapes[0], dtype=output_dtype, device=input_device
|
||||
)
|
||||
self._graph.set_output_device_ptr(
|
||||
output_name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
self._graph.run()
|
||||
if not self._graph.output_is_zero_copy(output_name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
output_name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
return (out,)
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -86,25 +170,18 @@ class CompiledModel:
|
||||
if self._has_dynamic_dims:
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
elif (
|
||||
self._single_float_output_fast_path
|
||||
and input_device.type != "cpu"
|
||||
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
|
||||
):
|
||||
return self._run_static_single_float_output(user_inputs, input_device)
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
_input_refs = self._bind_user_inputs(user_inputs)
|
||||
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
@@ -112,8 +189,6 @@ class CompiledModel:
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
@@ -121,8 +196,8 @@ class CompiledModel:
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
@@ -140,8 +215,8 @@ class CompiledModel:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
else torch.float32
|
||||
)
|
||||
out = output_tensors[i]
|
||||
@@ -178,8 +253,8 @@ class CompiledModel:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
@@ -199,3 +274,41 @@ class CompiledModel:
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
def _leaf_paths(tree: Any, prefix=()):
|
||||
if torch.is_tensor(tree):
|
||||
return [prefix]
|
||||
if isinstance(tree, (list, tuple)):
|
||||
paths = []
|
||||
for idx, value in enumerate(tree):
|
||||
paths.extend(_leaf_paths(value, prefix + (idx,)))
|
||||
return paths
|
||||
if isinstance(tree, dict):
|
||||
paths = []
|
||||
for key, value in tree.items():
|
||||
paths.extend(_leaf_paths(value, prefix + (key,)))
|
||||
return paths
|
||||
return [prefix]
|
||||
|
||||
|
||||
def _follow_path(tree: Any, path):
|
||||
value = tree
|
||||
for key in path:
|
||||
value = value[key]
|
||||
return value
|
||||
|
||||
|
||||
class StructuredCompiledModel:
|
||||
"""Preserve a module's original nested input structure for direct PT2 compile()."""
|
||||
|
||||
def __init__(self, compiled_model, example_args):
|
||||
self._compiled = compiled_model
|
||||
self._leaf_paths = _leaf_paths(example_args)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._compiled, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
|
||||
return self._compiled(*flat_inputs)
|
||||
|
||||
@@ -9,10 +9,12 @@ import inspect
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from .compiled_model import CompiledModel
|
||||
from .compiled_model import CompiledModel, StructuredCompiledModel
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
@@ -184,6 +186,66 @@ def _save_and_compile(
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _has_cuda_inputs(flat_example_inputs):
|
||||
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
|
||||
|
||||
|
||||
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
|
||||
"""Search env overrides for direct compile() calls.
|
||||
|
||||
CUDA DLRM-style models benefit materially from a deeper per-candidate
|
||||
profile and from keeping more parents alive between generations. Keep the
|
||||
defaults narrow so CPU and env-configured callers are unchanged.
|
||||
"""
|
||||
has_cuda = _has_cuda_inputs(flat_example_inputs)
|
||||
overrides = {}
|
||||
|
||||
if search_trials is not None:
|
||||
overrides["LUMINAL_SEARCH_TRIALS"] = str(search_trials)
|
||||
elif has_cuda and "LUMINAL_SEARCH_TRIALS" not in os.environ:
|
||||
overrides["LUMINAL_SEARCH_TRIALS"] = "5"
|
||||
|
||||
if search_keep_best is not None:
|
||||
overrides["LUMINAL_SEARCH_KEEP_BEST"] = str(search_keep_best)
|
||||
elif has_cuda and "LUMINAL_SEARCH_KEEP_BEST" not in os.environ:
|
||||
overrides["LUMINAL_SEARCH_KEEP_BEST"] = "3"
|
||||
|
||||
return overrides
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _temporary_env(overrides):
|
||||
sentinel = object()
|
||||
previous = {}
|
||||
try:
|
||||
for key, value in overrides.items():
|
||||
previous[key] = os.environ.get(key, sentinel)
|
||||
os.environ[key] = value
|
||||
yield
|
||||
finally:
|
||||
for key, old_value in previous.items():
|
||||
if old_value is sentinel:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = old_value
|
||||
|
||||
|
||||
def _strip_exported_weights_for_zero_copy(ep, original_weights):
|
||||
"""Shrink the saved .pt2 artifact when original weights will be reused."""
|
||||
if not original_weights:
|
||||
return
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
if isinstance(orig, torch.nn.Parameter):
|
||||
replacement = torch.nn.Parameter(
|
||||
replacement, requires_grad=orig.requires_grad
|
||||
)
|
||||
ep._state_dict[key] = replacement
|
||||
del orig
|
||||
|
||||
|
||||
def _safe_int_bound(value):
|
||||
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
|
||||
|
||||
@@ -401,7 +463,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
def compile(
|
||||
model,
|
||||
example_input,
|
||||
search_iterations=25,
|
||||
search_iterations=None,
|
||||
search_trials=None,
|
||||
search_keep_best=None,
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
@@ -413,7 +477,13 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor — or a list/tuple of tensors for
|
||||
multi-input models.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
search_iterations: Number of optimization search iterations. When None,
|
||||
defaults to 200 on CUDA inputs and 10 otherwise.
|
||||
search_trials: Optional per-candidate profiling trials inside Luminal's
|
||||
search. When unset, direct CUDA compile defaults to 5.
|
||||
search_keep_best: Optional number of parent candidates to retain
|
||||
between search generations. When unset, direct CUDA compile
|
||||
defaults to 3.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
|
||||
@@ -432,17 +502,26 @@ def compile(
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(
|
||||
example_input
|
||||
if isinstance(example_input, (list, tuple))
|
||||
else [example_input]
|
||||
)
|
||||
|
||||
if isinstance(example_input, (list, tuple)):
|
||||
example_args = tuple(example_input)
|
||||
else:
|
||||
example_args = (example_input,)
|
||||
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
|
||||
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(flat_example_inputs)
|
||||
|
||||
if search_iterations is None:
|
||||
search_iterations = (
|
||||
200
|
||||
if _has_cuda_inputs(flat_example_inputs)
|
||||
else 10
|
||||
)
|
||||
search_env = _direct_search_env(
|
||||
flat_example_inputs,
|
||||
search_trials=search_trials,
|
||||
search_keep_best=search_keep_best,
|
||||
)
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -488,7 +567,16 @@ def compile(
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
original_weights = model.state_dict()
|
||||
_strip_exported_weights_for_zero_copy(ep, original_weights)
|
||||
with _temporary_env(search_env):
|
||||
compiled = _save_and_compile(
|
||||
ep,
|
||||
factory,
|
||||
search_iterations,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
return StructuredCompiledModel(compiled, example_args)
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
@@ -576,12 +664,7 @@ def _eager_pt2_compile(
|
||||
# from the EP before saving. The Rust side uses device pointers for these
|
||||
# weights, not the .pt2 file data, so serializing them is pure IO waste
|
||||
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
_strip_exported_weights_for_zero_copy(ep, original_weights)
|
||||
|
||||
# Save EP to disk, then free it and the traced graph module before Rust
|
||||
# compilation. torch.export clones the state_dict internally; holding ep
|
||||
@@ -595,11 +678,20 @@ def _eager_pt2_compile(
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
default_search_iterations = (
|
||||
50
|
||||
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
|
||||
else 10
|
||||
)
|
||||
search_iterations = int(
|
||||
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
|
||||
)
|
||||
|
||||
try:
|
||||
return _save_and_compile(
|
||||
pt2_path,
|
||||
factory,
|
||||
10,
|
||||
search_iterations,
|
||||
original_weights=original_weights,
|
||||
user_indices=user_indices,
|
||||
)
|
||||
|
||||
11
crates/luminal_python/tests/for_jake.md
Normal file
11
crates/luminal_python/tests/for_jake.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# DLRM CUDA Benchmark
|
||||
|
||||
These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
|
||||
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.
|
||||
|
||||
| Path | Median latency | Throughput |
|
||||
| --- | ---: | ---: |
|
||||
| eager | 0.267 ms | 7,674,321 candidates/s |
|
||||
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
|
||||
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
|
||||
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |
|
||||
213
crates/luminal_python/tests/test_deepctr_torch.py
Normal file
213
crates/luminal_python/tests/test_deepctr_torch.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
|
||||
|
||||
These tests are intended for the local integration workflow where the
|
||||
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
|
||||
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
|
||||
then run the same model through ``backend=luminal_backend``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("sklearn")
|
||||
pytest.importorskip("tqdm")
|
||||
|
||||
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
|
||||
if not DEEPCTR_ROOT.exists():
|
||||
pytest.skip(
|
||||
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
deepctr_root = str(DEEPCTR_ROOT)
|
||||
if deepctr_root not in sys.path:
|
||||
sys.path.insert(0, deepctr_root)
|
||||
|
||||
from deepctr_torch.inputs import (
|
||||
DenseFeat,
|
||||
SparseFeat,
|
||||
VarLenSparseFeat,
|
||||
build_input_features,
|
||||
)
|
||||
from deepctr_torch.models import DCN
|
||||
from deepctr_torch.models.din import DIN
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _stack_features(
|
||||
feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
|
||||
) -> torch.Tensor:
|
||||
parts = []
|
||||
for name in build_input_features(feature_columns):
|
||||
value = np.asarray(feature_dict[name])
|
||||
if value.ndim == 1:
|
||||
value = np.expand_dims(value, axis=1)
|
||||
parts.append(value)
|
||||
stacked = np.concatenate(parts, axis=-1)
|
||||
return torch.tensor(stacked, dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
if isinstance(output, tuple) and len(output) == 1:
|
||||
return output[0]
|
||||
return output
|
||||
|
||||
|
||||
def _assert_allclose(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
|
||||
) -> None:
|
||||
max_diff = torch.max(torch.abs(lhs - rhs)).item()
|
||||
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
|
||||
|
||||
|
||||
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
return _unwrap(model(*inputs))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _relaxed_dynamo_limits():
|
||||
prev_recompile_limit = torch._dynamo.config.recompile_limit
|
||||
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
torch._dynamo.config.cache_size_limit = 16
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.recompile_limit = prev_recompile_limit
|
||||
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
|
||||
|
||||
|
||||
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
|
||||
with torch.no_grad():
|
||||
return _unwrap(compiled(*inputs))
|
||||
|
||||
|
||||
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
return _unwrap(compiled(*inputs))
|
||||
|
||||
|
||||
def _make_dcn(
|
||||
device: torch.device, cross_parameterization: str
|
||||
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
|
||||
torch.manual_seed(0)
|
||||
feature_columns = [
|
||||
SparseFeat("s0", 5, embedding_dim=4),
|
||||
SparseFeat("s1", 7, embedding_dim=4),
|
||||
DenseFeat("d0", 1),
|
||||
DenseFeat("d1", 1),
|
||||
]
|
||||
feature_dict = {
|
||||
"s0": np.array([0, 1, 2, 3], dtype=np.int64),
|
||||
"s1": np.array([1, 2, 3, 4], dtype=np.int64),
|
||||
"d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
|
||||
"d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
model = DCN(
|
||||
linear_feature_columns=feature_columns,
|
||||
dnn_feature_columns=feature_columns,
|
||||
cross_num=2,
|
||||
cross_parameterization=cross_parameterization,
|
||||
dnn_hidden_units=(16,),
|
||||
dnn_dropout=0.0,
|
||||
device=str(device),
|
||||
).eval()
|
||||
inputs = (_stack_features(feature_columns, feature_dict, device),)
|
||||
return model.to(device), inputs
|
||||
|
||||
|
||||
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
|
||||
torch.manual_seed(0)
|
||||
feature_columns = [
|
||||
SparseFeat("user", 4, embedding_dim=4),
|
||||
SparseFeat("gender", 2, embedding_dim=4),
|
||||
SparseFeat("item_id", 4, embedding_dim=8),
|
||||
SparseFeat("cate_id", 3, embedding_dim=4),
|
||||
DenseFeat("pay_score", 1),
|
||||
VarLenSparseFeat(
|
||||
SparseFeat(
|
||||
"hist_item_id",
|
||||
vocabulary_size=4,
|
||||
embedding_dim=8,
|
||||
embedding_name="item_id",
|
||||
),
|
||||
maxlen=4,
|
||||
length_name="seq_length",
|
||||
),
|
||||
VarLenSparseFeat(
|
||||
SparseFeat(
|
||||
"hist_cate_id",
|
||||
vocabulary_size=3,
|
||||
embedding_dim=4,
|
||||
embedding_name="cate_id",
|
||||
),
|
||||
maxlen=4,
|
||||
length_name="seq_length",
|
||||
),
|
||||
]
|
||||
feature_dict = {
|
||||
"user": np.array([0, 1, 2, 3], dtype=np.int64),
|
||||
"gender": np.array([0, 1, 0, 1], dtype=np.int64),
|
||||
"item_id": np.array([1, 2, 3, 2], dtype=np.int64),
|
||||
"cate_id": np.array([1, 2, 1, 2], dtype=np.int64),
|
||||
"pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
|
||||
"hist_item_id": np.array(
|
||||
[[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]],
|
||||
dtype=np.int64,
|
||||
),
|
||||
"hist_cate_id": np.array(
|
||||
[[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]],
|
||||
dtype=np.int64,
|
||||
),
|
||||
"seq_length": np.array([3, 3, 2, 2], dtype=np.int64),
|
||||
}
|
||||
model = DIN(
|
||||
feature_columns,
|
||||
["item_id", "cate_id"],
|
||||
dnn_dropout=0.0,
|
||||
device=str(device),
|
||||
).eval()
|
||||
inputs = (_stack_features(feature_columns, feature_dict, device),)
|
||||
return model.to(device), inputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
|
||||
def test_deepctr_dcn_matches_inductor_when_supported(
|
||||
device: torch.device, cross_parameterization: str
|
||||
) -> None:
|
||||
model, inputs = _make_dcn(device, cross_parameterization)
|
||||
|
||||
eager = _run_eager(model, *inputs)
|
||||
inductor = _run_inductor(model, *inputs)
|
||||
_assert_allclose(inductor, eager, "inductor vs eager")
|
||||
|
||||
luminal = _run_luminal(model, *inputs)
|
||||
_assert_allclose(luminal, eager, "luminal vs eager")
|
||||
_assert_allclose(luminal, inductor, "luminal vs inductor")
|
||||
|
||||
|
||||
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
|
||||
model, inputs = _make_din(device)
|
||||
|
||||
eager = _run_eager(model, *inputs)
|
||||
inductor = _run_inductor(model, *inputs)
|
||||
_assert_allclose(inductor, eager, "inductor vs eager")
|
||||
|
||||
luminal = _run_luminal(model, *inputs)
|
||||
_assert_allclose(luminal, eager, "luminal vs eager")
|
||||
_assert_allclose(luminal, inductor, "luminal vs inductor")
|
||||
456
crates/luminal_python/tests/test_dlrm.py
Normal file
456
crates/luminal_python/tests/test_dlrm.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""DLRM coverage for the luminal torch.compile backend.
|
||||
|
||||
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
|
||||
that eager mode, ``torch.compile(..., backend="inductor")``, and
|
||||
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
|
||||
configurations, including a CUDA benchmark that compares Luminal against
|
||||
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("sklearn")
|
||||
|
||||
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
|
||||
if not DLRM_ROOT.exists():
|
||||
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
|
||||
|
||||
dlrm_root = str(DLRM_ROOT)
|
||||
if dlrm_root not in sys.path:
|
||||
sys.path.insert(0, dlrm_root)
|
||||
|
||||
|
||||
def _install_dlrm_import_stubs() -> None:
|
||||
ext_dist = types.ModuleType("extend_distributed")
|
||||
ext_dist.my_size = 1
|
||||
ext_dist.dist = None
|
||||
ext_dist.get_split_lengths = lambda n: (n, [n])
|
||||
ext_dist.get_my_slice = lambda n: slice(0, n)
|
||||
|
||||
class _AllToAll:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def wait(self):
|
||||
return self._values
|
||||
|
||||
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
|
||||
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"extend_distributed", loader=None
|
||||
)
|
||||
sys.modules["extend_distributed"] = ext_dist
|
||||
|
||||
mlperf_logger = types.ModuleType("mlperf_logger")
|
||||
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"mlperf_logger", loader=None
|
||||
)
|
||||
sys.modules["mlperf_logger"] = mlperf_logger
|
||||
|
||||
tensorboard = types.ModuleType("torch.utils.tensorboard")
|
||||
|
||||
class SummaryWriter:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def add_scalar(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
tensorboard.SummaryWriter = SummaryWriter
|
||||
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"torch.utils.tensorboard", loader=None
|
||||
)
|
||||
sys.modules["torch.utils.tensorboard"] = tensorboard
|
||||
|
||||
onnx = types.ModuleType("onnx")
|
||||
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
|
||||
sys.modules["onnx"] = onnx
|
||||
|
||||
|
||||
_install_dlrm_import_stubs()
|
||||
|
||||
import dlrm_s_pytorch as dlrm_mod
|
||||
from luminal import luminal_backend
|
||||
|
||||
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
|
||||
|
||||
|
||||
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
if isinstance(output, tuple) and len(output) == 1:
|
||||
return output[0]
|
||||
return output
|
||||
|
||||
|
||||
def _assert_allclose(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
|
||||
) -> None:
|
||||
max_diff = torch.max(torch.abs(lhs - rhs)).item()
|
||||
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
|
||||
|
||||
|
||||
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
return _unwrap(model(*inputs))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _relaxed_dynamo_limits():
|
||||
prev_recompile_limit = torch._dynamo.config.recompile_limit
|
||||
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
torch._dynamo.config.cache_size_limit = 16
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.recompile_limit = prev_recompile_limit
|
||||
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
|
||||
|
||||
|
||||
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
|
||||
with torch.no_grad():
|
||||
return _unwrap(compiled(*inputs))
|
||||
|
||||
|
||||
def _compile_inductor(model: torch.nn.Module):
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
return torch.compile(copy.deepcopy(model), backend="inductor")
|
||||
|
||||
|
||||
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
return _unwrap(compiled(*inputs))
|
||||
|
||||
|
||||
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
return torch.compile(
|
||||
copy.deepcopy(model),
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
)
|
||||
|
||||
|
||||
def _compile_luminal(model: torch.nn.Module):
|
||||
with _relaxed_dynamo_limits():
|
||||
torch._dynamo.reset()
|
||||
return torch.compile(copy.deepcopy(model), backend=luminal_backend)
|
||||
|
||||
|
||||
def _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters: int,
|
||||
timed_iters: int,
|
||||
mark_step_begin: bool = False,
|
||||
) -> dict[str, float]:
|
||||
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup_iters):
|
||||
if mark_step_begin:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
_unwrap(compiled_model(*inputs))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
|
||||
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
|
||||
|
||||
with torch.no_grad():
|
||||
for idx in range(timed_iters):
|
||||
if mark_step_begin:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
starts[idx].record()
|
||||
_unwrap(compiled_model(*inputs))
|
||||
ends[idx].record()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = np.array(
|
||||
[start.elapsed_time(end) for start, end in zip(starts, ends)],
|
||||
dtype=np.float64,
|
||||
)
|
||||
return {
|
||||
"mean_ms": float(elapsed_ms.mean()),
|
||||
"median_ms": float(np.median(elapsed_ms)),
|
||||
"min_ms": float(elapsed_ms.min()),
|
||||
}
|
||||
|
||||
|
||||
def _timed_cuda_rounds(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
pre_round_warmup_iters: int,
|
||||
timed_iters: int,
|
||||
rounds: int,
|
||||
mark_step_begin: bool = False,
|
||||
) -> dict[str, float | list[float]]:
|
||||
assert rounds > 0, "rounds must be positive"
|
||||
|
||||
stats = _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters=pre_round_warmup_iters,
|
||||
timed_iters=timed_iters,
|
||||
mark_step_begin=mark_step_begin,
|
||||
)
|
||||
round_medians = [stats["median_ms"]]
|
||||
|
||||
for _ in range(rounds - 1):
|
||||
stats = _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters=0,
|
||||
timed_iters=timed_iters,
|
||||
mark_step_begin=mark_step_begin,
|
||||
)
|
||||
round_medians.append(stats["median_ms"])
|
||||
|
||||
round_medians_np = np.array(round_medians, dtype=np.float64)
|
||||
return {
|
||||
"round_medians_ms": [float(value) for value in round_medians],
|
||||
"median_ms": float(np.median(round_medians_np)),
|
||||
"mean_ms": float(round_medians_np.mean()),
|
||||
"min_ms": float(round_medians_np.min()),
|
||||
}
|
||||
|
||||
|
||||
def _make_dlrm(
|
||||
device: torch.device,
|
||||
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
|
||||
m_spa = 4
|
||||
ln_emb = np.array([8, 6, 4])
|
||||
ln_bot = np.array([3, 4])
|
||||
num_fea = ln_emb.size + 1
|
||||
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
|
||||
ln_top = np.array([num_int, 8, 1])
|
||||
|
||||
model = dlrm_mod.DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op="dot",
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_top=1,
|
||||
).eval()
|
||||
|
||||
inputs = (
|
||||
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device),
|
||||
[
|
||||
torch.tensor([0, 1], dtype=torch.int64, device=device),
|
||||
torch.tensor([0, 1], dtype=torch.int64, device=device),
|
||||
torch.tensor([0, 1], dtype=torch.int64, device=device),
|
||||
],
|
||||
[
|
||||
torch.tensor([1, 2], dtype=torch.int64, device=device),
|
||||
torch.tensor([0, 3], dtype=torch.int64, device=device),
|
||||
torch.tensor([2, 1], dtype=torch.int64, device=device),
|
||||
],
|
||||
)
|
||||
return model.to(device), inputs
|
||||
|
||||
|
||||
def _make_dlrm_batch_2048(
|
||||
device: torch.device,
|
||||
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
|
||||
batch_size = 2048
|
||||
indices_per_bag = 2
|
||||
m_spa = 16
|
||||
ln_emb = np.array([4096, 2048, 1024])
|
||||
ln_bot = np.array([3, 64, m_spa])
|
||||
num_fea = ln_emb.size + 1
|
||||
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
|
||||
ln_top = np.array([num_int, 64, 32, 1])
|
||||
|
||||
model = dlrm_mod.DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op="dot",
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_top=2,
|
||||
).eval()
|
||||
|
||||
dense_x = torch.linspace(
|
||||
-1.0,
|
||||
1.0,
|
||||
steps=batch_size * 3,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
).reshape(batch_size, 3)
|
||||
|
||||
total_sparse_indices = batch_size * indices_per_bag
|
||||
positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
|
||||
offsets = torch.arange(
|
||||
0,
|
||||
total_sparse_indices,
|
||||
indices_per_bag,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
inputs = (
|
||||
dense_x,
|
||||
[offsets.clone(), offsets.clone(), offsets.clone()],
|
||||
[
|
||||
((positions * 3 + 1) % int(ln_emb[0])).to(torch.int64),
|
||||
((positions * 5 + 2) % int(ln_emb[1])).to(torch.int64),
|
||||
((positions * 7 + 3) % int(ln_emb[2])).to(torch.int64),
|
||||
],
|
||||
)
|
||||
return model.to(device), inputs
|
||||
|
||||
|
||||
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
|
||||
model, inputs = _make_dlrm(device)
|
||||
|
||||
eager = _run_eager(model, *inputs)
|
||||
inductor = _run_inductor(model, *inputs)
|
||||
_assert_allclose(inductor, eager, "inductor vs eager")
|
||||
|
||||
luminal = _run_luminal(model, *inputs)
|
||||
_assert_allclose(luminal, eager, "luminal vs eager")
|
||||
_assert_allclose(luminal, inductor, "luminal vs inductor")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
if device.type != "cuda":
|
||||
pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")
|
||||
|
||||
model, inputs = _make_dlrm_batch_2048(device)
|
||||
|
||||
eager = _run_eager(model, *inputs)
|
||||
|
||||
eager_model = copy.deepcopy(model).to(device).eval()
|
||||
inductor_compiled = _compile_inductor(model)
|
||||
inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
with torch.no_grad():
|
||||
inductor_output = _unwrap(inductor_reduce_overhead(*inputs))
|
||||
|
||||
luminal_compiled = _compile_luminal(model)
|
||||
with torch.no_grad():
|
||||
luminal_output = _unwrap(luminal_compiled(*inputs))
|
||||
|
||||
with torch.no_grad():
|
||||
inductor_default_output = _unwrap(inductor_compiled(*inputs))
|
||||
|
||||
_assert_allclose(inductor_default_output, eager, "inductor default vs eager", atol=1e-4)
|
||||
_assert_allclose(inductor_output, eager, "inductor reduce-overhead vs eager", atol=1e-4)
|
||||
_assert_allclose(luminal_output, eager, "luminal vs eager", atol=1e-4)
|
||||
_assert_allclose(
|
||||
luminal_output,
|
||||
inductor_default_output,
|
||||
"luminal vs inductor default",
|
||||
atol=1e-4,
|
||||
)
|
||||
_assert_allclose(
|
||||
luminal_output,
|
||||
inductor_output,
|
||||
"luminal vs inductor reduce-overhead",
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
benchmark_rounds = 5
|
||||
benchmark_iters = 20
|
||||
post_compile_warmup_iters = 10
|
||||
|
||||
eager_stats = _timed_cuda_rounds(
|
||||
eager_model,
|
||||
*inputs,
|
||||
pre_round_warmup_iters=post_compile_warmup_iters,
|
||||
timed_iters=benchmark_iters,
|
||||
rounds=benchmark_rounds,
|
||||
)
|
||||
inductor_default_stats = _timed_cuda_rounds(
|
||||
inductor_compiled,
|
||||
*inputs,
|
||||
pre_round_warmup_iters=post_compile_warmup_iters,
|
||||
timed_iters=benchmark_iters,
|
||||
rounds=benchmark_rounds,
|
||||
)
|
||||
inductor_stats = _timed_cuda_rounds(
|
||||
inductor_reduce_overhead,
|
||||
*inputs,
|
||||
pre_round_warmup_iters=post_compile_warmup_iters,
|
||||
timed_iters=benchmark_iters,
|
||||
rounds=benchmark_rounds,
|
||||
mark_step_begin=True,
|
||||
)
|
||||
luminal_stats = _timed_cuda_rounds(
|
||||
luminal_compiled,
|
||||
*inputs,
|
||||
pre_round_warmup_iters=post_compile_warmup_iters,
|
||||
timed_iters=benchmark_iters,
|
||||
rounds=benchmark_rounds,
|
||||
)
|
||||
|
||||
batch_size = inputs[0].shape[0]
|
||||
benchmark_results = [
|
||||
("eager", eager_stats),
|
||||
("inductor default", inductor_default_stats),
|
||||
("inductor reduce-overhead", inductor_stats),
|
||||
("luminal backend", luminal_stats),
|
||||
]
|
||||
ranked_results = sorted(
|
||||
benchmark_results,
|
||||
key=lambda item: float(item[1]["median_ms"]),
|
||||
)
|
||||
speed_lines = []
|
||||
for idx, (label, stats) in enumerate(ranked_results, start=1):
|
||||
throughput = batch_size / (float(stats["median_ms"]) / 1000.0)
|
||||
rounds_repr = ", ".join(
|
||||
f"{value:.3f}" for value in stats["round_medians_ms"] # type: ignore[index]
|
||||
)
|
||||
speed_lines.append(
|
||||
f" {idx}. {label}: {float(stats['median_ms']):.3f} ms"
|
||||
f" ({throughput:,.0f} candidates/s)"
|
||||
f" [round medians: {rounds_repr}]"
|
||||
)
|
||||
|
||||
luminal_vs_inductor = float(luminal_stats["median_ms"]) / float(
|
||||
inductor_stats["median_ms"]
|
||||
)
|
||||
luminal_vs_eager = float(luminal_stats["median_ms"]) / float(eager_stats["median_ms"])
|
||||
|
||||
print(
|
||||
"\n"
|
||||
f"DLRM batch={batch_size} candidates on CUDA after compile/warmup\n"
|
||||
f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations\n"
|
||||
f" Ranking by median latency:\n"
|
||||
+ "\n".join(speed_lines)
|
||||
+ "\n"
|
||||
f" Luminal backend / TorchInductor reduce-overhead latency ratio:"
|
||||
f" {luminal_vs_inductor:.3f}x\n"
|
||||
f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x"
|
||||
)
|
||||
@@ -7,13 +7,14 @@
|
||||
//! - [`NativeDynBackend`]: the reference implementation for CPU
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use half::{bf16, f16};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::dtype::DType;
|
||||
use crate::graph::Graph;
|
||||
use crate::graph::{Graph, SearchOptions};
|
||||
use crate::hlir::{NativeData, NativeRuntime, Output};
|
||||
use crate::op::Runtime;
|
||||
|
||||
@@ -46,6 +47,18 @@ pub trait DynBackend {
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
|
||||
|
||||
// --- Optional diagnostics --------------------------------------------
|
||||
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn print_execution_stats(&self) {}
|
||||
|
||||
// --- Optional device pointer support (GPU backends) --------------------
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
@@ -162,7 +175,9 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, args.search_iters);
|
||||
let mut rng = rand::rng();
|
||||
let search_options = search_options_from_env(args.search_iters);
|
||||
let mut rt = graph.search_options(rt, search_options, &mut rng);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
@@ -179,6 +194,39 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
Ok(wrap(rt))
|
||||
}
|
||||
|
||||
fn env_usize(name: &str) -> Option<usize> {
|
||||
std::env::var(name).ok()?.parse().ok()
|
||||
}
|
||||
|
||||
fn env_duration_ms(name: &str) -> Option<Duration> {
|
||||
Some(Duration::from_millis(env_usize(name)? as u64))
|
||||
}
|
||||
|
||||
fn search_options_from_env(limit: usize) -> SearchOptions {
|
||||
let mut options = SearchOptions::new(limit);
|
||||
|
||||
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
|
||||
options = options.generation_size(generation_size);
|
||||
}
|
||||
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
|
||||
options = options.mutations(mutations);
|
||||
}
|
||||
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
|
||||
options = options.trials(trials);
|
||||
}
|
||||
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
|
||||
options = options.keep_best(keep_best);
|
||||
}
|
||||
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
|
||||
options = options.profile_timeout(profile_timeout);
|
||||
}
|
||||
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
|
||||
options = options.group_timeout(group_timeout);
|
||||
}
|
||||
|
||||
options
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
293
src/hlir.rs
293
src/hlir.rs
@@ -217,32 +217,38 @@ pub fn reduce_sort(name: &str) -> SortDef {
|
||||
}
|
||||
|
||||
pub type HLIROps = (
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
Exp2,
|
||||
Log2,
|
||||
Sin,
|
||||
Recip,
|
||||
Sqrt,
|
||||
Add,
|
||||
Mul,
|
||||
Mod,
|
||||
LessThan,
|
||||
Gather,
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
(
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
Exp2,
|
||||
Log2,
|
||||
),
|
||||
(
|
||||
Sin,
|
||||
Recip,
|
||||
Sqrt,
|
||||
Add,
|
||||
Mul,
|
||||
Mod,
|
||||
LessThan,
|
||||
Gather,
|
||||
Concat2D,
|
||||
EmbeddingBagSum,
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
),
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1721,7 +1727,9 @@ impl NativeOp for Add {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
|
||||
NativeData::Bool(a) => {
|
||||
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1808,7 +1816,9 @@ impl NativeOp for Mul {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
|
||||
NativeData::Bool(a) => {
|
||||
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2126,6 +2136,233 @@ impl NativeOp for Gather {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct Concat2D {
|
||||
pub rows: Expression,
|
||||
pub lhs_cols: Expression,
|
||||
pub rhs_cols: Expression,
|
||||
}
|
||||
impl Display for Concat2D {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Concat2D")
|
||||
}
|
||||
}
|
||||
impl HLIROp for Concat2D {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (Concat2D {} {} {}) {})",
|
||||
self.rows.to_egglog(),
|
||||
self.lhs_cols.to_egglog(),
|
||||
self.rhs_cols.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for Concat2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"Concat2D",
|
||||
&[
|
||||
("rows", EXPRESSION),
|
||||
("lhs_cols", EXPRESSION),
|
||||
("rhs_cols", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for Concat2D {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
let rows = self.rows.exec(dyn_map).unwrap();
|
||||
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
|
||||
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
|
||||
|
||||
fn concat_rows<T: Clone>(
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
rows: usize,
|
||||
lhs_cols: usize,
|
||||
rhs_cols: usize,
|
||||
) -> Vec<T> {
|
||||
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
|
||||
for row in 0..rows {
|
||||
let lhs_base = row * lhs_cols;
|
||||
let rhs_base = row * rhs_cols;
|
||||
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
|
||||
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
match (inputs[0], inputs[1]) {
|
||||
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
|
||||
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
|
||||
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
|
||||
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
|
||||
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
|
||||
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
_ => panic!("Concat2D inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct EmbeddingBagSum {
|
||||
pub n_bags: Expression,
|
||||
pub n_indices: Expression,
|
||||
pub hidden_dim: Expression,
|
||||
pub num_embeddings: Expression,
|
||||
}
|
||||
impl Display for EmbeddingBagSum {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "EmbeddingBagSum")
|
||||
}
|
||||
}
|
||||
impl HLIROp for EmbeddingBagSum {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (EmbeddingBagSum {} {} {} {}) {})",
|
||||
self.n_bags.to_egglog(),
|
||||
self.n_indices.to_egglog(),
|
||||
self.hidden_dim.to_egglog(),
|
||||
self.num_embeddings.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for EmbeddingBagSum {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"EmbeddingBagSum",
|
||||
&[
|
||||
("n_bags", EXPRESSION),
|
||||
("n_indices", EXPRESSION),
|
||||
("hidden_dim", EXPRESSION),
|
||||
("num_embeddings", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for EmbeddingBagSum {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
let n_bags = self.n_bags.exec(dyn_map).unwrap();
|
||||
let n_indices = self.n_indices.exec(dyn_map).unwrap();
|
||||
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
|
||||
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
|
||||
|
||||
let clamp_index = |value: i32, limit: usize| -> usize {
|
||||
if limit == 0 {
|
||||
return 0;
|
||||
}
|
||||
value.clamp(0, limit.saturating_sub(1) as i32) as usize
|
||||
};
|
||||
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
|
||||
|
||||
let NativeData::Int(indices) = inputs[1] else {
|
||||
panic!("EmbeddingBagSum indices must be Int")
|
||||
};
|
||||
let NativeData::Int(offsets) = inputs[2] else {
|
||||
panic!("EmbeddingBagSum offsets must be Int")
|
||||
};
|
||||
|
||||
let bag_bounds = |bag: usize| -> (usize, usize) {
|
||||
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
|
||||
let end = offsets
|
||||
.get(bag + 1)
|
||||
.copied()
|
||||
.map(clamp_offset)
|
||||
.unwrap_or(n_indices);
|
||||
(start, end.max(start))
|
||||
};
|
||||
|
||||
match inputs[0] {
|
||||
NativeData::F32(weight) => {
|
||||
let mut out = vec![0.0f32; n_bags * hidden_dim];
|
||||
for bag in 0..n_bags {
|
||||
let (start, end) = bag_bounds(bag);
|
||||
let out_base = bag * hidden_dim;
|
||||
for pos in start..end {
|
||||
let row = clamp_index(indices[pos], num_embeddings);
|
||||
let row_base = row * hidden_dim;
|
||||
for dim in 0..hidden_dim {
|
||||
out[out_base + dim] += weight[row_base + dim];
|
||||
}
|
||||
}
|
||||
}
|
||||
NativeData::F32(out)
|
||||
}
|
||||
_ => panic!("EmbeddingBagSum only supports F32 weights"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter Op (inverse of Gather)
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
|
||||
Reference in New Issue
Block a user