mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
4 Commits
codex/dlrm
...
codex/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b1e09cf23 | ||
|
|
6416ddb5f8 | ||
|
|
c9d4ce6217 | ||
|
|
7402503bd4 |
@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
|
||||
@@ -1540,19 +1540,22 @@ impl KernelOp for KernelIota {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
vars.extend(self.range.dyn_vars());
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let range = self.range.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void iota_k(int *C{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {range}) return;
|
||||
C[const_z] = {};
|
||||
}}
|
||||
}}",
|
||||
@@ -1571,8 +1574,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.range, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.range.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2935,6 +2938,14 @@ impl KernelOp for KernelCast {
|
||||
) {
|
||||
let out_dtype = cuda_dtype(self.out_dtype);
|
||||
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
|
||||
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let size = self.size.to_kernel();
|
||||
|
||||
let kernel = if self.in_dtype.bits() < 8 {
|
||||
// Sub-byte packed types: multiple values packed per byte.
|
||||
@@ -2944,9 +2955,11 @@ impl KernelOp for KernelCast {
|
||||
let mask = (1u32 << bits) - 1;
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {size}) return;
|
||||
long long bit_offset = idx * {bits};
|
||||
long long byte_idx = bit_offset >> 3;
|
||||
int bit_pos = (int)(bit_offset & 7);
|
||||
@@ -2962,9 +2975,11 @@ extern \"C\" {{
|
||||
let in_dtype = cuda_dtype(self.in_dtype);
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {size}) return;
|
||||
out[const_z] = ({out_dtype})in[const_z];
|
||||
}}
|
||||
}}"
|
||||
@@ -2983,8 +2998,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.size, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.size.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -3275,12 +3290,15 @@ impl KernelOp for KernelEmbed {
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
let n_elements = total_threads.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {n_elements}) return;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
long long embed_idx = idx % embed_dim;
|
||||
@@ -3303,13 +3321,12 @@ extern \"C\" {{
|
||||
};
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(total_threads, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(total_threads.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
constants,
|
||||
)
|
||||
|
||||
@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
|
||||
@@ -61,7 +61,8 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -478,7 +479,8 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -855,8 +855,6 @@ Two important details:
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
|
||||
@@ -98,7 +98,12 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -124,7 +129,9 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -476,10 +483,7 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
self.output_dtypes.clone()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -262,10 +262,13 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,14 +284,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -15,7 +15,16 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
pub min_val: i64,
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,8 +389,11 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -147,6 +148,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -219,6 +221,16 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -236,6 +248,13 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -274,18 +293,27 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
a.cumsum(dim)
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -381,6 +409,17 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -444,6 +483,28 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -120,6 +120,47 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -333,6 +374,17 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
@@ -344,7 +396,12 @@ impl<'a> Translator<'a> {
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
|
||||
@@ -6,6 +6,20 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -37,32 +51,26 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
|
||||
// override below assumes contiguous, no-broadcast storage —
|
||||
// otherwise the `[1, N]` view treats stride-0 broadcast dims
|
||||
// as if they held N distinct values and reads past the backing
|
||||
// buffer. Materialize first when that's not the case (matches
|
||||
// the guard `translate_reshape` already applies).
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
let total = concrete_numel(&a)?;
|
||||
let has_broadcast = a
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(a.shape.strides.iter())
|
||||
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
|
||||
let a = if has_broadcast || !a.shape.is_contiguous() {
|
||||
a + 0.0
|
||||
} else {
|
||||
a
|
||||
};
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -86,4 +94,100 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,12 +213,18 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -269,4 +275,52 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,11 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
@@ -150,11 +155,12 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -182,9 +188,13 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
if out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
|
||||
@@ -491,6 +491,62 @@ def compile(
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def _drop_input_guards(ep):
|
||||
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
|
||||
|
||||
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
|
||||
``i = torch.tensor(2)``), torch.export records two equivalent input
|
||||
guards: ``L['i'].item() == 2`` (referencing the original local source)
|
||||
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
|
||||
Two failures stack on top of each other:
|
||||
|
||||
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
|
||||
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
|
||||
a literal ``L`` reference in the generated ``_guards_fn`` and raising
|
||||
``NameError: name 'L' is not defined`` during retracing.
|
||||
2. Even after dropping the unresolvable guard, the surviving
|
||||
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
|
||||
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
|
||||
a graph break.
|
||||
|
||||
These guards exist solely to validate inputs at runtime in eager-mode
|
||||
consumers of the ExportedProgram; the luminal compiler does its own
|
||||
input shape/dtype checks against the compiled graph signature, so we
|
||||
are not losing any safety by clearing them.
|
||||
"""
|
||||
|
||||
if hasattr(ep, "_guards_code"):
|
||||
ep._guards_code = []
|
||||
|
||||
|
||||
def _drop_dead_data_dependent_ops(gm):
|
||||
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
|
||||
|
||||
When dynamo specializes a 0-d int input by tracing through ``.item()``,
|
||||
the resulting graph may contain a dead ``aten.item.default`` node whose
|
||||
output is never consumed. luminal's translator does not lower
|
||||
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
|
||||
node in the graph causes a graph break at compile time. Eliminating it
|
||||
keeps the (correctly specialized) downstream graph in a single subgraph.
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
changed = False
|
||||
for node in list(graph.nodes):
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
|
||||
and len(node.users) == 0
|
||||
):
|
||||
graph.erase_node(node)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
graph.eliminate_dead_code()
|
||||
graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
"""Match the historical `dynamic_dim="auto"` heuristic.
|
||||
|
||||
@@ -570,6 +626,11 @@ def _eager_pt2_compile(
|
||||
if dynamic_shapes is None:
|
||||
raise
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
# LUM-499: drop dynamo-emitted input guards before run_decompositions
|
||||
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
|
||||
# data-dependent .item() calls and unresolved `L[...]` references.
|
||||
_drop_input_guards(ep)
|
||||
_drop_dead_data_dependent_ops(ep.graph_module)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -220,6 +221,7 @@ from test_models import (
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv1dFloorDivPositionalModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
|
||||
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
|
||||
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
|
||||
eager = model(x)
|
||||
out = model_compiled(x)
|
||||
assert out.dtype == eager.dtype == torch.int64
|
||||
assert torch.equal(out, eager)
|
||||
|
||||
|
||||
def test_reduce_sum_3d_axis1(device: torch.device):
|
||||
"""Test sum reduction along axis 1 for a 3D tensor."""
|
||||
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
|
||||
@@ -2022,9 +2035,16 @@ def test_split(device: torch.device):
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking.
|
||||
|
||||
Parametrized over int32/int64 to verify luminal preserves whichever
|
||||
integer dtype the eager model declares (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
|
||||
device
|
||||
)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
@@ -2033,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
|
||||
assert output.dtype == original.dtype, (
|
||||
f"luminal returned {output.dtype}, eager produced {original.dtype}"
|
||||
)
|
||||
assert torch.equal(output, original)
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Focused proof for built MoE routing support.
|
||||
|
||||
Parametrized over int32/int64 for the integer-valued outputs to verify
|
||||
luminal preserves the dtype declared by the eager model (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
@@ -2050,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
for actual, eager in zip(output, expected):
|
||||
assert actual.dtype == eager.dtype, (
|
||||
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
|
||||
)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
@@ -2477,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_floor_div_positional_pt2(device: torch.device):
|
||||
"""Conv1d stride output uses floor division before positional add."""
|
||||
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
|
||||
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.shape == original.shape == (15, 16)
|
||||
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
|
||||
@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
"""Tests deterministic duplicate ordering for exported argsort.
|
||||
|
||||
``idx_dtype`` parameterizes the integer dtype of the returned indices so
|
||||
the test can verify dtype preservation across luminal's int dtype paths
|
||||
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
|
||||
lets us drive the same model toward int32 or int64 outputs.
|
||||
"""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
|
||||
|
||||
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
|
||||
group_ids) to the requested dtype so the test can sweep int32 and int64
|
||||
output paths (LUM-486). Internal indices stay int64 because torch.gather
|
||||
/ torch.scatter require int64 index tensors.
|
||||
"""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
routed_indices.to(self.idx_dtype),
|
||||
masked_values,
|
||||
dispatch,
|
||||
dispatch.to(self.idx_dtype),
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
group_ids.to(self.idx_dtype),
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dFloorDivPositionalModel(torch.nn.Module):
|
||||
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
self.position = torch.nn.Parameter(torch.randn(15, 16))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.gelu(self.conv1(x))
|
||||
x = torch.nn.functional.gelu(self.conv2(x))
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
return x + self.position
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
|
||||
1544
crates/luminal_python/tests/test_scalars.py
Normal file
1544
crates/luminal_python/tests/test_scalars.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -315,8 +315,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,18 +8,14 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
@@ -38,11 +37,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -63,11 +57,14 @@ fn main() {
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -75,15 +72,66 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
&prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
print_token_ids: bool,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -93,7 +141,7 @@ fn main() {
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -121,12 +169,26 @@ fn main() {
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -165,10 +227,21 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
|
||||
@@ -462,8 +462,13 @@ fn hlir_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,14 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -66,10 +74,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -77,12 +88,65 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -94,13 +158,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -159,12 +226,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -246,8 +246,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,7 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -54,10 +69,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -65,12 +83,58 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -82,13 +146,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -147,12 +214,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -287,8 +287,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,15 +8,27 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 30)
|
||||
} else {
|
||||
30
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
|
||||
} else {
|
||||
50
|
||||
};
|
||||
let prompt = "The capital of France is";
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -24,7 +38,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -47,10 +60,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -58,14 +74,63 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -76,7 +141,7 @@ fn main() {
|
||||
const STOP_TOKEN: u32 = 151643;
|
||||
|
||||
// Prefill: process prompt tokens one at a time
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -105,13 +170,27 @@ fn main() {
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -150,13 +229,23 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report benchmarks
|
||||
println!();
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
|
||||
@@ -287,7 +287,8 @@ impl QwenMoE {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -385,8 +386,13 @@ fn attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -174,8 +174,13 @@ fn decoder_self_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
|
||||
k_full.shape.dims[1] = total;
|
||||
v_full.shape.dims[1] = total;
|
||||
|
||||
let q = split_heads(q);
|
||||
|
||||
|
||||
58
examples_common/benchmark_stdio.rs
Normal file
58
examples_common/benchmark_stdio.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::{
|
||||
io::{BufRead, Write},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
std::env::args().any(|arg| arg == "--stdio")
|
||||
}
|
||||
|
||||
pub fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn emit_ready() {
|
||||
println!("READY");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn serve(mut f: impl FnMut(&str)) {
|
||||
emit_ready();
|
||||
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines() {
|
||||
let line = line.unwrap();
|
||||
f(&line);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_token(token: &str) {
|
||||
println!("TOK\t{}", escape_token(token));
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn emit_eoq(generated: usize, query_start: Instant) {
|
||||
println!(
|
||||
"EOQ\t{}\t{:.3}",
|
||||
generated,
|
||||
query_start.elapsed().as_secs_f64() * 1e3
|
||||
);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
fn escape_token(s: &str) -> String {
|
||||
let mut out = String::with_capacity(s.len());
|
||||
for ch in s.chars() {
|
||||
match ch {
|
||||
'\\' => out.push_str("\\\\"),
|
||||
'\t' => out.push_str("\\t"),
|
||||
'\n' => out.push_str("\\n"),
|
||||
'\r' => out.push_str("\\r"),
|
||||
_ => out.push(ch),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -11,6 +11,7 @@ impl Add for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn add(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to add tensors. Got {:?} and {:?}",
|
||||
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn mul(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(
|
||||
self.dims(),
|
||||
rhs.dims(),
|
||||
"Dims must match to multiply tensors."
|
||||
);
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
|
||||
@@ -474,6 +480,42 @@ pub(super) mod tests {
|
||||
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to add tensors.")]
|
||||
fn test_add_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a + b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to multiply tensors.")]
|
||||
fn test_mul_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a * b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to mod tensors.")]
|
||||
fn test_mod_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a % b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to lt tensors.")]
|
||||
fn test_lt_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a.lt(b);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -557,6 +599,27 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_mod_scalar_broadcast(size in 1usize..64) {
|
||||
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
|
||||
test_binary_transforms(
|
||||
size,
|
||||
(),
|
||||
|a, b| a % b.expand_rhs(a.shape),
|
||||
|a, b| {
|
||||
let lhs = a.to_vec1::<f32>().unwrap();
|
||||
let rhs_scalar = b.to_scalar::<f32>().unwrap();
|
||||
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
|
||||
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
|
||||
},
|
||||
identity,
|
||||
shift_from_zero,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -570,6 +633,28 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_lt_scalar_broadcast(size in 1usize..64) {
|
||||
// rank-0 RHS expanded against rank-N LHS for `lt`.
|
||||
test_binary(
|
||||
size,
|
||||
(),
|
||||
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|
||||
|a, b| {
|
||||
let scalar = b.to_scalar::<f32>().unwrap();
|
||||
let lhs = a.to_vec1::<f32>().unwrap();
|
||||
let result: Vec<f32> = lhs
|
||||
.iter()
|
||||
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
|
||||
.collect();
|
||||
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
|
||||
@@ -467,7 +467,7 @@ impl GraphTensor {
|
||||
let mut win = Vec::with_capacity(n);
|
||||
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
|
||||
let effective_window = *d * (*k - 1) + 1;
|
||||
win.push(((*dim - effective_window) / s) + 1);
|
||||
win.push((*dim - effective_window).floor_div(s) + 1);
|
||||
}
|
||||
|
||||
// [win..., kernel...]
|
||||
@@ -905,6 +905,14 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
|
||||
let mut cx = Graph::new();
|
||||
let inp = cx.tensor((80, 3000));
|
||||
let out = inp.pad(((0, 0), (1, 1)), 0.).unfold((1, 3), (1, 2), (1, 1));
|
||||
assert_eq!(out.dims(), &[80, 1500, 1, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsqueeze() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
@@ -455,6 +455,31 @@ impl Expression {
|
||||
terms.push(Term::CeilDiv);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Floor Division
|
||||
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 {
|
||||
return self;
|
||||
}
|
||||
if self == 0 {
|
||||
return 0.into();
|
||||
}
|
||||
if self == rhs {
|
||||
return 1.into();
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
|
||||
&& let Some(c) = floor_div_i64(a, b)
|
||||
{
|
||||
return c.into();
|
||||
}
|
||||
|
||||
// Shape dimensions are non-negative, so the existing integer Div term
|
||||
// evaluates with floor semantics for dynamic shape expressions.
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Div);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Less than
|
||||
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
@@ -654,6 +679,16 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
|
||||
depth == 1
|
||||
}
|
||||
|
||||
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
|
||||
let q = a.checked_div(b)?;
|
||||
let r = a.checked_rem(b)?;
|
||||
if r != 0 && ((r > 0) != (b > 0)) {
|
||||
q.checked_sub(1)
|
||||
} else {
|
||||
Some(q)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Term> for Expression {
|
||||
fn from(value: Term) -> Self {
|
||||
Expression::new(vec![value])
|
||||
@@ -994,8 +1029,12 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
|
||||
|
||||
impl std::iter::Product for Expression {
|
||||
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
|
||||
// Empty product is the multiplicative identity, 1 — not 0. Returning
|
||||
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
|
||||
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
|
||||
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
|
||||
let Some(mut p) = iter.next() else {
|
||||
return 0.into();
|
||||
return 1.into();
|
||||
};
|
||||
for n in iter {
|
||||
p *= n;
|
||||
@@ -1106,6 +1145,27 @@ mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_product_is_one() {
|
||||
// The empty product (e.g. for a rank-0 tensor's shape) must be the
|
||||
// multiplicative identity, 1 — not 0. cuda_lite and other kernel
|
||||
// emitters use `shape.iter().product()` to compute `numel`, and a
|
||||
// rank-0 tensor has 1 element. Returning 0 here would yield a CUDA
|
||||
// launch with grid=(0, 1, 1) and crash at runtime.
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(
|
||||
empty.into_iter().product::<Expression>(),
|
||||
Expression::from(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_sum_is_zero() {
|
||||
// Sanity check the additive identity stays 0 (it always was).
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(empty.into_iter().sum::<Expression>(), Expression::from(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_simplifications() {
|
||||
let x = expr('x');
|
||||
|
||||
Reference in New Issue
Block a user