Compare commits

...

3 Commits

Author SHA1 Message Date
Austin Glover
a4bda06d64 tests for interface specification 2026-05-14 21:24:56 +00:00
tucker-luminal
6416ddb5f8 Use parallel launches for small CUDA kernels (#315)
* Use parallel launches for cast and iota kernels

* Use parallel launch for embed kernel
2026-05-14 00:47:12 -04:00
Austin Glover
c9d4ce6217 Better scalar support: tests + 12 fixes (LUM-474) (#300)
* Add scalar torture test suite (LUM-474)

60 tests asserting strict shape, dtype, and value match between PyTorch
eager and luminal_backend. Includes 9 xfail markers (12 cases) for the
known scalar bugs being addressed under LUM-485 through LUM-490.

* Add aten.select.int support to luminal_python translator (LUM-487)

Single-element indexing (`x[0]`, `x[i, j]`, `x[1, 2, 3]`) lowers to
`aten.select.int` in the FX graph. The translator previously bailed
with "Unsupported ATen op", blocking any model that reads a scalar by
indexing.

Implements `aten.select.int(self, dim, index)` as
`slice_along(index..index+1, dim).squeeze(dim)` — a pure
shape-manipulation that the luminal compiler can fold into surrounding
ops, with a single iota for the slice. Negative `dim` is normalized via
the existing `normalize_dim` helper; negative `index` is normalized
against the (concrete) axis size, mirroring how `translate_gather`
normalizes negative gather indices.

Removes the four `xfail(_INDEX_SELECT_REASON)` markers in
`tests/test_scalar_torture.py` (and the now-unused reason constant);
these tests now pass. Final counts: 52 passed / 8 xfailed (was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Fix LUM-488: support rank-0 tensor mod/lt and add aten.remainder dispatch

Two related issues prevented `x % torch.tensor(c)` from translating:

1. The luminal_python translator did not dispatch aten.remainder.Tensor /
   aten.remainder.Scalar at all, so any module that mods a tensor against
   a 0-d torch.tensor failed with "Unsupported ATen op".

2. core::ops::Rem and GraphTensor::lt asserted exact dim equality, blocking
   rank-0 to rank-N broadcasting that the backend already supports
   transparently for Add/Mul (the input_shapes vec is forwarded to the
   strided iterator).

Drop the dim assertions in Rem and lt so they match Add/Mul's broadcast
behavior, and add aten.remainder.Tensor/Scalar handlers in dispatch.rs that
mirror aten.fmod.Tensor (with ensure_same_dtype + broadcast_binary). For
the Scalar form, build a constant_float and expand_rhs onto the LHS shape.

Tests:
- New proptests test_mod_scalar_broadcast / test_lt_scalar_broadcast in
  src/frontend/binary.rs cover rank-0 RHS via expand_rhs.
- Removed @pytest.mark.xfail from test_mod_by_scalar_tensor; added the
  test_scalar_torture.py file to luminal_python's test suite.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python translator: dispatch aten.clamp.Tensor (LUM-489)

torch.clamp(x, lo, hi) where lo/hi are 0-d tensors routes to
aten.clamp.Tensor, which the translator did not previously handle. Add
a dedicated dispatch that decomposes clamp(x, lo, hi) into
min(max(x, lo), hi), broadcasting each rank-0 bound up to x's shape via
expand_rhs. Either bound may be absent (PyTorch allows min=None or
max=None), so each side is applied only when its FX input is a tensor.

Removes the @pytest.mark.xfail on test_clamp_with_scalar_tensors;
test_scalar_torture now reports 50 passed / 10 xfailed (was 48 / 12).

* luminal_python: support aten.prod.default full-reduction (LUM-490)

The translator's dispatch table mapped aten.{sum,mean,amax,amin}.default
to translate_reduction but lacked an entry for aten.prod.default, so
x.prod() with no axis raised "Unsupported ATen op". Add the missing
dispatch entry; the ReductionOp::Prod branch in translate_reduction
already handles both full-reduce and dim-reduce cases.

aten.prod.dim_int was already wired up; verified it routes correctly.

Removes the xfail marker on test_prod_all_produces_scalar in
test_scalar_torture.py — suite now reports 50 passed / 10 xfailed
(was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: preserve int64 (and other integer) output dtypes (LUM-486)

Full reductions of int64 tensors silently downcast to int32 on the
PyTorch boundary because `output_dtypes` was stored as luminal `DType`,
which collapses every integer width to `DType::Int` (i32). The Python
wrapper therefore reported int32 to PyTorch even when the user passed
int64, breaking strict dtype checks and risking silent overflow on
larger reductions / downstream ops that require int64.

Store `output_dtypes` directly as PT2 dtype codes (the original PyTorch
type IDs) instead of converting through luminal `DType` first. This
preserves int64 vs int32 (and similar) end-to-end. The Python output
path now reads int outputs as i32 and casts to the requested torch
dtype, so int8/int16/int32/int64/uint8 outputs all round-trip with the
right type tag.

Updates two existing assertions (`test_argsort_stable_duplicates`,
`test_tiny_moe_routing`) that were pinning int32 — the new behavior
matches PyTorch eager (int64). Adds `test_reduce_sum_all_axes_int64_preserves_dtype`
as a regression check, and removes the xfail on
`test_int_sum_produces_int_scalar`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: parametrize argsort/MoE dtype tests over int32 and int64

The LUM-486 fix preserves whichever integer dtype the eager model declares
on output. The original tests hardcoded int64 (the dtype torch.argsort and
torch.topk natively produce), which only exercised one path through the
preservation logic.

Add an idx_dtype knob to ArgsortStableDuplicatesModel and TinyMoERoutingModel
that casts the integer outputs to the requested dtype, and parametrize both
tests over [torch.int32, torch.int64]. Internal indices (passed to gather /
scatter) stay int64 since PyTorch requires that for index tensors; the cast
applies only to the returned values.

LUM-486

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Remove xfail markers for fixed scalar bugs

Drops the @pytest.mark.xfail markers on tests now passing after the
LUM-486, LUM-487, LUM-488, LUM-489, and LUM-490 fixes:

  - test_prod_all_produces_scalar (LUM-490)
  - test_clamp_with_scalar_tensors (LUM-489)
  - test_mod_by_scalar_tensor (LUM-488)
  - test_index_1d_produces_scalar (LUM-487)
  - test_index_all_dims_produces_scalar (LUM-487)
  - test_index_then_add_scalar_const (LUM-487)
  - test_model_returns_scalar_from_index (LUM-487)
  - test_int_sum_produces_int_scalar (LUM-486)

Also removes the now-unused _INDEX_SELECT_REASON constant.

The single remaining xfail is test_unsqueeze_expand_sum_back, blocked
on LUM-485 (full reduction returns shape [1] instead of rank-0 ()).

* luminal_python: full reductions return rank-0 () instead of [1] (LUM-485)

The translator's full-reduce path used to flatten the input to [1, N] and
reduce axis 1, leaving a residual [1] dimension. PyTorch eager produces
rank-0 () for x.sum() etc., and downstream ops (e.g. unsqueeze(0).expand(5))
rely on that rank — the residual [1] caused panics like "Cannot expand
from 2 dims to 1 dims" once the scalar fed any further op.

Drop the flatten and reduce over every axis directly. Special-case rank-0
input as a no-op so reducing a scalar is well-defined. Mean still divides
by the cached total to avoid redundant axis-prod work.

Removes the xfail marker on test_unsqueeze_expand_sum_back, which now
passes. With this commit the integration branch has zero xfails:
284 passed across test_scalar_torture.py + test_hlir_ops.py + test_unary.py.

* ruff format: tests/test_hlir_ops.py

Collapse a two-line f-string into one line per ruff format. No behavior change.

* Expand scalar torture suite with PyTorch / NumPy gap coverage

Cross-referenced our suite against PyTorch's test_torch / test_reductions
/ test_view_ops / test_indexing / test_type_promotion / test_binary_ufuncs
and NumPy's test_multiarray / test_indexing / test_shape_base. Added 14
new sections covering 47 in-scope gaps:

  - Binary ops with INPUT 0-d (not reduction-derived) on either side:
    add/sub/mul/div/mod/maximum/minimum/pow/floor_divide
  - Pure 0-d ↔ 0-d arithmetic (no broadcasting required)
  - Full comparison set (gt/ge/lt/le/eq/ne) on input 0-d, plus mask-by-eq
  - Reduction extras: argmax/argmin (no-arg + keepdim), sum(dim=()),
    sum/mean of 0-d input, cumsum of 0-d
  - Shape-flattening on 0-d: flatten/ravel/reshape(-1)/view(-1) all
    return shape (1,); reshape(()) on 1-element collapses to (); plus
    permute([]), contiguous(), squeeze() of (1,1,1,1), expand_as
  - Indexing extras: ellipsis x[...], index by 0-d int tensor, gather
    with 0-d index, negative-index x[-1]
  - Type promotion: float-0-d + int-Nd, int-0-d + float-Nd, cast
    roundtrip through 0-d, .float()/.int() shorthands, where with
    mixed-dtype scalar branches
  - Unary math (abs/neg/exp/sin/cos/tanh/sigmoid/sqrt/sign/floor/ceil)
    on reduction-derived 0-d
  - Bool logic: AND, OR, XOR, NOT on 0-d bool from comparisons
  - Stack of 0-ds; cat of unsqueezed 0-ds
  - Constants: torch.full((), v), torch.full_like on 0-d
  - Reduction edge cases: keepdim across all axes then divide;
    scalar broadcast onto transposed tensor
  - Mixed where/clamp shapes: clamp(x, scalar_tensor, py_float),
    where(cond, scalar_tensor, x)
  - Multi-output models: (scalar, tensor) tuple

Result: 363 passed / 15 xfailed across the python suite. The 15 new
xfails are documented inline with concrete failure modes:

  - 6 op-coverage gaps: aten.argmax.default, aten.argmin.default,
    aten.eq.Scalar, aten.ne.Tensor (translator dispatch entries needed).
  - 2 PT2 export issues: 0-d int64 graph inputs hit "invalid type: null,
    expected i64" in luminal's model.json parser; affects
    test_int_0d_plus_float_nd and test_gather_with_0d_index.
  - 2 real correctness bugs:
      * floor_divide with 0-d divisor returns the un-floored quotient
        (float division result, not floor(x/d)).
      * cumsum on a 0-d tensor panics with index-out-of-bounds.
  - 1 dynamo guard edge case: torch._dynamo emits an unresolved 'L'
    name in _guards_fn for 0-d index tensors.

Plus 4 cross-marker xfails on consequence of the above (the parametric
ne case, mask_by_scalar_eq variants, and other downstream effects).

* Rename test_scalar_torture.py -> test_scalars.py; drop 'torture' wording

The original 'torture test' label is jargon. The file is just a scalar
test module — keep the name simple to match the rest of the suite
(test_unary.py, test_hlir_ops.py).

* luminal_python: parse rounding_mode string arg correctly (LUM-494)

torch.floor_divide(x, d) decomposes to aten.div.Tensor_mode with
rounding_mode='floor' during PT2 export. The translator was reading
the kwarg via serde_json::Value::as_str(), but PT2 serializes string
args as {"as_string": "<value>"} objects, not bare JSON strings. The
extraction silently returned None, so the floor branch was skipped
and the regular un-floored quotient was returned.

Drill into the as_string field as a fallback so floor_divide and
div(x, d, rounding_mode='floor'/'trunc') produce floor(x/d) /
trunc(x/d) as expected.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix cumsum on rank-0 tensor (LUM-495)

The translator's cumsum handler called normalize_dim(dim, a.shape.len())
and then a.cumsum(dim) for any rank — including rank-0. The underlying
cumop in src/frontend/unary.rs indexes self.dims()[axis] inside the
padding/unfold loop, which panics with "index out of bounds: the len is
0 but the index is 0" when shape is empty.

PyTorch eager treats torch.cumsum(s, 0) on a 0-d tensor as an identity
op (cumsum of a single element is the element itself). Mirror the
rank-0 short-circuit pattern from the LUM-485 reduction fix and return
the input unchanged when a.shape.is_empty(). Move the dim arg fetch
inside the non-empty branch since dim is unused for rank-0.

Drops the xfail marker on test_cumsum_of_0d and adds a 1-element 1-D
sibling test that asserts shape (1,) round-trips.

* luminal_python: support aten.argmax/argmin (LUM-496)

argmax/argmin were missing from the translator dispatch table even
though we already have stable_argsort. Add a thin wrapper so the
PyTorch boundary lights up:

  argmax(x, dim=None)    -> argsort(flatten(x), descending=True).select(0, 0)
  argmax(x, dim=N)       -> argsort(x, dim=N, descending=True).select(N, 0)
  argmax(x, dim=N, keepdim=True) -> .unsqueeze(N) over the above
  argmin(...)            -> same with descending=False

The slice + squeeze chain produces a non-contiguous DType::Int view
whose underlying buffer is still sized for the un-sliced argsort
tensor. Final `* 1` materializes a contiguous Int copy with strides
matching the visible shape — same trick `translate_topk` uses for
its sliced index output. Without it the keepdim case panics
("No output node found") and the full-reduce case throws a Python
shape mismatch on the oversized buffer.

PyTorch's argmax returns int64 while luminal collapses to int32 (Int);
LUM-486 already widens at the Python boundary, so the contract is
preserved end-to-end. Drops the three `@pytest.mark.xfail` markers
from `test_argmax_all`, `test_argmin_all`, and `test_argmax_keepdim_1d`
in `test_scalars.py` (6 cases via parametrization).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: dispatch aten.eq.Scalar and aten.ne.Tensor (LUM-497)

Add the two missing comparison overloads to the translator dispatch.
eq.Scalar mirrors the existing ne.Scalar handler (constant_float + cast
+ expand_rhs to broadcast the scalar), and ne.Tensor mirrors the
existing eq.Tensor handler. Removes the corresponding xfail markers on
test_input_0d_comparisons[_NeInput0ds-...] and test_mask_by_scalar_eq.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: accept null range_constraint bounds (LUM-498)

A 0-d int64 graph input made PyTorch 2.10+ emit
`range_constraints: { sN: { min_val: null, max_val: null } }` for the
unbacked symbol PT2 introduces around the rank-0 tensor. Our serde
schema modeled `RangeConstraint.min_val` as `i64`, so deserialization
failed with `invalid type: null, expected i64`, blocking any model
with a scalar integer tensor input.

Make `min_val` and `max_val` `Option<i64>` (matching PT2's
`Optional[int]`) and fall back to 1 as the initial dynamic-dim value
when no lower bound is provided.

Tests: removes the xfail on `test_int_0d_plus_float_nd`, adds a new
`test_int32_0d_plus_float_nd` regression, and updates the xfail
reason on `test_gather_with_0d_index` (the parse error is fixed; a
separate downstream gather panic remains).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: drop dynamo input guards in pt2_backend (LUM-499)

When a 0-d int tensor is used as a tensor index (x[i] where
i = torch.tensor(2)), torch.export records duplicate input guards that
reference both the original local source (L['i']) and the rewrapped flat
args (L['args'][1]). The unlift pass cannot resolve L['i'] against the
wrapped (*args, **kwargs) signature, leaving a literal `L` reference in
the generated _guards_fn that raises NameError during retracing. The
data-dependent .item() in the surviving guard then trips fake-tensor
analysis with DataDependentOutputException.

Drop the guard list before run_decompositions so unlift produces an
empty _guards_fn, and DCE any leftover dead aten.item.default nodes
that came from index specialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix gather with rank-0 index on rank-1 source

PyTorch eager allows torch.gather(rank-1, dim, rank-0) — the only
rank-mismatch case it permits — and returns a rank-0 scalar. Our
gather_elements requires source-rank == index-rank, so the rank-0
index hit flatten_strides with mismatched (0, 1) lengths and
panicked.

Detect this specific pattern in translate_gather: unsqueeze the
rank-0 index to (1,), gather, then squeeze the result back to ().
Output shape and value match eager.

This was the last remaining xfail in test_scalars.py. Suite is now
381 passed / 0 xfailed / 0 failed across test_scalars.py +
test_hlir_ops.py + test_unary.py.

* luminal_python: clamp.Tensor handles all broadcastable bound shapes

PyTorch's aten.clamp.Tensor accepts bounds with any NumPy-broadcastable
shape (rank-0, same-shape, or broadcastable). The previous translator
used expand_rhs(result.shape) which appends dims rather than broadcasts,
so only rank-0 bounds came out correctly. Same-shape and broadcastable
bounds either panicked or silently produced wrong values.

Switch to broadcast_binary (the right-align + size-1 expand helper used
by aten.remainder.Tensor, aten.eq.Tensor, etc.). Now all three modes
work uniformly.

Add 7 new tests covering the previously-broken modes:
  - same-shape bounds (per-element clamp, e.g. learned bounds)
  - per-row broadcast (3,1) against (3,4)
  - per-col broadcast (4,) against (3,4)
  - mixed rank-0 lo + same-shape hi
  - min-only with same-shape lo
  - max-only with per-row hi
  - 3-D x with 2-D bounds (left-unsqueeze broadcast)

Suite goes from 381 to 388 passing, 0 xfailed.

* shape: empty Expression product returns 1, not 0

The empty product is the multiplicative identity (1) — every shape-iterator
call site (`shape.iter().product()` for `numel`, output-buffer sizing, CUDA
grid-dim computation) implicitly relies on this. The previous impl returned
0 for an empty iterator, which was a latent bug masked while no path
produced rank-0 shapes.

The LUM-485 fix (full reductions return rank-0 () instead of rank-1 [1])
exposed it on CUDA: SumReduce kernels with rank-0 output got `n_outputs=0`,
launched with `grid=(0, 1, 1)`, and crashed with "invalid CUDA launch
dimensions" — every CUDA reduction in the Python CUDA tests was failing.

Fix: return Expression::from(1) for empty iteration. Sum's identity (0)
was already correct and is unchanged.

Add two unit tests covering both identities.

* cargo fmt

* Fix PT2 passthrough input output ID collision

* Fix scalar argextremum keepdim behavior

* Defer PT2 interface collision fix

* Keep HLIR binary ops shape-strict

* fixed gemma issue

* Fix explicit broadcasts and conv shape division

* Normalize Whisper cache slice shape

---------

Co-authored-by: Austin Glover <austin@luminal.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Austin Glover <austin_glover@berekely.edu>
Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-13 20:16:30 -04:00
32 changed files with 2809 additions and 116 deletions

View File

@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn gather_experts(

View File

@@ -1540,19 +1540,22 @@ impl KernelOp for KernelIota {
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
vars.extend(self.range.dyn_vars());
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let range = self.range.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void iota_k(int *C{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {range}) return;
C[const_z] = {};
}}
}}",
@@ -1571,8 +1574,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.range, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.range.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -2935,6 +2938,14 @@ impl KernelOp for KernelCast {
) {
let out_dtype = cuda_dtype(self.out_dtype);
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let size = self.size.to_kernel();
let kernel = if self.in_dtype.bits() < 8 {
// Sub-byte packed types: multiple values packed per byte.
@@ -2944,9 +2955,11 @@ impl KernelOp for KernelCast {
let mask = (1u32 << bits) - 1;
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {size}) return;
long long bit_offset = idx * {bits};
long long byte_idx = bit_offset >> 3;
int bit_pos = (int)(bit_offset & 7);
@@ -2962,9 +2975,11 @@ extern \"C\" {{
let in_dtype = cuda_dtype(self.in_dtype);
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {size}) return;
out[const_z] = ({out_dtype})in[const_z];
}}
}}"
@@ -2983,8 +2998,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.size, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.size.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -3275,12 +3290,15 @@ impl KernelOp for KernelEmbed {
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let total_threads = batch_size * self.embed_dim;
let n_elements = total_threads.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {n_elements}) return;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
long long embed_idx = idx % embed_dim;
@@ -3303,13 +3321,12 @@ extern \"C\" {{
};
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
let total_threads = batch_size * self.embed_dim;
(
func,
module,
kernel,
(total_threads, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(total_threads.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
constants,
)

View File

@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
QwenMoeGraph {
graph: cx,
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
GemmaMoeGraph {
graph: cx,

View File

@@ -61,7 +61,8 @@ impl MoE {
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
weights_exp.shape.expand(expert_out.dims());
(expert_out * weights_exp).sum(n - 1)
}
}
@@ -478,7 +479,8 @@ mod tests {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
let _output = (down_out * weights_exp).sum(n - 1).output();
// Dump the HLIR to egglog

View File

@@ -855,8 +855,6 @@ Two important details:
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
---
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.

View File

@@ -8,7 +8,7 @@ echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
CUDA_TESTS="tests/"
# ── Phase 1: Native Backend ─────────────────────────────────

View File

@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far to long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -98,7 +98,12 @@ pub struct GraphTranslation {
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
/// distinctions luminal collapses internally — notably int64 vs int32,
/// both of which map to `DType::Int` in luminal but must be reported
/// back to PyTorch with their original precision.
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -124,7 +129,9 @@ pub struct CompiledGraph {
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
/// that luminal collapses to `DType::Int` internally).
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -476,10 +483,7 @@ impl CompiledGraph {
/// Get the PT2 dtype codes for all outputs (in order).
#[getter]
fn output_dtypes(&self) -> Vec<u32> {
self.output_dtypes
.iter()
.map(|d| luminal_dtype_to_pt2_code(*d))
.collect()
self.output_dtypes.clone()
}
/// Get output tensor data by name as f32 (copies to host).

View File

@@ -262,10 +262,13 @@ pub fn translate_pt2(
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges
// Set initial dynamic dim values from symbol ranges. PT2 emits
// `min_val: null` when the constraint is unbounded; fall back to 1 in
// that case (the smallest valid dim — used only as an initial value).
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
graph.set_dim(*c, rc.min_val as usize);
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
graph.set_dim(*c, initial);
}
}
@@ -281,14 +284,14 @@ pub fn translate_pt2(
})
.collect();
let output_dtypes: Vec<DType> = translated
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
// Python wrapper can return tensors with the right torch.dtype, even when
// luminal collapses the type internally (e.g. int64 → DType::Int).
let output_dtypes: Vec<u32> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
.unwrap_or(DType::F32)
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
})
.collect();

View File

@@ -15,7 +15,16 @@ pub struct ExportedProgram {
#[derive(Debug, Clone, Deserialize)]
pub struct RangeConstraint {
pub min_val: i64,
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
/// constraint is unbounded (no min set), so this must accept None.
#[serde(default)]
pub min_val: Option<i64>,
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
/// unused on the luminal side, but accepted to avoid deserialization
/// errors when PT2 emits it.
#[serde(default)]
#[allow(dead_code)]
pub max_val: Option<i64>,
}
#[derive(Debug, Deserialize)]

View File

@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
if let Some(b) = bias {
let out_dims = out.dims();
let mut b_expanded = b.expand_dim(0, 1);
let mut b_expanded = b.expand_dim(0, out_dims[0]);
for i in 0..spatial {
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
}
@@ -389,8 +389,11 @@ fn depthwise_conv(
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
let patches = patches.expand_dim(2, group_out);
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
// Explicitly expand weight across the batch axis so the elementwise Mul
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
let w_expanded = w_flat
.expand_dim(0, patches.dims()[0])
.expand_dim(3, patches.dims()[3]);
// Element-wise multiply and sum over kernel dim
let product = patches * w_expanded;

View File

@@ -6,6 +6,7 @@ use crate::pt2_util::*;
use super::Translator;
use super::attention::SdpaVariant;
use super::reduction::ArgExtremum;
impl<'a> Translator<'a> {
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
@@ -147,6 +148,7 @@ impl<'a> Translator<'a> {
// Slice/index ops
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
@@ -219,6 +221,16 @@ impl<'a> Translator<'a> {
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
// Tensor comparisons
"torch.ops.aten.eq.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a.eq(scalar)
}
"torch.ops.aten.ne.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
@@ -236,6 +248,13 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a.eq(b)
}
"torch.ops.aten.ne.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a.ne(b)
}
"torch.ops.aten.le.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
@@ -274,18 +293,27 @@ impl<'a> Translator<'a> {
// Clamp
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
// Cumsum
"torch.ops.aten.cumsum.default" => {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let a = if a.dtype == DType::Bool {
a.cast(DType::Int)
} else {
a
};
a.cumsum(dim)
// Rank-0 (scalar) input: cumsum of a single element is the element
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
// and the underlying `cumop` indexes `shape.dims[axis]` which would
// panic with empty dims.
if a.shape.is_empty() {
a
} else {
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
a.cumsum(dim)
}
}
// Floor / Ceil / Erf (approximations)
@@ -381,6 +409,17 @@ impl<'a> Translator<'a> {
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
// PyTorch's argmax/argmin returns int64; the dtype is preserved
// through the LUM-486 boundary widening.
"torch.ops.aten.argmax.default" => {
self.translate_argextremum(node, ArgExtremum::Max)?
}
"torch.ops.aten.argmin.default" => {
self.translate_argextremum(node, ArgExtremum::Min)?
}
// Gather (axis-aware)
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
@@ -444,6 +483,28 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a % b
}
// Remainder (Python-style modulo). For float tensors aten.remainder
// returns the same value as `%` would in luminal (Mod follows the
// language's % semantics on f32). The Tensor variant accepts a
// tensor RHS that may be rank-0; broadcast both operands so a
// scalar RHS is expanded to match the LHS shape before mod.
"torch.ops.aten.remainder.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a % b
}
"torch.ops.aten.remainder.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a % scalar
}
// Prod reduction
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,

View File

@@ -120,6 +120,47 @@ impl<'a> Translator<'a> {
Ok(a.slice_along(start..end, dim))
}
/// `aten.select.int(self, dim, index)` — select element `index` along
/// `dim`, dropping that dim. Output rank = input rank 1, so a 1-D input
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
/// are normalized against the input shape.
///
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
/// slice + squeeze decomposition (rather than `gather`) because the
/// composition is a pure shape manipulation with a single iota, which the
/// luminal compiler can fold into surrounding ops.
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let index_raw = self.get_int_arg(node, 2)?;
// Normalize a possibly-negative index. PyTorch accepts indices in
// [-size, size); negative wraps from the end.
let index = if index_raw < 0 {
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"select.int: dim {} must be concrete to normalize a negative index",
dim
)
})?;
let normalized = axis_size as i64 + index_raw;
if normalized < 0 {
bail!(
"select.int: index {} out of range for dim {} of size {}",
index_raw,
dim,
axis_size
);
}
normalized as usize
} else {
index_raw as usize
};
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
}
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
names
@@ -333,6 +374,17 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, 2)?;
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
// gather_elements requires the index rank to match the source rank,
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
let indices = if promoted_rank0 {
indices.unsqueeze(0)
} else {
indices
};
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
// Stay in Int the whole way — multiplying an Int tensor by an
// Expression broadcasts the axis size and avoids three Cast nodes
@@ -344,7 +396,12 @@ impl<'a> Translator<'a> {
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
Ok(a.gather_elements(normalized, dim))
let result = a.gather_elements(normalized, dim);
Ok(if promoted_rank0 {
result.squeeze(0)
} else {
result
})
}
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {

View File

@@ -6,6 +6,20 @@ use crate::pt2_util::*;
use super::Translator;
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
/// smallest (ascending sort) element when scanning the input.
#[derive(Clone, Copy)]
pub(crate) enum ArgExtremum {
Max,
Min,
}
impl ArgExtremum {
fn descending(self) -> bool {
matches!(self, ArgExtremum::Max)
}
}
/// Compute total element count, returning an error if any dimension is symbolic.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
a.dims().iter().try_fold(1usize, |acc, d| {
@@ -37,32 +51,26 @@ impl<'a> Translator<'a> {
(axes, keepdim)
}
_ => {
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
// override below assumes contiguous, no-broadcast storage —
// otherwise the `[1, N]` view treats stride-0 broadcast dims
// as if they held N distinct values and reads past the backing
// buffer. Materialize first when that's not the case (matches
// the guard `translate_reshape` already applies).
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
let ndim = a.shape.len();
if ndim == 0 {
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
// and mean of a scalar is just the scalar.
return Ok(a);
}
let total = concrete_numel(&a)?;
let has_broadcast = a
.shape
.dims
.iter()
.zip(a.shape.strides.iter())
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
let a = if has_broadcast || !a.shape.is_contiguous() {
a + 0.0
} else {
a
};
let mut flat = a;
flat.shape = ShapeTracker::new(vec![1, total]);
let axes: Vec<usize> = (0..ndim).collect();
let result = match op {
ReductionOp::Sum => flat.sum(vec![1]),
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
ReductionOp::Max => flat.max(vec![1]),
ReductionOp::Min => flat.min(vec![1]),
ReductionOp::Prod => flat.prod(vec![1]),
ReductionOp::Sum => a.sum(axes),
// Note: the luminal `mean` helper divides by the product of the
// axis dims, but we already require concrete dims here so we
// divide by the cached `total` to avoid recomputing.
ReductionOp::Mean => a.sum(axes) / total as f32,
ReductionOp::Max => a.max(axes),
ReductionOp::Min => a.min(axes),
ReductionOp::Prod => a.prod(axes),
};
return Ok(result);
}
@@ -86,4 +94,100 @@ impl<'a> Translator<'a> {
Ok(result)
}
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
/// existing `stable_argsort` op and selecting the first index along the
/// sort axis.
///
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
/// for argmin). FX export emits the inputs positionally:
/// - input 0: tensor
/// - input 1: dim (Int) or None (Other) — when `dim=None`
/// - input 2: keepdim (Bool, optional)
///
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
/// The result of argsort along the sort axis is sliced at index 0,
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
/// `dim`.
///
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
/// view; we materialize it with `* 1` so the resulting node has
/// contiguous strides matching its visible shape (mirroring the
/// `topk` lowering in `translate_topk`). Without this, the output
/// buffer would be sized for the un-sliced argsort tensor while the
/// shape tracker reports a smaller rank.
///
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
/// metadata records int64 and the Python wrapper widens at the
/// boundary, so the PyTorch contract is preserved end-to-end
/// (LUM-486).
pub(crate) fn translate_argextremum(
&mut self,
node: &Node,
which: ArgExtremum,
) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
// argument (typically `Argument::Other(Null)`), so a missing or
// non-int slot means "reduce over the flattened tensor".
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
self.get_int_arg(node, 1).ok()
} else {
None
};
let keepdim = if node.inputs.len() > 2 {
self.get_bool_arg(node, 2).unwrap_or(false)
} else {
false
};
if a.shape.is_empty() {
match dim_opt {
None | Some(0) | Some(-1) => {
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
// `keepdim=True` does not add a dimension when the input is 0-d.
return Ok(self.graph.constant(0i64).cast(DType::Int));
}
Some(dim) => {
return Err(anyhow::anyhow!(
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
));
}
}
}
let descending = which.descending();
let (sort_axis, base) = match dim_opt {
None => {
// Full-reduce: flatten to 1-D, argsort along axis 0.
let total = concrete_numel(&a)?;
let flat = reshape_tensor(a, vec![Expression::from(total)]);
(0usize, flat)
}
Some(dim_raw) => {
let dim = normalize_dim(dim_raw, a.shape.len());
(dim, a)
}
};
// Pick index 0 along the sort axis. The slice-then-squeeze chain
// produces a non-contiguous view whose physical buffer is still
// sized for the un-sliced argsort tensor; the optional `keepdim`
// unsqueeze adds a stride-0 axis which is also non-contiguous.
// Materialize at the end with `* 1` so the resulting node has
// contiguous strides matching its visible shape (matches the
// pattern used by `translate_topk` for sliced index outputs).
let sorted = base.stable_argsort(sort_axis, descending);
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
let result = if keepdim {
picked.unsqueeze(sort_axis)
} else {
picked
};
Ok(result * 1)
}
}

