Compare commits

..

4 Commits

Author SHA1 Message Date
Tucker Morgan
0b1e09cf23 Merge remote-tracking branch 'origin/main' into codex/rust-stdio-benchmark 2026-05-14 17:01:19 +00:00
tucker-luminal
6416ddb5f8 Use parallel launches for small CUDA kernels (#315)
* Use parallel launches for cast and iota kernels

* Use parallel launch for embed kernel
2026-05-14 00:47:12 -04:00
Austin Glover
c9d4ce6217 Better scalar support: tests + 12 fixes (LUM-474) (#300)
* Add scalar torture test suite (LUM-474)

60 tests asserting strict shape, dtype, and value match between PyTorch
eager and luminal_backend. Includes 9 xfail markers (12 cases) for the
known scalar bugs being addressed under LUM-485 through LUM-490.

* Add aten.select.int support to luminal_python translator (LUM-487)

Single-element indexing (`x[0]`, `x[i, j]`, `x[1, 2, 3]`) lowers to
`aten.select.int` in the FX graph. The translator previously bailed
with "Unsupported ATen op", blocking any model that reads a scalar by
indexing.

Implements `aten.select.int(self, dim, index)` as
`slice_along(index..index+1, dim).squeeze(dim)` — a pure
shape-manipulation that the luminal compiler can fold into surrounding
ops, with a single iota for the slice. Negative `dim` is normalized via
the existing `normalize_dim` helper; negative `index` is normalized
against the (concrete) axis size, mirroring how `translate_gather`
normalizes negative gather indices.

Removes the four `xfail(_INDEX_SELECT_REASON)` markers in
`tests/test_scalar_torture.py` (and the now-unused reason constant);
these tests now pass. Final counts: 52 passed / 8 xfailed (was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Fix LUM-488: support rank-0 tensor mod/lt and add aten.remainder dispatch

Two related issues prevented `x % torch.tensor(c)` from translating:

1. The luminal_python translator did not dispatch aten.remainder.Tensor /
   aten.remainder.Scalar at all, so any module that mods a tensor against
   a 0-d torch.tensor failed with "Unsupported ATen op".

2. core::ops::Rem and GraphTensor::lt asserted exact dim equality, blocking
   rank-0 to rank-N broadcasting that the backend already supports
   transparently for Add/Mul (the input_shapes vec is forwarded to the
   strided iterator).

Drop the dim assertions in Rem and lt so they match Add/Mul's broadcast
behavior, and add aten.remainder.Tensor/Scalar handlers in dispatch.rs that
mirror aten.fmod.Tensor (with ensure_same_dtype + broadcast_binary). For
the Scalar form, build a constant_float and expand_rhs onto the LHS shape.

Tests:
- New proptests test_mod_scalar_broadcast / test_lt_scalar_broadcast in
  src/frontend/binary.rs cover rank-0 RHS via expand_rhs.
- Removed @pytest.mark.xfail from test_mod_by_scalar_tensor; added the
  test_scalar_torture.py file to luminal_python's test suite.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python translator: dispatch aten.clamp.Tensor (LUM-489)

torch.clamp(x, lo, hi) where lo/hi are 0-d tensors routes to
aten.clamp.Tensor, which the translator did not previously handle. Add
a dedicated dispatch that decomposes clamp(x, lo, hi) into
min(max(x, lo), hi), broadcasting each rank-0 bound up to x's shape via
expand_rhs. Either bound may be absent (PyTorch allows min=None or
max=None), so each side is applied only when its FX input is a tensor.

Removes the @pytest.mark.xfail on test_clamp_with_scalar_tensors;
test_scalar_torture now reports 50 passed / 10 xfailed (was 48 / 12).

* luminal_python: support aten.prod.default full-reduction (LUM-490)

The translator's dispatch table mapped aten.{sum,mean,amax,amin}.default
to translate_reduction but lacked an entry for aten.prod.default, so
x.prod() with no axis raised "Unsupported ATen op". Add the missing
dispatch entry; the ReductionOp::Prod branch in translate_reduction
already handles both full-reduce and dim-reduce cases.

aten.prod.dim_int was already wired up; verified it routes correctly.

Removes the xfail marker on test_prod_all_produces_scalar in
test_scalar_torture.py — suite now reports 50 passed / 10 xfailed
(was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: preserve int64 (and other integer) output dtypes (LUM-486)

Full reductions of int64 tensors silently downcast to int32 on the
PyTorch boundary because `output_dtypes` was stored as luminal `DType`,
which collapses every integer width to `DType::Int` (i32). The Python
wrapper therefore reported int32 to PyTorch even when the user passed
int64, breaking strict dtype checks and risking silent overflow on
larger reductions / downstream ops that require int64.

Store `output_dtypes` directly as PT2 dtype codes (the original PyTorch
type IDs) instead of converting through luminal `DType` first. This
preserves int64 vs int32 (and similar) end-to-end. The Python output
path now reads int outputs as i32 and casts to the requested torch
dtype, so int8/int16/int32/int64/uint8 outputs all round-trip with the
right type tag.

Updates two existing assertions (`test_argsort_stable_duplicates`,
`test_tiny_moe_routing`) that were pinning int32 — the new behavior
matches PyTorch eager (int64). Adds `test_reduce_sum_all_axes_int64_preserves_dtype`
as a regression check, and removes the xfail on
`test_int_sum_produces_int_scalar`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: parametrize argsort/MoE dtype tests over int32 and int64

The LUM-486 fix preserves whichever integer dtype the eager model declares
on output. The original tests hardcoded int64 (the dtype torch.argsort and
torch.topk natively produce), which only exercised one path through the
preservation logic.

Add an idx_dtype knob to ArgsortStableDuplicatesModel and TinyMoERoutingModel
that casts the integer outputs to the requested dtype, and parametrize both
tests over [torch.int32, torch.int64]. Internal indices (passed to gather /
scatter) stay int64 since PyTorch requires that for index tensors; the cast
applies only to the returned values.

LUM-486

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Remove xfail markers for fixed scalar bugs

Drops the @pytest.mark.xfail markers on tests now passing after the
LUM-486, LUM-487, LUM-488, LUM-489, and LUM-490 fixes:

  - test_prod_all_produces_scalar (LUM-490)
  - test_clamp_with_scalar_tensors (LUM-489)
  - test_mod_by_scalar_tensor (LUM-488)
  - test_index_1d_produces_scalar (LUM-487)
  - test_index_all_dims_produces_scalar (LUM-487)
  - test_index_then_add_scalar_const (LUM-487)
  - test_model_returns_scalar_from_index (LUM-487)
  - test_int_sum_produces_int_scalar (LUM-486)

Also removes the now-unused _INDEX_SELECT_REASON constant.

The single remaining xfail is test_unsqueeze_expand_sum_back, blocked
on LUM-485 (full reduction returns shape [1] instead of rank-0 ()).

* luminal_python: full reductions return rank-0 () instead of [1] (LUM-485)

The translator's full-reduce path used to flatten the input to [1, N] and
reduce axis 1, leaving a residual [1] dimension. PyTorch eager produces
rank-0 () for x.sum() etc., and downstream ops (e.g. unsqueeze(0).expand(5))
rely on that rank — the residual [1] caused panics like "Cannot expand
from 2 dims to 1 dims" once the scalar fed any further op.

Drop the flatten and reduce over every axis directly. Special-case rank-0
input as a no-op so reducing a scalar is well-defined. Mean still divides
by the cached total to avoid redundant axis-prod work.

Removes the xfail marker on test_unsqueeze_expand_sum_back, which now
passes. With this commit the integration branch has zero xfails:
284 passed across test_scalar_torture.py + test_hlir_ops.py + test_unary.py.

* ruff format: tests/test_hlir_ops.py

Collapse a two-line f-string into one line per ruff format. No behavior change.

* Expand scalar torture suite with PyTorch / NumPy gap coverage

Cross-referenced our suite against PyTorch's test_torch / test_reductions
/ test_view_ops / test_indexing / test_type_promotion / test_binary_ufuncs
and NumPy's test_multiarray / test_indexing / test_shape_base. Added 14
new sections covering 47 in-scope gaps:

  - Binary ops with INPUT 0-d (not reduction-derived) on either side:
    add/sub/mul/div/mod/maximum/minimum/pow/floor_divide
  - Pure 0-d ↔ 0-d arithmetic (no broadcasting required)
  - Full comparison set (gt/ge/lt/le/eq/ne) on input 0-d, plus mask-by-eq
  - Reduction extras: argmax/argmin (no-arg + keepdim), sum(dim=()),
    sum/mean of 0-d input, cumsum of 0-d
  - Shape-flattening on 0-d: flatten/ravel/reshape(-1)/view(-1) all
    return shape (1,); reshape(()) on 1-element collapses to (); plus
    permute([]), contiguous(), squeeze() of (1,1,1,1), expand_as
  - Indexing extras: ellipsis x[...], index by 0-d int tensor, gather
    with 0-d index, negative-index x[-1]
  - Type promotion: float-0-d + int-Nd, int-0-d + float-Nd, cast
    roundtrip through 0-d, .float()/.int() shorthands, where with
    mixed-dtype scalar branches
  - Unary math (abs/neg/exp/sin/cos/tanh/sigmoid/sqrt/sign/floor/ceil)
    on reduction-derived 0-d
  - Bool logic: AND, OR, XOR, NOT on 0-d bool from comparisons
  - Stack of 0-ds; cat of unsqueezed 0-ds
  - Constants: torch.full((), v), torch.full_like on 0-d
  - Reduction edge cases: keepdim across all axes then divide;
    scalar broadcast onto transposed tensor
  - Mixed where/clamp shapes: clamp(x, scalar_tensor, py_float),
    where(cond, scalar_tensor, x)
  - Multi-output models: (scalar, tensor) tuple

Result: 363 passed / 15 xfailed across the python suite. The 15 new
xfails are documented inline with concrete failure modes:

  - 6 op-coverage gaps: aten.argmax.default, aten.argmin.default,
    aten.eq.Scalar, aten.ne.Tensor (translator dispatch entries needed).
  - 2 PT2 export issues: 0-d int64 graph inputs hit "invalid type: null,
    expected i64" in luminal's model.json parser; affects
    test_int_0d_plus_float_nd and test_gather_with_0d_index.
  - 2 real correctness bugs:
      * floor_divide with 0-d divisor returns the un-floored quotient
        (float division result, not floor(x/d)).
      * cumsum on a 0-d tensor panics with index-out-of-bounds.
  - 1 dynamo guard edge case: torch._dynamo emits an unresolved 'L'
    name in _guards_fn for 0-d index tensors.

Plus 4 cross-marker xfails on consequence of the above (the parametric
ne case, mask_by_scalar_eq variants, and other downstream effects).

* Rename test_scalar_torture.py -> test_scalars.py; drop 'torture' wording

The original 'torture test' label is jargon. The file is just a scalar
test module — keep the name simple to match the rest of the suite
(test_unary.py, test_hlir_ops.py).

* luminal_python: parse rounding_mode string arg correctly (LUM-494)

torch.floor_divide(x, d) decomposes to aten.div.Tensor_mode with
rounding_mode='floor' during PT2 export. The translator was reading
the kwarg via serde_json::Value::as_str(), but PT2 serializes string
args as {"as_string": "<value>"} objects, not bare JSON strings. The
extraction silently returned None, so the floor branch was skipped
and the regular un-floored quotient was returned.

Drill into the as_string field as a fallback so floor_divide and
div(x, d, rounding_mode='floor'/'trunc') produce floor(x/d) /
trunc(x/d) as expected.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix cumsum on rank-0 tensor (LUM-495)

The translator's cumsum handler called normalize_dim(dim, a.shape.len())
and then a.cumsum(dim) for any rank — including rank-0. The underlying
cumop in src/frontend/unary.rs indexes self.dims()[axis] inside the
padding/unfold loop, which panics with "index out of bounds: the len is
0 but the index is 0" when shape is empty.

PyTorch eager treats torch.cumsum(s, 0) on a 0-d tensor as an identity
op (cumsum of a single element is the element itself). Mirror the
rank-0 short-circuit pattern from the LUM-485 reduction fix and return
the input unchanged when a.shape.is_empty(). Move the dim arg fetch
inside the non-empty branch since dim is unused for rank-0.

Drops the xfail marker on test_cumsum_of_0d and adds a 1-element 1-D
sibling test that asserts shape (1,) round-trips.

* luminal_python: support aten.argmax/argmin (LUM-496)

argmax/argmin were missing from the translator dispatch table even
though we already have stable_argsort. Add a thin wrapper so the
PyTorch boundary lights up:

  argmax(x, dim=None)    -> argsort(flatten(x), descending=True).select(0, 0)
  argmax(x, dim=N)       -> argsort(x, dim=N, descending=True).select(N, 0)
  argmax(x, dim=N, keepdim=True) -> .unsqueeze(N) over the above
  argmin(...)            -> same with descending=False

The slice + squeeze chain produces a non-contiguous DType::Int view
whose underlying buffer is still sized for the un-sliced argsort
tensor. Final `* 1` materializes a contiguous Int copy with strides
matching the visible shape — same trick `translate_topk` uses for
its sliced index output. Without it the keepdim case panics
("No output node found") and the full-reduce case throws a Python
shape mismatch on the oversized buffer.

PyTorch's argmax returns int64 while luminal collapses to int32 (Int);
LUM-486 already widens at the Python boundary, so the contract is
preserved end-to-end. Drops the three `@pytest.mark.xfail` markers
from `test_argmax_all`, `test_argmin_all`, and `test_argmax_keepdim_1d`
in `test_scalars.py` (6 cases via parametrization).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: dispatch aten.eq.Scalar and aten.ne.Tensor (LUM-497)

Add the two missing comparison overloads to the translator dispatch.
eq.Scalar mirrors the existing ne.Scalar handler (constant_float + cast
+ expand_rhs to broadcast the scalar), and ne.Tensor mirrors the
existing eq.Tensor handler. Removes the corresponding xfail markers on
test_input_0d_comparisons[_NeInput0ds-...] and test_mask_by_scalar_eq.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: accept null range_constraint bounds (LUM-498)

A 0-d int64 graph input made PyTorch 2.10+ emit
`range_constraints: { sN: { min_val: null, max_val: null } }` for the
unbacked symbol PT2 introduces around the rank-0 tensor. Our serde
schema modeled `RangeConstraint.min_val` as `i64`, so deserialization
failed with `invalid type: null, expected i64`, blocking any model
with a scalar integer tensor input.

Make `min_val` and `max_val` `Option<i64>` (matching PT2's
`Optional[int]`) and fall back to 1 as the initial dynamic-dim value
when no lower bound is provided.

Tests: removes the xfail on `test_int_0d_plus_float_nd`, adds a new
`test_int32_0d_plus_float_nd` regression, and updates the xfail
reason on `test_gather_with_0d_index` (the parse error is fixed; a
separate downstream gather panic remains).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: drop dynamo input guards in pt2_backend (LUM-499)

When a 0-d int tensor is used as a tensor index (x[i] where
i = torch.tensor(2)), torch.export records duplicate input guards that
reference both the original local source (L['i']) and the rewrapped flat
args (L['args'][1]). The unlift pass cannot resolve L['i'] against the
wrapped (*args, **kwargs) signature, leaving a literal `L` reference in
the generated _guards_fn that raises NameError during retracing. The
data-dependent .item() in the surviving guard then trips fake-tensor
analysis with DataDependentOutputException.

Drop the guard list before run_decompositions so unlift produces an
empty _guards_fn, and DCE any leftover dead aten.item.default nodes
that came from index specialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix gather with rank-0 index on rank-1 source

PyTorch eager allows torch.gather(rank-1, dim, rank-0) — the only
rank-mismatch case it permits — and returns a rank-0 scalar. Our
gather_elements requires source-rank == index-rank, so the rank-0
index hit flatten_strides with mismatched (0, 1) lengths and
panicked.

Detect this specific pattern in translate_gather: unsqueeze the
rank-0 index to (1,), gather, then squeeze the result back to ().
Output shape and value match eager.

This was the last remaining xfail in test_scalars.py. Suite is now
381 passed / 0 xfailed / 0 failed across test_scalars.py +
test_hlir_ops.py + test_unary.py.

* luminal_python: clamp.Tensor handles all broadcastable bound shapes

PyTorch's aten.clamp.Tensor accepts bounds with any NumPy-broadcastable
shape (rank-0, same-shape, or broadcastable). The previous translator
used expand_rhs(result.shape) which appends dims rather than broadcasts,
so only rank-0 bounds came out correctly. Same-shape and broadcastable
bounds either panicked or silently produced wrong values.

Switch to broadcast_binary (the right-align + size-1 expand helper used
by aten.remainder.Tensor, aten.eq.Tensor, etc.). Now all three modes
work uniformly.

Add 7 new tests covering the previously-broken modes:
  - same-shape bounds (per-element clamp, e.g. learned bounds)
  - per-row broadcast (3,1) against (3,4)
  - per-col broadcast (4,) against (3,4)
  - mixed rank-0 lo + same-shape hi
  - min-only with same-shape lo
  - max-only with per-row hi
  - 3-D x with 2-D bounds (left-unsqueeze broadcast)

Suite goes from 381 to 388 passing, 0 xfailed.

* shape: empty Expression product returns 1, not 0

The empty product is the multiplicative identity (1) — every shape-iterator
call site (`shape.iter().product()` for `numel`, output-buffer sizing, CUDA
grid-dim computation) implicitly relies on this. The previous impl returned
0 for an empty iterator, which was a latent bug masked while no path
produced rank-0 shapes.

The LUM-485 fix (full reductions return rank-0 () instead of rank-1 [1])
exposed it on CUDA: SumReduce kernels with rank-0 output got `n_outputs=0`,
launched with `grid=(0, 1, 1)`, and crashed with "invalid CUDA launch
dimensions" — every CUDA reduction in the Python CUDA tests was failing.

Fix: return Expression::from(1) for empty iteration. Sum's identity (0)
was already correct and is unchanged.

Add two unit tests covering both identities.

* cargo fmt

* Fix PT2 passthrough input output ID collision

* Fix scalar argextremum keepdim behavior

* Defer PT2 interface collision fix

* Keep HLIR binary ops shape-strict

* fixed gemma issue

* Fix explicit broadcasts and conv shape division

* Normalize Whisper cache slice shape

---------

Co-authored-by: Austin Glover <austin@luminal.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Austin Glover <austin_glover@berekely.edu>
Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-13 20:16:30 -04:00
Tucker Morgan
7402503bd4 Add stdio mode to Rust benchmark examples 2026-05-12 21:21:29 +00:00
46 changed files with 2873 additions and 2408 deletions

View File

@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn gather_experts(

View File

@@ -5,7 +5,6 @@ use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
use luminal::prelude::*;
use crate::cudarc::driver::CudaContext;
use crate::host::describe_host_op;
use crate::runtime::CudaRuntime;
/// [`DynBackend`] wrapper for [`CudaRuntime`].
@@ -40,26 +39,6 @@ impl DynBackend for CudaLiteDynBackend {
self.runtime.execute(dyn_map);
}
fn kernel_names(&self) -> Vec<String> {
self.runtime
.kernel_names()
.iter()
.map(|name| (*name).to_string())
.collect()
}
fn host_op_names(&self) -> Vec<String> {
self.runtime
.host_ops()
.iter()
.map(|op| describe_host_op(*op))
.collect()
}
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
fn supports_device_ptrs(&self) -> bool {
true
}

View File

@@ -247,10 +247,6 @@ impl HostOp for CuBlasSgemmV2 {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasSgemmV2")
}
fn output_size(&self) -> Expression {
self.m * self.n
}

View File

@@ -419,6 +419,7 @@ fn transpose_op_name(op: cublasOperation_t) -> &'static str {
}
}
#[cfg(test)]
fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
match epilogue {
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => "DEFAULT",
@@ -977,18 +978,6 @@ impl CuBlasLt {
&& normalize(self.stride_c) == normalize(self.stride_d)
&& self.c_order == self.d_order
}
pub(crate) fn debug_summary(&self) -> String {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
format!(
"CuBlasLt[m={}, n={}, k={}, batch={}, epilogue={}]",
resolve(&self.m),
resolve(&self.n),
resolve(&self.k),
resolve(&self.batch_count),
epilogue_name(self.epilogue),
)
}
}
impl HostOp for CuBlasLt {
@@ -1125,10 +1114,6 @@ impl HostOp for CuBlasLt {
Ok(())
}
fn stats_name(&self) -> Option<&'static str> {
Some("CuBlasLt")
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)

View File

@@ -1,7 +1,6 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use crate::kernel::CudaGraphOp;
use luminal::{op::EgglogOp, prelude::*};
pub mod compute_attn_mask;
mod cublas;
@@ -80,24 +79,6 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
.map(cublaslt::CuBlasLt::c_d_layouts_match)
}
pub(crate) fn describe_host_op(op: &dyn HostOp) -> String {
if let Some(op) = op.as_any().downcast_ref::<cublaslt::CuBlasLt>() {
return op.debug_summary();
}
if let Some(op) = op.as_any().downcast_ref::<CudaGraphOp>() {
let mut summary = op.debug_summary();
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& let Some(timing) = op.debug_timing_summary()
{
summary.push_str(" [");
summary.push_str(&timing);
summary.push(']');
}
return summary;
}
op.stats_name().unwrap_or("unknown").to_string()
}
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside

View File

@@ -12,10 +12,7 @@ use luminal::{
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
extract_dtype, extract_expr, extract_expr_list,
},
hlir::{
Add, Concat2D, EmbeddingBagSum, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter,
Sin, Sqrt, SumReduce,
},
hlir::{Add, Exp2, LessThan, Log2, MaxReduce, Mod, Mul, Recip, Scatter, Sin, Sqrt, SumReduce},
op::*,
prelude::*,
};
@@ -68,8 +65,6 @@ pub type Ops = (
KernelConstant,
KernelCast,
KernelEmbed,
KernelConcat2D,
KernelEmbeddingBagSum,
);
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
@@ -1545,19 +1540,22 @@ impl KernelOp for KernelIota {
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
vars.extend(self.range.dyn_vars());
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let range = self.range.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void iota_k(int *C{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {range}) return;
C[const_z] = {};
}}
}}",
@@ -1576,8 +1574,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.range, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.range.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -2940,6 +2938,14 @@ impl KernelOp for KernelCast {
) {
let out_dtype = cuda_dtype(self.out_dtype);
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let size = self.size.to_kernel();
let kernel = if self.in_dtype.bits() < 8 {
// Sub-byte packed types: multiple values packed per byte.
@@ -2949,9 +2955,11 @@ impl KernelOp for KernelCast {
let mask = (1u32 << bits) - 1;
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {size}) return;
long long bit_offset = idx * {bits};
long long byte_idx = bit_offset >> 3;
int bit_pos = (int)(bit_offset & 7);
@@ -2967,9 +2975,11 @@ extern \"C\" {{
let in_dtype = cuda_dtype(self.in_dtype);
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {size}) return;
out[const_z] = ({out_dtype})in[const_z];
}}
}}"
@@ -2988,8 +2998,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.size, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.size.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -3280,12 +3290,15 @@ impl KernelOp for KernelEmbed {
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let total_threads = batch_size * self.embed_dim;
let n_elements = total_threads.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {n_elements}) return;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
long long embed_idx = idx % embed_dim;
@@ -3308,13 +3321,12 @@ extern \"C\" {{
};
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
let total_threads = batch_size * self.embed_dim;
(
func,
module,
kernel,
(total_threads, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(total_threads.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
constants,
)
@@ -3359,361 +3371,3 @@ extern \"C\" {{
"Embed"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelEmbeddingBagSum {
n_bags: Expression,
n_indices: Expression,
hidden_dim: Expression,
num_embeddings: Expression,
dtype: DType,
}
impl EgglogOp for KernelEmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelEmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<EmbeddingBagSum, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[4]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelEmbeddingBagSum {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelEmbeddingBagSum only supports F32 weights today, got {:?}",
self.dtype
);
let vars = self
.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_bags = self.n_bags.to_kernel();
let n_indices = self.n_indices.to_kernel();
let hidden_dim = self.hidden_dim.to_kernel();
let num_embeddings = self.num_embeddings.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embedding_bag_sum(float *out, const float *weight, const int *indices, const int *offsets{dyn_dims_param}) {{
long long dim = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long bag = blockIdx.y;
long long hidden_dim = {hidden_dim};
long long n_bags = {n_bags};
long long n_indices = {n_indices};
long long num_embeddings = {num_embeddings};
if (bag >= n_bags || dim >= hidden_dim) return;
int start_raw = offsets[bag];
int end_raw = (bag + 1 < n_bags) ? offsets[bag + 1] : (int)n_indices;
int start = max(0, min(start_raw, (int)n_indices));
int end = max(start, min(end_raw, (int)n_indices));
float sum = 0.0f;
for (int pos = start; pos < end; ++pos) {{
int row = indices[pos];
row = max(0, min(row, (int)num_embeddings - 1));
sum += weight[(long long)row * hidden_dim + dim];
}}
out[bag * hidden_dim + dim] = sum;
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embedding_bag_sum").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(self.hidden_dim.ceil_div(256), self.n_bags, 1.into()),
(self.hidden_dim.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.n_bags * self.hidden_dim
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.n_bags
.dyn_vars()
.into_iter()
.chain(self.n_indices.dyn_vars())
.chain(self.hidden_dim.dyn_vars())
.chain(self.num_embeddings.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
// Approximate: weights + indices + offsets
self.n_indices * (self.hidden_dim * 4 + 4) + self.n_bags * 4
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.n_indices * self.hidden_dim
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"EmbeddingBagSum"
}
}
#[derive(Default, Debug, Clone)]
pub struct KernelConcat2D {
rows: Expression,
lhs_cols: Expression,
rhs_cols: Expression,
dtype: DType,
}
impl EgglogOp for KernelConcat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelConcat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<Concat2D, Self>()]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelConcat2D {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
assert!(
self.dtype == DType::F32,
"KernelConcat2D only supports F32 today, got {:?}",
self.dtype
);
let vars = self
.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let rows = self.rows.to_kernel();
let lhs_cols = self.lhs_cols.to_kernel();
let rhs_cols = self.rhs_cols.to_kernel();
let total = (self.rows * (self.lhs_cols + self.rhs_cols)).to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void concat_2d(float *out, const float *lhs, const float *rhs{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = {total};
if (idx >= total) return;
long long rows = {rows};
long long lhs_cols = {lhs_cols};
long long rhs_cols = {rhs_cols};
long long out_cols = lhs_cols + rhs_cols;
if (rows == 0 || out_cols == 0) return;
long long row = idx / out_cols;
long long col = idx - row * out_cols;
if (col < lhs_cols) {{
out[idx] = lhs[row * lhs_cols + col];
}} else {{
long long rhs_col = col - lhs_cols;
out[idx] = rhs[row * rhs_cols + rhs_col];
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("concat_2d").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let output_size = self.output_size();
(
func,
module,
kernel,
(output_size.ceil_div(256), 1.into(), 1.into()),
(output_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.rows * (self.lhs_cols + self.rhs_cols)
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.rows
.dyn_vars()
.into_iter()
.chain(self.lhs_cols.dyn_vars())
.chain(self.rhs_cols.dyn_vars())
.collect()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
0.into()
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Concat2D"
}
}

View File

@@ -4,8 +4,6 @@
//! that can be executed like any other HostOp.
use std::cell::RefCell;
use std::cmp::Reverse;
use std::collections::BTreeMap;
use std::sync::Arc;
use cudarc::driver::{
@@ -143,8 +141,6 @@ struct CudaGraphOpState {
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
/// Timing events for profiling
timing_events: Vec<cudarc::driver::sys::CUevent>,
/// Last per-kernel GPU timings (microseconds) captured for diagnostics.
last_kernel_timings_us: Vec<(&'static str, f64)>,
}
impl CudaGraphOpState {
@@ -159,7 +155,6 @@ impl CudaGraphOpState {
last_dyn_values: FxHashMap::default(),
last_buffer_ptrs: FxHashMap::default(),
timing_events: Vec::new(),
last_kernel_timings_us: Vec::new(),
}
}
}
@@ -197,41 +192,6 @@ impl CudaGraphOp {
state: RefCell::new(state),
}
}
pub fn debug_summary(&self) -> String {
let state = self.state.borrow();
let mut counts: BTreeMap<&'static str, usize> = BTreeMap::new();
for kernel in &state.kernels {
*counts.entry(kernel.kernel_name).or_default() += 1;
}
let mut counts: Vec<_> = counts.into_iter().collect();
counts.sort_by_key(|(name, count)| (Reverse(*count), *name));
let top = counts
.into_iter()
.take(4)
.map(|(name, count)| format!("{name}x{count}"))
.join(", ");
format!("CudaGraph[{} kernels: {top}]", state.kernels.len())
}
pub fn debug_timing_summary(&self) -> Option<String> {
let state = self.state.borrow();
if state.last_kernel_timings_us.is_empty() {
return None;
}
let mut totals: BTreeMap<&'static str, f64> = BTreeMap::new();
for (name, us) in &state.last_kernel_timings_us {
*totals.entry(*name).or_default() += *us;
}
let mut totals: Vec<_> = totals.into_iter().collect();
totals.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
let top = totals
.into_iter()
.take(4)
.map(|(name, us)| format!("{name}={us:.0}us"))
.join(", ");
Some(top)
}
}
impl std::fmt::Debug for CudaGraphOp {
@@ -606,23 +566,6 @@ impl CudaGraphOp {
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
if std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some()
&& state.timing_events.len() >= state.kernels.len() + 1
{
stream.synchronize()?;
let ctx = stream.context().clone();
state.last_kernel_timings_us.clear();
for idx in 0..state.kernels.len() {
let start_event = state.timing_events[idx];
let end_event = state.timing_events[idx + 1];
let kernel_name = state.kernels[idx].kernel_name;
let us = crate::kernel::event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
state.last_kernel_timings_us.push((kernel_name, us));
}
}
Ok(())
}
@@ -641,9 +584,8 @@ impl CudaGraphOp {
state.kernel_params.clear();
state.kernel_params.reserve(num_kernels);
let profile_cuda_graph = std::env::var_os("LUMINAL_PROFILE_CUDA_GRAPH").is_some();
let tracing_enabled = enabled!(Level::TRACE);
if tracing_enabled || profile_cuda_graph {
if tracing_enabled {
let needed_events = num_kernels + 1;
while state.timing_events.len() < needed_events {
state.timing_events.push(create_cuda_event(&ctx)?);
@@ -759,7 +701,7 @@ impl CudaGraphOp {
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled || profile_cuda_graph {
let timing_event = if tracing_enabled {
Some(state.timing_events[idx])
} else {
None
@@ -797,9 +739,7 @@ impl CudaGraphOp {
prev_graph_node = Some(graph_node);
}
if (tracing_enabled || profile_cuda_graph)
&& let Some(prev) = prev_graph_node
{
if tracing_enabled && let Some(prev) = prev_graph_node {
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}

View File

@@ -1,9 +1,6 @@
use crate::{
host::{DeviceBuffer, HostOp, describe_host_op},
kernel::{
CudaGraphTiming, KernelOp, create_cuda_event, destroy_cuda_event, event_elapsed_ms,
record_cuda_graph_timings, record_event_on_stream,
},
host::{DeviceBuffer, HostOp},
kernel::{CudaGraphTiming, KernelOp, record_cuda_graph_timings},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, result};
@@ -63,12 +60,6 @@ pub struct KernelStats {
pub tflops: f64,
}
#[derive(Debug, Clone)]
pub struct HostOpStats {
pub name: String,
pub execution_time_us: f64,
}
impl Debug for ExecutableHostOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HostOp: ({:?})", self.internal)
@@ -150,7 +141,6 @@ pub struct CudaRuntime {
changed_hlir: FxHashSet<NodeIndex>,
pub(crate) cuda_graph_timings: Vec<(CudaGraphTiming, Uuid)>,
pub last_kernel_stats: Vec<KernelStats>,
pub last_host_op_stats: Vec<HostOpStats>,
pub last_total_time_us: f64,
kernel_cache: FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
/// When true, execute() skips input buffer consumption (used during search/profile)
@@ -1213,7 +1203,6 @@ impl Runtime for CudaRuntime {
changed_hlir: FxHashSet::default(),
cuda_graph_timings: vec![],
last_kernel_stats: vec![],
last_host_op_stats: vec![],
last_total_time_us: 0.0,
kernel_cache: FxHashMap::default(),
profiling: false,
@@ -1465,8 +1454,6 @@ impl Runtime for CudaRuntime {
let total_start = std::time::Instant::now();
let bucket = &self.compiled_buckets[self.active_bucket];
let profile_host_ops = std::env::var_os("LUMINAL_PROFILE_HOST_OPS").is_some();
let mut host_timing_events = Vec::new();
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
let exec_op = &bucket.exec_graph[exec_node];
@@ -1520,16 +1507,6 @@ impl Runtime for CudaRuntime {
n_inputs = exec_op.inputs.len()
)
.entered();
let host_op_timing = if profile_host_ops {
let name = describe_host_op(exec_op.internal.as_ref().as_ref());
let ctx = exec_op.stream.context().clone();
let start_event = create_cuda_event(&ctx).unwrap();
let end_event = create_cuda_event(&ctx).unwrap();
record_event_on_stream(&ctx, start_event, &exec_op.stream).unwrap();
Some((name, ctx, start_event, end_event))
} else {
None
};
exec_op
.internal
.execute(
@@ -1545,28 +1522,10 @@ impl Runtime for CudaRuntime {
exec_op.internal.stats_name().unwrap_or("unknown")
);
});
if let Some((name, ctx, start_event, end_event)) = host_op_timing {
record_event_on_stream(&ctx, end_event, &exec_op.stream).unwrap();
host_timing_events.push((name, ctx, start_event, end_event));
}
}
// Single sync at end - CUDA stream ordering guarantees sequential execution
self.cuda_stream.synchronize().unwrap();
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
self.last_host_op_stats.clear();
if profile_host_ops {
for (name, ctx, start_event, end_event) in host_timing_events {
let execution_time_us = event_elapsed_ms(&ctx, start_event, end_event)
.map(|ms| ms as f64 * 1_000.0)
.unwrap_or(0.0);
self.last_host_op_stats.push(HostOpStats {
name,
execution_time_us,
});
destroy_cuda_event(&ctx, start_event);
destroy_cuda_event(&ctx, end_event);
}
}
// Populate last_kernel_stats from HostOps that report stats
self.last_kernel_stats.clear();
@@ -1902,16 +1861,6 @@ impl CudaRuntime {
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
if !self.last_host_op_stats.is_empty() {
println!("\n=== Host Operation Statistics ===\n");
println!("{:<20} {:>12}", "HostOp", "Time (us)");
println!("{}", "-".repeat(34));
for s in &self.last_host_op_stats {
println!("{:<20} {:>12.2}", s.name, s.execution_time_us);
}
println!("{}", "-".repeat(34));
}
// Print kernel stats
if !self.last_kernel_stats.is_empty() {
println!("\n=== Kernel Execution Statistics ===\n");

View File

@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
QwenMoeGraph {
graph: cx,
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
GemmaMoeGraph {
graph: cx,

View File

@@ -61,7 +61,8 @@ impl MoE {
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
weights_exp.shape.expand(expert_out.dims());
(expert_out * weights_exp).sum(n - 1)
}
}
@@ -478,7 +479,8 @@ mod tests {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
let _output = (down_out * weights_exp).sum(n - 1).output();
// Dump the HLIR to egglog

View File

@@ -855,8 +855,6 @@ Two important details:
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
---
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.

View File

@@ -98,7 +98,12 @@ pub struct GraphTranslation {
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
/// distinctions luminal collapses internally — notably int64 vs int32,
/// both of which map to `DType::Int` in luminal but must be reported
/// back to PyTorch with their original precision.
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -124,7 +129,9 @@ pub struct CompiledGraph {
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
/// that luminal collapses to `DType::Int` internally).
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -248,23 +255,6 @@ impl CompiledGraph {
self.runtime.device_type()
}
/// Names of kernels compiled into the active runtime bucket, if available.
#[getter]
fn kernel_names(&self) -> Vec<String> {
self.runtime.kernel_names()
}
/// Names of host ops in the active runtime bucket, if available.
#[getter]
fn host_op_names(&self) -> Vec<String> {
self.runtime.host_op_names()
}
/// Print backend execution statistics for the last run, if supported.
fn print_execution_stats(&self) {
self.runtime.print_execution_stats();
}
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
#[getter]
fn supports_device_ptrs(&self) -> bool {
@@ -493,10 +483,7 @@ impl CompiledGraph {
/// Get the PT2 dtype codes for all outputs (in order).
#[getter]
fn output_dtypes(&self) -> Vec<u32> {
self.output_dtypes
.iter()
.map(|d| luminal_dtype_to_pt2_code(*d))
.collect()
self.output_dtypes.clone()
}
/// Get output tensor data by name as f32 (copies to host).

View File

@@ -262,10 +262,13 @@ pub fn translate_pt2(
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges
// Set initial dynamic dim values from symbol ranges. PT2 emits
// `min_val: null` when the constraint is unbounded; fall back to 1 in
// that case (the smallest valid dim — used only as an initial value).
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
graph.set_dim(*c, rc.min_val as usize);
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
graph.set_dim(*c, initial);
}
}
@@ -281,14 +284,14 @@ pub fn translate_pt2(
})
.collect();
let output_dtypes: Vec<DType> = translated
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
// Python wrapper can return tensors with the right torch.dtype, even when
// luminal collapses the type internally (e.g. int64 → DType::Int).
let output_dtypes: Vec<u32> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
.unwrap_or(DType::F32)
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
})
.collect();

View File

@@ -15,7 +15,16 @@ pub struct ExportedProgram {
#[derive(Debug, Clone, Deserialize)]
pub struct RangeConstraint {
pub min_val: i64,
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
/// constraint is unbounded (no min set), so this must accept None.
#[serde(default)]
pub min_val: Option<i64>,
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
/// unused on the luminal side, but accepted to avoid deserialization
/// errors when PT2 emits it.
#[serde(default)]
#[allow(dead_code)]
pub max_val: Option<i64>,
}
#[derive(Debug, Deserialize)]

View File

@@ -1,4 +1,4 @@
use anyhow::{Result, bail};
use anyhow::Result;
use luminal::prelude::*;
use crate::pt2_schema::*;
@@ -8,62 +8,21 @@ use super::Translator;
impl<'a> Translator<'a> {
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
let alpha = match op {
BinaryOp::Add | BinaryOp::Sub => self.get_float_arg(node, 2).unwrap_or(1.0) as f32,
BinaryOp::Mul | BinaryOp::Div => 1.0,
};
let lhs = node.inputs[0]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
let rhs = node.inputs[1]
.arg
.as_tensor_name()
.map(|name| self.get_tensor(name))
.transpose()?;
match (lhs, rhs) {
(Some(a), Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(Some(a), None) => {
let mut val = self.get_float_arg(node, 1)? as f32;
if alpha != 1.0 {
val *= alpha;
}
Ok(self.apply_scalar_op(a, val, op))
}
(None, Some(mut b)) => {
if alpha != 1.0 {
b = self.apply_scalar_op(b, alpha, BinaryOp::Mul);
}
let lhs_val = self.get_float_arg(node, 0)? as f32;
let a = self
.graph
.constant_float(lhs_val)
.cast(b.dtype)
.expand_rhs(b.shape);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
}
(None, None) => bail!("{} expects at least one tensor operand", node.target),
let a = self.get_input_tensor(node, 0)?;
let arg1 = &node.inputs[1].arg;
if let Some(name) = arg1.as_tensor_name() {
let b = self.get_tensor(name)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
} else {
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
}
}

View File

@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
if let Some(b) = bias {
let out_dims = out.dims();
let mut b_expanded = b.expand_dim(0, 1);
let mut b_expanded = b.expand_dim(0, out_dims[0]);
for i in 0..spatial {
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
}
@@ -389,8 +389,11 @@ fn depthwise_conv(
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
let patches = patches.expand_dim(2, group_out);
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
// Explicitly expand weight across the batch axis so the elementwise Mul
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
let w_expanded = w_flat
.expand_dim(0, patches.dims()[0])
.expand_dim(3, patches.dims()[3]);
// Element-wise multiply and sum over kernel dim
let product = patches * w_expanded;

View File

@@ -6,6 +6,7 @@ use crate::pt2_util::*;
use super::Translator;
use super::attention::SdpaVariant;
use super::reduction::ArgExtremum;
impl<'a> Translator<'a> {
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
@@ -111,7 +112,6 @@ impl<'a> Translator<'a> {
result
}
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
"torch.ops.aten.repeat.default" => self.translate_repeat(node)?,
"torch.ops.aten.clone.default" => {
let a = self.get_input_tensor(node, 0)?;
if !a.shape.is_contiguous() { a + 0.0 } else { a }
@@ -134,28 +134,8 @@ impl<'a> Translator<'a> {
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
if alpha == 0.0 && beta == 0.0 {
self.graph
.constant_float(0.0)
.cast(mm.dtype)
.expand_rhs(mm.shape)
} else if beta == 0.0 {
if alpha == 1.0 { mm } else { mm * alpha }
} else if alpha == 0.0 {
let input = if beta == 1.0 { input } else { input * beta };
let zero = self
.graph
.constant_float(0.0)
.cast(input.dtype)
.expand_rhs(mm.shape);
let (input, _) = broadcast_binary(input, zero);
input
} else {
let input = if beta == 1.0 { input } else { input * beta };
let mm = if alpha == 1.0 { mm } else { mm * alpha };
let (input, mm) = broadcast_binary(input, mm);
input + mm
}
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
}
// Convolution
@@ -171,11 +151,6 @@ impl<'a> Translator<'a> {
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
"torch.ops.aten._embedding_bag.default"
| "torch.ops.aten._embedding_bag_forward_only.default" => {
self.translate_embedding_bag(node)?
}
"<built-in function getitem>" => self.translate_getitem(node)?,
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
@@ -190,9 +165,6 @@ impl<'a> Translator<'a> {
// LayerNorm
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
"torch.ops.aten._native_batch_norm_legit_no_training.default" => {
self.translate_native_batch_norm_no_training(node)?
}
// Where
"torch.ops.aten.where.self" => self.translate_where(node)?,
@@ -249,6 +221,16 @@ impl<'a> Translator<'a> {
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
// Tensor comparisons
"torch.ops.aten.eq.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a.eq(scalar)
}
"torch.ops.aten.ne.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
@@ -266,6 +248,13 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a.eq(b)
}
"torch.ops.aten.ne.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a.ne(b)
}
"torch.ops.aten.le.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
@@ -304,18 +293,27 @@ impl<'a> Translator<'a> {
// Clamp
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
// Cumsum
"torch.ops.aten.cumsum.default" => {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let a = if a.dtype == DType::Bool {
a.cast(DType::Int)
} else {
a
};
a.cumsum(dim)
// Rank-0 (scalar) input: cumsum of a single element is the element
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
// and the underlying `cumop` indexes `shape.dims[axis]` which would
// panic with empty dims.
if a.shape.is_empty() {
a
} else {
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
a.cumsum(dim)
}
}
// Floor / Ceil / Erf (approximations)
@@ -411,6 +409,17 @@ impl<'a> Translator<'a> {
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
// PyTorch's argmax/argmin returns int64; the dtype is preserved
// through the LUM-486 boundary widening.
"torch.ops.aten.argmax.default" => {
self.translate_argextremum(node, ArgExtremum::Max)?
}
"torch.ops.aten.argmin.default" => {
self.translate_argextremum(node, ArgExtremum::Min)?
}
// Gather (axis-aware)
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
@@ -474,6 +483,28 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a % b
}
// Remainder (Python-style modulo). For float tensors aten.remainder
// returns the same value as `%` would in luminal (Mod follows the
// language's % semantics on f32). The Tensor variant accepts a
// tensor RHS that may be rank-0; broadcast both operands so a
// scalar RHS is expanded to match the LHS shape before mod.
"torch.ops.aten.remainder.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a % b
}
"torch.ops.aten.remainder.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a % scalar
}
// Prod reduction
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,

View File

@@ -12,64 +12,6 @@ const SCATTER_INDEX_ARG: usize = 2;
const SCATTER_VALUE_ARG: usize = 3;
impl<'a> Translator<'a> {
fn try_concat_2d_fast(
&mut self,
lhs: GraphTensor,
rhs: GraphTensor,
axis: usize,
) -> Option<GraphTensor> {
if axis != 1
|| lhs.dtype != DType::F32
|| rhs.dtype != DType::F32
|| lhs.shape.len() != 2
|| rhs.shape.len() != 2
|| !lhs.shape.is_contiguous()
|| !rhs.shape.is_contiguous()
|| lhs.shape.dims[0] != rhs.shape.dims[0]
{
return None;
}
let rows = lhs.shape.dims[0];
let lhs_cols = lhs.shape.dims[1];
let rhs_cols = rhs.shape.dims[1];
let id = self.graph.add_op(
luminal::hlir::Concat2D {
rows,
lhs_cols,
rhs_cols,
},
&[lhs.id, rhs.id],
);
Some(GraphTensor::from_id(
id,
ShapeTracker::new(vec![rows, lhs_cols + rhs_cols]),
lhs.graph_ref,
lhs.dtype,
))
}
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = normalize_dim(self.get_int_arg(node, 1).unwrap_or(0), a.shape.len());
let index = self
.get_int_arg(node, 2)
.context("select.int: missing index")?;
let dim_size = a.shape.dims[dim]
.to_usize()
.context("select.int: symbolic dims are not supported for negative indices")?;
let normalized_index = if index < 0 {
(dim_size as i64 + index) as usize
} else {
index as usize
};
Ok(a.slice_along(normalized_index..normalized_index + 1, dim)
.squeeze(dim))
}
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -138,43 +80,6 @@ impl<'a> Translator<'a> {
Ok(a)
}
pub(crate) fn translate_repeat(&mut self, node: &Node) -> Result<GraphTensor> {
let mut a = self.get_input_tensor(node, 0)?;
let repeats: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
sizes
.into_iter()
.map(|size| {
anyhow::ensure!(size >= 0, "repeat: negative repeats are not supported");
Ok(Expression::from(size as usize))
})
.collect::<Result<_>>()?
} else {
self.get_exprs_arg(node, 1)?
};
anyhow::ensure!(
repeats.len() >= a.shape.len(),
"repeat: repeats rank {} is smaller than input rank {}",
repeats.len(),
a.shape.len()
);
while a.shape.len() < repeats.len() {
a = a.unsqueeze(0);
}
Ok(a.repeat(repeats))
}
pub(crate) fn translate_getitem(&mut self, node: &Node) -> Result<GraphTensor> {
let index = self.get_int_arg(node, 1)?;
anyhow::ensure!(
index == 0,
"getitem: only tuple[0] access is supported today, got index={index}"
);
self.get_input_tensor(node, 0)
}
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1).unwrap_or(0);
@@ -215,6 +120,47 @@ impl<'a> Translator<'a> {
Ok(a.slice_along(start..end, dim))
}
/// `aten.select.int(self, dim, index)` — select element `index` along
/// `dim`, dropping that dim. Output rank = input rank 1, so a 1-D input
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
/// are normalized against the input shape.
///
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
/// slice + squeeze decomposition (rather than `gather`) because the
/// composition is a pure shape manipulation with a single iota, which the
/// luminal compiler can fold into surrounding ops.
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let index_raw = self.get_int_arg(node, 2)?;
// Normalize a possibly-negative index. PyTorch accepts indices in
// [-size, size); negative wraps from the end.
let index = if index_raw < 0 {
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"select.int: dim {} must be concrete to normalize a negative index",
dim
)
})?;
let normalized = axis_size as i64 + index_raw;
if normalized < 0 {
bail!(
"select.int: index {} out of range for dim {} of size {}",
index_raw,
dim,
axis_size
);
}
normalized as usize
} else {
index_raw as usize
};
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
}
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
names
@@ -256,11 +202,7 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, tensors[0].shape.len());
let mut result = tensors[0];
for t in &tensors[1..] {
if let Some(fast) = self.try_concat_2d_fast(result, *t, dim) {
result = fast;
} else {
result = result.concat_along(*t, dim);
}
result = result.concat_along(*t, dim);
}
Ok(result)
}
@@ -317,79 +259,6 @@ impl<'a> Translator<'a> {
bail!("index.Tensor: no index tensors in optional_tensors list");
}
index_names = found_tensors;
// Multiple explicit index tensors after leading `None`s mean
// "keep the prefix dims, then advanced-index the contiguous
// tail dims". DLRM's `Z[:, li, lj]` is exactly this pattern.
if first_non_none_dim > 0
&& index_names.len() > 1
&& first_non_none_dim + index_names.len() == source.shape.len()
{
let src_dims = source.shape.dims;
let indexed_dims = &src_dims[first_non_none_dim..];
let n_indexed = index_names.len();
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * indexed_dims[i + 1];
}
let mut flat_idx: Option<GraphTensor> = None;
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
let axis_size = indexed_dims[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
idx_int
} else {
idx_int * stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (acc_b, w_b) = broadcast_binary(acc, weighted);
acc_b + w_b
}
None => weighted,
});
}
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
let idx_shape = flat_idx.shape.dims.to_vec();
let mut idx_numel = Expression::from(1usize);
for dim in &idx_shape {
idx_numel *= *dim;
}
let flat_idx = reshape_tensor(flat_idx, vec![idx_numel]);
let prefix_dims = src_dims[..first_non_none_dim].to_vec();
let mut indexed_size = Expression::from(1usize);
for dim in indexed_dims {
indexed_size *= *dim;
}
let mut flat_source_shape = prefix_dims.clone();
flat_source_shape.push(indexed_size);
let flat_source = reshape_tensor(source, flat_source_shape);
let mut expanded_idx = flat_idx;
for _ in 0..prefix_dims.len() {
expanded_idx = expanded_idx.expand_dim(0, Expression::from(1usize));
}
let mut target = prefix_dims.clone();
target.push(idx_numel);
expanded_idx.shape.expand(target);
let gathered = flat_source.gather_elements(expanded_idx, prefix_dims.len());
let mut result_shape = prefix_dims;
result_shape.extend_from_slice(&idx_shape);
return Ok(reshape_tensor(gathered, result_shape));
}
// Simple case: single non-None index on a specific dim → gather_elements
if first_non_none_dim > 0 && index_names.len() == 1 {
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
@@ -505,6 +374,17 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, 2)?;
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
// gather_elements requires the index rank to match the source rank,
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
let indices = if promoted_rank0 {
indices.unsqueeze(0)
} else {
indices
};
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
// Stay in Int the whole way — multiplying an Int tensor by an
// Expression broadcasts the axis size and avoids three Cast nodes
@@ -516,7 +396,12 @@ impl<'a> Translator<'a> {
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
Ok(a.gather_elements(normalized, dim))
let result = a.gather_elements(normalized, dim);
Ok(if promoted_rank0 {
result.squeeze(0)
} else {
result
})
}
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {

View File

@@ -6,6 +6,20 @@ use crate::pt2_util::*;
use super::Translator;
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
/// smallest (ascending sort) element when scanning the input.
#[derive(Clone, Copy)]
pub(crate) enum ArgExtremum {
Max,
Min,
}
impl ArgExtremum {
fn descending(self) -> bool {
matches!(self, ArgExtremum::Max)
}
}
/// Compute total element count, returning an error if any dimension is symbolic.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
a.dims().iter().try_fold(1usize, |acc, d| {
@@ -37,32 +51,26 @@ impl<'a> Translator<'a> {
(axes, keepdim)
}
_ => {
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
// override below assumes contiguous, no-broadcast storage —
// otherwise the `[1, N]` view treats stride-0 broadcast dims
// as if they held N distinct values and reads past the backing
// buffer. Materialize first when that's not the case (matches
// the guard `translate_reshape` already applies).
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
let ndim = a.shape.len();
if ndim == 0 {
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
// and mean of a scalar is just the scalar.
return Ok(a);
}
let total = concrete_numel(&a)?;
let has_broadcast = a
.shape
.dims
.iter()
.zip(a.shape.strides.iter())
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
let a = if has_broadcast || !a.shape.is_contiguous() {
a + 0.0
} else {
a
};
let mut flat = a;
flat.shape = ShapeTracker::new(vec![1, total]);
let axes: Vec<usize> = (0..ndim).collect();
let result = match op {
ReductionOp::Sum => flat.sum(vec![1]),
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
ReductionOp::Max => flat.max(vec![1]),
ReductionOp::Min => flat.min(vec![1]),
ReductionOp::Prod => flat.prod(vec![1]),
ReductionOp::Sum => a.sum(axes),
// Note: the luminal `mean` helper divides by the product of the
// axis dims, but we already require concrete dims here so we
// divide by the cached `total` to avoid recomputing.
ReductionOp::Mean => a.sum(axes) / total as f32,
ReductionOp::Max => a.max(axes),
ReductionOp::Min => a.min(axes),
ReductionOp::Prod => a.prod(axes),
};
return Ok(result);
}
@@ -86,4 +94,100 @@ impl<'a> Translator<'a> {
Ok(result)
}
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
/// existing `stable_argsort` op and selecting the first index along the
/// sort axis.
///
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
/// for argmin). FX export emits the inputs positionally:
/// - input 0: tensor
/// - input 1: dim (Int) or None (Other) — when `dim=None`
/// - input 2: keepdim (Bool, optional)
///
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
/// The result of argsort along the sort axis is sliced at index 0,
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
/// `dim`.
///
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
/// view; we materialize it with `* 1` so the resulting node has
/// contiguous strides matching its visible shape (mirroring the
/// `topk` lowering in `translate_topk`). Without this, the output
/// buffer would be sized for the un-sliced argsort tensor while the
/// shape tracker reports a smaller rank.
///
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
/// metadata records int64 and the Python wrapper widens at the
/// boundary, so the PyTorch contract is preserved end-to-end
/// (LUM-486).
pub(crate) fn translate_argextremum(
&mut self,
node: &Node,
which: ArgExtremum,
) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
// argument (typically `Argument::Other(Null)`), so a missing or
// non-int slot means "reduce over the flattened tensor".
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
self.get_int_arg(node, 1).ok()
} else {
None
};
let keepdim = if node.inputs.len() > 2 {
self.get_bool_arg(node, 2).unwrap_or(false)
} else {
false
};
if a.shape.is_empty() {
match dim_opt {
None | Some(0) | Some(-1) => {
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
// `keepdim=True` does not add a dimension when the input is 0-d.
return Ok(self.graph.constant(0i64).cast(DType::Int));
}
Some(dim) => {
return Err(anyhow::anyhow!(
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
));
}
}
}
let descending = which.descending();
let (sort_axis, base) = match dim_opt {
None => {
// Full-reduce: flatten to 1-D, argsort along axis 0.
let total = concrete_numel(&a)?;
let flat = reshape_tensor(a, vec![Expression::from(total)]);
(0usize, flat)
}
Some(dim_raw) => {
let dim = normalize_dim(dim_raw, a.shape.len());
(dim, a)
}
};
// Pick index 0 along the sort axis. The slice-then-squeeze chain
// produces a non-contiguous view whose physical buffer is still
// sized for the un-sliced argsort tensor; the optional `keepdim`
// unsqueeze adds a stride-0 axis which is also non-contiguous.
// Materialize at the end with `* 1` so the resulting node has
// contiguous strides matching its visible shape (matches the
// pattern used by `translate_topk` for sliced index outputs).
let sorted = base.stable_argsort(sort_axis, descending);
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
let result = if keepdim {
picked.unsqueeze(sort_axis)
} else {
picked
};
Ok(result * 1)
}
}

View File

@@ -28,45 +28,6 @@ const TRIANGULAR_INPUT_ARG: usize = 0;
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
impl<'a> Translator<'a> {
fn translate_embedding_bag_generic(
&mut self,
weight: GraphTensor,
indices: GraphTensor,
offsets: GraphTensor,
) -> Result<GraphTensor> {
let hidden_dim = weight.shape.dims[1];
let n_indices = indices.shape.dims[0];
let n_bags = offsets.shape.dims[0];
// Gather per-index embeddings: [E] -> [E, D].
let ids_expanded = (indices * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim).expand_dim(0, n_indices);
let gathered = weight.gather(ids_expanded + arange);
// Bag assignment per position:
// bag_id[pos] = count(offsets <= pos) - 1
// This supports empty bags too, because repeated offsets simply skip a
// bag id when no positions land in that interval.
let positions = self.graph.arange(n_indices).expand_dim(0, n_bags);
let starts = offsets.expand_dim(1, n_indices);
let bag_ids = positions.ge(starts).cast(DType::Int).sum(0)
- self
.graph
.constant_float(1.0)
.cast(DType::Int)
.expand_rhs(vec![n_indices]);
let bag_axis = self.graph.arange(n_bags).expand_dim(1, n_indices);
let bag_ids = bag_ids.expand_dim(0, n_bags);
let mask = bag_ids
.eq(bag_axis)
.expand_dim(2, hidden_dim)
.cast(gathered.dtype);
let gathered = gathered.expand_dim(0, n_bags);
Ok((gathered * mask).sum(1))
}
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
let positional_args: Vec<Expression> = node
.inputs
@@ -219,45 +180,6 @@ impl<'a> Translator<'a> {
Ok(value.expand_rhs(reference.shape))
}
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?.cast(DType::Int);
let offsets = self.get_input_tensor(node, 2)?.cast(DType::Int);
let mode = self.get_int_arg(node, 4).unwrap_or(0);
anyhow::ensure!(
mode == 0,
"_embedding_bag: only mode=0 (sum) is supported, got mode={mode}"
);
anyhow::ensure!(
indices.shape.len() == 1 && offsets.shape.len() == 1,
"_embedding_bag: expected 1D indices/offsets, got indices={}D offsets={}D",
indices.shape.len(),
offsets.shape.len()
);
if weight.dtype == DType::F32 {
let id = self.graph.add_op(
luminal::hlir::EmbeddingBagSum {
n_bags: offsets.shape.dims[0],
n_indices: indices.shape.dims[0],
hidden_dim: weight.shape.dims[1],
num_embeddings: weight.shape.dims[0],
},
&[weight.id, indices.id, offsets.id],
);
return Ok(GraphTensor::from_id(
id,
ShapeTracker::new(vec![offsets.shape.dims[0], weight.shape.dims[1]]),
weight.graph_ref,
DType::F32,
));
}
self.translate_embedding_bag_generic(weight, indices, offsets)
}
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
let output_name = node
.outputs

View File

@@ -21,30 +21,6 @@ const DIV_MODE_INPUT_ARG: usize = 0;
const DIV_MODE_OTHER_ARG: usize = 1;
impl<'a> Translator<'a> {
fn expand_channel_parameter(
&self,
input: GraphTensor,
parameter: GraphTensor,
) -> Result<GraphTensor> {
anyhow::ensure!(
input.shape.len() >= 2,
"batch_norm: expected rank >= 2 input, got rank {}",
input.shape.len()
);
anyhow::ensure!(
parameter.shape.len() == 1,
"batch_norm: expected 1D channel parameter, got rank {}",
parameter.shape.len()
);
let mut expanded = parameter.unsqueeze(0);
for axis in 2..input.shape.len() {
expanded = expanded.unsqueeze(axis);
}
expanded.shape.expand(input.dims().to_vec());
Ok(expanded)
}
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
@@ -125,30 +101,6 @@ impl<'a> Translator<'a> {
Ok(result)
}
pub(crate) fn translate_native_batch_norm_no_training(
&mut self,
node: &Node,
) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let running_mean = self.expand_channel_parameter(input, self.get_input_tensor(node, 3)?)?;
let running_var = self.expand_channel_parameter(input, self.get_input_tensor(node, 4)?)?;
let eps = self.get_float_arg(node, 6).unwrap_or(1e-5) as f32;
let mut result = (input - running_mean) / (running_var + eps).sqrt();
if let Some(weight_name) = node.inputs.get(1).and_then(|i| i.arg.as_tensor_name()) {
let weight = self.expand_channel_parameter(input, self.get_tensor(weight_name)?)?;
result = result * weight;
}
if let Some(bias_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
let bias = self.expand_channel_parameter(input, self.get_tensor(bias_name)?)?;
result = result + bias;
}
Ok(result)
}
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let zero = self
@@ -261,12 +213,18 @@ impl<'a> Translator<'a> {
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
// Check rounding_mode kwarg
// Check rounding_mode kwarg. PT2 serializes string args as
// {"as_string": "<value>"}, so we have to drill into the JSON.
let rounding_mode = node.inputs.iter().find_map(|input| {
if input.name == "rounding_mode"
&& let Argument::Other(val) = &input.arg
{
return val.as_str().map(|s| s.to_string());
if let Some(s) = val.as_str() {
return Some(s.to_string());
}
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
return Some(s.to_string());
}
}
None
});
@@ -317,4 +275,52 @@ impl<'a> Translator<'a> {
}
Ok(result)
}
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
///
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
/// overload takes tensor bounds that appear as separate input nodes in the
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
///
/// - rank-0 (scalar wrapped in a tensor) — most common
/// - same shape as self (per-element clamp, e.g. learned bounds)
/// - any shape that broadcasts to self via right-align + size-1 expand
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
///
/// We use `broadcast_binary` to right-align and expand both operands to a
/// common shape before the elementwise max/min, matching PyTorch semantics
/// across all three modes.
///
/// Either bound may be absent (FX represents this as a non-tensor argument
/// at the corresponding input slot), in which case we clamp to one side
/// only.
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let min_tensor = node
.inputs
.get(1)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let max_tensor = node
.inputs
.get(2)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let mut result = a;
if let Some(lo) = min_tensor {
let lo = lo.cast(result.dtype);
let (r, lo) = broadcast_binary(result, lo);
result = r.maximum(lo);
}
if let Some(hi) = max_tensor {
let hi = hi.cast(result.dtype);
let (r, hi) = broadcast_binary(result, hi);
result = r.minimum(hi);
}
Ok(result)
}
}

View File

@@ -8,14 +8,12 @@ from .compiled_model import CompiledModel
# Import Rust extension components (built by maturin)
from .luminal import CompiledGraph, process_pt2
from .main import luminal_backend, register_backend
from .pt2 import compile
_register_cache_serialization()
# Re-export everything for clean package interface
__all__ = [
"CompiledModel",
"compile",
"luminal_backend",
"register_backend",
"CompiledGraph",

View File

@@ -1,8 +1,9 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
from typing import Any, List
from typing import List
import torch
from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
@@ -27,14 +28,6 @@ class CompiledModel:
self._input_names = input_names or graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
output_dtype_codes = graph_result.output_dtypes
self._output_dtypes = [
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
for i in range(len(self._output_names))
]
self._static_output_shapes = [tuple(shape) for shape in self._output_shapes]
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
self._weight_refs = weight_refs or []
self._user_indices = user_indices
@@ -42,9 +35,6 @@ class CompiledModel:
self._supports_device_ptrs = getattr(
graph_result, "supports_device_ptrs", False
)
# Cache converted/contiguous views for repeated calls with the same input
# tensors so we don't rebuild sparse index buffers every forward.
self._prepared_input_cache = [None] * len(self._input_names)
# Expected input dtypes from graph (used to convert user inputs)
input_dtype_codes = graph_result.input_dtypes
self._input_dtypes = [
@@ -53,80 +43,6 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
self._single_float_output_fast_path = (
self._supports_device_ptrs
and not self._has_dynamic_dims
and len(self._output_names) == 1
and self._output_dtypes[0].is_floating_point
)
@staticmethod
def _input_cache_key(tensor: torch.Tensor, expected_dtype: torch.dtype):
return (
id(tensor),
getattr(tensor, "_version", None),
tensor.data_ptr(),
expected_dtype,
)
def _prepare_input_tensor(
self, index: int, tensor: torch.Tensor, expected_dtype: torch.dtype
) -> torch.Tensor:
detached = tensor.detach()
needs_preparation = (
detached.dtype != expected_dtype or not detached.is_contiguous()
)
if not needs_preparation:
return detached
cache_key = self._input_cache_key(detached, expected_dtype)
cached = self._prepared_input_cache[index]
if cached is not None and cached[0] == cache_key:
return cached[1]
prepared = detached.contiguous().to(expected_dtype)
self._prepared_input_cache[index] = (cache_key, prepared)
return prepared
def _bind_user_inputs(self, user_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Bind the current user inputs into the Rust graph."""
input_refs = []
for index, (name, tensor, expected_dtype) in enumerate(
zip(self._input_names, user_inputs, self._input_dtypes)
):
prepared = self._prepare_input_tensor(index, tensor, expected_dtype)
if self._supports_device_ptrs and tensor.is_cuda:
n_bytes = prepared.numel() * prepared.element_size()
self._graph.set_input_device_ptr(name, prepared.data_ptr(), n_bytes)
input_refs.append(prepared)
else:
if prepared.device.type != "cpu":
prepared = prepared.cpu()
n_bytes = prepared.numel() * prepared.element_size()
dtype_code = _torch_dtype_code(prepared.dtype)
self._graph.set_input_from_ptr(
name, prepared.data_ptr(), n_bytes, dtype_code
)
return input_refs
def _run_static_single_float_output(
self, user_inputs: List[torch.Tensor], input_device: torch.device
):
_input_refs = self._bind_user_inputs(user_inputs)
output_name = self._output_names[0]
output_dtype = self._output_dtypes[0]
out = torch.empty(
self._static_output_shapes[0], dtype=output_dtype, device=input_device
)
self._graph.set_output_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
self._graph.run()
if not self._graph.output_is_zero_copy(output_name):
self._graph.copy_output_to_device_ptr(
output_name, out.data_ptr(), out.numel() * out.element_size()
)
return (out,)
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -170,18 +86,25 @@ class CompiledModel:
if self._has_dynamic_dims:
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
elif (
self._single_float_output_fast_path
and input_device.type != "cpu"
and all(torch.is_tensor(t) and t.is_cuda for t in user_inputs)
):
return self._run_static_single_float_output(user_inputs, input_device)
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = self._bind_user_inputs(user_inputs)
_input_refs = []
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
# Resolve output shapes before run() (needed for pre-allocation).
if self._has_dynamic_dims:
@@ -189,6 +112,8 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
output_dtype_codes = self._graph.output_dtypes
# CUDA zero-copy path: pre-allocate output tensors and register their device
# pointers so the final kernel writes directly into PyTorch's buffer.
_use_zero_copy = self._supports_device_ptrs
@@ -196,8 +121,8 @@ class CompiledModel:
if _use_zero_copy:
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
self._output_dtypes[i]
if i < len(self._output_dtypes)
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
out = torch.empty(shape, dtype=out_dtype, device=input_device)
@@ -210,13 +135,18 @@ class CompiledModel:
# Run the graph
self._graph.run()
# Integer dtypes for which we read the buffer as i32 and then cast.
# Includes int64 because luminal collapses all integer types to its
# 32-bit `Int` internally — we restore the original precision here.
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
# Collect outputs
if _use_zero_copy:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
self._output_dtypes[i]
if i < len(self._output_dtypes)
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
out = output_tensors[i]
@@ -225,11 +155,12 @@ class CompiledModel:
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype == torch.int32:
elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
elif out_dtype == torch.bool:
@@ -253,13 +184,17 @@ class CompiledModel:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
self._output_dtypes[i]
if i < len(self._output_dtypes)
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
if out_dtype == torch.int32:
if out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
)
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
@@ -274,41 +209,3 @@ class CompiledModel:
outputs.append(out)
return tuple(outputs)
def _leaf_paths(tree: Any, prefix=()):
if torch.is_tensor(tree):
return [prefix]
if isinstance(tree, (list, tuple)):
paths = []
for idx, value in enumerate(tree):
paths.extend(_leaf_paths(value, prefix + (idx,)))
return paths
if isinstance(tree, dict):
paths = []
for key, value in tree.items():
paths.extend(_leaf_paths(value, prefix + (key,)))
return paths
return [prefix]
def _follow_path(tree: Any, path):
value = tree
for key in path:
value = value[key]
return value
class StructuredCompiledModel:
"""Preserve a module's original nested input structure for direct PT2 compile()."""
def __init__(self, compiled_model, example_args):
self._compiled = compiled_model
self._leaf_paths = _leaf_paths(example_args)
def __getattr__(self, name):
return getattr(self._compiled, name)
def __call__(self, *args):
flat_inputs = [_follow_path(args, path) for path in self._leaf_paths]
return self._compiled(*flat_inputs)

View File

@@ -9,12 +9,10 @@ import inspect
import os
import shutil
import tempfile
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
from .compiled_model import CompiledModel, StructuredCompiledModel
from .compiled_model import CompiledModel
from .luminal import process_pt2
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
@@ -186,66 +184,6 @@ def _save_and_compile(
shutil.rmtree(tmpdir, ignore_errors=True)
def _has_cuda_inputs(flat_example_inputs):
return any(torch.is_tensor(inp) and inp.is_cuda for inp in flat_example_inputs)
def _direct_search_env(flat_example_inputs, search_trials=None, search_keep_best=None):
"""Search env overrides for direct compile() calls.
CUDA DLRM-style models benefit materially from a deeper per-candidate
profile and from keeping more parents alive between generations. Keep the
defaults narrow so CPU and env-configured callers are unchanged.
"""
has_cuda = _has_cuda_inputs(flat_example_inputs)
overrides = {}
if search_trials is not None:
overrides["LUMINAL_SEARCH_TRIALS"] = str(search_trials)
elif has_cuda and "LUMINAL_SEARCH_TRIALS" not in os.environ:
overrides["LUMINAL_SEARCH_TRIALS"] = "5"
if search_keep_best is not None:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = str(search_keep_best)
elif has_cuda and "LUMINAL_SEARCH_KEEP_BEST" not in os.environ:
overrides["LUMINAL_SEARCH_KEEP_BEST"] = "3"
return overrides
@contextmanager
def _temporary_env(overrides):
sentinel = object()
previous = {}
try:
for key, value in overrides.items():
previous[key] = os.environ.get(key, sentinel)
os.environ[key] = value
yield
finally:
for key, old_value in previous.items():
if old_value is sentinel:
os.environ.pop(key, None)
else:
os.environ[key] = old_value
def _strip_exported_weights_for_zero_copy(ep, original_weights):
"""Shrink the saved .pt2 artifact when original weights will be reused."""
if not original_weights:
return
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
replacement = torch.zeros(1, dtype=orig.dtype, device="cpu")
if isinstance(orig, torch.nn.Parameter):
replacement = torch.nn.Parameter(
replacement, requires_grad=orig.requires_grad
)
ep._state_dict[key] = replacement
del orig
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
@@ -463,9 +401,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
def compile(
model,
example_input,
search_iterations=None,
search_trials=None,
search_keep_best=None,
search_iterations=25,
factory=None,
export_kwargs=None,
dynamic_dim=None,
@@ -477,13 +413,7 @@ def compile(
model: A PyTorch nn.Module.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations. When None,
defaults to 200 on CUDA inputs and 10 otherwise.
search_trials: Optional per-candidate profiling trials inside Luminal's
search. When unset, direct CUDA compile defaults to 5.
search_keep_best: Optional number of parent candidates to retain
between search generations. When unset, direct CUDA compile
defaults to 3.
search_iterations: Number of optimization search iterations.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
@@ -502,26 +432,17 @@ def compile(
Returns:
A CompiledModel callable.
"""
if factory is None:
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
flat_example_inputs = pytree.arg_tree_leaves(*example_args)
if factory is None:
factory = _detect_factory_capsule(flat_example_inputs)
if search_iterations is None:
search_iterations = (
200
if _has_cuda_inputs(flat_example_inputs)
else 10
)
search_env = _direct_search_env(
flat_example_inputs,
search_trials=search_trials,
search_keep_best=search_keep_best,
)
kwargs = export_kwargs or {}
extra = _export_kwargs()
@@ -567,16 +488,63 @@ def compile(
)
ep = ep.run_decompositions(_decomp_table())
original_weights = model.state_dict()
_strip_exported_weights_for_zero_copy(ep, original_weights)
with _temporary_env(search_env):
compiled = _save_and_compile(
ep,
factory,
search_iterations,
original_weights=original_weights,
)
return StructuredCompiledModel(compiled, example_args)
return _save_and_compile(ep, factory, search_iterations)
def _drop_input_guards(ep):
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
``i = torch.tensor(2)``), torch.export records two equivalent input
guards: ``L['i'].item() == 2`` (referencing the original local source)
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
Two failures stack on top of each other:
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
a literal ``L`` reference in the generated ``_guards_fn`` and raising
``NameError: name 'L' is not defined`` during retracing.
2. Even after dropping the unresolvable guard, the surviving
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
a graph break.
These guards exist solely to validate inputs at runtime in eager-mode
consumers of the ExportedProgram; the luminal compiler does its own
input shape/dtype checks against the compiled graph signature, so we
are not losing any safety by clearing them.
"""
if hasattr(ep, "_guards_code"):
ep._guards_code = []
def _drop_dead_data_dependent_ops(gm):
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
When dynamo specializes a 0-d int input by tracing through ``.item()``,
the resulting graph may contain a dead ``aten.item.default`` node whose
output is never consumed. luminal's translator does not lower
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
node in the graph causes a graph break at compile time. Eliminating it
keeps the (correctly specialized) downstream graph in a single subgraph.
"""
graph = gm.graph
changed = False
for node in list(graph.nodes):
if (
node.op == "call_function"
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
and len(node.users) == 0
):
graph.erase_node(node)
changed = True
if changed:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
def _legacy_auto_dim(example_args):
@@ -658,13 +626,23 @@ def _eager_pt2_compile(
if dynamic_shapes is None:
raise
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
# LUM-499: drop dynamo-emitted input guards before run_decompositions
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
# data-dependent .item() calls and unresolved `L[...]` references.
_drop_input_guards(ep)
_drop_dead_data_dependent_ops(ep.graph_module)
ep = ep.run_decompositions(_decomp_table())
# When using shared memory (original_weights), strip large weight buffers
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
_strip_exported_weights_for_zero_copy(ep, original_weights)
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
@@ -678,20 +656,11 @@ def _eager_pt2_compile(
if torch.cuda.is_available():
torch.cuda.empty_cache()
default_search_iterations = (
50
if any(torch.is_tensor(inp) and inp.is_cuda for inp in user_inputs)
else 10
)
search_iterations = int(
os.environ.get("LUMINAL_PT2_SEARCH_ITERATIONS", str(default_search_iterations))
)
try:
return _save_and_compile(
pt2_path,
factory,
search_iterations,
10,
original_weights=original_weights,
user_indices=user_indices,
)

View File

@@ -1,11 +0,0 @@
# DLRM CUDA Benchmark
These numbers are from the focused `2048`-candidate DLRM CUDA benchmark in
`test_dlrm.py`, measured after compile and warmup with `5 x 20` timed runs.
| Path | Median latency | Throughput |
| --- | ---: | ---: |
| eager | 0.267 ms | 7,674,321 candidates/s |
| torch.compile + inductor | 0.295 ms | 6,933,911 candidates/s |
| torch.compile + inductor (`reduce-overhead`) | 0.299 ms | 6,843,456 candidates/s |
| torch.compile + `luminal_backend` | 0.476 ms | 4,299,775 candidates/s |

View File

@@ -1,213 +0,0 @@
"""DeepCTR-Torch DCN / DIN coverage for the luminal torch.compile backend.
These tests are intended for the local integration workflow where the
``DeepCTR-Torch`` repo is checked out next to ``luminal``. They first confirm
that eager mode and regular ``torch.compile(..., backend="inductor")`` agree,
then run the same model through ``backend=luminal_backend``.
"""
from __future__ import annotations
import copy
import sys
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
pytest.importorskip("tqdm")
DEEPCTR_ROOT = Path(__file__).resolve().parents[4] / "DeepCTR-Torch"
if not DEEPCTR_ROOT.exists():
pytest.skip(
f"DeepCTR-Torch checkout not found at {DEEPCTR_ROOT}",
allow_module_level=True,
)
deepctr_root = str(DEEPCTR_ROOT)
if deepctr_root not in sys.path:
sys.path.insert(0, deepctr_root)
from deepctr_torch.inputs import (
DenseFeat,
SparseFeat,
VarLenSparseFeat,
build_input_features,
)
from deepctr_torch.models import DCN
from deepctr_torch.models.din import DIN
from luminal import luminal_backend
def _stack_features(
feature_columns: list, feature_dict: dict[str, np.ndarray], device: torch.device
) -> torch.Tensor:
parts = []
for name in build_input_features(feature_columns):
value = np.asarray(feature_dict[name])
if value.ndim == 1:
value = np.expand_dims(value, axis=1)
parts.append(value)
stacked = np.concatenate(parts, axis=-1)
return torch.tensor(stacked, dtype=torch.float32, device=device)
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _run_luminal(model: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _make_dcn(
device: torch.device, cross_parameterization: str
) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("s0", 5, embedding_dim=4),
SparseFeat("s1", 7, embedding_dim=4),
DenseFeat("d0", 1),
DenseFeat("d1", 1),
]
feature_dict = {
"s0": np.array([0, 1, 2, 3], dtype=np.int64),
"s1": np.array([1, 2, 3, 4], dtype=np.int64),
"d0": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32),
"d1": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
}
model = DCN(
linear_feature_columns=feature_columns,
dnn_feature_columns=feature_columns,
cross_num=2,
cross_parameterization=cross_parameterization,
dnn_hidden_units=(16,),
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
def _make_din(device: torch.device) -> tuple[torch.nn.Module, tuple[torch.Tensor]]:
torch.manual_seed(0)
feature_columns = [
SparseFeat("user", 4, embedding_dim=4),
SparseFeat("gender", 2, embedding_dim=4),
SparseFeat("item_id", 4, embedding_dim=8),
SparseFeat("cate_id", 3, embedding_dim=4),
DenseFeat("pay_score", 1),
VarLenSparseFeat(
SparseFeat(
"hist_item_id",
vocabulary_size=4,
embedding_dim=8,
embedding_name="item_id",
),
maxlen=4,
length_name="seq_length",
),
VarLenSparseFeat(
SparseFeat(
"hist_cate_id",
vocabulary_size=3,
embedding_dim=4,
embedding_name="cate_id",
),
maxlen=4,
length_name="seq_length",
),
]
feature_dict = {
"user": np.array([0, 1, 2, 3], dtype=np.int64),
"gender": np.array([0, 1, 0, 1], dtype=np.int64),
"item_id": np.array([1, 2, 3, 2], dtype=np.int64),
"cate_id": np.array([1, 2, 1, 2], dtype=np.int64),
"pay_score": np.array([0.1, 0.2, 0.3, 0.2], dtype=np.float32),
"hist_item_id": np.array(
[[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"hist_cate_id": np.array(
[[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]],
dtype=np.int64,
),
"seq_length": np.array([3, 3, 2, 2], dtype=np.int64),
}
model = DIN(
feature_columns,
["item_id", "cate_id"],
dnn_dropout=0.0,
device=str(device),
).eval()
inputs = (_stack_features(feature_columns, feature_dict, device),)
return model.to(device), inputs
@pytest.mark.parametrize("cross_parameterization", ["vector", "matrix"])
def test_deepctr_dcn_matches_inductor_when_supported(
device: torch.device, cross_parameterization: str
) -> None:
model, inputs = _make_dcn(device, cross_parameterization)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
def test_deepctr_din_matches_inductor_when_supported(device: torch.device) -> None:
model, inputs = _make_din(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")

View File

@@ -1,456 +0,0 @@
"""DLRM coverage for the luminal torch.compile backend.
This test expects a sibling ``dlrm`` checkout next to ``luminal`` and validates
that eager mode, ``torch.compile(..., backend="inductor")``, and
``torch.compile(..., backend=luminal_backend)`` agree on deterministic DLRM
configurations, including a CUDA benchmark that compares Luminal against
TorchInductor's CUDA-graph-enabled ``mode="reduce-overhead"`` path.
"""
from __future__ import annotations
import copy
import importlib.machinery
import sys
import types
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pytest
import torch
pytest.importorskip("sklearn")
DLRM_ROOT = Path(__file__).resolve().parents[4] / "dlrm"
if not DLRM_ROOT.exists():
pytest.skip(f"dlrm checkout not found at {DLRM_ROOT}", allow_module_level=True)
dlrm_root = str(DLRM_ROOT)
if dlrm_root not in sys.path:
sys.path.insert(0, dlrm_root)
def _install_dlrm_import_stubs() -> None:
ext_dist = types.ModuleType("extend_distributed")
ext_dist.my_size = 1
ext_dist.dist = None
ext_dist.get_split_lengths = lambda n: (n, [n])
ext_dist.get_my_slice = lambda n: slice(0, n)
class _AllToAll:
def __init__(self, values):
self._values = values
def wait(self):
return self._values
ext_dist.alltoall = lambda values, n_emb_per_rank: _AllToAll(values)
ext_dist.__spec__ = importlib.machinery.ModuleSpec(
"extend_distributed", loader=None
)
sys.modules["extend_distributed"] = ext_dist
mlperf_logger = types.ModuleType("mlperf_logger")
mlperf_logger.__spec__ = importlib.machinery.ModuleSpec(
"mlperf_logger", loader=None
)
sys.modules["mlperf_logger"] = mlperf_logger
tensorboard = types.ModuleType("torch.utils.tensorboard")
class SummaryWriter:
def __init__(self, *args, **kwargs):
pass
def add_scalar(self, *args, **kwargs):
pass
def close(self):
pass
tensorboard.SummaryWriter = SummaryWriter
tensorboard.__spec__ = importlib.machinery.ModuleSpec(
"torch.utils.tensorboard", loader=None
)
sys.modules["torch.utils.tensorboard"] = tensorboard
onnx = types.ModuleType("onnx")
onnx.__spec__ = importlib.machinery.ModuleSpec("onnx", loader=None)
sys.modules["onnx"] = onnx
_install_dlrm_import_stubs()
import dlrm_s_pytorch as dlrm_mod
from luminal import luminal_backend
dlrm_mod.args = types.SimpleNamespace(loss_weights="1-1", loss_function="bce")
def _unwrap(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
if isinstance(output, tuple) and len(output) == 1:
return output[0]
return output
def _assert_allclose(
lhs: torch.Tensor, rhs: torch.Tensor, label: str, atol: float = 1e-5
) -> None:
max_diff = torch.max(torch.abs(lhs - rhs)).item()
assert torch.allclose(lhs, rhs, atol=atol), f"{label} max_diff={max_diff:.2e}"
def _run_eager(model: torch.nn.Module, *inputs) -> torch.Tensor:
with torch.no_grad():
return _unwrap(model(*inputs))
@contextmanager
def _relaxed_dynamo_limits():
prev_recompile_limit = torch._dynamo.config.recompile_limit
prev_cache_size_limit = torch._dynamo.config.cache_size_limit
torch._dynamo.config.recompile_limit = 16
torch._dynamo.config.cache_size_limit = 16
try:
yield
finally:
torch._dynamo.config.recompile_limit = prev_recompile_limit
torch._dynamo.config.cache_size_limit = prev_cache_size_limit
def _run_inductor(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend="inductor")
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend="inductor")
def _run_luminal(model: torch.nn.Module, *inputs) -> torch.Tensor:
with _relaxed_dynamo_limits():
torch._dynamo.reset()
compiled = torch.compile(copy.deepcopy(model), backend=luminal_backend)
with torch.no_grad():
return _unwrap(compiled(*inputs))
def _compile_inductor_reduce_overhead(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(
copy.deepcopy(model),
backend="inductor",
mode="reduce-overhead",
)
def _compile_luminal(model: torch.nn.Module):
with _relaxed_dynamo_limits():
torch._dynamo.reset()
return torch.compile(copy.deepcopy(model), backend=luminal_backend)
def _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters: int,
timed_iters: int,
mark_step_begin: bool = False,
) -> dict[str, float]:
assert torch.cuda.is_available(), "CUDA timing requires an available GPU"
with torch.no_grad():
for _ in range(warmup_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
_unwrap(compiled_model(*inputs))
torch.cuda.synchronize()
starts = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(timed_iters)]
with torch.no_grad():
for idx in range(timed_iters):
if mark_step_begin:
torch.compiler.cudagraph_mark_step_begin()
starts[idx].record()
_unwrap(compiled_model(*inputs))
ends[idx].record()
torch.cuda.synchronize()
elapsed_ms = np.array(
[start.elapsed_time(end) for start, end in zip(starts, ends)],
dtype=np.float64,
)
return {
"mean_ms": float(elapsed_ms.mean()),
"median_ms": float(np.median(elapsed_ms)),
"min_ms": float(elapsed_ms.min()),
}
def _timed_cuda_rounds(
compiled_model,
*inputs,
pre_round_warmup_iters: int,
timed_iters: int,
rounds: int,
mark_step_begin: bool = False,
) -> dict[str, float | list[float]]:
assert rounds > 0, "rounds must be positive"
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=pre_round_warmup_iters,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians = [stats["median_ms"]]
for _ in range(rounds - 1):
stats = _timed_cuda_runs(
compiled_model,
*inputs,
warmup_iters=0,
timed_iters=timed_iters,
mark_step_begin=mark_step_begin,
)
round_medians.append(stats["median_ms"])
round_medians_np = np.array(round_medians, dtype=np.float64)
return {
"round_medians_ms": [float(value) for value in round_medians],
"median_ms": float(np.median(round_medians_np)),
"mean_ms": float(round_medians_np.mean()),
"min_ms": float(round_medians_np.min()),
}
def _make_dlrm(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
m_spa = 4
ln_emb = np.array([8, 6, 4])
ln_bot = np.array([3, 4])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 8, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=1,
).eval()
inputs = (
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32, device=device),
[
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
torch.tensor([0, 1], dtype=torch.int64, device=device),
],
[
torch.tensor([1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 3], dtype=torch.int64, device=device),
torch.tensor([2, 1], dtype=torch.int64, device=device),
],
)
return model.to(device), inputs
def _make_dlrm_batch_2048(
device: torch.device,
) -> tuple[torch.nn.Module, tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]]:
np.random.seed(0)
torch.manual_seed(0)
batch_size = 2048
indices_per_bag = 2
m_spa = 16
ln_emb = np.array([4096, 2048, 1024])
ln_bot = np.array([3, 64, m_spa])
num_fea = ln_emb.size + 1
num_int = (num_fea * (num_fea - 1)) // 2 + m_spa
ln_top = np.array([num_int, 64, 32, 1])
model = dlrm_mod.DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op="dot",
arch_interaction_itself=False,
sigmoid_top=2,
).eval()
dense_x = torch.linspace(
-1.0,
1.0,
steps=batch_size * 3,
dtype=torch.float32,
device=device,
).reshape(batch_size, 3)
total_sparse_indices = batch_size * indices_per_bag
positions = torch.arange(total_sparse_indices, dtype=torch.int64, device=device)
offsets = torch.arange(
0,
total_sparse_indices,
indices_per_bag,
dtype=torch.int64,
device=device,
)
inputs = (
dense_x,
[offsets.clone(), offsets.clone(), offsets.clone()],
[
((positions * 3 + 1) % int(ln_emb[0])).to(torch.int64),
((positions * 5 + 2) % int(ln_emb[1])).to(torch.int64),
((positions * 7 + 3) % int(ln_emb[2])).to(torch.int64),
],
)
return model.to(device), inputs
def test_dlrm_matches_inductor_and_luminal(device: torch.device) -> None:
model, inputs = _make_dlrm(device)
eager = _run_eager(model, *inputs)
inductor = _run_inductor(model, *inputs)
_assert_allclose(inductor, eager, "inductor vs eager")
luminal = _run_luminal(model, *inputs)
_assert_allclose(luminal, eager, "luminal vs eager")
_assert_allclose(luminal, inductor, "luminal vs inductor")
@pytest.mark.slow
def test_dlrm_batch_2048_cuda_matches_torchinductor_reduce_overhead_and_reports_speed(
device: torch.device,
) -> None:
if device.type != "cuda":
pytest.skip("Requires `LUMINAL_TEST_DEVICE=cuda` for the CUDA benchmark")
model, inputs = _make_dlrm_batch_2048(device)
eager = _run_eager(model, *inputs)
eager_model = copy.deepcopy(model).to(device).eval()
inductor_compiled = _compile_inductor(model)
inductor_reduce_overhead = _compile_inductor_reduce_overhead(model)
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad():
inductor_output = _unwrap(inductor_reduce_overhead(*inputs))
luminal_compiled = _compile_luminal(model)
with torch.no_grad():
luminal_output = _unwrap(luminal_compiled(*inputs))
with torch.no_grad():
inductor_default_output = _unwrap(inductor_compiled(*inputs))
_assert_allclose(inductor_default_output, eager, "inductor default vs eager", atol=1e-4)
_assert_allclose(inductor_output, eager, "inductor reduce-overhead vs eager", atol=1e-4)
_assert_allclose(luminal_output, eager, "luminal vs eager", atol=1e-4)
_assert_allclose(
luminal_output,
inductor_default_output,
"luminal vs inductor default",
atol=1e-4,
)
_assert_allclose(
luminal_output,
inductor_output,
"luminal vs inductor reduce-overhead",
atol=1e-4,
)
benchmark_rounds = 5
benchmark_iters = 20
post_compile_warmup_iters = 10
eager_stats = _timed_cuda_rounds(
eager_model,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_default_stats = _timed_cuda_rounds(
inductor_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
inductor_stats = _timed_cuda_rounds(
inductor_reduce_overhead,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
mark_step_begin=True,
)
luminal_stats = _timed_cuda_rounds(
luminal_compiled,
*inputs,
pre_round_warmup_iters=post_compile_warmup_iters,
timed_iters=benchmark_iters,
rounds=benchmark_rounds,
)
batch_size = inputs[0].shape[0]
benchmark_results = [
("eager", eager_stats),
("inductor default", inductor_default_stats),
("inductor reduce-overhead", inductor_stats),
("luminal backend", luminal_stats),
]
ranked_results = sorted(
benchmark_results,
key=lambda item: float(item[1]["median_ms"]),
)
speed_lines = []
for idx, (label, stats) in enumerate(ranked_results, start=1):
throughput = batch_size / (float(stats["median_ms"]) / 1000.0)
rounds_repr = ", ".join(
f"{value:.3f}" for value in stats["round_medians_ms"] # type: ignore[index]
)
speed_lines.append(
f" {idx}. {label}: {float(stats['median_ms']):.3f} ms"
f" ({throughput:,.0f} candidates/s)"
f" [round medians: {rounds_repr}]"
)
luminal_vs_inductor = float(luminal_stats["median_ms"]) / float(
inductor_stats["median_ms"]
)
luminal_vs_eager = float(luminal_stats["median_ms"]) / float(eager_stats["median_ms"])
print(
"\n"
f"DLRM batch={batch_size} candidates on CUDA after compile/warmup\n"
f" Timed rounds: {benchmark_rounds} x {benchmark_iters} iterations\n"
f" Ranking by median latency:\n"
+ "\n".join(speed_lines)
+ "\n"
f" Luminal backend / TorchInductor reduce-overhead latency ratio:"
f" {luminal_vs_inductor:.3f}x\n"
f" Luminal backend / eager latency ratio: {luminal_vs_eager:.3f}x"
)

View File

@@ -1,5 +1,6 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
@@ -220,6 +221,7 @@ from test_models import (
Conv1dNoPadModel,
Conv1dSamePadModel,
Conv1dBiasModel,
Conv1dFloorDivPositionalModel,
Conv2dNoPadModel,
Conv2dSamePadModel,
Conv2dBiasModel,
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
eager = model(x)
out = model_compiled(x)
assert out.dtype == eager.dtype == torch.int64
assert torch.equal(out, eager)
def test_reduce_sum_3d_axis1(device: torch.device):
"""Test sum reduction along axis 1 for a 3D tensor."""
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
@@ -2022,9 +2035,16 @@ def test_split(device: torch.device):
# ========== Argsort / MoE Routing Tests ==========
def test_argsort_stable_duplicates(device: torch.device):
"""Duplicate values should follow stable lower-index-first tie-breaking."""
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
"""Duplicate values should follow stable lower-index-first tie-breaking.
Parametrized over int32/int64 to verify luminal preserves whichever
integer dtype the eager model declares (LUM-486).
"""
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
device
)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.tensor(
[[2.0, 1.0, 1.0, 3.0]],
@@ -2033,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.dtype == torch.int32
assert torch.equal(output, original.to(torch.int32))
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
assert output.dtype == original.dtype, (
f"luminal returned {output.dtype}, eager produced {original.dtype}"
)
assert torch.equal(output, original)
def test_tiny_moe_routing(device: torch.device):
"""Focused proof for build MoE routing support."""
model: torch.nn.Module = TinyMoERoutingModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
"""Focused proof for built MoE routing support.
Parametrized over int32/int64 for the integer-valued outputs to verify
luminal preserves the dtype declared by the eager model (LUM-486).
"""
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
scores = torch.tensor(
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
@@ -2050,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
expected = model(scores)
output = model_compiled(scores)
expected_dtypes = (
torch.int32,
torch.float32,
torch.int32,
torch.bool,
torch.int32,
torch.float32,
)
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
assert actual.dtype == expected_dtype
eager = eager.to(actual.dtype)
for actual, eager in zip(output, expected):
assert actual.dtype == eager.dtype, (
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
)
if actual.dtype.is_floating_point:
assert torch.allclose(actual, eager)
else:
@@ -2477,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_floor_div_positional_pt2(device: torch.device):
"""Conv1d stride output uses floor division before positional add."""
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.shape == original.shape == (15, 16)
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)

View File

@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
class ArgsortStableDuplicatesModel(torch.nn.Module):
"""Tests deterministic duplicate ordering for exported argsort."""
"""Tests deterministic duplicate ordering for exported argsort.
``idx_dtype`` parameterizes the integer dtype of the returned indices so
the test can verify dtype preservation across luminal's int dtype paths
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
lets us drive the same model toward int32 or int64 outputs.
"""
SORT_DIM = 1
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.argsort(x, dim=self.SORT_DIM)
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
class TinyMoERoutingModel(torch.nn.Module):
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
group_ids) to the requested dtype so the test can sweep int32 and int64
output paths (LUM-486). Internal indices stay int64 because torch.gather
/ torch.scatter require int64 index tensors.
"""
TOP_K = 2
ROUTING_DIM = -1
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
DISPATCH_ON = 1
GROUP_SIZE = 2
def __init__(self) -> None:
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
self.register_buffer(
"expert_scale",
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
routing_sign = torch.sign(masked_values)
return (
routed_indices,
routed_indices.to(self.idx_dtype),
masked_values,
dispatch,
dispatch.to(self.idx_dtype),
inactive_mask,
group_ids,
group_ids.to(self.idx_dtype),
routing_sign,
)
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
return self.conv(x)
class Conv1dFloorDivPositionalModel(torch.nn.Module):
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
self.conv2 = torch.nn.Conv1d(
16, 16, kernel_size=3, stride=2, padding=1, bias=True
)
self.position = torch.nn.Parameter(torch.randn(15, 16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.gelu(self.conv1(x))
x = torch.nn.functional.gelu(self.conv2(x))
x = x.squeeze(0).transpose(0, 1)
return x + self.position
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""

File diff suppressed because it is too large Load Diff

View File

@@ -315,8 +315,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -6,18 +8,14 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn env_bool(name: &str) -> bool {
std::env::var(name)
.ok()
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
}
fn main() {
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = env_usize("GEN_TOKENS", 30);
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
let stdio = benchmark_stdio::enabled();
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
@@ -38,11 +37,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
@@ -63,11 +57,14 @@ fn main() {
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -75,15 +72,66 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
&prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
print_token_ids: bool,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -93,7 +141,7 @@ fn main() {
const EOS_TOKEN: u32 = 1;
let prefill_start = std::time::Instant::now();
let prefill_start = Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -121,12 +169,26 @@ fn main() {
.unwrap()
.0 as u32;
generated_token_ids.push(next_token);
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
seen_tokens.insert(next_token);
for _ in 1..gen_tokens {
let start = std::time::Instant::now();
if stdio && next_token == EOS_TOKEN {
break;
}
let start = Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -165,10 +227,21 @@ fn main() {
break;
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
fwd_durations.push(start.elapsed());
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
if print_token_ids {
println!("Generated token ids: {generated_token_ids:?}");

View File

@@ -462,8 +462,13 @@ fn hlir_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
let hidden_exp = hidden.unsqueeze(2);
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 500;
let search_graphs = 500;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let prompt = "Explain what a neural network is in a paragraph.";
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -31,14 +47,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
// Build graph
let mut cx = Graph::default();
@@ -66,10 +74,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -77,12 +88,65 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -94,13 +158,16 @@ fn main() {
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
let mut generated = 0usize;
for i in 0..total_steps {
let start = std::time::Instant::now();
let start = Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -159,12 +226,21 @@ fn main() {
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{}", decoded);
std::io::stdout().flush().unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -246,8 +246,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "Qwen/Qwen3-4B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 500;
let search_graphs = 500;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let prompt = "Explain what a neural network is in a paragraph.";
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -31,7 +47,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -54,10 +69,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -65,12 +83,58 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -82,13 +146,16 @@ fn main() {
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
const STOP_TOKEN: u32 = 151643; // <|end|>
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
let mut generated = 0usize;
for i in 0..total_steps {
let start = std::time::Instant::now();
let start = Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -147,12 +214,21 @@ fn main() {
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{}", decoded);
std::io::stdout().flush().unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -287,8 +287,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -6,15 +8,27 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 30;
let search_graphs = 50;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 30)
} else {
30
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
} else {
50
};
let prompt = "The capital of France is";
let ctx = CudaContext::new(0).unwrap();
@@ -24,7 +38,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -47,10 +60,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -58,14 +74,63 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -76,7 +141,7 @@ fn main() {
const STOP_TOKEN: u32 = 151643;
// Prefill: process prompt tokens one at a time
let prefill_start = std::time::Instant::now();
let prefill_start = Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -105,13 +170,27 @@ fn main() {
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
seen_tokens.insert(next_token);
// Decode loop
for _ in 1..gen_tokens {
let start = std::time::Instant::now();
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
break;
}
let start = Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -150,13 +229,23 @@ fn main() {
break;
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
fwd_durations.push(start.elapsed());
}
println!();
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
// Report benchmarks
println!();
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,

View File

@@ -287,7 +287,8 @@ impl QwenMoE {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}
@@ -385,8 +386,13 @@ fn attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -174,8 +174,13 @@ fn decoder_self_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total, ..));
let v_full = v_cache_out.slice((.., ..total, ..));
let mut k_full = k_cache_out.slice((.., ..total, ..));
let mut v_full = v_cache_out.slice((.., ..total, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
k_full.shape.dims[1] = total;
v_full.shape.dims[1] = total;
let q = split_heads(q);

View File

@@ -0,0 +1,58 @@
use std::{
io::{BufRead, Write},
time::Instant,
};
pub fn enabled() -> bool {
std::env::args().any(|arg| arg == "--stdio")
}
pub fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn emit_ready() {
println!("READY");
std::io::stdout().flush().unwrap();
}
pub fn serve(mut f: impl FnMut(&str)) {
emit_ready();
let stdin = std::io::stdin();
for line in stdin.lock().lines() {
let line = line.unwrap();
f(&line);
}
}
pub fn emit_token(token: &str) {
println!("TOK\t{}", escape_token(token));
std::io::stdout().flush().unwrap();
}
pub fn emit_eoq(generated: usize, query_start: Instant) {
println!(
"EOQ\t{}\t{:.3}",
generated,
query_start.elapsed().as_secs_f64() * 1e3
);
std::io::stdout().flush().unwrap();
}
fn escape_token(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'\\' => out.push_str("\\\\"),
'\t' => out.push_str("\\t"),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
_ => out.push(ch),
}
}
out
}

View File

@@ -7,14 +7,13 @@
//! - [`NativeDynBackend`]: the reference implementation for CPU
use std::collections::HashMap;
use std::time::Duration;
use half::{bf16, f16};
use petgraph::stable_graph::NodeIndex;
use rustc_hash::FxHashMap;
use crate::dtype::DType;
use crate::graph::{Graph, SearchOptions};
use crate::graph::Graph;
use crate::hlir::{NativeData, NativeRuntime, Output};
use crate::op::Runtime;
@@ -47,18 +46,6 @@ pub trait DynBackend {
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
// --- Optional diagnostics --------------------------------------------
fn kernel_names(&self) -> Vec<String> {
vec![]
}
fn host_op_names(&self) -> Vec<String> {
vec![]
}
fn print_execution_stats(&self) {}
// --- Optional device pointer support (GPU backends) --------------------
fn supports_device_ptrs(&self) -> bool {
@@ -175,9 +162,7 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rng = rand::rng();
let search_options = search_options_from_env(args.search_iters);
let mut rt = graph.search_options(rt, search_options, &mut rng);
let mut rt = graph.search(rt, args.search_iters);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);
@@ -194,39 +179,6 @@ pub fn compile_backend<Rt: Runtime + 'static>(
Ok(wrap(rt))
}
fn env_usize(name: &str) -> Option<usize> {
std::env::var(name).ok()?.parse().ok()
}
fn env_duration_ms(name: &str) -> Option<Duration> {
Some(Duration::from_millis(env_usize(name)? as u64))
}
fn search_options_from_env(limit: usize) -> SearchOptions {
let mut options = SearchOptions::new(limit);
if let Some(generation_size) = env_usize("LUMINAL_SEARCH_GENERATION_SIZE") {
options = options.generation_size(generation_size);
}
if let Some(mutations) = env_usize("LUMINAL_SEARCH_MUTATIONS") {
options = options.mutations(mutations);
}
if let Some(trials) = env_usize("LUMINAL_SEARCH_TRIALS") {
options = options.trials(trials);
}
if let Some(keep_best) = env_usize("LUMINAL_SEARCH_KEEP_BEST") {
options = options.keep_best(keep_best);
}
if let Some(profile_timeout) = env_duration_ms("LUMINAL_SEARCH_PROFILE_TIMEOUT_MS") {
options = options.profile_timeout(profile_timeout);
}
if let Some(group_timeout) = env_duration_ms("LUMINAL_SEARCH_GROUP_TIMEOUT_MS") {
options = options.group_timeout(group_timeout);
}
options
}
// ---------------------------------------------------------------------------
// Shared utilities
// ---------------------------------------------------------------------------

View File

@@ -11,6 +11,7 @@ impl Add for GraphTensor {
type Output = GraphTensor;
fn add(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to add tensors. Got {:?} and {:?}",
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
type Output = GraphTensor;
fn mul(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(
self.dims(),
rhs.dims(),
"Dims must match to multiply tensors."
);
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
@@ -474,6 +480,42 @@ pub(super) mod tests {
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
}
#[test]
#[should_panic(expected = "Dims must match to add tensors.")]
fn test_add_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a + b;
}
#[test]
#[should_panic(expected = "Dims must match to multiply tensors.")]
fn test_mul_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a * b;
}
#[test]
#[should_panic(expected = "Dims must match to mod tensors.")]
fn test_mod_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a % b;
}
#[test]
#[should_panic(expected = "Dims must match to lt tensors.")]
fn test_lt_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a.lt(b);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -557,6 +599,27 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_mod_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
test_binary_transforms(
size,
(),
|a, b| a % b.expand_rhs(a.shape),
|a, b| {
let lhs = a.to_vec1::<f32>().unwrap();
let rhs_scalar = b.to_scalar::<f32>().unwrap();
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
},
identity,
shift_from_zero,
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -570,6 +633,28 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_lt_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS for `lt`.
test_binary(
size,
(),
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|a, b| {
let scalar = b.to_scalar::<f32>().unwrap();
let lhs = a.to_vec1::<f32>().unwrap();
let result: Vec<f32> = lhs
.iter()
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
.collect();
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
},
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]

View File

@@ -467,7 +467,7 @@ impl GraphTensor {
let mut win = Vec::with_capacity(n);
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
let effective_window = *d * (*k - 1) + 1;
win.push(((*dim - effective_window) / s) + 1);
win.push((*dim - effective_window).floor_div(s) + 1);
}
// [win..., kernel...]
@@ -905,6 +905,14 @@ mod tests {
);
}
#[test]
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
let mut cx = Graph::new();
let inp = cx.tensor((80, 3000));
let out = inp.pad(((0, 0), (1, 1)), 0.).unfold((1, 3), (1, 2), (1, 1));
assert_eq!(out.dims(), &[80, 1500, 1, 3]);
}
#[test]
fn test_unsqueeze() {
let mut cx = Graph::new();

View File

@@ -217,38 +217,32 @@ pub fn reduce_sort(name: &str) -> SortDef {
}
pub type HLIROps = (
(
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
),
(
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Concat2D,
EmbeddingBagSum,
Scatter,
SumReduce,
MaxReduce,
Softmax,
),
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
Exp2,
Log2,
Sin,
Recip,
Sqrt,
Add,
Mul,
Mod,
LessThan,
Gather,
Scatter,
SumReduce,
MaxReduce,
Softmax,
);
#[derive(Default, Debug, Clone)]
@@ -1727,9 +1721,7 @@ impl NativeOp for Add {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
}
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x || y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
}
}
}
@@ -1816,9 +1808,7 @@ impl NativeOp for Mul {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
}
NativeData::Bool(a) => {
NativeData::Bool(bin_fn(a_ind, a, b_ind, b, NativeData::bool, |x, y| x && y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
}
}
}
@@ -2136,233 +2126,6 @@ impl NativeOp for Gather {
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Concat2D {
pub rows: Expression,
pub lhs_cols: Expression,
pub rhs_cols: Expression,
}
impl Display for Concat2D {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Concat2D")
}
}
impl HLIROp for Concat2D {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (Concat2D {} {} {}) {})",
self.rows.to_egglog(),
self.lhs_cols.to_egglog(),
self.rhs_cols.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1]),
)
}
}
impl EgglogOp for Concat2D {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"Concat2D",
&[
("rows", EXPRESSION),
("lhs_cols", EXPRESSION),
("rhs_cols", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
rows: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
lhs_cols: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
rhs_cols: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for Concat2D {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let rows = self.rows.exec(dyn_map).unwrap();
let lhs_cols = self.lhs_cols.exec(dyn_map).unwrap();
let rhs_cols = self.rhs_cols.exec(dyn_map).unwrap();
fn concat_rows<T: Clone>(
lhs: &[T],
rhs: &[T],
rows: usize,
lhs_cols: usize,
rhs_cols: usize,
) -> Vec<T> {
let mut out = Vec::with_capacity(rows * (lhs_cols + rhs_cols));
for row in 0..rows {
let lhs_base = row * lhs_cols;
let rhs_base = row * rhs_cols;
out.extend_from_slice(&lhs[lhs_base..lhs_base + lhs_cols]);
out.extend_from_slice(&rhs[rhs_base..rhs_base + rhs_cols]);
}
out
}
match (inputs[0], inputs[1]) {
(NativeData::F32(lhs), NativeData::F32(rhs)) => {
NativeData::F32(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::F16(lhs), NativeData::F16(rhs)) => {
NativeData::F16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bf16(lhs), NativeData::Bf16(rhs)) => {
NativeData::Bf16(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Int(lhs), NativeData::Int(rhs)) => {
NativeData::Int(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
(NativeData::Bool(lhs), NativeData::Bool(rhs)) => {
NativeData::Bool(concat_rows(lhs, rhs, rows, lhs_cols, rhs_cols))
}
_ => panic!("Concat2D inputs must have the same dtype"),
}
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingBagSum {
pub n_bags: Expression,
pub n_indices: Expression,
pub hidden_dim: Expression,
pub num_embeddings: Expression,
}
impl Display for EmbeddingBagSum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "EmbeddingBagSum")
}
}
impl HLIROp for EmbeddingBagSum {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (EmbeddingBagSum {} {} {} {}) {})",
self.n_bags.to_egglog(),
self.n_indices.to_egglog(),
self.hidden_dim.to_egglog(),
self.num_embeddings.to_egglog(),
ilist_egglog(&[&inputs[0].1, &inputs[1].1, &inputs[2].1]),
)
}
}
impl EgglogOp for EmbeddingBagSum {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"EmbeddingBagSum",
&[
("n_bags", EXPRESSION),
("n_indices", EXPRESSION),
("hidden_dim", EXPRESSION),
("num_embeddings", EXPRESSION),
],
)
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
n_bags: extract_expr(egraph, kind_children[0], expr_cache).unwrap(),
n_indices: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
hidden_dim: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
num_embeddings: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
})),
input_enodes,
)
}
}
impl NativeOp for EmbeddingBagSum {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
let n_bags = self.n_bags.exec(dyn_map).unwrap();
let n_indices = self.n_indices.exec(dyn_map).unwrap();
let hidden_dim = self.hidden_dim.exec(dyn_map).unwrap();
let num_embeddings = self.num_embeddings.exec(dyn_map).unwrap_or(0);
let clamp_index = |value: i32, limit: usize| -> usize {
if limit == 0 {
return 0;
}
value.clamp(0, limit.saturating_sub(1) as i32) as usize
};
let clamp_offset = |value: i32| -> usize { value.clamp(0, n_indices as i32) as usize };
let NativeData::Int(indices) = inputs[1] else {
panic!("EmbeddingBagSum indices must be Int")
};
let NativeData::Int(offsets) = inputs[2] else {
panic!("EmbeddingBagSum offsets must be Int")
};
let bag_bounds = |bag: usize| -> (usize, usize) {
let start = offsets.get(bag).copied().map(clamp_offset).unwrap_or(0);
let end = offsets
.get(bag + 1)
.copied()
.map(clamp_offset)
.unwrap_or(n_indices);
(start, end.max(start))
};
match inputs[0] {
NativeData::F32(weight) => {
let mut out = vec![0.0f32; n_bags * hidden_dim];
for bag in 0..n_bags {
let (start, end) = bag_bounds(bag);
let out_base = bag * hidden_dim;
for pos in start..end {
let row = clamp_index(indices[pos], num_embeddings);
let row_base = row * hidden_dim;
for dim in 0..hidden_dim {
out[out_base + dim] += weight[row_base + dim];
}
}
}
NativeData::F32(out)
}
_ => panic!("EmbeddingBagSum only supports F32 weights"),
}
}
}
// Scatter Op (inverse of Gather)
#[derive(Debug, Clone, Default, PartialEq)]

View File

@@ -455,6 +455,31 @@ impl Expression {
terms.push(Term::CeilDiv);
Expression::new(terms)
}
/// Floor Division
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == 1 {
return self;
}
if self == 0 {
return 0.into();
}
if self == rhs {
return 1.into();
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
&& let Some(c) = floor_div_i64(a, b)
{
return c.into();
}
// Shape dimensions are non-negative, so the existing integer Div term
// evaluates with floor semantics for dynamic shape expressions.
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Div);
Expression::new(terms)
}
/// Less than
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
@@ -654,6 +679,16 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
depth == 1
}
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
let q = a.checked_div(b)?;
let r = a.checked_rem(b)?;
if r != 0 && ((r > 0) != (b > 0)) {
q.checked_sub(1)
} else {
Some(q)
}
}
impl From<Term> for Expression {
fn from(value: Term) -> Self {
Expression::new(vec![value])
@@ -994,8 +1029,12 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
impl std::iter::Product for Expression {
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
// Empty product is the multiplicative identity, 1 — not 0. Returning
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
let Some(mut p) = iter.next() else {
return 0.into();
return 1.into();
};
for n in iter {
p *= n;
@@ -1106,6 +1145,27 @@ mod tests {
use super::*;
use proptest::prelude::*;
#[test]
fn test_empty_product_is_one() {
// The empty product (e.g. for a rank-0 tensor's shape) must be the
// multiplicative identity, 1 — not 0. cuda_lite and other kernel
// emitters use `shape.iter().product()` to compute `numel`, and a
// rank-0 tensor has 1 element. Returning 0 here would yield a CUDA
// launch with grid=(0, 1, 1) and crash at runtime.
let empty: Vec<Expression> = vec![];
assert_eq!(
empty.into_iter().product::<Expression>(),
Expression::from(1)
);
}
#[test]
fn test_empty_sum_is_zero() {
// Sanity check the additive identity stays 0 (it always was).
let empty: Vec<Expression> = vec![];
assert_eq!(empty.into_iter().sum::<Expression>(), Expression::from(0));
}
#[test]
fn test_basic_simplifications() {
let x = expr('x');