mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
4 Commits
codex/dlrm
...
codex/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b1e09cf23 | ||
|
|
6416ddb5f8 | ||
|
|
c9d4ce6217 | ||
|
|
7402503bd4 |
@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
|
||||
@@ -5,7 +5,6 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::host::describe_host_op;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
@@ -40,26 +39,6 @@ impl DynBackend for CudaLiteDynBackend {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
self.runtime
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.map(|name| (*name).to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
self.runtime
|
||||
.host_ops()
|
||||
.iter()
|
||||
.map(|op| describe_host_op(*op))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn print_execution_stats(&self) {
|
||||
self.runtime.print_execution_stats();
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -247,10 +247,6 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("CuBlasSgemmV2")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
@@ -419,6 +419,7 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
|
||||
match epilogue {
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
|
||||
@@ -977,18 +978,6 @@ impl CuBlasLt {
|
||||
&& normalize(self.stride_c) == normalize(self.stride_d)
|
||||
&& self.c_order == self.d_order
|
||||
}
|
||||
|
||||
pub(crate) fn debug_summary(&self) -> String {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
format!(
|
||||
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
|
||||
resolve(&self.m),
|
||||
resolve(&self.n),
|
||||
resolve(&self.k),
|
||||
resolve(&self.batch_count),
|
||||
epilogue_name(self.epilogue),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
@@ -1125,10 +1114,6 @@ impl HostOp for CuBlasLt {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("CuBlasLt")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use crate::kernel::CudaGraphOp;
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
@@ -80,24 +79,6 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
|
||||
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
|
||||
return op.debug_summary();
|
||||
}
|
||||
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
|
||||
let mut summary = op.debug_summary();
|
||||
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
|
||||
&& let Some(timing) = op.debug_timing_summary()
|
||||
{
|
||||
summary.push_str(" [");
|
||||
summary.push_str(&timing);
|
||||
summary.push(']');
|
||||
}
|
||||
return summary;
|
||||
}
|
||||
op.stats_name().unwrap_or("unknown").to_string()
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
|
||||
@@ -12,10 +12,7 @@ use luminal::{
|
||||
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
hlir::{
|
||||
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
|
||||
Sin, Sqrt, SumReduce,
|
||||
},
|
||||
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
@@ -68,8 +65,6 @@ pub type Ops = (
|
||||
KernelConstant,
|
||||
KernelCast,
|
||||
KernelEmbed,
|
||||
KernelConcat2D,
|
||||
KernelEmbeddingBagSum,
|
||||
);
|
||||
|
||||
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
|
||||
@@ -1545,19 +1540,22 @@ impl KernelOp for KernelIota {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
vars.extend(self.range.dyn_vars());
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let range = self.range.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void iota_k(int *C{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {range}) return;
|
||||
C[const_z] = {};
|
||||
}}
|
||||
}}",
|
||||
@@ -1576,8 +1574,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.range, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.range.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2940,6 +2938,14 @@ impl KernelOp for KernelCast {
|
||||
) {
|
||||
let out_dtype = cuda_dtype(self.out_dtype);
|
||||
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
|
||||
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let size = self.size.to_kernel();
|
||||
|
||||
let kernel = if self.in_dtype.bits() < 8 {
|
||||
// Sub-byte packed types: multiple values packed per byte.
|
||||
@@ -2949,9 +2955,11 @@ impl KernelOp for KernelCast {
|
||||
let mask = (1u32 << bits) - 1;
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {size}) return;
|
||||
long long bit_offset = idx * {bits};
|
||||
long long byte_idx = bit_offset >> 3;
|
||||
int bit_pos = (int)(bit_offset & 7);
|
||||
@@ -2967,9 +2975,11 @@ extern \"C\" {{
|
||||
let in_dtype = cuda_dtype(self.in_dtype);
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {size}) return;
|
||||
out[const_z] = ({out_dtype})in[const_z];
|
||||
}}
|
||||
}}"
|
||||
@@ -2988,8 +2998,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.size, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.size.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -3280,12 +3290,15 @@ impl KernelOp for KernelEmbed {
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
let n_elements = total_threads.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {n_elements}) return;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
long long embed_idx = idx % embed_dim;
|
||||
@@ -3308,13 +3321,12 @@ extern \"C\" {{
|
||||
};
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(total_threads, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(total_threads.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
constants,
|
||||
)
|
||||
@@ -3359,361 +3371,3 @@ extern \"C\" {{
|
||||
"Embed"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelEmbeddingBagSum {
|
||||
n_bags: Expression,
|
||||
n_indices: Expression,
|
||||
hidden_dim: Expression,
|
||||
num_embeddings: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelEmbeddingBagSum {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelEmbeddingBagSum",
|
||||
&[
|
||||
("n_bags", EXPRESSION),
|
||||
("n_indices", EXPRESSION),
|
||||
("hidden_dim", EXPRESSION),
|
||||
("num_embeddings", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelEmbeddingBagSum {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert!(
|
||||
self.dtype == DType::F32,
|
||||
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
|
||||
self.dtype
|
||||
);
|
||||
let vars = self
|
||||
.n_bags
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.n_indices.dyn_vars())
|
||||
.chain(self.hidden_dim.dyn_vars())
|
||||
.chain(self.num_embeddings.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_bags = self.n_bags.to_kernel();
|
||||
let n_indices = self.n_indices.to_kernel();
|
||||
let hidden_dim = self.hidden_dim.to_kernel();
|
||||
let num_embeddings = self.num_embeddings.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
|
||||
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long bag = blockIdx.y;
|
||||
long long hidden_dim = {hidden_dim};
|
||||
long long n_bags = {n_bags};
|
||||
long long n_indices = {n_indices};
|
||||
long long num_embeddings = {num_embeddings};
|
||||
if (bag >= n_bags || dim >= hidden_dim) return;
|
||||
|
||||
int start_raw = offsets[bag];
|
||||
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
|
||||
int start = max(0, min(start_raw, (int)n_indices));
|
||||
int end = max(start, min(end_raw, (int)n_indices));
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int pos = start; pos < end; ++pos) {{
|
||||
int row = indices[pos];
|
||||
row = max(0, min(row, (int)num_embeddings - 1));
|
||||
sum += weight[(long long)row * hidden_dim + dim];
|
||||
}}
|
||||
|
||||
out[bag * hidden_dim + dim] = sum;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("embedding_bag_sum").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
|
||||
(self.hidden_dim.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.n_bags * self.hidden_dim
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.n_bags
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.n_indices.dyn_vars())
|
||||
.chain(self.hidden_dim.dyn_vars())
|
||||
.chain(self.num_embeddings.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Approximate: weights + indices + offsets
|
||||
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.n_indices * self.hidden_dim
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"EmbeddingBagSum"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConcat2D {
|
||||
rows: Expression,
|
||||
lhs_cols: Expression,
|
||||
rhs_cols: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelConcat2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelConcat2D",
|
||||
&[
|
||||
("rows", EXPRESSION),
|
||||
("lhs_cols", EXPRESSION),
|
||||
("rhs_cols", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite::<Concat2D, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelConcat2D {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert!(
|
||||
self.dtype == DType::F32,
|
||||
"KernelConcat2D only supports F32 today, got {:?}",
|
||||
self.dtype
|
||||
);
|
||||
let vars = self
|
||||
.rows
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.lhs_cols.dyn_vars())
|
||||
.chain(self.rhs_cols.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let rows = self.rows.to_kernel();
|
||||
let lhs_cols = self.lhs_cols.to_kernel();
|
||||
let rhs_cols = self.rhs_cols.to_kernel();
|
||||
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long total = {total};
|
||||
if (idx >= total) return;
|
||||
|
||||
long long rows = {rows};
|
||||
long long lhs_cols = {lhs_cols};
|
||||
long long rhs_cols = {rhs_cols};
|
||||
long long out_cols = lhs_cols + rhs_cols;
|
||||
if (rows == 0 || out_cols == 0) return;
|
||||
|
||||
long long row = idx / out_cols;
|
||||
long long col = idx - row * out_cols;
|
||||
if (col < lhs_cols) {{
|
||||
out[idx] = lhs[row * lhs_cols + col];
|
||||
}} else {{
|
||||
long long rhs_col = col - lhs_cols;
|
||||
out[idx] = rhs[row * rhs_cols + rhs_col];
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("concat_2d").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let output_size = self.output_size();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(output_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(output_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.rows * (self.lhs_cols + self.rhs_cols)
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.rows
|
||||
.dyn_vars()
|
||||
.into_iter()
|
||||
.chain(self.lhs_cols.dyn_vars())
|
||||
.chain(self.rhs_cols.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Concat2D"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
//! that can be executed like any other HostOp.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
@@ -143,8 +141,6 @@ struct CudaGraphOpState {
|
||||
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
|
||||
/// Timing events for profiling
|
||||
timing_events: Vec<cudarc::driver::sys::CUevent>,
|
||||
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
|
||||
last_kernel_timings_us: Vec<(&'static str, f64)>,
|
||||
}
|
||||
|
||||
impl CudaGraphOpState {
|
||||
@@ -159,7 +155,6 @@ impl CudaGraphOpState {
|
||||
last_dyn_values: FxHashMap::default(),
|
||||
last_buffer_ptrs: FxHashMap::default(),
|
||||
timing_events: Vec::new(),
|
||||
last_kernel_timings_us: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,41 +192,6 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_summary(&self) -> String {
|
||||
let state = self.state.borrow();
|
||||
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
|
||||
for kernel in &state.kernels {
|
||||
*counts.entry(kernel.kernel_name).or_default() += 1;
|
||||
}
|
||||
let mut counts: Vec<_> = counts.into_iter().collect();
|
||||
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
|
||||
let top = counts
|
||||
.into_iter()
|
||||
.take(4)
|
||||
.map(|(name, count)| format!("{name}x{count}"))
|
||||
.join(", ");
|
||||
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
|
||||
}
|
||||
|
||||
pub fn debug_timing_summary(&self) -> Option<String> {
|
||||
let state = self.state.borrow();
|
||||
if state.last_kernel_timings_us.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
|
||||
for (name, us) in &state.last_kernel_timings_us {
|
||||
*totals.entry(*name).or_default() += *us;
|
||||
}
|
||||
let mut totals: Vec<_> = totals.into_iter().collect();
|
||||
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
|
||||
let top = totals
|
||||
.into_iter()
|
||||
.take(4)
|
||||
.map(|(name, us)| format!("{name}={us:.0}us"))
|
||||
.join(", ");
|
||||
Some(top)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -606,23 +566,6 @@ impl CudaGraphOp {
|
||||
// Launch the graph
|
||||
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
|
||||
|
||||
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
|
||||
&& state.timing_events.len() >= state.kernels.len() + 1
|
||||
{
|
||||
stream.synchronize()?;
|
||||
let ctx = stream.context().clone();
|
||||
state.last_kernel_timings_us.clear();
|
||||
for idx in 0..state.kernels.len() {
|
||||
let start_event = state.timing_events[idx];
|
||||
let end_event = state.timing_events[idx + 1];
|
||||
let kernel_name = state.kernels[idx].kernel_name;
|
||||
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
|
||||
.map(|ms| ms as f64 * 1_000.0)
|
||||
.unwrap_or(0.0);
|
||||
state.last_kernel_timings_us.push((kernel_name, us));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -641,9 +584,8 @@ impl CudaGraphOp {
|
||||
state.kernel_params.clear();
|
||||
state.kernel_params.reserve(num_kernels);
|
||||
|
||||
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
|
||||
let tracing_enabled = enabled!(Level::TRACE);
|
||||
if tracing_enabled || profile_cuda_graph {
|
||||
if tracing_enabled {
|
||||
let needed_events = num_kernels + 1;
|
||||
while state.timing_events.len() < needed_events {
|
||||
state.timing_events.push(create_cuda_event(&ctx)?);
|
||||
@@ -759,7 +701,7 @@ impl CudaGraphOp {
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled || profile_cuda_graph {
|
||||
let timing_event = if tracing_enabled {
|
||||
Some(state.timing_events[idx])
|
||||
} else {
|
||||
None
|
||||
@@ -797,9 +739,7 @@ impl CudaGraphOp {
|
||||
prev_graph_node = Some(graph_node);
|
||||
}
|
||||
|
||||
if (tracing_enabled || profile_cuda_graph)
|
||||
&& let Some(prev) = prev_graph_node
|
||||
{
|
||||
if tracing_enabled && let Some(prev) = prev_graph_node {
|
||||
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
use crate::{
|
||||
host::{DeviceBuffer, HostOp, describe_host_op},
|
||||
kernel::{
|
||||
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
|
||||
record_cuda_graph_timings, record_event_on_stream,
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
|
||||
|
||||
@@ -63,12 +60,6 @@ pub struct KernelStats {
|
||||
pub tflops: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HostOpStats {
|
||||
pub name: String,
|
||||
pub execution_time_us: f64,
|
||||
}
|
||||
|
||||
impl Debug for ExecutableHostOp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "HostOp: ({:?})", self.internal)
|
||||
@@ -150,7 +141,6 @@ pub struct CudaRuntime {
|
||||
changed_hlir: FxHashSet<NodeIndex>,
|
||||
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
|
||||
pub last_kernel_stats: Vec<KernelStats>,
|
||||
pub last_host_op_stats: Vec<HostOpStats>,
|
||||
pub last_total_time_us: f64,
|
||||
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
/// When true, execute() skips input buffer consumption (used during search/profile)
|
||||
@@ -1213,7 +1203,6 @@ impl Runtime for CudaRuntime {
|
||||
changed_hlir: FxHashSet::default(),
|
||||
cuda_graph_timings: vec![],
|
||||
last_kernel_stats: vec![],
|
||||
last_host_op_stats: vec![],
|
||||
last_total_time_us: 0.0,
|
||||
kernel_cache: FxHashMap::default(),
|
||||
profiling: false,
|
||||
@@ -1465,8 +1454,6 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
|
||||
let mut host_timing_events = Vec::new();
|
||||
|
||||
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
@@ -1520,16 +1507,6 @@ impl Runtime for CudaRuntime {
|
||||
n_inputs = exec_op.inputs.len()
|
||||
)
|
||||
.entered();
|
||||
let host_op_timing = if profile_host_ops {
|
||||
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
|
||||
let ctx = exec_op.stream.context().clone();
|
||||
let start_event = create_cuda_event(&ctx).unwrap();
|
||||
let end_event = create_cuda_event(&ctx).unwrap();
|
||||
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
|
||||
Some((name, ctx, start_event, end_event))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
exec_op
|
||||
.internal
|
||||
.execute(
|
||||
@@ -1545,28 +1522,10 @@ impl Runtime for CudaRuntime {
|
||||
exec_op.internal.stats_name().unwrap_or("unknown")
|
||||
);
|
||||
});
|
||||
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
|
||||
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
|
||||
host_timing_events.push((name, ctx, start_event, end_event));
|
||||
}
|
||||
}
|
||||
// Single sync at end - CUDA stream ordering guarantees sequential execution
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
self.last_host_op_stats.clear();
|
||||
if profile_host_ops {
|
||||
for (name, ctx, start_event, end_event) in host_timing_events {
|
||||
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
|
||||
.map(|ms| ms as f64 * 1_000.0)
|
||||
.unwrap_or(0.0);
|
||||
self.last_host_op_stats.push(HostOpStats {
|
||||
name,
|
||||
execution_time_us,
|
||||
});
|
||||
destroy_cuda_event(&ctx, start_event);
|
||||
destroy_cuda_event(&ctx, end_event);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate last_kernel_stats from HostOps that report stats
|
||||
self.last_kernel_stats.clear();
|
||||
@@ -1902,16 +1861,6 @@ impl CudaRuntime {
|
||||
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
|
||||
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
|
||||
|
||||
if !self.last_host_op_stats.is_empty() {
|
||||
println!("\n=== Host Operation Statistics ===\n");
|
||||
println!("{:<20} {:>12}", "HostOp", "Time (us)");
|
||||
println!("{}", "-".repeat(34));
|
||||
for s in &self.last_host_op_stats {
|
||||
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
|
||||
}
|
||||
println!("{}", "-".repeat(34));
|
||||
}
|
||||
|
||||
// Print kernel stats
|
||||
if !self.last_kernel_stats.is_empty() {
|
||||
println!("\n=== Kernel Execution Statistics ===\n");
|
||||
|
||||
@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
|
||||
@@ -61,7 +61,8 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -478,7 +479,8 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -855,8 +855,6 @@ Two important details:
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
|
||||
@@ -98,7 +98,12 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -124,7 +129,9 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -248,23 +255,6 @@ impl CompiledGraph {
|
||||
self.runtime.device_type()
|
||||
}
|
||||
|
||||
/// Names of kernels compiled into the active runtime bucket, if available.
|
||||
#[getter]
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
self.runtime.kernel_names()
|
||||
}
|
||||
|
||||
/// Names of host ops in the active runtime bucket, if available.
|
||||
#[getter]
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
self.runtime.host_op_names()
|
||||
}
|
||||
|
||||
/// Print backend execution statistics for the last run, if supported.
|
||||
fn print_execution_stats(&self) {
|
||||
self.runtime.print_execution_stats();
|
||||
}
|
||||
|
||||
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
|
||||
#[getter]
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
@@ -493,10 +483,7 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
self.output_dtypes.clone()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -262,10 +262,13 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,14 +284,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -15,7 +15,16 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
pub min_val: i64,
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Result, bail};
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
@@ -8,62 +8,21 @@ use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let alpha = match op {
|
||||
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
|
||||
BinaryOp::Mul | BinaryOp::Div => 1.0,
|
||||
};
|
||||
|
||||
let lhs = node.inputs[0]
|
||||
.arg
|
||||
.as_tensor_name()
|
||||
.map(|name| self.get_tensor(name))
|
||||
.transpose()?;
|
||||
let rhs = node.inputs[1]
|
||||
.arg
|
||||
.as_tensor_name()
|
||||
.map(|name| self.get_tensor(name))
|
||||
.transpose()?;
|
||||
|
||||
match (lhs, rhs) {
|
||||
(Some(a), Some(mut b)) => {
|
||||
if alpha != 1.0 {
|
||||
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
|
||||
}
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
}
|
||||
(Some(a), None) => {
|
||||
let mut val = self.get_float_arg(node, 1)? as f32;
|
||||
if alpha != 1.0 {
|
||||
val *= alpha;
|
||||
}
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
(None, Some(mut b)) => {
|
||||
if alpha != 1.0 {
|
||||
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
|
||||
}
|
||||
let lhs_val = self.get_float_arg(node, 0)? as f32;
|
||||
let a = self
|
||||
.graph
|
||||
.constant_float(lhs_val)
|
||||
.cast(b.dtype)
|
||||
.expand_rhs(b.shape);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
}
|
||||
(None, None) => bail!("{} expects at least one tensor operand", node.target),
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,8 +389,11 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -111,7 +112,6 @@ impl<'a> Translator<'a> {
|
||||
result
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
|
||||
"torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
@@ -134,28 +134,8 @@ impl<'a> Translator<'a> {
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
if alpha == 0.0 && beta == 0.0 {
|
||||
self.graph
|
||||
.constant_float(0.0)
|
||||
.cast(mm.dtype)
|
||||
.expand_rhs(mm.shape)
|
||||
} else if beta == 0.0 {
|
||||
if alpha == 1.0 { mm } else { mm * alpha }
|
||||
} else if alpha == 0.0 {
|
||||
let input = if beta == 1.0 { input } else { input * beta };
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(input.dtype)
|
||||
.expand_rhs(mm.shape);
|
||||
let (input, _) = broadcast_binary(input, zero);
|
||||
input
|
||||
} else {
|
||||
let input = if beta == 1.0 { input } else { input * beta };
|
||||
let mm = if alpha == 1.0 { mm } else { mm * alpha };
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input + mm
|
||||
}
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
}
|
||||
|
||||
// Convolution
|
||||
@@ -171,11 +151,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
"torch.ops.aten._embedding_bag.default"
|
||||
| "torch.ops.aten._embedding_bag_forward_only.default" => {
|
||||
self.translate_embedding_bag(node)?
|
||||
}
|
||||
"<built-in function getitem>" => self.translate_getitem(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
@@ -190,9 +165,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
|
||||
self.translate_native_batch_norm_no_training(node)?
|
||||
}
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
@@ -249,6 +221,16 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -266,6 +248,13 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -304,18 +293,27 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
a.cumsum(dim)
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -411,6 +409,17 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -474,6 +483,28 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -12,64 +12,6 @@ const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn try_concat_2d_fast(
|
||||
&mut self,
|
||||
lhs: GraphTensor,
|
||||
rhs: GraphTensor,
|
||||
axis: usize,
|
||||
) -> Option<GraphTensor> {
|
||||
if axis != 1
|
||||
|| lhs.dtype != DType::F32
|
||||
|| rhs.dtype != DType::F32
|
||||
|| lhs.shape.len() != 2
|
||||
|| rhs.shape.len() != 2
|
||||
|| !lhs.shape.is_contiguous()
|
||||
|| !rhs.shape.is_contiguous()
|
||||
|| lhs.shape.dims[0] != rhs.shape.dims[0]
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let rows = lhs.shape.dims[0];
|
||||
let lhs_cols = lhs.shape.dims[1];
|
||||
let rhs_cols = rhs.shape.dims[1];
|
||||
let id = self.graph.add_op(
|
||||
luminal::hlir::Concat2D {
|
||||
rows,
|
||||
lhs_cols,
|
||||
rhs_cols,
|
||||
},
|
||||
&[lhs.id, rhs.id],
|
||||
);
|
||||
|
||||
Some(GraphTensor::from_id(
|
||||
id,
|
||||
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
|
||||
lhs.graph_ref,
|
||||
lhs.dtype,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
|
||||
let index = self
|
||||
.get_int_arg(node, 2)
|
||||
.context("select.int: missing index")?;
|
||||
|
||||
let dim_size = a.shape.dims[dim]
|
||||
.to_usize()
|
||||
.context("select.int: symbolic dims are not supported for negative indices")?;
|
||||
let normalized_index = if index < 0 {
|
||||
(dim_size as i64 + index) as usize
|
||||
} else {
|
||||
index as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
|
||||
.squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
@@ -138,43 +80,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
|
||||
sizes
|
||||
.into_iter()
|
||||
.map(|size| {
|
||||
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
|
||||
Ok(Expression::from(size as usize))
|
||||
})
|
||||
.collect::<Result<_>>()?
|
||||
} else {
|
||||
self.get_exprs_arg(node, 1)?
|
||||
};
|
||||
|
||||
anyhow::ensure!(
|
||||
repeats.len() >= a.shape.len(),
|
||||
"repeat: repeats rank {} is smaller than input rank {}",
|
||||
repeats.len(),
|
||||
a.shape.len()
|
||||
);
|
||||
|
||||
while a.shape.len() < repeats.len() {
|
||||
a = a.unsqueeze(0);
|
||||
}
|
||||
|
||||
Ok(a.repeat(repeats))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let index = self.get_int_arg(node, 1)?;
|
||||
anyhow::ensure!(
|
||||
index == 0,
|
||||
"getitem: only tuple[0] access is supported today, got index={index}"
|
||||
);
|
||||
self.get_input_tensor(node, 0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1).unwrap_or(0);
|
||||
@@ -215,6 +120,47 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -256,11 +202,7 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
for t in &tensors[1..] {
|
||||
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
|
||||
result = fast;
|
||||
} else {
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -317,79 +259,6 @@ impl<'a> Translator<'a> {
|
||||
bail!("index.Tensor: no index tensors in optional_tensors list");
|
||||
}
|
||||
index_names = found_tensors;
|
||||
|
||||
// Multiple explicit index tensors after leading `None`s mean
|
||||
// "keep the prefix dims, then advanced-index the contiguous
|
||||
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
|
||||
if first_non_none_dim > 0
|
||||
&& index_names.len() > 1
|
||||
&& first_non_none_dim + index_names.len() == source.shape.len()
|
||||
{
|
||||
let src_dims = source.shape.dims;
|
||||
let indexed_dims = &src_dims[first_non_none_dim..];
|
||||
let n_indexed = index_names.len();
|
||||
|
||||
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
|
||||
for i in (0..n_indexed - 1).rev() {
|
||||
strides[i] = strides[i + 1] * indexed_dims[i + 1];
|
||||
}
|
||||
|
||||
let mut flat_idx: Option<GraphTensor> = None;
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
let axis_size = indexed_dims[dim_idx];
|
||||
let idx_int = idx_tensor.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
|
||||
let is_negative = idx_int.lt(zero).cast(DType::Int);
|
||||
let idx_int = idx_int + is_negative * axis_size;
|
||||
|
||||
let stride = strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
idx_int
|
||||
} else {
|
||||
idx_int * stride
|
||||
};
|
||||
|
||||
flat_idx = Some(match flat_idx {
|
||||
Some(acc) => {
|
||||
let (acc_b, w_b) = broadcast_binary(acc, weighted);
|
||||
acc_b + w_b
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
|
||||
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
|
||||
let idx_shape = flat_idx.shape.dims.to_vec();
|
||||
let mut idx_numel = Expression::from(1usize);
|
||||
for dim in &idx_shape {
|
||||
idx_numel *= *dim;
|
||||
}
|
||||
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
|
||||
|
||||
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
|
||||
let mut indexed_size = Expression::from(1usize);
|
||||
for dim in indexed_dims {
|
||||
indexed_size *= *dim;
|
||||
}
|
||||
let mut flat_source_shape = prefix_dims.clone();
|
||||
flat_source_shape.push(indexed_size);
|
||||
let flat_source = reshape_tensor(source, flat_source_shape);
|
||||
|
||||
let mut expanded_idx = flat_idx;
|
||||
for _ in 0..prefix_dims.len() {
|
||||
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let mut target = prefix_dims.clone();
|
||||
target.push(idx_numel);
|
||||
expanded_idx.shape.expand(target);
|
||||
|
||||
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
|
||||
let mut result_shape = prefix_dims;
|
||||
result_shape.extend_from_slice(&idx_shape);
|
||||
return Ok(reshape_tensor(gathered, result_shape));
|
||||
}
|
||||
|
||||
// Simple case: single non-None index on a specific dim → gather_elements
|
||||
if first_non_none_dim > 0 && index_names.len() == 1 {
|
||||
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
@@ -505,6 +374,17 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
@@ -516,7 +396,12 @@ impl<'a> Translator<'a> {
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
|
||||
@@ -6,6 +6,20 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -37,32 +51,26 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
|
||||
// override below assumes contiguous, no-broadcast storage —
|
||||
// otherwise the `[1, N]` view treats stride-0 broadcast dims
|
||||
// as if they held N distinct values and reads past the backing
|
||||
// buffer. Materialize first when that's not the case (matches
|
||||
// the guard `translate_reshape` already applies).
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
let total = concrete_numel(&a)?;
|
||||
let has_broadcast = a
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(a.shape.strides.iter())
|
||||
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
|
||||
let a = if has_broadcast || !a.shape.is_contiguous() {
|
||||
a + 0.0
|
||||
} else {
|
||||
a
|
||||
};
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -86,4 +94,100 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,45 +28,6 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_embedding_bag_generic(
|
||||
&mut self,
|
||||
weight: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
offsets: GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
let n_indices = indices.shape.dims[0];
|
||||
let n_bags = offsets.shape.dims[0];
|
||||
|
||||
// Gather per-index embeddings: [E] -> [E, D].
|
||||
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
|
||||
let gathered = weight.gather(ids_expanded + arange);
|
||||
|
||||
// Bag assignment per position:
|
||||
// bag_id[pos] = count(offsets <= pos) - 1
|
||||
// This supports empty bags too, because repeated offsets simply skip a
|
||||
// bag id when no positions land in that interval.
|
||||
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
|
||||
let starts = offsets.expand_dim(1, n_indices);
|
||||
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
|
||||
- self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(vec![n_indices]);
|
||||
|
||||
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
|
||||
let bag_ids = bag_ids.expand_dim(0, n_bags);
|
||||
let mask = bag_ids
|
||||
.eq(bag_axis)
|
||||
.expand_dim(2, hidden_dim)
|
||||
.cast(gathered.dtype);
|
||||
let gathered = gathered.expand_dim(0, n_bags);
|
||||
|
||||
Ok((gathered * mask).sum(1))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
.inputs
|
||||
@@ -219,45 +180,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
|
||||
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
|
||||
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
anyhow::ensure!(
|
||||
mode == 0,
|
||||
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
|
||||
);
|
||||
|
||||
anyhow::ensure!(
|
||||
indices.shape.len() == 1 && offsets.shape.len() == 1,
|
||||
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
|
||||
indices.shape.len(),
|
||||
offsets.shape.len()
|
||||
);
|
||||
|
||||
if weight.dtype == DType::F32 {
|
||||
let id = self.graph.add_op(
|
||||
luminal::hlir::EmbeddingBagSum {
|
||||
n_bags: offsets.shape.dims[0],
|
||||
n_indices: indices.shape.dims[0],
|
||||
hidden_dim: weight.shape.dims[1],
|
||||
num_embeddings: weight.shape.dims[0],
|
||||
},
|
||||
&[weight.id, indices.id, offsets.id],
|
||||
);
|
||||
return Ok(GraphTensor::from_id(
|
||||
id,
|
||||
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
|
||||
weight.graph_ref,
|
||||
DType::F32,
|
||||
));
|
||||
}
|
||||
|
||||
self.translate_embedding_bag_generic(weight, indices, offsets)
|
||||
}
|
||||
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
|
||||
@@ -21,30 +21,6 @@ const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn expand_channel_parameter(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
parameter: GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
anyhow::ensure!(
|
||||
input.shape.len() >= 2,
|
||||
"batch_norm: expected rank >= 2 input, got rank {}",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
parameter.shape.len() == 1,
|
||||
"batch_norm: expected 1D channel parameter, got rank {}",
|
||||
parameter.shape.len()
|
||||
);
|
||||
|
||||
let mut expanded = parameter.unsqueeze(0);
|
||||
for axis in 2..input.shape.len() {
|
||||
expanded = expanded.unsqueeze(axis);
|
||||
}
|
||||
expanded.shape.expand(input.dims().to_vec());
|
||||
Ok(expanded)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
@@ -125,30 +101,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_native_batch_norm_no_training(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
|
||||
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
|
||||
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
|
||||
|
||||
let mut result = (input - running_mean) / (running_var + eps).sqrt();
|
||||
|
||||
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
|
||||
result = result * weight;
|
||||
}
|
||||
|
||||
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
|
||||
result = result + bias;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
@@ -261,12 +213,18 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -317,4 +275,52 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,14 +8,12 @@ from .compiled_model import CompiledModel
|
||||
# Import Rust extension components (built by maturin)
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .main import luminal_backend, register_backend
|
||||
from .pt2 import compile
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"compile",
|
||||
"luminal_backend",
|
||||
"register_backend",
|
||||
"CompiledGraph",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
@@ -27,14 +28,6 @@ class CompiledModel:
|
||||
self._input_names = input_names or graph_result.input_names
|
||||
self._output_names = graph_result.output_names
|
||||
self._output_shapes = graph_result.output_shapes
|
||||
output_dtype_codes = graph_result.output_dtypes
|
||||
self._output_dtypes = [
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._output_names))
|
||||
]
|
||||
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
@@ -42,9 +35,6 @@ class CompiledModel:
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Cache converted/contiguous views for repeated calls with the same input
|
||||
# tensors so we don't rebuild sparse index buffers every forward.
|
||||
self._prepared_input_cache = [None] * len(self._input_names)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
@@ -53,80 +43,6 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
self._single_float_output_fast_path = (
|
||||
self._supports_device_ptrs
|
||||
and not self._has_dynamic_dims
|
||||
and len(self._output_names) == 1
|
||||
and self._output_dtypes[0].is_floating_point
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
|
||||
return (
|
||||
id(tensor),
|
||||
getattr(tensor, "_version", None),
|
||||
tensor.data_ptr(),
|
||||
expected_dtype,
|
||||
)
|
||||
|
||||
def _prepare_input_tensor(
|
||||
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
detached = tensor.detach()
|
||||
needs_preparation = (
|
||||
detached.dtype != expected_dtype or not detached.is_contiguous()
|
||||
)
|
||||
if not needs_preparation:
|
||||
return detached
|
||||
|
||||
cache_key = self._input_cache_key(detached, expected_dtype)
|
||||
cached = self._prepared_input_cache[index]
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
|
||||
prepared = detached.contiguous().to(expected_dtype)
|
||||
self._prepared_input_cache[index] = (cache_key, prepared)
|
||||
return prepared
|
||||
|
||||
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
    """Hand every user input to the Rust graph and return tensors to keep alive.

    Each input is first normalised with ``_prepare_input_tensor``. CUDA
    inputs (when the graph supports device pointers) are registered by
    device pointer and collected in the returned list so the caller can
    hold references to them. All other inputs are moved to the CPU if
    necessary and registered through ``set_input_from_ptr``.
    """
    # NOTE(review): host-path tensors are not retained, which presumably
    # relies on set_input_from_ptr copying the bytes immediately — confirm.
    kept_alive = []
    bindings = zip(self._input_names, user_inputs, self._input_dtypes)
    for index, (name, tensor, expected_dtype) in enumerate(bindings):
        prepared = self._prepare_input_tensor(index, tensor, expected_dtype)

        if self._supports_device_ptrs and tensor.is_cuda:
            self._graph.set_input_device_ptr(
                name,
                prepared.data_ptr(),
                prepared.numel() * prepared.element_size(),
            )
            kept_alive.append(prepared)
            continue

        if prepared.device.type != "cpu":
            prepared = prepared.cpu()
        byte_count = prepared.numel() * prepared.element_size()
        self._graph.set_input_from_ptr(
            name,
            prepared.data_ptr(),
            byte_count,
            _torch_dtype_code(prepared.dtype),
        )
    return kept_alive
|
||||
|
||||
def _run_static_single_float_output(
|
||||
self, user_inputs: List[torch.Tensor], input_device: torch.device
|
||||
):
|
||||
_input_refs = self._bind_user_inputs(user_inputs)
|
||||
output_name = self._output_names[0]
|
||||
output_dtype = self._output_dtypes[0]
|
||||
out = torch.empty(
|
||||
self._static_output_shapes[0], dtype=output_dtype, device=input_device
|
||||
)
|
||||
self._graph.set_output_device_ptr(
|
||||
output_name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
self._graph.run()
|
||||
if not self._graph.output_is_zero_copy(output_name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
output_name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
return (out,)
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -170,18 +86,25 @@ class CompiledModel:
|
||||
if self._has_dynamic_dims:
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
elif (
|
||||
self._single_float_output_fast_path
|
||||
and input_device.type != "cpu"
|
||||
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
|
||||
):
|
||||
return self._run_static_single_float_output(user_inputs, input_device)
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = self._bind_user_inputs(user_inputs)
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
@@ -189,6 +112,8 @@ class CompiledModel:
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
@@ -196,8 +121,8 @@ class CompiledModel:
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
@@ -210,13 +135,18 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = output_tensors[i]
|
||||
@@ -225,11 +155,12 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -253,13 +184,17 @@ class CompiledModel:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
self._output_dtypes[i]
|
||||
if i < len(self._output_dtypes)
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
if out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
@@ -274,41 +209,3 @@ class CompiledModel:
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
def _leaf_paths(tree: Any, prefix=()):
|
||||
if torch.is_tensor(tree):
|
||||
return [prefix]
|
||||
if isinstance(tree, (list, tuple)):
|
||||
paths = []
|
||||
for idx, value in enumerate(tree):
|
||||
paths.extend(_leaf_paths(value, prefix + (idx,)))
|
||||
return paths
|
||||
if isinstance(tree, dict):
|
||||
paths = []
|
||||
for key, value in tree.items():
|
||||
paths.extend(_leaf_paths(value, prefix + (key,)))
|
||||
return paths
|
||||
return [prefix]
|
||||
|
||||
|
||||
def _follow_path(tree: Any, path):
|
||||
value = tree
|
||||
for key in path:
|
||||
value = value[key]
|
||||
return value
|
||||
|
||||
|
||||
class StructuredCompiledModel:
    """Preserve a module's original nested input structure for direct PT2 compile().

    Wraps a compiled model that takes flat positional tensors. On call, the
    leaves of the nested arguments are extracted in the order recorded from
    ``example_args`` and forwarded to the wrapped model. Any other attribute
    access is delegated to the wrapped model.
    """

    def __init__(self, compiled_model, example_args):
        self._compiled = compiled_model
        # Paths to every leaf of the example arguments, in flattening order.
        self._leaf_paths = _leaf_paths(example_args)

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup failed. If one of our
        # own fields is missing (instance created without __init__, e.g.
        # by copy/pickle), delegating would look up `self._compiled` again
        # and recurse until RecursionError; report a plain AttributeError.
        if name in ("_compiled", "_leaf_paths"):
            raise AttributeError(name)
        return getattr(self._compiled, name)

    def __call__(self, *args):
        flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
        return self._compiled(*flat_inputs)
|
||||
|
||||
@@ -9,12 +9,10 @@ import inspect
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from .compiled_model import CompiledModel, StructuredCompiledModel
|
||||
from .compiled_model import CompiledModel
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
@@ -186,66 +184,6 @@ def _save_and_compile(
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _has_cuda_inputs(flat_example_inputs):
|
||||
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
|
||||
|
||||
|
||||
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
    """Compute search-related environment overrides for a direct compile() call.

    An explicit ``search_trials`` / ``search_keep_best`` always becomes an
    override (as a string). Otherwise a default is supplied only when the
    inputs include a CUDA tensor and the variable is not already present in
    ``os.environ``: ``LUMINAL_SEARCH_TRIALS=5`` and
    ``LUMINAL_SEARCH_KEEP_BEST=3``. CPU-only callers with no explicit values
    get an empty dict.

    Returns:
        Dict mapping environment variable names to string values.
    """
    cuda_present = _has_cuda_inputs(flat_example_inputs)
    # (variable name, caller-supplied value, default applied for CUDA inputs)
    knobs = (
        ("LUMINAL_SEARCH_TRIALS", search_trials, "5"),
        ("LUMINAL_SEARCH_KEEP_BEST", search_keep_best, "3"),
    )

    env = {}
    for variable, explicit, cuda_default in knobs:
        if explicit is not None:
            env[variable] = str(explicit)
        elif cuda_present and variable not in os.environ:
            env[variable] = cuda_default
    return env
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _temporary_env(overrides):
|
||||
sentinel = object()
|
||||
previous = {}
|
||||
try:
|
||||
for key, value in overrides.items():
|
||||
previous[key] = os.environ.get(key, sentinel)
|
||||
os.environ[key] = value
|
||||
yield
|
||||
finally:
|
||||
for key, old_value in previous.items():
|
||||
if old_value is sentinel:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = old_value
|
||||
|
||||
|
||||
def _strip_exported_weights_for_zero_copy(ep, original_weights):
|
||||
"""Shrink the saved .pt2 artifact when original weights will be reused."""
|
||||
if not original_weights:
|
||||
return
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
if isinstance(orig, torch.nn.Parameter):
|
||||
replacement = torch.nn.Parameter(
|
||||
replacement, requires_grad=orig.requires_grad
|
||||
)
|
||||
ep._state_dict[key] = replacement
|
||||
del orig
|
||||
|
||||
|
||||
def _safe_int_bound(value):
|
||||
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
|
||||
|
||||
@@ -463,9 +401,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
def compile(
|
||||
model,
|
||||
example_input,
|
||||
search_iterations=None,
|
||||
search_trials=None,
|
||||
search_keep_best=None,
|
||||
search_iterations=25,
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
@@ -477,13 +413,7 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor — or a list/tuple of tensors for
|
||||
multi-input models.
|
||||
search_iterations: Number of optimization search iterations. When None,
|
||||
defaults to 200 on CUDA inputs and 10 otherwise.
|
||||
search_trials: Optional per-candidate profiling trials inside Luminal's
|
||||
search. When unset, direct CUDA compile defaults to 5.
|
||||
search_keep_best: Optional number of parent candidates to retain
|
||||
between search generations. When unset, direct CUDA compile
|
||||
defaults to 3.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
|
||||
@@ -502,26 +432,17 @@ def compile(
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(
|
||||
example_input
|
||||
if isinstance(example_input, (list, tuple))
|
||||
else [example_input]
|
||||
)
|
||||
|
||||
if isinstance(example_input, (list, tuple)):
|
||||
example_args = tuple(example_input)
|
||||
else:
|
||||
example_args = (example_input,)
|
||||
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
|
||||
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(flat_example_inputs)
|
||||
|
||||
if search_iterations is None:
|
||||
search_iterations = (
|
||||
200
|
||||
if _has_cuda_inputs(flat_example_inputs)
|
||||
else 10
|
||||
)
|
||||
search_env = _direct_search_env(
|
||||
flat_example_inputs,
|
||||
search_trials=search_trials,
|
||||
search_keep_best=search_keep_best,
|
||||
)
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -567,16 +488,63 @@ def compile(
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
|
||||
original_weights = model.state_dict()
|
||||
_strip_exported_weights_for_zero_copy(ep, original_weights)
|
||||
with _temporary_env(search_env):
|
||||
compiled = _save_and_compile(
|
||||
ep,
|
||||
factory,
|
||||
search_iterations,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
return StructuredCompiledModel(compiled, example_args)
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def _drop_input_guards(ep):
|
||||
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
|
||||
|
||||
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
|
||||
``i = torch.tensor(2)``), torch.export records two equivalent input
|
||||
guards: ``L['i'].item() == 2`` (referencing the original local source)
|
||||
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
|
||||
Two failures stack on top of each other:
|
||||
|
||||
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
|
||||
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
|
||||
a literal ``L`` reference in the generated ``_guards_fn`` and raising
|
||||
``NameError: name 'L' is not defined`` during retracing.
|
||||
2. Even after dropping the unresolvable guard, the surviving
|
||||
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
|
||||
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
|
||||
a graph break.
|
||||
|
||||
These guards exist solely to validate inputs at runtime in eager-mode
|
||||
consumers of the ExportedProgram; the luminal compiler does its own
|
||||
input shape/dtype checks against the compiled graph signature, so we
|
||||
are not losing any safety by clearing them.
|
||||
"""
|
||||
|
||||
if hasattr(ep, "_guards_code"):
|
||||
ep._guards_code = []
|
||||
|
||||
|
||||
def _drop_dead_data_dependent_ops(gm):
|
||||
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
|
||||
|
||||
When dynamo specializes a 0-d int input by tracing through ``.item()``,
|
||||
the resulting graph may contain a dead ``aten.item.default`` node whose
|
||||
output is never consumed. luminal's translator does not lower
|
||||
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
|
||||
node in the graph causes a graph break at compile time. Eliminating it
|
||||
keeps the (correctly specialized) downstream graph in a single subgraph.
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
changed = False
|
||||
for node in list(graph.nodes):
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
|
||||
and len(node.users) == 0
|
||||
):
|
||||
graph.erase_node(node)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
graph.eliminate_dead_code()
|
||||
graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
@@ -658,13 +626,23 @@ def _eager_pt2_compile(
|
||||
if dynamic_shapes is None:
|
||||
raise
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
# LUM-499: drop dynamo-emitted input guards before run_decompositions
|
||||
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
|
||||
# data-dependent .item() calls and unresolved `L[...]` references.
|
||||
_drop_input_guards(ep)
|
||||
_drop_dead_data_dependent_ops(ep.graph_module)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers
|
||||
# from the EP before saving. The Rust side uses device pointers for these
|
||||
# weights, not the .pt2 file data, so serializing them is pure IO waste
|
||||
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
|
||||
_strip_exported_weights_for_zero_copy(ep, original_weights)
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
|
||||
# Save EP to disk, then free it and the traced graph module before Rust
|
||||
# compilation. torch.export clones the state_dict internally; holding ep
|
||||
@@ -678,20 +656,11 @@ def _eager_pt2_compile(
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
default_search_iterations = (
|
||||
50
|
||||
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
|
||||
else 10
|
||||
)
|
||||
search_iterations = int(
|
||||
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
|
||||
)
|
||||
|
||||
try:
|
||||
return _save_and_compile(
|
||||
pt2_path,
|
||||
factory,
|
||||
search_iterations,
|
||||
10,
|
||||
original_weights=original_weights,
|
||||
user_indices=user_indices,
|
||||
)
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
# DLRM CUDA Benchmark

These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.

| Path | Median latency | Throughput |
| --- | ---: | ---: |
| eager | 0.267 ms | 7,674,321 candidates/s |
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |
|
||||
@@ -1,213 +0,0 @@
|
||||
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
|
||||
|
||||
These tests are intended for the local integration workflow where the
|
||||
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
|
||||
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
|
||||
then run the same model through ``backend=luminal_backend``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("sklearn")
|
||||
pytest.importorskip("tqdm")
|
||||
|
||||
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
|
||||
if not DEEPCTR_ROOT.exists():
|
||||
pytest.skip(
|
||||
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
deepctr_root = str(DEEPCTR_ROOT)
|
||||
if deepctr_root not in sys.path:
|
||||
sys.path.insert(0, deepctr_root)
|
||||
|
||||
from deepctr_torch.inputs import (
|
||||
DenseFeat,
|
||||
SparseFeat,
|
||||
VarLenSparseFeat,
|
||||
build_input_features,
|
||||
)
|
||||
from deepctr_torch.models import DCN
|
||||
from deepctr_torch.models.din import DIN
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _stack_features(
    feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
) -> torch.Tensor:
    """Concatenate per-feature arrays into one float32 tensor on ``device``.

    Iterates the names produced by ``build_input_features(feature_columns)``,
    looks each one up in ``feature_dict``, promotes 1-D arrays to a single
    column, and joins everything along the last axis.
    """

    def _as_columns(name: str) -> np.ndarray:
        array = np.asarray(feature_dict[name])
        # 1-D features become an (N, 1) column so they can be concatenated
        # next to features that already carry a trailing axis.
        return array[:, np.newaxis] if array.ndim == 1 else array

    columns = [_as_columns(name) for name in build_input_features(feature_columns)]
    return torch.tensor(
        np.concatenate(columns, axis=-1), dtype=torch.float32, device=device
    )
|
||||
|
||||
|
||||
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
if isinstance(output, tuple) and len(output) == 1:
|
||||
return output[0]
|
||||
return output
|
||||
|
||||
|
||||
def _assert_allclose(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
|
||||
) -> None:
|
||||
max_diff = torch.max(torch.abs(lhs - rhs)).item()
|
||||
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
|
||||
|
||||
|
||||
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    """Call ``model`` on ``inputs`` with gradients disabled and unwrap the result."""
    with torch.no_grad():
        raw_output = model(*inputs)
    return _unwrap(raw_output)
|
||||
|
||||
|
||||
@contextmanager
def _relaxed_dynamo_limits():
    """Temporarily set Dynamo's recompile and cache-size limits to 16.

    The previous values of ``torch._dynamo.config.recompile_limit`` and
    ``torch._dynamo.config.cache_size_limit`` are restored on exit, including
    when the body raises.
    """
    config = torch._dynamo.config
    saved = (config.recompile_limit, config.cache_size_limit)
    config.recompile_limit = 16
    config.cache_size_limit = 16
    try:
        yield
    finally:
        config.recompile_limit, config.cache_size_limit = saved
|
||||
|
||||
|
||||
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    """Run a deep copy of ``model`` through ``torch.compile(backend="inductor")``.

    Dynamo state is reset first, and both the compile call and the forward
    call happen inside ``_relaxed_dynamo_limits``. Gradients are disabled for
    the forward call; a one-element tuple result is unwrapped.
    """
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        # Compile a private copy so the caller's module is left untouched.
        private_copy = copy.deepcopy(model)
        compiled = torch.compile(private_copy, backend="inductor")
        with torch.no_grad():
            raw_output = compiled(*inputs)
        return _unwrap(raw_output)
|
||||
|
||||
|
||||
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    """Run a deep copy of ``model`` through ``torch.compile(backend=luminal_backend)``.

    Dynamo state is reset first, and both the compile call and the forward
    call happen inside ``_relaxed_dynamo_limits``. Gradients are disabled for
    the forward call; a one-element tuple result is unwrapped.
    """
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        # Compile a private copy so the caller's module is left untouched.
        private_copy = copy.deepcopy(model)
        compiled = torch.compile(private_copy, backend=luminal_backend)
        with torch.no_grad():
            raw_output = compiled(*inputs)
        return _unwrap(raw_output)
|
||||
|
||||
|
||||
def _make_dcn(
    device: torch.device, cross_parameterization: str
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
    """Build a small seeded DCN model in eval mode plus a 4-row input batch.

    Args:
        device: Device the model and the stacked input tensor are placed on.
        cross_parameterization: Forwarded unchanged to ``DCN``.

    Returns:
        ``(model, inputs)`` where ``inputs`` is a one-element tuple holding the
        tensor produced by ``_stack_features``.
    """
    # Seed before constructing the model so its initial weights are repeatable.
    torch.manual_seed(0)

    sparse_columns = [
        SparseFeat(name, vocab, embedding_dim=4) for name, vocab in (("s0", 5), ("s1", 7))
    ]
    dense_columns = [DenseFeat(name, 1) for name in ("d0", "d1")]
    feature_columns = sparse_columns + dense_columns

    feature_dict = {
        "s0": np.array([0, 1, 2, 3], dtype=np.int64),
        "s1": np.array([1, 2, 3, 4], dtype=np.int64),
        "d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
        "d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
    }

    dcn = DCN(
        linear_feature_columns=feature_columns,
        dnn_feature_columns=feature_columns,
        cross_num=2,
        cross_parameterization=cross_parameterization,
        dnn_hidden_units=(16,),
        dnn_dropout=0.0,
        device=str(device),
    )
    dcn = dcn.eval()

    stacked = _stack_features(feature_columns, feature_dict, device)
    return dcn.to(device), (stacked,)
|
||||
|
||||
|
||||
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
    """Build a small seeded DIN model in eval mode plus a 4-row input batch.

    The two variable-length history features share embeddings with
    ``item_id`` / ``cate_id`` through ``embedding_name`` and use
    ``seq_length`` as their length feature.

    Returns:
        ``(model, inputs)`` where ``inputs`` is a one-element tuple holding the
        tensor produced by ``_stack_features``.
    """
    # Seed before constructing the model so its initial weights are repeatable.
    torch.manual_seed(0)

    def _history(name: str, vocab: int, dim: int, shared_with: str) -> VarLenSparseFeat:
        return VarLenSparseFeat(
            SparseFeat(
                name,
                vocabulary_size=vocab,
                embedding_dim=dim,
                embedding_name=shared_with,
            ),
            maxlen=4,
            length_name="seq_length",
        )

    feature_columns = [
        SparseFeat("user", 4, embedding_dim=4),
        SparseFeat("gender", 2, embedding_dim=4),
        SparseFeat("item_id", 4, embedding_dim=8),
        SparseFeat("cate_id", 3, embedding_dim=4),
        DenseFeat("pay_score", 1),
        _history("hist_item_id", 4, 8, "item_id"),
        _history("hist_cate_id", 3, 4, "cate_id"),
    ]

    def _ints(values) -> np.ndarray:
        return np.array(values, dtype=np.int64)

    feature_dict = {
        "user": _ints([0, 1, 2, 3]),
        "gender": _ints([0, 1, 0, 1]),
        "item_id": _ints([1, 2, 3, 2]),
        "cate_id": _ints([1, 2, 1, 2]),
        "pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
        "hist_item_id": _ints([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]),
        "hist_cate_id": _ints([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]),
        "seq_length": _ints([3, 3, 2, 2]),
    }

    din = DIN(
        feature_columns,
        ["item_id", "cate_id"],
        dnn_dropout=0.0,
        device=str(device),
    )
    din = din.eval()

    stacked = _stack_features(feature_columns, feature_dict, device)
    return din.to(device), (stacked,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
def test_deepctr_dcn_matches_inductor_when_supported(
    device: torch.device, cross_parameterization: str
) -> None:
    """DCN: eager, inductor and luminal outputs must agree within tolerance.

    Inductor is checked against eager before luminal runs, so a failure in the
    reference path is reported before any luminal comparison.
    """
    model, inputs = _make_dcn(device, cross_parameterization)

    reference = _run_eager(model, *inputs)
    via_inductor = _run_inductor(model, *inputs)
    _assert_allclose(via_inductor, reference, "inductor vs eager")

    via_luminal = _run_luminal(model, *inputs)
    for baseline, label in (
        (reference, "luminal vs eager"),
        (via_inductor, "luminal vs inductor"),
    ):
        _assert_allclose(via_luminal, baseline, label)
|
||||
|
||||
|
||||
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
    """DIN: eager, inductor and luminal outputs must agree within tolerance.

    Inductor is checked against eager before luminal runs, so a failure in the
    reference path is reported before any luminal comparison.
    """
    model, inputs = _make_din(device)

    reference = _run_eager(model, *inputs)
    via_inductor = _run_inductor(model, *inputs)
    _assert_allclose(via_inductor, reference, "inductor vs eager")

    via_luminal = _run_luminal(model, *inputs)
    for baseline, label in (
        (reference, "luminal vs eager"),
        (via_inductor, "luminal vs inductor"),
    ):
        _assert_allclose(via_luminal, baseline, label)
|
||||
@@ -1,456 +0,0 @@
|
||||
"""DLRM coverage for the luminal torch.compile backend.
|
||||
|
||||
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
|
||||
that eager mode, ``torch.compile(..., backend="inductor")``, and
|
||||
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
|
||||
configurations, including a CUDA benchmark that compares Luminal against
|
||||
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("sklearn")
|
||||
|
||||
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
|
||||
if not DLRM_ROOT.exists():
|
||||
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
|
||||
|
||||
dlrm_root = str(DLRM_ROOT)
|
||||
if dlrm_root not in sys.path:
|
||||
sys.path.insert(0, dlrm_root)
|
||||
|
||||
|
||||
def _install_dlrm_import_stubs() -> None:
|
||||
ext_dist = types.ModuleType("extend_distributed")
|
||||
ext_dist.my_size = 1
|
||||
ext_dist.dist = None
|
||||
ext_dist.get_split_lengths = lambda n: (n, [n])
|
||||
ext_dist.get_my_slice = lambda n: slice(0, n)
|
||||
|
||||
class _AllToAll:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def wait(self):
|
||||
return self._values
|
||||
|
||||
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
|
||||
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"extend_distributed", loader=None
|
||||
)
|
||||
sys.modules["extend_distributed"] = ext_dist
|
||||
|
||||
mlperf_logger = types.ModuleType("mlperf_logger")
|
||||
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"mlperf_logger", loader=None
|
||||
)
|
||||
sys.modules["mlperf_logger"] = mlperf_logger
|
||||
|
||||
tensorboard = types.ModuleType("torch.utils.tensorboard")
|
||||
|
||||
class SummaryWriter:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def add_scalar(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
tensorboard.SummaryWriter = SummaryWriter
|
||||
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
|
||||
"torch.utils.tensorboard", loader=None
|
||||
)
|
||||
sys.modules["torch.utils.tensorboard"] = tensorboard
|
||||
|
||||
onnx = types.ModuleType("onnx")
|
||||
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
|
||||
sys.modules["onnx"] = onnx
|
||||
|
||||
|
||||
_install_dlrm_import_stubs()
|
||||
|
||||
import dlrm_s_pytorch as dlrm_mod
|
||||
from luminal import luminal_backend
|
||||
|
||||
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
|
||||
|
||||
|
||||
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
if isinstance(output, tuple) and len(output) == 1:
|
||||
return output[0]
|
||||
return output
|
||||
|
||||
|
||||
def _assert_allclose(
|
||||
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
|
||||
) -> None:
|
||||
max_diff = torch.max(torch.abs(lhs - rhs)).item()
|
||||
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
|
||||
|
||||
|
||||
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
    """Call ``model`` on ``inputs`` with gradients disabled and unwrap the result."""
    with torch.no_grad():
        raw_output = model(*inputs)
    return _unwrap(raw_output)
|
||||
|
||||
|
||||
@contextmanager
def _relaxed_dynamo_limits():
    """Temporarily set Dynamo's recompile and cache-size limits to 16.

    The previous values of ``torch._dynamo.config.recompile_limit`` and
    ``torch._dynamo.config.cache_size_limit`` are restored on exit, including
    when the body raises.
    """
    config = torch._dynamo.config
    saved = (config.recompile_limit, config.cache_size_limit)
    config.recompile_limit = 16
    config.cache_size_limit = 16
    try:
        yield
    finally:
        config.recompile_limit, config.cache_size_limit = saved
|
||||
|
||||
|
||||
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
    """Run a deep copy of ``model`` through ``torch.compile(backend="inductor")``.

    Dynamo state is reset first, and both the compile call and the forward
    call happen inside ``_relaxed_dynamo_limits``. Gradients are disabled for
    the forward call; a one-element tuple result is unwrapped.
    """
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        # Compile a private copy so the caller's module is left untouched.
        private_copy = copy.deepcopy(model)
        compiled = torch.compile(private_copy, backend="inductor")
        with torch.no_grad():
            raw_output = compiled(*inputs)
        return _unwrap(raw_output)
|
||||
|
||||
|
||||
def _compile_inductor(model: torch.nn.Module):
    """Return ``torch.compile`` of a deep copy of ``model`` with the inductor backend.

    Dynamo state is reset before wrapping. Only the ``torch.compile`` call
    itself runs inside ``_relaxed_dynamo_limits``.
    """
    # NOTE(review): the returned callable is invoked by the caller after this
    # context has exited, so the relaxed limits are already restored by then --
    # confirm that is intended.
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        private_copy = copy.deepcopy(model)
        return torch.compile(private_copy, backend="inductor")
|
||||
|
||||
|
||||
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
    """Run a deep copy of ``model`` through ``torch.compile(backend=luminal_backend)``.

    Dynamo state is reset first, and both the compile call and the forward
    call happen inside ``_relaxed_dynamo_limits``. Gradients are disabled for
    the forward call; a one-element tuple result is unwrapped.
    """
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        # Compile a private copy so the caller's module is left untouched.
        private_copy = copy.deepcopy(model)
        compiled = torch.compile(private_copy, backend=luminal_backend)
        with torch.no_grad():
            raw_output = compiled(*inputs)
        return _unwrap(raw_output)
|
||||
|
||||
|
||||
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
    """Return ``torch.compile`` of a deep copy of ``model`` in ``reduce-overhead`` mode.

    Uses the inductor backend. Dynamo state is reset before wrapping. Only the
    ``torch.compile`` call itself runs inside ``_relaxed_dynamo_limits``.
    """
    # NOTE(review): the returned callable is invoked by the caller after this
    # context has exited, so the relaxed limits are already restored by then --
    # confirm that is intended.
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        private_copy = copy.deepcopy(model)
        return torch.compile(private_copy, backend="inductor", mode="reduce-overhead")
|
||||
|
||||
|
||||
def _compile_luminal(model: torch.nn.Module):
    """Return ``torch.compile`` of a deep copy of ``model`` with ``luminal_backend``.

    Dynamo state is reset before wrapping. Only the ``torch.compile`` call
    itself runs inside ``_relaxed_dynamo_limits``.
    """
    # NOTE(review): the returned callable is invoked by the caller after this
    # context has exited, so the relaxed limits are already restored by then --
    # confirm that is intended.
    with _relaxed_dynamo_limits():
        torch._dynamo.reset()
        private_copy = copy.deepcopy(model)
        return torch.compile(private_copy, backend=luminal_backend)
|
||||
|
||||
|
||||
def _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters: int,
|
||||
timed_iters: int,
|
||||
mark_step_begin: bool = False,
|
||||
) -> dict[str, float]:
|
||||
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup_iters):
|
||||
if mark_step_begin:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
_unwrap(compiled_model(*inputs))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
|
||||
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
|
||||
|
||||
with torch.no_grad():
|
||||
for idx in range(timed_iters):
|
||||
if mark_step_begin:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
starts[idx].record()
|
||||
_unwrap(compiled_model(*inputs))
|
||||
ends[idx].record()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = np.array(
|
||||
[start.elapsed_time(end) for start, end in zip(starts, ends)],
|
||||
dtype=np.float64,
|
||||
)
|
||||
return {
|
||||
"mean_ms": float(elapsed_ms.mean()),
|
||||
"median_ms": float(np.median(elapsed_ms)),
|
||||
"min_ms": float(elapsed_ms.min()),
|
||||
}
|
||||
|
||||
|
||||
def _timed_cuda_rounds(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
pre_round_warmup_iters: int,
|
||||
timed_iters: int,
|
||||
rounds: int,
|
||||
mark_step_begin: bool = False,
|
||||
) -> dict[str, float | list[float]]:
|
||||
assert rounds > 0, "rounds must be positive"
|
||||
|
||||
stats = _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters=pre_round_warmup_iters,
|
||||
timed_iters=timed_iters,
|
||||
mark_step_begin=mark_step_begin,
|
||||
)
|
||||
round_medians = [stats["median_ms"]]
|
||||
|
||||
for _ in range(rounds - 1):
|
||||
stats = _timed_cuda_runs(
|
||||
compiled_model,
|
||||
*inputs,
|
||||
warmup_iters=0,
|
||||
timed_iters=timed_iters,
|
||||
mark_step_begin=mark_step_begin,
|
||||
)
|
||||
round_medians.append(stats["median_ms"])
|
||||
|
||||
round_medians_np = np.array(round_medians, dtype=np.float64)
|
||||
return {
|
||||
"round_medians_ms": [float(value) for value in round_medians],
|
||||
"median_ms": float(np.median(round_medians_np)),
|
||||
"mean_ms": float(round_medians_np.mean()),
|
||||
"min_ms": float(round_medians_np.min()),
|
||||
}
|
||||
|
||||
|
||||
def _make_dlrm(
    device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
    """Build a tiny seeded ``DLRM_Net`` in eval mode plus a 2-row input batch.

    Returns:
        ``(model, inputs)`` where ``inputs`` is a dense float32 tensor followed
        by two lists of three int64 tensors each.
    """
    # Seed both RNGs before constructing the model so the setup is repeatable.
    np.random.seed(0)
    torch.manual_seed(0)

    m_spa = 4
    ln_emb = np.array([8, 6, 4])
    ln_bot = np.array([3, 4])
    # One feature per embedding table plus the dense branch; the top MLP input
    # is the number of distinct feature pairs plus the dense width.
    num_fea = ln_emb.size + 1
    num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
    ln_top = np.array([num_int, 8, 1])

    net = dlrm_mod.DLRM_Net(
        m_spa=m_spa,
        ln_emb=ln_emb,
        ln_bot=ln_bot,
        ln_top=ln_top,
        arch_interaction_op="dot",
        arch_interaction_itself=False,
        sigmoid_top=1,
    )
    net = net.eval()

    def _int64(values: list[int]) -> torch.Tensor:
        return torch.tensor(values, dtype=torch.int64, device=device)

    dense_x = torch.tensor(
        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device
    )
    # NOTE(review): presumably per-table bag offsets followed by per-table
    # indices (the batch-2048 builder names them that way) -- confirm against
    # DLRM_Net.forward.
    first_group = [_int64([0, 1]) for _ in range(3)]
    second_group = [_int64(values) for values in ([1, 2], [0, 3], [2, 1])]

    inputs = (dense_x, first_group, second_group)
    return net.to(device), inputs
|
||||
|
||||
|
||||
def _make_dlrm_batch_2048(
    device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
    """Build a seeded ``DLRM_Net`` in eval mode plus a deterministic 2048-row batch.

    Returns:
        ``(model, inputs)`` where ``inputs`` is ``(dense_x, offsets, indices)``:
        a ``(2048, 3)`` float32 tensor, three identical offset tensors stepping
        by two, and three index tensors reduced modulo each table size.
    """
    # Seed both RNGs before constructing the model so the setup is repeatable.
    np.random.seed(0)
    torch.manual_seed(0)

    batch_size = 2048
    indices_per_bag = 2
    m_spa = 16
    ln_emb = np.array([4096, 2048, 1024])
    ln_bot = np.array([3, 64, m_spa])
    num_fea = ln_emb.size + 1
    num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
    ln_top = np.array([num_int, 64, 32, 1])

    net = dlrm_mod.DLRM_Net(
        m_spa=m_spa,
        ln_emb=ln_emb,
        ln_bot=ln_bot,
        ln_top=ln_top,
        arch_interaction_op="dot",
        arch_interaction_itself=False,
        sigmoid_top=2,
    )
    net = net.eval()

    dense_x = torch.linspace(
        -1.0, 1.0, steps=batch_size * 3, dtype=torch.float32, device=device
    ).reshape(batch_size, 3)

    total_sparse_indices = batch_size * indices_per_bag
    positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
    # One bag start every ``indices_per_bag`` positions.
    offsets = torch.arange(
        0, total_sparse_indices, indices_per_bag, dtype=torch.int64, device=device
    )

    # Affine pattern per table, wrapped into that table's row count.
    affine_patterns = ((3, 1), (5, 2), (7, 3))
    sparse_indices = [
        ((positions * scale + shift) % int(table_rows)).to(torch.int64)
        for (scale, shift), table_rows in zip(affine_patterns, ln_emb)
    ]

    inputs = (dense_x, [offsets.clone() for _ in range(3)], sparse_indices)
    return net.to(device), inputs
|
||||
|
||||
|
||||
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
    """Tiny DLRM: eager, inductor and luminal outputs must agree within tolerance.

    Inductor is checked against eager before luminal runs, so a failure in the
    reference path is reported before any luminal comparison.
    """
    model, inputs = _make_dlrm(device)

    reference = _run_eager(model, *inputs)
    via_inductor = _run_inductor(model, *inputs)
    _assert_allclose(via_inductor, reference, "inductor vs eager")

    via_luminal = _run_luminal(model, *inputs)
    for baseline, label in (
        (reference, "luminal vs eager"),
        (via_inductor, "luminal vs inductor"),
    ):
        _assert_allclose(via_luminal, baseline, label)
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
    device: torch.device,
) -> None:
    """Check numerical agreement on the 2048-row DLRM batch, then print a latency ranking.

    Skipped unless ``device`` is CUDA. Correctness is asserted for inductor
    (default and ``reduce-overhead``) and luminal against eager with
    ``atol=1e-4``. The benchmark section only prints; it asserts nothing about
    speed.
    """
    if device.type != "cuda":
        pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")

    model, inputs = _make_dlrm_batch_2048(device)
    eager = _run_eager(model, *inputs)

    def _first_call(fn) -> torch.Tensor:
        with torch.no_grad():
            return _unwrap(fn(*inputs))

    # Build every variant from the same source model; each helper deep-copies it.
    eager_model = copy.deepcopy(model).to(device).eval()
    inductor_compiled = _compile_inductor(model)
    inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
    torch.compiler.cudagraph_mark_step_begin()
    inductor_output = _first_call(inductor_reduce_overhead)

    luminal_compiled = _compile_luminal(model)
    luminal_output = _first_call(luminal_compiled)

    inductor_default_output = _first_call(inductor_compiled)

    comparisons = (
        (inductor_default_output, eager, "inductor default vs eager"),
        (inductor_output, eager, "inductor reduce-overhead vs eager"),
        (luminal_output, eager, "luminal vs eager"),
        (luminal_output, inductor_default_output, "luminal vs inductor default"),
        (luminal_output, inductor_output, "luminal vs inductor reduce-overhead"),
    )
    for actual, expected, label in comparisons:
        _assert_allclose(actual, expected, label, atol=1e-4)

    benchmark_rounds = 5
    benchmark_iters = 20
    post_compile_warmup_iters = 10

    def _bench(fn, **extra):
        return _timed_cuda_rounds(
            fn,
            *inputs,
            pre_round_warmup_iters=post_compile_warmup_iters,
            timed_iters=benchmark_iters,
            rounds=benchmark_rounds,
            **extra,
        )

    eager_stats = _bench(eager_model)
    inductor_default_stats = _bench(inductor_compiled)
    # Only the reduce-overhead variant is timed with step markers.
    inductor_stats = _bench(inductor_reduce_overhead, mark_step_begin=True)
    luminal_stats = _bench(luminal_compiled)

    batch_size = inputs[0].shape[0]
    ranked_results = sorted(
        [
            ("eager", eager_stats),
            ("inductor default", inductor_default_stats),
            ("inductor reduce-overhead", inductor_stats),
            ("luminal backend", luminal_stats),
        ],
        key=lambda item: float(item[1]["median_ms"]),
    )

    speed_lines = []
    for rank, (label, stats) in enumerate(ranked_results, start=1):
        median_ms = float(stats["median_ms"])
        throughput = batch_size / (median_ms / 1000.0)
        rounds_repr = ", ".join(
            f"{value:.3f}" for value in stats["round_medians_ms"]  # type: ignore[index]
        )
        speed_lines.append(
            f" {rank}. {label}: {median_ms:.3f} ms"
            f" ({throughput:,.0f} candidates/s)"
            f" [round medians: {rounds_repr}]"
        )

    luminal_median = float(luminal_stats["median_ms"])
    luminal_vs_inductor = luminal_median / float(inductor_stats["median_ms"])
    luminal_vs_eager = luminal_median / float(eager_stats["median_ms"])

    report_lines = [
        "",
        f"DLRM batch={batch_size} candidates on CUDA after compile/warmup",
        f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations",
        " Ranking by median latency:",
        "\n".join(speed_lines),
        " Luminal backend / TorchInductor reduce-overhead latency ratio:"
        f" {luminal_vs_inductor:.3f}x",
        f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x",
    ]
    print("\n".join(report_lines))
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -220,6 +221,7 @@ from test_models import (
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv1dFloorDivPositionalModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
    """Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
    reference_model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
    compiled_model: Callable = torch.compile(reference_model, backend=luminal_backend)
    x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)

    expected = reference_model(x)
    actual = compiled_model(x)

    # Both paths must report int64, and integer results must match exactly.
    assert actual.dtype == expected.dtype and expected.dtype == torch.int64
    assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
def test_reduce_sum_3d_axis1(device: torch.device):
|
||||
"""Test sum reduction along axis 1 for a 3D tensor."""
|
||||
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
|
||||
@@ -2022,9 +2035,16 @@ def test_split(device: torch.device):
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking.
|
||||
|
||||
Parametrized over int32/int64 to verify luminal preserves whichever
|
||||
integer dtype the eager model declares (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
|
||||
device
|
||||
)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
@@ -2033,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
|
||||
assert output.dtype == original.dtype, (
|
||||
f"luminal returned {output.dtype}, eager produced {original.dtype}"
|
||||
)
|
||||
assert torch.equal(output, original)
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Focused proof for built MoE routing support.
|
||||
|
||||
Parametrized over int32/int64 for the integer-valued outputs to verify
|
||||
luminal preserves the dtype declared by the eager model (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
@@ -2050,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
for actual, eager in zip(output, expected):
|
||||
assert actual.dtype == eager.dtype, (
|
||||
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
|
||||
)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
@@ -2477,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_floor_div_positional_pt2(device: torch.device):
    """Conv1d stride output uses floor division before positional add."""
    reference_model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
    compiled_model: Callable = _compile_for_export_mode(reference_model, "pt2")
    x: torch.Tensor = torch.randn(1, 8, 30, device=device)

    expected: torch.Tensor = reference_model(x)
    actual: torch.Tensor = compiled_model(x)

    # Length 30 through the stride-2 conv yields 15 positions of 16 channels.
    assert actual.shape == expected.shape and expected.shape == (15, 16)
    assert torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
|
||||
@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
"""Tests deterministic duplicate ordering for exported argsort.
|
||||
|
||||
``idx_dtype`` parameterizes the integer dtype of the returned indices so
|
||||
the test can verify dtype preservation across luminal's int dtype paths
|
||||
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
|
||||
lets us drive the same model toward int32 or int64 outputs.
|
||||
"""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
|
||||
|
||||
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
|
||||
group_ids) to the requested dtype so the test can sweep int32 and int64
|
||||
output paths (LUM-486). Internal indices stay int64 because torch.gather
|
||||
/ torch.scatter require int64 index tensors.
|
||||
"""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
routed_indices.to(self.idx_dtype),
|
||||
masked_values,
|
||||
dispatch,
|
||||
dispatch.to(self.idx_dtype),
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
group_ids.to(self.idx_dtype),
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dFloorDivPositionalModel(torch.nn.Module):
    """Whisper-like Conv1d downsample followed by a fixed positional add."""

    def __init__(self) -> None:
        super().__init__()
        # Length-preserving conv (k=3, pad=1), then a stride-2 conv that
        # downsamples the sequence axis.
        self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
        self.conv2 = torch.nn.Conv1d(
            16, 16, kernel_size=3, stride=2, padding=1, bias=True
        )
        # Learned (15, 16) table added to the transposed conv output.
        self.position = torch.nn.Parameter(torch.randn(15, 16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv1 -> GELU -> conv2 -> GELU, drop dim 0, transpose, add ``position``."""
        gelu = torch.nn.functional.gelu
        hidden = gelu(self.conv1(x))
        hidden = gelu(self.conv2(hidden))
        sequence_major = torch.transpose(torch.squeeze(hidden, 0), 0, 1)
        return sequence_major + self.position
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
|
||||
1544
crates/luminal_python/tests/test_scalars.py
Normal file
1544
crates/luminal_python/tests/test_scalars.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -315,8 +315,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,18 +8,14 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
@@ -38,11 +37,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -63,11 +57,14 @@ fn main() {
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -75,15 +72,66 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
&prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
print_token_ids: bool,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -93,7 +141,7 @@ fn main() {
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -121,12 +169,26 @@ fn main() {
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -165,10 +227,21 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
|
||||
@@ -462,8 +462,13 @@ fn hlir_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,14 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -66,10 +74,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -77,12 +88,65 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -94,13 +158,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -159,12 +226,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -246,8 +246,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,7 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -54,10 +69,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -65,12 +83,58 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -82,13 +146,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -147,12 +214,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -287,8 +287,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,15 +8,27 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 30)
|
||||
} else {
|
||||
30
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
|
||||
} else {
|
||||
50
|
||||
};
|
||||
let prompt = "The capital of France is";
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -24,7 +38,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -47,10 +60,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -58,14 +74,63 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -76,7 +141,7 @@ fn main() {
|
||||
const STOP_TOKEN: u32 = 151643;
|
||||
|
||||
// Prefill: process prompt tokens one at a time
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -105,13 +170,27 @@ fn main() {
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -150,13 +229,23 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report benchmarks
|
||||
println!();
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
|
||||
@@ -287,7 +287,8 @@ impl QwenMoE {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -385,8 +386,13 @@ fn attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -174,8 +174,13 @@ fn decoder_self_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
|
||||
k_full.shape.dims[1] = total;
|
||||
v_full.shape.dims[1] = total;
|
||||
|
||||
let q = split_heads(q);
|
||||
|
||||
|
||||
58
examples_common/benchmark_stdio.rs
Normal file
58
examples_common/benchmark_stdio.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::{
|
||||
io::{BufRead, Write},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
/// Returns `true` when the process was started with a `--stdio` argument.
///
/// Arguments are inspected as raw OS strings (`args_os`) rather than through
/// `std::env::args`, whose iterator panics as soon as it meets an argument that
/// is not valid Unicode. An unrelated malformed argument therefore cannot abort
/// the program from inside this check.
pub fn enabled() -> bool {
    std::env::args_os().any(|arg| arg == "--stdio")
}
|
||||
|
||||
/// Reads the environment variable `name` as a `usize`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or does
/// not parse as a `usize`. Surrounding whitespace is trimmed before parsing so
/// that a value such as `"4096 "` is honoured instead of silently falling back
/// to `default`.
pub fn env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(default)
}
|
||||
|
||||
/// Prints the line `READY` to stdout and flushes it immediately.
///
/// Panics if stdout cannot be flushed.
fn emit_ready() {
    println!("READY");
    let mut stdout = std::io::stdout();
    stdout.flush().unwrap();
}
|
||||
|
||||
pub fn serve(mut f: impl FnMut(&str)) {
|
||||
emit_ready();
|
||||
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines() {
|
||||
let line = line.unwrap();
|
||||
f(&line);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_token(token: &str) {
|
||||
println!("TOK\t{}", escape_token(token));
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
/// Writes one `EOQ` record to stdout and flushes it.
///
/// The record is `EOQ`, a tab, `generated`, a tab, and the time elapsed since
/// `query_start` in milliseconds with three decimal places.
/// NOTE(review): `EOQ` presumably stands for "end of query" -- confirm with the
/// consumer of this protocol.
pub fn emit_eoq(generated: usize, query_start: Instant) {
    let elapsed_ms = query_start.elapsed().as_secs_f64() * 1e3;
    println!("EOQ\t{}\t{:.3}", generated, elapsed_ms);
    std::io::stdout().flush().unwrap();
}
|
||||
|
||||
/// Returns `s` with backslash, tab, newline and carriage return replaced by the
/// two-character sequences `\\`, `\t`, `\n` and `\r`; every other character is
/// copied through unchanged.
fn escape_token(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, c| {
            let replacement = match c {
                '\\' => Some("\\\\"),
                '\t' => Some("\\t"),
                '\n' => Some("\\n"),
                '\r' => Some("\\r"),
                _ => None,
            };
            match replacement {
                Some(text) => escaped.push_str(text),
                None => escaped.push(c),
            }
            escaped
        })
}
|
||||
@@ -7,14 +7,13 @@
|
||||
//! - [`NativeDynBackend`]: the reference implementation for CPU
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use half::{bf16, f16};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::dtype::DType;
|
||||
use crate::graph::{Graph, SearchOptions};
|
||||
use crate::graph::Graph;
|
||||
use crate::hlir::{NativeData, NativeRuntime, Output};
|
||||
use crate::op::Runtime;
|
||||
|
||||
@@ -47,18 +46,6 @@ pub trait DynBackend {
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
|
||||
|
||||
// --- Optional diagnostics --------------------------------------------
|
||||
|
||||
fn kernel_names(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn host_op_names(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Optional diagnostic hook; this default body does nothing.
fn print_execution_stats(&self) {}
|
||||
|
||||
// --- Optional device pointer support (GPU backends) --------------------
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
@@ -175,9 +162,7 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rng = rand::rng();
|
||||
let search_options = search_options_from_env(args.search_iters);
|
||||
let mut rt = graph.search_options(rt, search_options, &mut rng);
|
||||
let mut rt = graph.search(rt, args.search_iters);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
@@ -194,39 +179,6 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
Ok(wrap(rt))
|
||||
}
|
||||
|
||||
/// Reads the environment variable `name` as a `usize`.
///
/// Returns `None` when the variable cannot be read or does not parse.
fn env_usize(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    raw.parse().ok()
}
|
||||
|
||||
fn env_duration_ms(name: &str) -> Option<Duration> {
|
||||
Some(Duration::from_millis(env_usize(name)? as u64))
|
||||
}
|
||||
|
||||
fn search_options_from_env(limit: usize) -> SearchOptions {
|
||||
let mut options = SearchOptions::new(limit);
|
||||
|
||||
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
|
||||
options = options.generation_size(generation_size);
|
||||
}
|
||||
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
|
||||
options = options.mutations(mutations);
|
||||
}
|
||||
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
|
||||
options = options.trials(trials);
|
||||
}
|
||||
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
|
||||
options = options.keep_best(keep_best);
|
||||
}
|
||||
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
|
||||
options = options.profile_timeout(profile_timeout);
|
||||
}
|
||||
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
|
||||
options = options.group_timeout(group_timeout);
|
||||
}
|
||||
|
||||
options
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -11,6 +11,7 @@ impl Add for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn add(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to add tensors. Got {:?} and {:?}",
|
||||
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn mul(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(
|
||||
self.dims(),
|
||||
rhs.dims(),
|
||||
"Dims must match to multiply tensors."
|
||||
);
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
|
||||
@@ -474,6 +480,42 @@ pub(super) mod tests {
|
||||
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
|
||||
}
|
||||
|
||||
#[test]
#[should_panic(expected = "Dims must match to add tensors.")]
fn test_add_rejects_implicit_broadcast() {
    // Shapes (2, 3) and (1, 3) differ on the first axis, so `+` must panic.
    let mut graph = Graph::new();
    let full = graph.tensor((2, 3));
    let single_row = graph.tensor((1, 3));
    let _ = full + single_row;
}
|
||||
|
||||
#[test]
#[should_panic(expected = "Dims must match to multiply tensors.")]
fn test_mul_rejects_implicit_broadcast() {
    // Shapes (2, 3) and (1, 3) differ on the first axis, so `*` must panic.
    let mut graph = Graph::new();
    let full = graph.tensor((2, 3));
    let single_row = graph.tensor((1, 3));
    let _ = full * single_row;
}
|
||||
|
||||
#[test]
#[should_panic(expected = "Dims must match to mod tensors.")]
fn test_mod_rejects_implicit_broadcast() {
    // Shapes (2, 3) and (1, 3) differ on the first axis, so `%` must panic.
    let mut graph = Graph::new();
    let full = graph.tensor((2, 3));
    let single_row = graph.tensor((1, 3));
    let _ = full % single_row;
}
|
||||
|
||||
#[test]
#[should_panic(expected = "Dims must match to lt tensors.")]
fn test_lt_rejects_implicit_broadcast() {
    // Shapes (2, 3) and (1, 3) differ on the first axis, so `lt` must panic.
    let mut graph = Graph::new();
    let full = graph.tensor((2, 3));
    let single_row = graph.tensor((1, 3));
    let _ = full.lt(single_row);
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -557,6 +599,27 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
    #![proptest_config(ProptestConfig::with_cases(10))]
    #[test]
    fn test_mod_scalar_broadcast(size in 1usize..64) {
        // A rank-0 RHS is explicitly expanded against the rank-N LHS,
        // mirroring `x % torch.tensor(c)`.
        test_binary_transforms(
            size,
            (),
            |lhs, rhs| lhs % rhs.expand_rhs(lhs.shape),
            |lhs, rhs| {
                // Reference result: element-wise remainder by the scalar.
                let values = lhs.to_vec1::<f32>().unwrap();
                let divisor = rhs.to_scalar::<f32>().unwrap();
                let mut expected: Vec<f32> = Vec::with_capacity(values.len());
                for value in &values {
                    expected.push(value % divisor);
                }
                Tensor::from_vec(expected, size, &Device::Cpu).unwrap()
            },
            identity,
            shift_from_zero,
        );
    }
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -570,6 +633,28 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
    #![proptest_config(ProptestConfig::with_cases(10))]
    #[test]
    fn test_lt_scalar_broadcast(size in 1usize..64) {
        // A rank-0 RHS is explicitly expanded against the rank-N LHS for `lt`.
        test_binary(
            size,
            (),
            |lhs, rhs| lhs.lt(rhs.expand_rhs(lhs.shape)).cast(crate::dtype::DType::F32),
            |lhs, rhs| {
                // Reference result: 1.0 where the element is below the scalar, else 0.0.
                let threshold = rhs.to_scalar::<f32>().unwrap();
                let values = lhs.to_vec1::<f32>().unwrap();
                let mut expected: Vec<f32> = Vec::with_capacity(values.len());
                for &value in &values {
                    expected.push(f32::from(u8::from(value < threshold)));
                }
                Tensor::from_vec(expected, size, &Device::Cpu).unwrap()
            },
        );
    }
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
|
||||
@@ -467,7 +467,7 @@ impl GraphTensor {
|
||||
let mut win = Vec::with_capacity(n);
|
||||
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
|
||||
let effective_window = *d * (*k - 1) + 1;
|
||||
win.push(((*dim - effective_window) / s) + 1);
|
||||
win.push((*dim - effective_window).floor_div(s) + 1);
|
||||
}
|
||||
|
||||
// [win..., kernel...]
|
||||
@@ -905,6 +905,14 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
    // Padded width is 3002; with kernel 3 and stride 2 the numerator
    // 3002 - 3 = 2999 is odd, so the window count must be floor(2999 / 2) + 1.
    let mut cx = Graph::new();
    let padded = cx.tensor((80, 3000)).pad(((0, 0), (1, 1)), 0.);
    let windows = padded.unfold((1, 3), (1, 2), (1, 1));
    assert_eq!(windows.dims(), &[80, 1500, 1, 3]);
}
|
||||
|
||||
#[test]
|
||||
fn test_unsqueeze() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
293
src/hlir.rs
293
src/hlir.rs
@@ -217,38 +217,32 @@ pub fn reduce_sort(name: &str) -> SortDef {
|
||||
}
|
||||
|
||||
pub type HLIROps = (
|
||||
(
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
Exp2,
|
||||
Log2,
|
||||
),
|
||||
(
|
||||
Sin,
|
||||
Recip,
|
||||
Sqrt,
|
||||
Add,
|
||||
Mul,
|
||||
Mod,
|
||||
LessThan,
|
||||
Gather,
|
||||
Concat2D,
|
||||
EmbeddingBagSum,
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
),
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
Exp2,
|
||||
Log2,
|
||||
Sin,
|
||||
Recip,
|
||||
Sqrt,
|
||||
Add,
|
||||
Mul,
|
||||
Mod,
|
||||
LessThan,
|
||||
Gather,
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1727,9 +1721,7 @@ impl NativeOp for Add {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
|
||||
}
|
||||
NativeData::Bool(a) => {
|
||||
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1816,9 +1808,7 @@ impl NativeOp for Mul {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
|
||||
}
|
||||
NativeData::Bool(a) => {
|
||||
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2136,233 +2126,6 @@ impl NativeOp for Gather {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct Concat2D {
|
||||
pub rows: Expression,
|
||||
pub lhs_cols: Expression,
|
||||
pub rhs_cols: Expression,
|
||||
}
|
||||
impl Display for Concat2D {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Concat2D")
|
||||
}
|
||||
}
|
||||
impl HLIROp for Concat2D {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (Concat2D {} {} {}) {})",
|
||||
self.rows.to_egglog(),
|
||||
self.lhs_cols.to_egglog(),
|
||||
self.rhs_cols.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for Concat2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"Concat2D",
|
||||
&[
|
||||
("rows", EXPRESSION),
|
||||
("lhs_cols", EXPRESSION),
|
||||
("rhs_cols", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for Concat2D {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
let rows = self.rows.exec(dyn_map).unwrap();
|
||||
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
|
||||
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
|
||||
|
||||
fn concat_rows<T: Clone>(
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
rows: usize,
|
||||
lhs_cols: usize,
|
||||
rhs_cols: usize,
|
||||
) -> Vec<T> {
|
||||
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
|
||||
for row in 0..rows {
|
||||
let lhs_base = row * lhs_cols;
|
||||
let rhs_base = row * rhs_cols;
|
||||
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
|
||||
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
match (inputs[0], inputs[1]) {
|
||||
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
|
||||
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
|
||||
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
|
||||
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
|
||||
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
|
||||
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
|
||||
}
|
||||
_ => panic!("Concat2D inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct EmbeddingBagSum {
|
||||
pub n_bags: Expression,
|
||||
pub n_indices: Expression,
|
||||
pub hidden_dim: Expression,
|
||||
pub num_embeddings: Expression,
|
||||
}
|
||||
impl Display for EmbeddingBagSum {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "EmbeddingBagSum")
|
||||
}
|
||||
}
|
||||
impl HLIROp for EmbeddingBagSum {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (EmbeddingBagSum {} {} {} {}) {})",
|
||||
self.n_bags.to_egglog(),
|
||||
self.n_indices.to_egglog(),
|
||||
self.hidden_dim.to_egglog(),
|
||||
self.num_embeddings.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for EmbeddingBagSum {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"EmbeddingBagSum",
|
||||
&[
|
||||
("n_bags", EXPRESSION),
|
||||
("n_indices", EXPRESSION),
|
||||
("hidden_dim", EXPRESSION),
|
||||
("num_embeddings", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
|
||||
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for EmbeddingBagSum {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
let n_bags = self.n_bags.exec(dyn_map).unwrap();
|
||||
let n_indices = self.n_indices.exec(dyn_map).unwrap();
|
||||
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
|
||||
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
|
||||
|
||||
let clamp_index = |value: i32, limit: usize| -> usize {
|
||||
if limit == 0 {
|
||||
return 0;
|
||||
}
|
||||
value.clamp(0, limit.saturating_sub(1) as i32) as usize
|
||||
};
|
||||
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
|
||||
|
||||
let NativeData::Int(indices) = inputs[1] else {
|
||||
panic!("EmbeddingBagSum indices must be Int")
|
||||
};
|
||||
let NativeData::Int(offsets) = inputs[2] else {
|
||||
panic!("EmbeddingBagSum offsets must be Int")
|
||||
};
|
||||
|
||||
let bag_bounds = |bag: usize| -> (usize, usize) {
|
||||
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
|
||||
let end = offsets
|
||||
.get(bag + 1)
|
||||
.copied()
|
||||
.map(clamp_offset)
|
||||
.unwrap_or(n_indices);
|
||||
(start, end.max(start))
|
||||
};
|
||||
|
||||
match inputs[0] {
|
||||
NativeData::F32(weight) => {
|
||||
let mut out = vec![0.0f32; n_bags * hidden_dim];
|
||||
for bag in 0..n_bags {
|
||||
let (start, end) = bag_bounds(bag);
|
||||
let out_base = bag * hidden_dim;
|
||||
for pos in start..end {
|
||||
let row = clamp_index(indices[pos], num_embeddings);
|
||||
let row_base = row * hidden_dim;
|
||||
for dim in 0..hidden_dim {
|
||||
out[out_base + dim] += weight[row_base + dim];
|
||||
}
|
||||
}
|
||||
}
|
||||
NativeData::F32(out)
|
||||
}
|
||||
_ => panic!("EmbeddingBagSum only supports F32 weights"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter Op (inverse of Gather)
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
|
||||
@@ -455,6 +455,31 @@ impl Expression {
|
||||
terms.push(Term::CeilDiv);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Floor Division
|
||||
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 {
|
||||
return self;
|
||||
}
|
||||
if self == 0 {
|
||||
return 0.into();
|
||||
}
|
||||
if self == rhs {
|
||||
return 1.into();
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
|
||||
&& let Some(c) = floor_div_i64(a, b)
|
||||
{
|
||||
return c.into();
|
||||
}
|
||||
|
||||
// Shape dimensions are non-negative, so the existing integer Div term
|
||||
// evaluates with floor semantics for dynamic shape expressions.
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Div);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Less than
|
||||
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
@@ -654,6 +679,16 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
|
||||
depth == 1
|
||||
}
|
||||
|
||||
/// Divides `a` by `b`, rounding toward negative infinity.
///
/// Returns `None` when `b` is zero or when the division or the rounding
/// adjustment overflows `i64`.
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
    let (quotient, remainder) = (a.checked_div(b)?, a.checked_rem(b)?);

    // Truncation rounds toward zero. That is one too high exactly when the
    // division is inexact and remainder and divisor have opposite signs.
    let needs_adjustment = remainder != 0 && (remainder ^ b) < 0;
    if needs_adjustment {
        quotient.checked_sub(1)
    } else {
        Some(quotient)
    }
}
|
||||
|
||||
impl From<Term> for Expression {
|
||||
fn from(value: Term) -> Self {
|
||||
Expression::new(vec![value])
|
||||
@@ -994,8 +1029,12 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
|
||||
|
||||
impl std::iter::Product for Expression {
|
||||
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
|
||||
// Empty product is the multiplicative identity, 1 — not 0. Returning
|
||||
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
|
||||
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
|
||||
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
|
||||
let Some(mut p) = iter.next() else {
|
||||
return 0.into();
|
||||
return 1.into();
|
||||
};
|
||||
for n in iter {
|
||||
p *= n;
|
||||
@@ -1106,6 +1145,27 @@ mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test]
fn test_empty_product_is_one() {
    // An empty product (e.g. the shape of a rank-0 tensor) must be the
    // multiplicative identity 1, not 0. Kernel emitters such as cuda_lite
    // compute `numel` via `shape.iter().product()`, and a rank-0 tensor has
    // one element; a 0 here would produce a CUDA launch with grid=(0, 1, 1)
    // and crash at runtime.
    let no_dims: Vec<Expression> = Vec::new();
    let numel: Expression = no_dims.into_iter().product();
    assert_eq!(numel, Expression::from(1));
}
|
||||
|
||||
#[test]
fn test_empty_sum_is_zero() {
    // Sanity check: the additive identity for an empty sum stays 0.
    let no_terms: Vec<Expression> = Vec::new();
    let total: Expression = no_terms.into_iter().sum();
    assert_eq!(total, Expression::from(0));
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_simplifications() {
|
||||
let x = expr('x');
|
||||
|
||||
Reference in New Issue
Block a user