View File

@@ -213,12 +213,18 @@ impl<'a> Translator<'a> {
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
// Check rounding_mode kwarg
// Check rounding_mode kwarg. PT2 serializes string args as
// {"as_string": "<value>"}, so we have to drill into the JSON.
let rounding_mode = node.inputs.iter().find_map(|input| {
if input.name == "rounding_mode"
&& let Argument::Other(val) = &input.arg
{
return val.as_str().map(|s| s.to_string());
if let Some(s) = val.as_str() {
return Some(s.to_string());
}
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
return Some(s.to_string());
}
}
None
});
@@ -269,4 +275,52 @@ impl<'a> Translator<'a> {
}
Ok(result)
}
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
///
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
/// overload takes tensor bounds that appear as separate input nodes in the
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
///
/// - rank-0 (scalar wrapped in a tensor) — most common
/// - same shape as self (per-element clamp, e.g. learned bounds)
/// - any shape that broadcasts to self via right-align + size-1 expand
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
///
/// We use `broadcast_binary` to right-align and expand both operands to a
/// common shape before the elementwise max/min, matching PyTorch semantics
/// across all three modes.
///
/// Either bound may be absent (FX represents this as a non-tensor argument
/// at the corresponding input slot), in which case we clamp to one side
/// only.
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let min_tensor = node
.inputs
.get(1)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let max_tensor = node
.inputs
.get(2)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let mut result = a;
if let Some(lo) = min_tensor {
let lo = lo.cast(result.dtype);
let (r, lo) = broadcast_binary(result, lo);
result = r.maximum(lo);
}
if let Some(hi) = max_tensor {
let hi = hi.cast(result.dtype);
let (r, hi) = broadcast_binary(result, hi);
result = r.minimum(hi);
}
Ok(result)
}
}

View File

@@ -1,5 +1,6 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
import warnings
from typing import List
import torch
@@ -8,6 +9,10 @@ from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
class DTypeBoundaryWarning(UserWarning):
"""Warns when the PyTorch boundary must cast input data before execution."""
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
@@ -95,6 +100,15 @@ class CompiledModel:
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if tensor.dtype != expected_dtype:
warnings.warn(
"Luminal compiled input "
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
f"expects {expected_dtype}; converting at every call will "
"allocate/copy input data.",
DTypeBoundaryWarning,
stacklevel=2,
)
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
@@ -135,6 +149,11 @@ class CompiledModel:
# Run the graph
self._graph.run()
# Integer dtypes for which we read the buffer as i32 and then cast.
# Includes int64 because luminal collapses all integer types to its
# 32-bit `Int` internally — we restore the original precision here.
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
# Collect outputs
if _use_zero_copy:
outputs = []
@@ -150,11 +169,12 @@ class CompiledModel:
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype == torch.int32:
elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
elif out_dtype == torch.bool:
@@ -182,9 +202,13 @@ class CompiledModel:
if i < len(output_dtype_codes)
else torch.float32
)
if out_dtype == torch.int32:
if out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
)
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))

View File

@@ -491,6 +491,62 @@ def compile(
return _save_and_compile(ep, factory, search_iterations)
def _drop_input_guards(ep):
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
``i = torch.tensor(2)``), torch.export records two equivalent input
guards: ``L['i'].item() == 2`` (referencing the original local source)
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
Two failures stack on top of each other:
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
a literal ``L`` reference in the generated ``_guards_fn`` and raising
``NameError: name 'L' is not defined`` during retracing.
2. Even after dropping the unresolvable guard, the surviving
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
a graph break.
These guards exist solely to validate inputs at runtime in eager-mode
consumers of the ExportedProgram; the luminal compiler does its own
input shape/dtype checks against the compiled graph signature, so we
are not losing any safety by clearing them.
"""
if hasattr(ep, "_guards_code"):
ep._guards_code = []
def _drop_dead_data_dependent_ops(gm):
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
When dynamo specializes a 0-d int input by tracing through ``.item()``,
the resulting graph may contain a dead ``aten.item.default`` node whose
output is never consumed. luminal's translator does not lower
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
node in the graph causes a graph break at compile time. Eliminating it
keeps the (correctly specialized) downstream graph in a single subgraph.
"""
graph = gm.graph
changed = False
for node in list(graph.nodes):
if (
node.op == "call_function"
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
and len(node.users) == 0
):
graph.erase_node(node)
changed = True
if changed:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
def _legacy_auto_dim(example_args):
"""Match the historical `dynamic_dim="auto"` heuristic.
@@ -570,6 +626,11 @@ def _eager_pt2_compile(
if dynamic_shapes is None:
raise
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
# LUM-499: drop dynamo-emitted input guards before run_decompositions
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
# data-dependent .item() calls and unresolved `L[...]` references.
_drop_input_guards(ep)
_drop_dead_data_dependent_ops(ep.graph_module)
ep = ep.run_decompositions(_decomp_table())
# When using shared memory (original_weights), strip large weight buffers

View File

@@ -0,0 +1,215 @@
from dataclasses import dataclass
import warnings
from typing import Callable
import pytest
import torch
from luminal import luminal_backend
from luminal.compiled_model import DTypeBoundaryWarning
class BoundaryNoopModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.bool:
return x | torch.zeros((), dtype=torch.bool, device=x.device)
return x + torch.zeros((), dtype=x.dtype, device=x.device)
@dataclass(frozen=True)
class DTypeCase:
name: str
dtype: torch.dtype
values: Callable[[], torch.Tensor]
xfail_reason: str | None = None
DTYPE_CASES = [
DTypeCase(
"bool",
torch.bool,
lambda: torch.tensor([True, False, True], dtype=torch.bool),
),
DTypeCase(
"uint8",
torch.uint8,
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
),
DTypeCase(
"int8",
torch.int8,
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
),
DTypeCase(
"int16",
torch.int16,
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
),
DTypeCase(
"int32",
torch.int32,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int32,
),
),
DTypeCase(
"int64_i32_range",
torch.int64,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int64,
),
),
DTypeCase(
"float16",
torch.float16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
),
DTypeCase(
"bfloat16",
torch.bfloat16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
),
DTypeCase(
"float32",
torch.float32,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
),
DTypeCase(
"float64_f32_exact",
torch.float64,
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
),
DTypeCase(
"int64_outside_i32_range",
torch.int64,
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
xfail_reason=(
"Luminal currently collapses integer inputs through i32 at the "
"compiled boundary, so out-of-range int64 values lose information."
),
),
DTypeCase(
"float64_precision_sensitive",
torch.float64,
lambda: torch.tensor(
[1.0, 1.0000000000000002, float(2**40) + 0.25],
dtype=torch.float64,
),
xfail_reason=(
"Luminal currently routes float64 no-op computation through f32 "
"storage/outputs before restoring the PyTorch-visible dtype."
),
),
]
def _cuda_skip_reason() -> str | None:
if not torch.cuda.is_available():
return "CUDA is not available"
try:
from luminal.luminal import _cuda_lite_factory_capsule
_cuda_lite_factory_capsule()
except (ImportError, AttributeError, RuntimeError) as exc:
return f"luminal_python was not built with CUDA support: {exc}"
return None
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
def boundary_device(request) -> torch.device:
device_name = request.param
if device_name == "cuda":
skip_reason = _cuda_skip_reason()
if skip_reason is not None:
pytest.skip(skip_reason)
return torch.device(device_name)
@pytest.mark.parametrize(
"case",
[
pytest.param(
case,
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
if case.xfail_reason is not None
else (),
id=case.name,
)
for case in DTYPE_CASES
],
)
def test_boundary_noop_preserves_dtype_and_values(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
expected = model(x)
actual = compiled(x)
assert isinstance(actual, torch.Tensor)
assert actual.dtype == expected.dtype
assert torch.equal(actual.cpu(), expected.cpu())
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name
in {
"uint8",
"int8",
"int16",
"int64_i32_range",
"int64_outside_i32_range",
"float64_f32_exact",
"float64_precision_sensitive",
}
],
)
def test_boundary_warns_when_input_dtype_requires_conversion(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
compiled(x)
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
],
)
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with warnings.catch_warnings(record=True) as records:
warnings.simplefilter("always")
compiled(x)
dtype_boundary_warnings = [
record
for record in records
if issubclass(record.category, DTypeBoundaryWarning)
]
assert dtype_boundary_warnings == []

View File

@@ -1,5 +1,6 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
@@ -220,6 +221,7 @@ from test_models import (
Conv1dNoPadModel,
Conv1dSamePadModel,
Conv1dBiasModel,
Conv1dFloorDivPositionalModel,
Conv2dNoPadModel,
Conv2dSamePadModel,
Conv2dBiasModel,
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
eager = model(x)
out = model_compiled(x)
assert out.dtype == eager.dtype == torch.int64
assert torch.equal(out, eager)
def test_reduce_sum_3d_axis1(device: torch.device):
"""Test sum reduction along axis 1 for a 3D tensor."""
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
@@ -2022,9 +2035,16 @@ def test_split(device: torch.device):
# ========== Argsort / MoE Routing Tests ==========
def test_argsort_stable_duplicates(device: torch.device):
"""Duplicate values should follow stable lower-index-first tie-breaking."""
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
"""Duplicate values should follow stable lower-index-first tie-breaking.
Parametrized over int32/int64 to verify luminal preserves whichever
integer dtype the eager model declares (LUM-486).
"""
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
device
)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.tensor(
[[2.0, 1.0, 1.0, 3.0]],
@@ -2033,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.dtype == torch.int32
assert torch.equal(output, original.to(torch.int32))
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
assert output.dtype == original.dtype, (
f"luminal returned {output.dtype}, eager produced {original.dtype}"
)
assert torch.equal(output, original)
def test_tiny_moe_routing(device: torch.device):
"""Focused proof for build MoE routing support."""
model: torch.nn.Module = TinyMoERoutingModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
"""Focused proof for built MoE routing support.
Parametrized over int32/int64 for the integer-valued outputs to verify
luminal preserves the dtype declared by the eager model (LUM-486).
"""
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
scores = torch.tensor(
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
@@ -2050,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
expected = model(scores)
output = model_compiled(scores)
expected_dtypes = (
torch.int32,
torch.float32,
torch.int32,
torch.bool,
torch.int32,
torch.float32,
)
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
assert actual.dtype == expected_dtype
eager = eager.to(actual.dtype)
for actual, eager in zip(output, expected):
assert actual.dtype == eager.dtype, (
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
)
if actual.dtype.is_floating_point:
assert torch.allclose(actual, eager)
else:
@@ -2477,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_floor_div_positional_pt2(device: torch.device):
"""Conv1d stride output uses floor division before positional add."""
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.shape == original.shape == (15, 16)
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)

View File

@@ -0,0 +1,142 @@
import torch
import pytest
from luminal import luminal_backend
class StrideSensitiveInputModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"coeff",
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.coeff
class TwoInputReadModel(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * 2.0 + y * 3.0
class ReturnInputModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
class ReturnInputAndComputedModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return x, x + 1.0
class CloneThenMutateModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x.clone()
y.add_(1.0)
return y, x * 2.0
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
return base, base.t()
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
assert not view.is_contiguous()
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
def _assert_same(actual, expected) -> None:
if isinstance(expected, tuple):
assert isinstance(actual, tuple)
assert len(actual) == len(expected)
for actual_item, expected_item in zip(actual, expected):
_assert_same(actual_item, expected_item)
return
assert torch.allclose(actual, expected)
def _single_non_contiguous_view(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return StrideSensitiveInputModel().to(device), (view,), base
def _same_view_twice(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return TwoInputReadModel().to(device), (view, view), base
def _overlapping_views(device: torch.device):
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
x = base[:3, :4]
y = base[1:, 1:]
assert not x.is_contiguous()
assert not y.is_contiguous()
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
return TwoInputReadModel().to(device), (x, y), base
def _return_input(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputModel().to(device), (view,), base
def _return_input_and_computed(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputAndComputedModel().to(device), (view,), base
def _internal_clone_inplace(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return CloneThenMutateModel().to(device), (view,), base
@pytest.mark.parametrize(
"make_case",
[
pytest.param(
_single_non_contiguous_view,
id="single_non_contiguous_view_stride_sensitive_read",
),
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
pytest.param(_return_input, id="return_input_boundary_value"),
pytest.param(
_return_input_and_computed,
id="return_input_boundary_value_and_computed_value",
),
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
],
)
def test_input_boundary_contiguous_materialization_cases(
device: torch.device, make_case
) -> None:
model, inputs, base = make_case(device)
compiled = torch.compile(model, backend=luminal_backend)
base_before = base.clone()
expected = model(*inputs)
actual = compiled(*inputs)
_assert_same(actual, expected)
assert torch.allclose(base, base_before)
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
device: torch.device,
) -> None:
model, (view,), base = _single_non_contiguous_view(device)
wrong_if_storage_order_used = model(base.reshape(view.shape))
expected = model(view)
assert not torch.allclose(wrong_if_storage_order_used, expected)

View File

@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
class ArgsortStableDuplicatesModel(torch.nn.Module):
"""Tests deterministic duplicate ordering for exported argsort."""
"""Tests deterministic duplicate ordering for exported argsort.
``idx_dtype`` parameterizes the integer dtype of the returned indices so
the test can verify dtype preservation across luminal's int dtype paths
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
lets us drive the same model toward int32 or int64 outputs.
"""
SORT_DIM = 1
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.argsort(x, dim=self.SORT_DIM)
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
class TinyMoERoutingModel(torch.nn.Module):
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
group_ids) to the requested dtype so the test can sweep int32 and int64
output paths (LUM-486). Internal indices stay int64 because torch.gather
/ torch.scatter require int64 index tensors.
"""
TOP_K = 2
ROUTING_DIM = -1
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
DISPATCH_ON = 1
GROUP_SIZE = 2
def __init__(self) -> None:
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
self.register_buffer(
"expert_scale",
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
routing_sign = torch.sign(masked_values)
return (
routed_indices,
routed_indices.to(self.idx_dtype),
masked_values,
dispatch,
dispatch.to(self.idx_dtype),
inactive_mask,
group_ids,
group_ids.to(self.idx_dtype),
routing_sign,
)
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
return self.conv(x)
class Conv1dFloorDivPositionalModel(torch.nn.Module):
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
self.conv2 = torch.nn.Conv1d(
16, 16, kernel_size=3, stride=2, padding=1, bias=True
)
self.position = torch.nn.Parameter(torch.randn(15, 16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.gelu(self.conv1(x))
x = torch.nn.functional.gelu(self.conv2(x))
x = x.squeeze(0).transpose(0, 1)
return x + self.position
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""

View File

@@ -0,0 +1,138 @@
"""Regression coverage for torch.compile mutation and alias contracts.
PyTorch backends are expected to preserve the semantics of the traced graph.
After torch.export functionalization, input mutations are represented as
leading mutation outputs before user outputs. Luminal currently treats every
compiled graph output as a user output and also materializes inputs at the
boundary, so caller-visible mutation and aliasing semantics are not preserved.
"""
import pytest
import torch
from luminal import luminal_backend
class MutateInputThenCompute(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x * 2.0
class MutateInputReturnAlias(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x
class MutateOverlappingInputAlias(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x.add_(10.0)
return y * 2.0
class ReturnInputView(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.t()
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
model = MutateInputThenCompute()
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
actual_input = expected_input.clone()
expected = model(expected_input)
compiled = torch.compile(model, backend=backend)
actual = compiled(actual_input)
assert torch.equal(actual, expected)
assert torch.equal(actual_input, expected_input)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
model = MutateInputReturnAlias()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
_assert_same_storage(out, x)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
model = ReturnInputView()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal currently treats functionalized input-mutation outputs as user "
"outputs and does not copy mutation outputs back to caller inputs."
),
)
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
model = MutateInputThenCompute().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
expected_out = expected_x * 2.0
assert torch.equal(out, expected_out)
assert torch.equal(x, expected_x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal does not preserve caller-visible overlapping input aliasing "
"when one aliased input is mutated."
),
)
def test_luminal_overlapping_input_alias_mutation_contract(
device: torch.device,
) -> None:
model = MutateOverlappingInputAlias().to(device)
eager_base = torch.arange(6, dtype=torch.float32, device=device)
expected = model(eager_base[:4], eager_base[1:5])
base = torch.arange(6, dtype=torch.float32, device=device)
compiled = torch.compile(model, backend=luminal_backend)
actual = compiled(base[:4], base[1:5])
assert torch.equal(actual, expected)
assert torch.equal(base, eager_base)
@pytest.mark.xfail(
strict=True,
reason="Luminal materializes returned input views instead of preserving aliasing.",
)
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
model = ReturnInputView().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)

File diff suppressed because it is too large Load Diff

View File

@@ -315,8 +315,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -462,8 +462,13 @@ fn hlir_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
let hidden_exp = hidden.unsqueeze(2);
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}

View File

@@ -246,8 +246,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -287,8 +287,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -287,7 +287,8 @@ impl QwenMoE {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}
@@ -385,8 +386,13 @@ fn attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -174,8 +174,13 @@ fn decoder_self_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total, ..));
let v_full = v_cache_out.slice((.., ..total, ..));
let mut k_full = k_cache_out.slice((.., ..total, ..));
let mut v_full = v_cache_out.slice((.., ..total, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
k_full.shape.dims[1] = total;
v_full.shape.dims[1] = total;
let q = split_heads(q);

View File

@@ -11,6 +11,7 @@ impl Add for GraphTensor {
type Output = GraphTensor;
fn add(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to add tensors. Got {:?} and {:?}",
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
type Output = GraphTensor;
fn mul(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(
self.dims(),
rhs.dims(),
"Dims must match to multiply tensors."
);
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
@@ -474,6 +480,42 @@ pub(super) mod tests {
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
}
#[test]
#[should_panic(expected = "Dims must match to add tensors.")]
fn test_add_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a + b;
}
#[test]
#[should_panic(expected = "Dims must match to multiply tensors.")]
fn test_mul_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a * b;
}
#[test]
#[should_panic(expected = "Dims must match to mod tensors.")]
fn test_mod_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a % b;
}
#[test]
#[should_panic(expected = "Dims must match to lt tensors.")]
fn test_lt_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a.lt(b);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -557,6 +599,27 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_mod_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
test_binary_transforms(
size,
(),
|a, b| a % b.expand_rhs(a.shape),
|a, b| {
let lhs = a.to_vec1::<f32>().unwrap();
let rhs_scalar = b.to_scalar::<f32>().unwrap();
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
},
identity,
shift_from_zero,
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -570,6 +633,28 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_lt_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS for `lt`.
test_binary(
size,
(),
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|a, b| {
let scalar = b.to_scalar::<f32>().unwrap();
let lhs = a.to_vec1::<f32>().unwrap();
let result: Vec<f32> = lhs
.iter()
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
.collect();
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
},
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]

View File

@@ -467,7 +467,7 @@ impl GraphTensor {
let mut win = Vec::with_capacity(n);
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
let effective_window = *d * (*k - 1) + 1;
win.push(((*dim - effective_window) / s) + 1);
win.push((*dim - effective_window).floor_div(s) + 1);
}
// [win..., kernel...]
@@ -905,6 +905,14 @@ mod tests {
);
}
#[test]
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
let mut cx = Graph::new();
let inp = cx.tensor((80, 3000));
let out = inp.pad(((0, 0), (1, 1)), 0.).unfold((1, 3), (1, 2), (1, 1));
assert_eq!(out.dims(), &[80, 1500, 1, 3]);
}
#[test]
fn test_unsqueeze() {
let mut cx = Graph::new();

View File

@@ -455,6 +455,31 @@ impl Expression {
terms.push(Term::CeilDiv);
Expression::new(terms)
}
/// Floor Division
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == 1 {
return self;
}
if self == 0 {
return 0.into();
}
if self == rhs {
return 1.into();
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
&& let Some(c) = floor_div_i64(a, b)
{
return c.into();
}
// Shape dimensions are non-negative, so the existing integer Div term
// evaluates with floor semantics for dynamic shape expressions.
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Div);
Expression::new(terms)
}
/// Less than
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
@@ -654,6 +679,16 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
depth == 1
}
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
let q = a.checked_div(b)?;
let r = a.checked_rem(b)?;
if r != 0 && ((r > 0) != (b > 0)) {
q.checked_sub(1)
} else {
Some(q)
}
}
impl From<Term> for Expression {
fn from(value: Term) -> Self {
Expression::new(vec![value])
@@ -994,8 +1029,12 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
impl std::iter::Product for Expression {
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
// Empty product is the multiplicative identity, 1 — not 0. Returning
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
let Some(mut p) = iter.next() else {
return 0.into();
return 1.into();
};
for n in iter {
p *= n;
@@ -1106,6 +1145,27 @@ mod tests {
use super::*;
use proptest::prelude::*;
#[test]
fn test_empty_product_is_one() {
// The empty product (e.g. for a rank-0 tensor's shape) must be the
// multiplicative identity, 1 — not 0. cuda_lite and other kernel
// emitters use `shape.iter().product()` to compute `numel`, and a
// rank-0 tensor has 1 element. Returning 0 here would yield a CUDA
// launch with grid=(0, 1, 1) and crash at runtime.
let empty: Vec<Expression> = vec![];
assert_eq!(
empty.into_iter().product::<Expression>(),
Expression::from(1)
);
}
#[test]
fn test_empty_sum_is_zero() {
// Sanity check the additive identity stays 0 (it always was).
let empty: Vec<Expression> = vec![];
assert_eq!(empty.into_iter().sum::<Expression>(), Expression::from(0));
}
#[test]
fn test_basic_simplifications() {
let x = expr('x');