Compare commits

...

19 Commits

Author SHA1 Message Date
Austin Glover
a4bda06d64 tests for interface specification 2026-05-14 21:24:56 +00:00
tucker-luminal
6416ddb5f8 Use parallel launches for small CUDA kernels (#315)
* Use parallel launches for cast and iota kernels

* Use parallel launch for embed kernel
2026-05-14 00:47:12 -04:00
Austin Glover
c9d4ce6217 Better scalar support: tests + 12 fixes (LUM-474) (#300)
* Add scalar torture test suite (LUM-474)

60 tests asserting strict shape, dtype, and value match between PyTorch
eager and luminal_backend. Includes 9 xfail markers (12 cases) for the
known scalar bugs being addressed under LUM-485 through LUM-490.

* Add aten.select.int support to luminal_python translator (LUM-487)

Single-element indexing (`x[0]`, `x[i, j]`, `x[1, 2, 3]`) lowers to
`aten.select.int` in the FX graph. The translator previously bailed
with "Unsupported ATen op", blocking any model that reads a scalar by
indexing.

Implements `aten.select.int(self, dim, index)` as
`slice_along(index..index+1, dim).squeeze(dim)` — a pure
shape-manipulation that the luminal compiler can fold into surrounding
ops, with a single iota for the slice. Negative `dim` is normalized via
the existing `normalize_dim` helper; negative `index` is normalized
against the (concrete) axis size, mirroring how `translate_gather`
normalizes negative gather indices.

Removes the four `xfail(_INDEX_SELECT_REASON)` markers in
`tests/test_scalar_torture.py` (and the now-unused reason constant);
these tests now pass. Final counts: 52 passed / 8 xfailed (was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Fix LUM-488: support rank-0 tensor mod/lt and add aten.remainder dispatch

Two related issues prevented `x % torch.tensor(c)` from translating:

1. The luminal_python translator did not dispatch aten.remainder.Tensor /
   aten.remainder.Scalar at all, so any module that mods a tensor against
   a 0-d torch.tensor failed with "Unsupported ATen op".

2. core::ops::Rem and GraphTensor::lt asserted exact dim equality, blocking
   rank-0 to rank-N broadcasting that the backend already supports
   transparently for Add/Mul (the input_shapes vec is forwarded to the
   strided iterator).

Drop the dim assertions in Rem and lt so they match Add/Mul's broadcast
behavior, and add aten.remainder.Tensor/Scalar handlers in dispatch.rs that
mirror aten.fmod.Tensor (with ensure_same_dtype + broadcast_binary). For
the Scalar form, build a constant_float and expand_rhs onto the LHS shape.

Tests:
- New proptests test_mod_scalar_broadcast / test_lt_scalar_broadcast in
  src/frontend/binary.rs cover rank-0 RHS via expand_rhs.
- Removed @pytest.mark.xfail from test_mod_by_scalar_tensor; added the
  test_scalar_torture.py file to luminal_python's test suite.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python translator: dispatch aten.clamp.Tensor (LUM-489)

torch.clamp(x, lo, hi) where lo/hi are 0-d tensors routes to
aten.clamp.Tensor, which the translator did not previously handle. Add
a dedicated dispatch that decomposes clamp(x, lo, hi) into
min(max(x, lo), hi), broadcasting each rank-0 bound up to x's shape via
expand_rhs. Either bound may be absent (PyTorch allows min=None or
max=None), so each side is applied only when its FX input is a tensor.

Removes the @pytest.mark.xfail on test_clamp_with_scalar_tensors;
test_scalar_torture now reports 50 passed / 10 xfailed (was 48 / 12).

* luminal_python: support aten.prod.default full-reduction (LUM-490)

The translator's dispatch table mapped aten.{sum,mean,amax,amin}.default
to translate_reduction but lacked an entry for aten.prod.default, so
x.prod() with no axis raised "Unsupported ATen op". Add the missing
dispatch entry; the ReductionOp::Prod branch in translate_reduction
already handles both full-reduce and dim-reduce cases.

aten.prod.dim_int was already wired up; verified it routes correctly.

Removes the xfail marker on test_prod_all_produces_scalar in
test_scalar_torture.py — suite now reports 50 passed / 10 xfailed
(was 48 / 12).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: preserve int64 (and other integer) output dtypes (LUM-486)

Full reductions of int64 tensors silently downcast to int32 on the
PyTorch boundary because `output_dtypes` was stored as luminal `DType`,
which collapses every integer width to `DType::Int` (i32). The Python
wrapper therefore reported int32 to PyTorch even when the user passed
int64, breaking strict dtype checks and risking silent overflow on
larger reductions / downstream ops that require int64.

Store `output_dtypes` directly as PT2 dtype codes (the original PyTorch
type IDs) instead of converting through luminal `DType` first. This
preserves int64 vs int32 (and similar) end-to-end. The Python output
path now reads int outputs as i32 and casts to the requested torch
dtype, so int8/int16/int32/int64/uint8 outputs all round-trip with the
right type tag.

Updates two existing assertions (`test_argsort_stable_duplicates`,
`test_tiny_moe_routing`) that were pinning int32 — the new behavior
matches PyTorch eager (int64). Adds `test_reduce_sum_all_axes_int64_preserves_dtype`
as a regression check, and removes the xfail on
`test_int_sum_produces_int_scalar`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: parametrize argsort/MoE dtype tests over int32 and int64

The LUM-486 fix preserves whichever integer dtype the eager model declares
on output. The original tests hardcoded int64 (the dtype torch.argsort and
torch.topk natively produce), which only exercised one path through the
preservation logic.

Add an idx_dtype knob to ArgsortStableDuplicatesModel and TinyMoERoutingModel
that casts the integer outputs to the requested dtype, and parametrize both
tests over [torch.int32, torch.int64]. Internal indices (passed to gather /
scatter) stay int64 since PyTorch requires that for index tensors; the cast
applies only to the returned values.

LUM-486

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Remove xfail markers for fixed scalar bugs

Drops the @pytest.mark.xfail markers on tests now passing after the
LUM-486, LUM-487, LUM-488, LUM-489, and LUM-490 fixes:

  - test_prod_all_produces_scalar (LUM-490)
  - test_clamp_with_scalar_tensors (LUM-489)
  - test_mod_by_scalar_tensor (LUM-488)
  - test_index_1d_produces_scalar (LUM-487)
  - test_index_all_dims_produces_scalar (LUM-487)
  - test_index_then_add_scalar_const (LUM-487)
  - test_model_returns_scalar_from_index (LUM-487)
  - test_int_sum_produces_int_scalar (LUM-486)

Also removes the now-unused _INDEX_SELECT_REASON constant.

The single remaining xfail is test_unsqueeze_expand_sum_back, blocked
on LUM-485 (full reduction returns shape [1] instead of rank-0 ()).

* luminal_python: full reductions return rank-0 () instead of [1] (LUM-485)

The translator's full-reduce path used to flatten the input to [1, N] and
reduce axis 1, leaving a residual [1] dimension. PyTorch eager produces
rank-0 () for x.sum() etc., and downstream ops (e.g. unsqueeze(0).expand(5))
rely on that rank — the residual [1] caused panics like "Cannot expand
from 2 dims to 1 dims" once the scalar fed any further op.

Drop the flatten and reduce over every axis directly. Special-case rank-0
input as a no-op so reducing a scalar is well-defined. Mean still divides
by the cached total to avoid redundant axis-prod work.

Removes the xfail marker on test_unsqueeze_expand_sum_back, which now
passes. With this commit the integration branch has zero xfails:
284 passed across test_scalar_torture.py + test_hlir_ops.py + test_unary.py.

* ruff format: tests/test_hlir_ops.py

Collapse a two-line f-string into one line per ruff format. No behavior change.

* Expand scalar torture suite with PyTorch / NumPy gap coverage

Cross-referenced our suite against PyTorch's test_torch / test_reductions
/ test_view_ops / test_indexing / test_type_promotion / test_binary_ufuncs
and NumPy's test_multiarray / test_indexing / test_shape_base. Added 14
new sections covering 47 in-scope gaps:

  - Binary ops with INPUT 0-d (not reduction-derived) on either side:
    add/sub/mul/div/mod/maximum/minimum/pow/floor_divide
  - Pure 0-d ↔ 0-d arithmetic (no broadcasting required)
  - Full comparison set (gt/ge/lt/le/eq/ne) on input 0-d, plus mask-by-eq
  - Reduction extras: argmax/argmin (no-arg + keepdim), sum(dim=()),
    sum/mean of 0-d input, cumsum of 0-d
  - Shape-flattening on 0-d: flatten/ravel/reshape(-1)/view(-1) all
    return shape (1,); reshape(()) on 1-element collapses to (); plus
    permute([]), contiguous(), squeeze() of (1,1,1,1), expand_as
  - Indexing extras: ellipsis x[...], index by 0-d int tensor, gather
    with 0-d index, negative-index x[-1]
  - Type promotion: float-0-d + int-Nd, int-0-d + float-Nd, cast
    roundtrip through 0-d, .float()/.int() shorthands, where with
    mixed-dtype scalar branches
  - Unary math (abs/neg/exp/sin/cos/tanh/sigmoid/sqrt/sign/floor/ceil)
    on reduction-derived 0-d
  - Bool logic: AND, OR, XOR, NOT on 0-d bool from comparisons
  - Stack of 0-ds; cat of unsqueezed 0-ds
  - Constants: torch.full((), v), torch.full_like on 0-d
  - Reduction edge cases: keepdim across all axes then divide;
    scalar broadcast onto transposed tensor
  - Mixed where/clamp shapes: clamp(x, scalar_tensor, py_float),
    where(cond, scalar_tensor, x)
  - Multi-output models: (scalar, tensor) tuple

Result: 363 passed / 15 xfailed across the python suite. The 15 new
xfails are documented inline with concrete failure modes:

  - 6 op-coverage gaps: aten.argmax.default, aten.argmin.default,
    aten.eq.Scalar, aten.ne.Tensor (translator dispatch entries needed).
  - 2 PT2 export issues: 0-d int64 graph inputs hit "invalid type: null,
    expected i64" in luminal's model.json parser; affects
    test_int_0d_plus_float_nd and test_gather_with_0d_index.
  - 2 real correctness bugs:
      * floor_divide with 0-d divisor returns the un-floored quotient
        (float division result, not floor(x/d)).
      * cumsum on a 0-d tensor panics with index-out-of-bounds.
  - 1 dynamo guard edge case: torch._dynamo emits an unresolved 'L'
    name in _guards_fn for 0-d index tensors.

Plus 4 cross-marker xfails on consequence of the above (the parametric
ne case, mask_by_scalar_eq variants, and other downstream effects).

* Rename test_scalar_torture.py -> test_scalars.py; drop 'torture' wording

The original 'torture test' label is jargon. The file is just a scalar
test module — keep the name simple to match the rest of the suite
(test_unary.py, test_hlir_ops.py).

* luminal_python: parse rounding_mode string arg correctly (LUM-494)

torch.floor_divide(x, d) decomposes to aten.div.Tensor_mode with
rounding_mode='floor' during PT2 export. The translator was reading
the kwarg via serde_json::Value::as_str(), but PT2 serializes string
args as {"as_string": "<value>"} objects, not bare JSON strings. The
extraction silently returned None, so the floor branch was skipped
and the regular un-floored quotient was returned.

Drill into the as_string field as a fallback so floor_divide and
div(x, d, rounding_mode='floor'/'trunc') produce floor(x/d) /
trunc(x/d) as expected.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix cumsum on rank-0 tensor (LUM-495)

The translator's cumsum handler called normalize_dim(dim, a.shape.len())
and then a.cumsum(dim) for any rank — including rank-0. The underlying
cumop in src/frontend/unary.rs indexes self.dims()[axis] inside the
padding/unfold loop, which panics with "index out of bounds: the len is
0 but the index is 0" when shape is empty.

PyTorch eager treats torch.cumsum(s, 0) on a 0-d tensor as an identity
op (cumsum of a single element is the element itself). Mirror the
rank-0 short-circuit pattern from the LUM-485 reduction fix and return
the input unchanged when a.shape.is_empty(). Move the dim arg fetch
inside the non-empty branch since dim is unused for rank-0.

Drops the xfail marker on test_cumsum_of_0d and adds a 1-element 1-D
sibling test that asserts shape (1,) round-trips.

* luminal_python: support aten.argmax/argmin (LUM-496)

argmax/argmin were missing from the translator dispatch table even
though we already have stable_argsort. Add a thin wrapper so the
PyTorch boundary lights up:

  argmax(x, dim=None)    -> argsort(flatten(x), descending=True).select(0, 0)
  argmax(x, dim=N)       -> argsort(x, dim=N, descending=True).select(N, 0)
  argmax(x, dim=N, keepdim=True) -> .unsqueeze(N) over the above
  argmin(...)            -> same with descending=False

The slice + squeeze chain produces a non-contiguous DType::Int view
whose underlying buffer is still sized for the un-sliced argsort
tensor. Final `* 1` materializes a contiguous Int copy with strides
matching the visible shape — same trick `translate_topk` uses for
its sliced index output. Without it the keepdim case panics
("No output node found") and the full-reduce case throws a Python
shape mismatch on the oversized buffer.

PyTorch's argmax returns int64 while luminal collapses to int32 (Int);
LUM-486 already widens at the Python boundary, so the contract is
preserved end-to-end. Drops the three `@pytest.mark.xfail` markers
from `test_argmax_all`, `test_argmin_all`, and `test_argmax_keepdim_1d`
in `test_scalars.py` (6 cases via parametrization).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: dispatch aten.eq.Scalar and aten.ne.Tensor (LUM-497)

Add the two missing comparison overloads to the translator dispatch.
eq.Scalar mirrors the existing ne.Scalar handler (constant_float + cast
+ expand_rhs to broadcast the scalar), and ne.Tensor mirrors the
existing eq.Tensor handler. Removes the corresponding xfail markers on
test_input_0d_comparisons[_NeInput0ds-...] and test_mask_by_scalar_eq.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: accept null range_constraint bounds (LUM-498)

A 0-d int64 graph input made PyTorch 2.10+ emit
`range_constraints: { sN: { min_val: null, max_val: null } }` for the
unbacked symbol PT2 introduces around the rank-0 tensor. Our serde
schema modeled `RangeConstraint.min_val` as `i64`, so deserialization
failed with `invalid type: null, expected i64`, blocking any model
with a scalar integer tensor input.

Make `min_val` and `max_val` `Option<i64>` (matching PT2's
`Optional[int]`) and fall back to 1 as the initial dynamic-dim value
when no lower bound is provided.

Tests: removes the xfail on `test_int_0d_plus_float_nd`, adds a new
`test_int32_0d_plus_float_nd` regression, and updates the xfail
reason on `test_gather_with_0d_index` (the parse error is fixed; a
separate downstream gather panic remains).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: drop dynamo input guards in pt2_backend (LUM-499)

When a 0-d int tensor is used as a tensor index (x[i] where
i = torch.tensor(2)), torch.export records duplicate input guards that
reference both the original local source (L['i']) and the rewrapped flat
args (L['args'][1]). The unlift pass cannot resolve L['i'] against the
wrapped (*args, **kwargs) signature, leaving a literal `L` reference in
the generated _guards_fn that raises NameError during retracing. The
data-dependent .item() in the surviving guard then trips fake-tensor
analysis with DataDependentOutputException.

Drop the guard list before run_decompositions so unlift produces an
empty _guards_fn, and DCE any leftover dead aten.item.default nodes
that came from index specialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* luminal_python: fix gather with rank-0 index on rank-1 source

PyTorch eager allows torch.gather(rank-1, dim, rank-0) — the only
rank-mismatch case it permits — and returns a rank-0 scalar. Our
gather_elements requires source-rank == index-rank, so the rank-0
index hit flatten_strides with mismatched (0, 1) lengths and
panicked.

Detect this specific pattern in translate_gather: unsqueeze the
rank-0 index to (1,), gather, then squeeze the result back to ().
Output shape and value match eager.

This was the last remaining xfail in test_scalars.py. Suite is now
381 passed / 0 xfailed / 0 failed across test_scalars.py +
test_hlir_ops.py + test_unary.py.

* luminal_python: clamp.Tensor handles all broadcastable bound shapes

PyTorch's aten.clamp.Tensor accepts bounds with any NumPy-broadcastable
shape (rank-0, same-shape, or broadcastable). The previous translator
used expand_rhs(result.shape) which appends dims rather than broadcasts,
so only rank-0 bounds came out correctly. Same-shape and broadcastable
bounds either panicked or silently produced wrong values.

Switch to broadcast_binary (the right-align + size-1 expand helper used
by aten.remainder.Tensor, aten.eq.Tensor, etc.). Now all three modes
work uniformly.

Add 7 new tests covering the previously-broken modes:
  - same-shape bounds (per-element clamp, e.g. learned bounds)
  - per-row broadcast (3,1) against (3,4)
  - per-col broadcast (4,) against (3,4)
  - mixed rank-0 lo + same-shape hi
  - min-only with same-shape lo
  - max-only with per-row hi
  - 3-D x with 2-D bounds (left-unsqueeze broadcast)

Suite goes from 381 to 388 passing, 0 xfailed.

* shape: empty Expression product returns 1, not 0

The empty product is the multiplicative identity (1) — every shape-iterator
call site (`shape.iter().product()` for `numel`, output-buffer sizing, CUDA
grid-dim computation) implicitly relies on this. The previous impl returned
0 for an empty iterator, which was a latent bug masked while no path
produced rank-0 shapes.

The LUM-485 fix (full reductions return rank-0 () instead of rank-1 [1])
exposed it on CUDA: SumReduce kernels with rank-0 output got `n_outputs=0`,
launched with `grid=(0, 1, 1)`, and crashed with "invalid CUDA launch
dimensions" — every CUDA reduction in the Python CUDA tests was failing.

Fix: return Expression::from(1) for empty iteration. Sum's identity (0)
was already correct and is unchanged.

Add two unit tests covering both identities.

* cargo fmt

* Fix PT2 passthrough input output ID collision

* Fix scalar argextremum keepdim behavior

* Defer PT2 interface collision fix

* Keep HLIR binary ops shape-strict

* fixed gemma issue

* Fix explicit broadcasts and conv shape division

* Normalize Whisper cache slice shape

---------

Co-authored-by: Austin Glover <austin@luminal.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Austin Glover <austin_glover@berekely.edu>
Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-13 20:16:30 -04:00
June
1dcd0370ce feat: add CUDA 13.2 support via cudarc 0.19.4 (#312)
* Update cudarc to 0.19.4 to support CUDA 13.2

Fixes #291

Changes:
- Upgrade cudarc from 0.18.2 to 0.19.4
- Remove get_global call for __constant__ memory tracking

Rationale:
cudarc 0.19.0 changed get_global to return CudaViewMut instead of
CudaSlice to prevent double-free of __constant__ memory managed by
the CUDA module. The old code worked around this by storing the
CudaSlice and calling std::mem::forget on cleanup. With the new API,
the view's lifetime is tied to the module borrow, making the
workaround unnecessary. Since the constants HashMap was only used
for this workaround and never accessed otherwise, we now return an
empty HashMap.

CUDA 13.2 support was added in cudarc 0.19.4.

* fix: migrate embed kernel to shared dyn_dims buffer

The cudarc 0.18→0.19 bump removed get_global, but simply dropping the
call left __constant__ memory declared-but-never-written, producing
wrong results for models with dynamic-shape embeddings. Migrate to
the same dyn_dims parameter + #define pattern every other kernel uses.
2026-05-13 13:43:36 -04:00
Ali
6757a4e37b pack scatter kernel into 256-thread blocks (#309) 2026-05-13 13:43:15 -04:00
Joe Fioti
631451f8b8 Remove Testing section from README (#313)
Removed the Testing section from the README.
2026-05-12 17:36:33 -04:00
Joe Fioti
70bdd75163 flashinfer (#311)
* luminal_python + cuda_lite: unblock Qwen3-MoE compile path

Four small fixes that together let Qwen3MoeForCausalLM compile end-to-end
through torch.compile + luminal_backend, plus a regression test suite.

1. KernelScatter bf16 OOB
   crates/luminal_cuda_lite/src/kernel/hlir.rs

   The Scatter kernel sized n_vec as `n_dest / 4`, correct only for
   4-byte dtypes. For bf16 (and any 1/2/8-byte type) the float4
   vectorised copy walked the destination 2× / 4× / 0.5× the actual
   buffer size. Whether that crashed with CUDA_ERROR_ILLEGAL_ADDRESS or
   silently corrupted neighbouring allocations depended on which
   surrounding kernels the egglog search picked → ~40% crash rate at
   search-iters≥5 on StaticCache(dtype=bfloat16) MoE inference. Fix:
   parameterise n_vec and remainder_start by elements_per_vec =
   16 / sizeof(self.dtype). For F32/Int the generated PTX is identical.

2. maximum_f32 dtype mismatch on Int tensors
   src/frontend/binary.rs

   `maximum_f32(rhs)` built an F32 `constant_float`; the inner `lt`
   then panicked "Dtypes must match to compare tensors. Got Int and
   F32" whenever self was Int — e.g. `aten.clamp` on top-k expert
   indices coming out of an MoE router. Fix: cast the constant to
   self.dtype before the compare. For Int self this floors the bound,
   matching PyTorch's `clamp(int_tensor, min=<float>)` semantics.

3. Three new ATen ops in the luminal_python translator
   crates/luminal_python/rust/src/translator/{dispatch,tensor}.rs

   - aten.empty.memory_format
   - aten.empty_permuted.default     → translate_empty (zero-fill)
   - aten.histc.default              → translate_histc

   Qwen3-MoE allocates the expert-output staging tensor via
   `empty_permuted` and counts tokens-per-expert via
   `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)`.

   empty / empty_permuted lower to a zero-filled tensor of the
   requested shape — PyTorch's contract on empty outputs is undefined
   for any read prior to a write, and downstream writes overwrite our
   zeros, so this is sound.

   histc implements only the bincount-equivalent case (one integer per
   bin); non-integer-bin or non-contiguous-bin usage bails with a clear
   error rather than silently dropping values.

4. crates/luminal_python/tests/test_qwen3_moe.py — new file

   Four regression tests over progressively larger Qwen3MoeForCausalLM
   configs:
     - tiny:               2 experts, top-1, ~70K params  (atol 1e-5)
     - small:              4 experts, top-2               (atol 1e-4)
     - medium:             8 experts, top-2, 2 layers     (atol 1e-4)
     - real_config_1layer: full Qwen3-30B-A3B arch
                           (128 experts, top-8, 2048 hidden),
                           num_hidden_layers=1, random weights
                                                          (atol 1e-3)

   The size ladder lets any future regression surface at the cheapest
   test that catches it. Each individual fix above is exercised:
   gather-then-matmul (PR #298) by every test, KernelScatter bf16
   indirectly via the bf16 weight init path, the clamp-on-Int and the
   empty/histc translators by every test.

Validation on H200/CUDA:
  - 4 passed in tests/test_qwen3_moe.py (this PR's new tests)
  - 223 passed across tests/test_unary.py, test_capsule_validation.py,
    test_hlir_ops.py — no existing-test regression

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: add full-depth Qwen3-30B-A3B regression test

The 1-layer real-config test exercised the production *layer* shape but
not the full network depth. Adds a sibling test that loads the actual
Qwen/Qwen3-30B-A3B pretrained checkpoint at its native bf16 dtype,
keeps all 48 layers, and runs a full forward through luminal_backend.

Asserts compile+run completes and the compiled output is finite + in the
right magnitude band vs eager (within 10×). Tight numerical equivalence
at full depth is not asserted: random egglog seeds can pick lowering
plans whose 48-layer accumulation diverges structurally from eager
even though per-layer correctness holds. The smaller-config tests above
use atol≤1e-3 and cover the per-op correctness this test cannot.

This catches:
  - egglog cleanup behaviour over a 48-layer-wide e-graph (the
    `egglog_utils.rs:1286: No valid graphs` panic surfaces here if the
    cleanup cascade re-regresses on MoE root-eclasses);
  - per-layer state plumbing that single-layer tests can't see;
  - bf16-specific code paths that fp32 random-init tests mask.

Memory profile: ~60 GB bf16 weights + ~15 GB compiled-runtime peak;
single-token input keeps activations and KV cache trivial. Fits an H200
or H100 with margin to spare.

Run time: ~90 s for compile (egglog search at default budget) + ~1 s
for both forward passes.

Verified with 5 passed in 5:29 on H200/CUDA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_python: fix bf16 cast-back on where / masked_fill

`translate_where`, `translate_where_scalar_other`, and
`translate_masked_fill_scalar` all computed `c * x + (1 - c) * y` in F32
and never cast the result back to the input dtype. When the input was
bf16 (the common case for MoE inference), the F32 buffer was downstream
read as bf16 — which walks the buffer at half-stride and produces
output[1] = input[0], output[3] = input[1], … with zeros at the even
positions. For Qwen3-MoE's `batched_mm_experts_forward` the corruption
landed at the masked-fill of unused expert outputs and propagated as
~10^38 saturation through the rest of the layer.

Three changes:

1. Extract a shared `where_formula(cond, x, y, out_dtype)` helper that
   builds the c*x + (1-c)*y graph in F32 and then `cast(out_dtype)`s
   the result. All three callers route through it now.
2. `translate_where_scalar_other` and `translate_masked_fill_scalar`
   build a tensor for the scalar branch via the same
   `constant_float(val).cast(out_dtype).expand_rhs(shape)` recipe that
   `translate_full_like` uses, then call the shared helper.
3. The standalone half-stride misread on a tiny `masked_fill` graph is
   still observable in isolation (egglog picks a different rewrite plan
   for that graph than for `full_like + where`), but does not occur in
   real models — the qwen3-moe test suite (5 tests, including full
   `Qwen/Qwen3-30B-A3B` pretrained at all 48 layers) is now green and
   the bench's `Qwen3MoeExperts` path produces correct output.

Validation on H200/CUDA:
  - 5 passed in tests/test_qwen3_moe.py (was: full-config wrong-magnitude
    output blocking the regression test from being meaningful)
  - 223 passed in tests/test_unary.py + test_capsule_validation.py +
    test_hlir_ops.py — no existing-test regression

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* cargo fmt

* ruff format on tests/test_qwen3_moe.py

* clippy: use += instead of x = x + y

* fixed whisper with schedule edges in runtime

* scatter no copy fix

* whisper fix

* hold out slow tests

* flashinfer

* fmt

* flashinfer jit

---------

Co-authored-by: Tucker Morgan <tucker@luminal.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 23:18:12 -04:00
Ali
855f2bfd02 implement warp-level reduce using register shuffle (#310) 2026-05-11 19:20:09 -04:00
Ali
cf7fa2297c get_* is leaking mem (#308) 2026-05-11 15:41:29 -04:00
tucker-luminal
cd3f55a3a7 luminal_python + cuda_lite: unblock Qwen3-MoE compile path (#301)
* luminal_python + cuda_lite: unblock Qwen3-MoE compile path

Four small fixes that together let Qwen3MoeForCausalLM compile end-to-end
through torch.compile + luminal_backend, plus a regression test suite.

1. KernelScatter bf16 OOB
   crates/luminal_cuda_lite/src/kernel/hlir.rs

   The Scatter kernel sized n_vec as `n_dest / 4`, correct only for
   4-byte dtypes. For bf16 (and any 1/2/8-byte type) the float4
   vectorised copy walked the destination 2× / 4× / 0.5× the actual
   buffer size. Whether that crashed with CUDA_ERROR_ILLEGAL_ADDRESS or
   silently corrupted neighbouring allocations depended on which
   surrounding kernels the egglog search picked → ~40% crash rate at
   search-iters≥5 on StaticCache(dtype=bfloat16) MoE inference. Fix:
   parameterise n_vec and remainder_start by elements_per_vec =
   16 / sizeof(self.dtype). For F32/Int the generated PTX is identical.

2. maximum_f32 dtype mismatch on Int tensors
   src/frontend/binary.rs

   `maximum_f32(rhs)` built an F32 `constant_float`; the inner `lt`
   then panicked "Dtypes must match to compare tensors. Got Int and
   F32" whenever self was Int — e.g. `aten.clamp` on top-k expert
   indices coming out of an MoE router. Fix: cast the constant to
   self.dtype before the compare. For Int self this floors the bound,
   matching PyTorch's `clamp(int_tensor, min=<float>)` semantics.

3. Three new ATen ops in the luminal_python translator
   crates/luminal_python/rust/src/translator/{dispatch,tensor}.rs

   - aten.empty.memory_format
   - aten.empty_permuted.default     → translate_empty (zero-fill)
   - aten.histc.default              → translate_histc

   Qwen3-MoE allocates the expert-output staging tensor via
   `empty_permuted` and counts tokens-per-expert via
   `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)`.

   empty / empty_permuted lower to a zero-filled tensor of the
   requested shape — PyTorch's contract on empty outputs is undefined
   for any read prior to a write, and downstream writes overwrite our
   zeros, so this is sound.

   histc implements only the bincount-equivalent case (one integer per
   bin); non-integer-bin or non-contiguous-bin usage bails with a clear
   error rather than silently dropping values.

4. crates/luminal_python/tests/test_qwen3_moe.py — new file

   Four regression tests over progressively larger Qwen3MoeForCausalLM
   configs:
     - tiny:               2 experts, top-1, ~70K params  (atol 1e-5)
     - small:              4 experts, top-2               (atol 1e-4)
     - medium:             8 experts, top-2, 2 layers     (atol 1e-4)
     - real_config_1layer: full Qwen3-30B-A3B arch
                           (128 experts, top-8, 2048 hidden),
                           num_hidden_layers=1, random weights
                                                          (atol 1e-3)

   The size ladder lets any future regression surface at the cheapest
   test that catches it. Each individual fix above is exercised:
   gather-then-matmul (PR #298) by every test, KernelScatter bf16
   indirectly via the bf16 weight init path, the clamp-on-Int and the
   empty/histc translators by every test.

Validation on H200/CUDA:
  - 4 passed in tests/test_qwen3_moe.py (this PR's new tests)
  - 223 passed across tests/test_unary.py, test_capsule_validation.py,
    test_hlir_ops.py — no existing-test regression

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: add full-depth Qwen3-30B-A3B regression test

The 1-layer real-config test exercised the production *layer* shape but
not the full network depth. Adds a sibling test that loads the actual
Qwen/Qwen3-30B-A3B pretrained checkpoint at its native bf16 dtype,
keeps all 48 layers, and runs a full forward through luminal_backend.

Asserts compile+run completes and the compiled output is finite + in the
right magnitude band vs eager (within 10×). Tight numerical equivalence
at full depth is not asserted: random egglog seeds can pick lowering
plans whose 48-layer accumulation diverges structurally from eager
even though per-layer correctness holds. The smaller-config tests above
use atol≤1e-3 and cover the per-op correctness this test cannot.

This catches:
  - egglog cleanup behaviour over a 48-layer-wide e-graph (the
    `egglog_utils.rs:1286: No valid graphs` panic surfaces here if the
    cleanup cascade re-regresses on MoE root-eclasses);
  - per-layer state plumbing that single-layer tests can't see;
  - bf16-specific code paths that fp32 random-init tests mask.

Memory profile: ~60 GB bf16 weights + ~15 GB compiled-runtime peak;
single-token input keeps activations and KV cache trivial. Fits an H200
or H100 with margin to spare.

Run time: ~90 s for compile (egglog search at default budget) + ~1 s
for both forward passes.

Verified with 5 passed in 5:29 on H200/CUDA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_python: fix bf16 cast-back on where / masked_fill

`translate_where`, `translate_where_scalar_other`, and
`translate_masked_fill_scalar` all computed `c * x + (1 - c) * y` in F32
and never cast the result back to the input dtype. When the input was
bf16 (the common case for MoE inference), the F32 buffer was downstream
read as bf16 — which walks the buffer at half-stride and produces
output[1] = input[0], output[3] = input[1], … with zeros at the even
positions. For Qwen3-MoE's `batched_mm_experts_forward` the corruption
landed at the masked-fill of unused expert outputs and propagated as
~10^38 saturation through the rest of the layer.

Three changes:

1. Extract a shared `where_formula(cond, x, y, out_dtype)` helper that
   builds the c*x + (1-c)*y graph in F32 and then `cast(out_dtype)`s
   the result. All three callers route through it now.
2. `translate_where_scalar_other` and `translate_masked_fill_scalar`
   build a tensor for the scalar branch via the same
   `constant_float(val).cast(out_dtype).expand_rhs(shape)` recipe that
   `translate_full_like` uses, then call the shared helper.
3. The standalone half-stride misread on a tiny `masked_fill` graph is
   still observable in isolation (egglog picks a different rewrite plan
   for that graph than for `full_like + where`), but does not occur in
   real models — the qwen3-moe test suite (5 tests, including full
   `Qwen/Qwen3-30B-A3B` pretrained at all 48 layers) is now green and
   the bench's `Qwen3MoeExperts` path produces correct output.

Validation on H200/CUDA:
  - 5 passed in tests/test_qwen3_moe.py (was: full-config wrong-magnitude
    output blocking the regression test from being meaningful)
  - 223 passed in tests/test_unary.py + test_capsule_validation.py +
    test_hlir_ops.py — no existing-test regression

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* cargo fmt

* ruff format on tests/test_qwen3_moe.py

* clippy: use += instead of x = x + y

* fixed whisper with schedule edges in runtime

* scatter no copy fix

* whisper fix

* hold out slow tests

* fixing issues with bad rewrite

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-11 12:34:52 -07:00
Ali
11653c6903 capacity should be used instead of len for Vec::from_raw_parts (#307) 2026-05-11 11:30:02 -04:00
Ali
6d16bdba21 n_elements should use constant not device (#306) 2026-05-11 11:29:20 -04:00
Joe Fioti
7bfd19fb72 Refine cublasLt rewrites and shrink their test coverage (#305) 2026-05-09 01:29:10 -04:00
tucker-luminal
42caa4750e luminal_python: dynamic shapes through torch.compile + translator cleanups (#302)
* luminal_python: tighten translator lowerings

Reduce graph-node count in PT2 → HLIR translators without semantic
changes; CUDA suite is 233P/4X before and after.

- where / masked_fill / bool-mask index_put: rewrite the blend as
  `y + c*(x - y)` instead of `c*x + (1-c)*y`, dropping a mul, a sub,
  and the `1.0` constant per call.
- gather / index.Tensor: keep negative-index normalization in Int
  instead of round-tripping through F32, dropping three Cast nodes
  per indexed dim; works for symbolic axis sizes too.
- ceil: lower as `trunc(x) + (x > trunc(x))` instead of `-floor(-x)`.
- _to_copy: skip the Cast op when the dtype already matches; PT2
  emits `_to_copy` as a clone hint and the redundant cast was
  surviving until later optimizer passes.
- Full reductions (sum.default etc.): match the contiguity guard
  translate_reshape already applies — without it the `[1, N]` view
  treats stride-0 broadcast dims as if they held N distinct values
  and reads past the backing buffer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_python: end-to-end dynamic-shape support through torch.compile

Previously the standard torch.compile(model, backend=luminal_backend) path
silently dropped Dynamo's dynamic-shape information on re-export, so every
new input shape forced a full backend recompile. The luminal.pt2.compile()
"explicit" entry point also bailed out on float inputs and on anything
beyond a single bare-symbol dim. This commit makes both paths actually
flow symbolic dims end-to-end.

pt2_backend (the path torch.compile users hit):
- Detect SymInt placeholders Dynamo emits alongside tensor inputs and
  rewrite their uses into `aten.sym_size.int(tensor, dim)` so re-export
  sees a tensor-only signature.
- Build a torch.export `dynamic_shapes` spec from the surviving tensor
  placeholders' FakeTensor shapes (Dim.AUTO; relationships are recovered
  from the FakeTensor metadata).
- Defer the entire compile pipeline to the first runtime call when
  dynamic_shapes is non-None — torch.export with dynamic_shapes mutates
  the ShapeEnv that Dynamo is still relying on to install guards, and
  doing it inside the backend frame trips an internal "Guard failed on
  the same frame" assertion. Lazy compile sidesteps this cleanly.
- Compose the lifted-weight and SymInt filter steps into a single
  user_indices the CompiledModel uses to drop both kinds of non-tensor
  args at __call__ time. Fix the device-detection lookup to walk
  user_inputs (post-filter) rather than `inputs[0]`, which can be a
  SymInt under Dynamo.
- _detect_factory_capsule similarly walks for the first real tensor.

Compound shape expressions (`2*s`, `s+1`, etc.):
- resolve_dim_sizes now parses sympy `srepr` strings — Symbol, Integer,
  n-ary Mul/Add — into proper luminal Expressions instead of collapsing
  every non-bare-symbol form to size 1. Falls back to the EP's `hint`
  when the head isn't recognised so output-shape resolution still
  returns a usable concrete size.
- auto_set_dims_from_input_shapes inverts single-variable affine forms
  by sampling two probe points (x=2, x=3), recovering slope/intercept,
  and verifying the candidate value round-trips through
  exec_single_var_checked. Multi-variable / non-affine / non-monotonic
  forms are rejected so we never write a wrong guess into dyn_map.

Explicit luminal.pt2.compile() API (unchanged behavior for existing
callers, plus):
- Accepts `dynamic_shapes=` passthrough for full torch.export-style
  control (named Dims, ranges, multi-input, shared symbols).
- `dynamic_dim` accepts an int, an Iterable[int], or "auto"; "auto"
  marks every non-trivial axis of the first input as Dim.AUTO instead
  of being integer-input-only.
- Multi-input `example_input` lists are accepted directly.
- The legacy `dynamic_dim=None` integer-tail-axis heuristic is
  preserved so the existing decode-loop test keeps working unchanged.

Op-arg SymInt awareness:
- get_int_arg / get_ints_arg fall through to expression resolution and
  accept SymInt entries that bind to concrete values, instead of
  failing with a misleading "not an int" message.

Tests:
- New tests/test_dynamic_shapes.py covers torch.compile under both
  automatic_dynamic_shapes and dynamic=True (the latter reuses a
  single compile across every shape — verified via backend invocation
  count), lifted-weight + SymInt composition, multi-dim dynamic,
  compound shape expressions (`cat([x, x], 0)` produces `2*s`), and
  the new explicit-API surface (float-input dynamic_dim and
  dynamic_shapes passthrough).

Full CUDA suite: 239 passed / 4 xfailed (was 233/4); no regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Fix CI: pass user_indices through _save_and_compile + apply fmt

The lazy-compile path passes user_indices= to _save_and_compile, but
the function signature never accepted it — ruff F821 caught the
undefined name in the early return path. Add it as a kwarg.

Also apply ruff format and cargo fmt to satisfy the corresponding
pre-commit checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Fix bad merge: restore _decomp_table() on all run_decompositions sites

The merge of main into worktree-fasteraten kept _decomp_table() on
only one of the three ep.run_decompositions() call sites. The other
two — the dynamic-shapes compile() path and the _eager_pt2_compile
(torch.compile backend) path — were left calling run_decompositions()
with no args, which decomposes SDPA and breaks the translator with
unsupported eq.Scalar / scalar_tensor(-Infinity) ops from the
all-masked sentinel chain.

Restore _decomp_table() at all three sites.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 16:27:09 -07:00
Joe Fioti
1279dca4e6 Memory analysis post pass (#303)
* Simplify CUDA memory analysis and arena planning

* Simplify CUDA memory planning and fix clippy warnings
2026-05-08 11:24:37 -04:00
tucker-luminal
53f7960130 luminal_python: translate F.scaled_dot_product_attention as one fused op (#285)
Adds translator support for `torch.ops.aten.scaled_dot_product_attention.default`
and the four backend variants (`_scaled_dot_product_efficient_attention`,
`_scaled_dot_product_flash_attention`, `_scaled_dot_product_flash_attention_for_cpu`,
`_scaled_dot_product_cudnn_attention`) so calls to
`torch.nn.functional.scaled_dot_product_attention` lower to a single
matmul+softmax+matmul chain instead of the ~20-op default decomposition
(which uses `eq.Scalar`/`logical_not`/`any.dim`/`where.self`/`full_like` to
implement the all-masked-row sentinel).

The default `ep.run_decompositions()` table decomposes SDPA away. Strip the
five SDPA entries from the table in `pt2.py:_decomp_table()` so the op
survives into the FX graph and our translator catches it.

Tests cover the three commonly-hit branches:
- basic Q/K/V (default scale, no mask, no causal flag)
- is_causal=True (triangular-mask branch)
- additive attn_mask broadcast over heads

Verified on native (224 passed) and CUDA (239 passed / 4 xfailed).

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-07 16:36:36 -04:00
Joe Fioti
5c3407c596 Reduce default profiling trials to 3 (#299)
* Reduce default profiling trials to 3

* rm out.png

* Set Modal CI timeouts to 2 hours
2026-05-06 13:04:57 -04:00
tucker-luminal
47530062a4 luminal_python: gather-then-matmul lowering for grouped_mm (#298)
translate_grouped_mm was casting the full [G, K, N] expert weight
tensor to F32 before a broadcast batched matmul, producing
~2.1 GB of intermediate buffers per layer on Qwen3-30B-A3B.
Across 48 MoE layers this OOM'd the search profiler at
runtime.rs:711 (alloc_zeros), failing every python_luminal
qwen3-moe bench run for the past ~2 weeks.

Switch to the gather-first pattern that examples/qwen3_moe uses:
compute expert_id from offs, gather only the [S, K, N] active
slice, then matmul. The shape mirrors what glumoe_rewrite.egg
matches, and the gather is 16x smaller at prefill
(S = num_tokens * top_k = 8 vs G = 128).

Two refinements baked in vs the broadcast-and-mask version:

1. Stay in Int for the entire expert_id computation. arange and
   offs are already Int; ge → Bool → cast(Int) → sum → minimum
   handles the clamp without four F32 round-trips. Same value as
   HF MoE's `expert_ids.clamp(0, num_experts-1)` for invalid expert
   IDs from EP, AND protects search-time profiling: dummy-1 input
   bytes give offs=[1,…,1], pushing the raw count to G for any
   token with index ≥ 1, which would OOB the gather without the
   clamp.

2. Drop the cast(F32) on input and on the gathered weight. The
   broadcast-and-mask version needed F32 because it casted the
   mask to F32; gather-then-matmul has no such requirement, and
   casting `[S, K, N]` to F32 doubled the gather scratch (~100 MB
   → ~200 MB per layer for Qwen3-30B-A3B prefill). Matmul rewrites
   (cuBLASLt etc.) handle bf16 input with F32 accumulator
   internally — no precision loss in practice.

Verification:
- tests/test_hlir_ops.py::test_grouped_mm_fallback{,_routing_invariance} pass.
- Synthetic g=128, s=8, k=2048, n=1536 bf16 test: max-abs-diff 1.56e-02
  (within bf16 accumulation tolerance; expected to drop to F32-accurate
  once the cuBLASLt rewrite fires at higher search budgets).

Result: original OOM-in-search is gone. With --search-iters 1
the full Qwen3-30B-A3B bench end-to-ends (TTFT ~9.4s).

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 16:38:15 -04:00
Joe Fioti
8524636d6f Yolo v11 example (#296)
* Add YOLO v11n example on luminal_cuda_lite (WIP)

End-to-end Object Detection demo running Ultralytics yolo11n on the cuda_lite
backend. Includes a Rust example crate (`yolo_v11`, `yolo_v11_tiny`,
`yolo_v11_egglog_debug`), a PyTorch reference + weight-prep script, and a
torch.compile path through luminal_python.

Surfaced and worked around several e-graph extraction issues that the heavy
conv + multi-stage Detect head exposes:

- **Gather dtype propagation** (`src/hlir.rs`): the HLIR Gather dtype-from-
  data rule was emitted in the default ruleset, so it only advanced one
  Gather per `(run)` iteration of the schedule. YOLO has deeply nested
  Gathers (each conv padding + each `make_contiguous` becomes a Gather);
  put the rule in `dtype_prop` so it saturates with Mul/Add/Sum/etc. Did
  the same for Scatter for symmetry.

- **KernelGather IList tail variable** (`crates/luminal_cuda_lite/src/
  kernel/hlir.rs`): mirror the `?__tail` pattern that Gather's dtype rule
  uses instead of a strict `(INil)` so the kernel-rewrite still matches
  when egglog has unioned the IList tail eclass with another chain.

- **Conditional cleanup** (`src/egglog_utils/mod.rs`): replaced
  `(saturate cleanup)` with a Rust post-pass that strips HLIR ops only
  when a kernel survivor exists in the same Op eclass. Otherwise the
  cleanup cascade kills the root with "No valid graphs present" on
  conv-heavy graphs.

- **inject_kernel_alternatives** (`src/egglog_utils/mod.rs`): synthesises
  KernelMul/KernelAdd/.../KernelMax enodes for HLIR-only Op eclasses
  whose dtype propagation didn't make it in time, with a deep-clone
  fallback that creates new ELIST chains so the extractor's first-enode
  walk is deterministic. Filtered by `OpTextParts::all_op_names` so the
  native runtime tests don't get CUDA-only kernel kinds.

- **enforce_consistent_first_kind_enodes** + **prefer_econs_first_in_
  elists** + extract-time consistency check (`src/egglog_utils/mod.rs`):
  reorder OpKind eclasses so the first enode is a kernel kind whose
  ELIST children all walk to the same length, and reorder ELIST eclasses
  so they start with `ECons`/`ENil` instead of `RemoveNthFromEnd` /
  `MReplaceList` / `RowMajor` (which would crash `extract_expr_list`).

- **Defensive truncate in KernelMul::extract** (`crates/luminal_cuda_
  lite/src/kernel/hlir.rs`): when an inconsistent kind enode survives all
  the above, truncate shape and strides to the shortest length so
  `flatten_strides` is structurally satisfied. Numerically wrong for
  that candidate but harmless to the search, which profiles many.

- **Diagnostic env vars** (`src/egglog_utils/mod.rs`,
  `crates/luminal_cuda_lite/src/runtime.rs`,
  `crates/luminal_cuda_lite/src/kernel/fusion/{markers,region_codegen}.rs`):
  `LUMINAL_DUMP_CLEANUP`, `LUMINAL_DUMP_INJECT`, `LUMINAL_DUMP_GATHER`,
  `LUMINAL_DUMP_CONSISTENCY`, `LUMINAL_DUMP_EXTRACT`, `LUMINAL_DUMP_
  EGGLOG`, `LUMINAL_STRICT_KERNEL_ONLY`, `LUMINAL_DISABLE_INJECT`,
  `LUMINAL_DISABLE_FUSION`, `LUMINAL_DUMP_FUSED_REGION`,
  `LUMINAL_SYNC_EACH_OP`.

- **Unrelated egglog rule disables** (`src/egglog_utils/base.rs`):
  `div-div` and `div-cancel-factor` triggered combinatorial explosion on
  the conv-heavy graph; replaced `div-div` with the constant-divisor
  variant `div-div-num`.

Status:
- Llama: 96/96 tests still pass.
- `yolo_v11_tiny YOLO_TINY_LAYERS=1..13` matches PyTorch within
  cumulative numerical drift.
- Full `yolo_v11`: compiles in ~150s and runs the forward in ~640ms.
  Detection accuracy is currently degraded (max_abs ~182 vs PyTorch
  reference) because of remaining multi-variant ELIST eclasses that
  fall through to the defensive truncate. The truncation produces
  wrong indices for those few ops; further work is needed on the
  e-graph rewriter side.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Accept YOLO input and output paths as CLI args

* Update commit message generation instructions

* metal clippy

* metal unit tests

* Fix yolo example clippy warnings

* Simplify yolo_v11 to a single self-contained binary

* Extend CUDA Modal test timeout to 2 hours

* Require CUDA build in Modal pytest runner

* Loosen Modal pytest timeout for CUDA CI

* Loosen Modal timeouts for CUDA CI

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 16:22:56 -04:00
106 changed files with 20565 additions and 1394 deletions

View File

@@ -18,7 +18,7 @@ jobs:
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 70
timeout-minutes: 120
strategy:
fail-fast: false
matrix:

View File

@@ -21,4 +21,4 @@ jobs:
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose

View File

@@ -18,7 +18,7 @@ jobs:
name: Cuda Unit Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 30
timeout-minutes: 120
steps:
- uses: actions/checkout@v6

View File

@@ -16,4 +16,4 @@ jobs:
steps:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1

View File

@@ -18,7 +18,7 @@ jobs:
name: Python CUDA Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 60
timeout-minutes: 120
defaults:
run:
working-directory: crates/luminal_python
@@ -38,7 +38,7 @@ jobs:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
run: modal run modal_pytest_runner.py --gpu A100 --timeout 7200 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
- name: Upload Modal pytest profiling artifacts
if: always()
uses: actions/upload-artifact@v4

View File

@@ -23,6 +23,6 @@ jobs:
- name: Update Rust toolchain
run: rustup update
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml
run: uv run maturin develop --manifest-path rust/Cargo.toml --profile release
- name: Run pytest
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"

View File

@@ -28,7 +28,7 @@ cuda_image = (
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=1800, # 30 minutes
timeout=7200, # 2 hours
)
def run_cargo_test():
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
@@ -47,6 +47,7 @@ def run_cargo_test():
[
"cargo",
"test",
"--release",
"-p",
"luminal_cuda_lite",
"--verbose",

View File

@@ -115,7 +115,7 @@ cuda_image = (
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=3600, # 60 minutes
timeout=7200, # 2 hours
volumes={
HF_CACHE_PATH: hf_cache,
},

View File

@@ -10,7 +10,8 @@ license = "MIT OR Apache-2.0"
[dependencies]
luminal = { path = "../.." }
luminal_tracing = { path = "../luminal_tracing" }
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
anyhow = "1.0"
as-any = "0.3.2"
itertools = "0.12.1"
fixedbitset = "0.5.7"
@@ -23,6 +24,7 @@ memmap2 = "0.9.9"
uuid = {version="1.19.0", features=["v4"]}
lru = "0.16.2"
libc = "0.2"
libloading = "0.8"
colorize = "*"
[dev-dependencies]

View File

@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
fn gather_experts(

View File

@@ -0,0 +1,198 @@
//! ComputeAttnMask — fused op that computes the paged attention mask from indptrs.
//!
//! This op exists so the indptr tensors (qo_indptr, kv_indptr) are visible in the
//! same e-graph chunk as the attention pattern, letting the FlashInfer egglog rule
//! capture them directly.
//!
//! Inputs (3): q_pos (s,) Int, qo_indptr (r,) Int, kv_indptr (r,) Int.
//! Output: mask (s, c) F32 where mask[i, j] = 0.0 (attend) or -1e10 (block).
use std::sync::Arc;
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND},
extract_expr,
},
op::{EgglogOp, HLIROp, LLIROp},
prelude::*,
};
use crate::{
cudarc::driver::{CudaStream, result},
host::{DeviceBuffer, HostOp},
};
/// Computes the paged attention mask from indptr arrays.
///
/// The mask encodes both request-membership and causality:
/// `mask[i, j] = 0.0` if query `i` and context `j` belong to the same request AND
/// context `j`'s local position is `<= q_pos[i]`; `-1e10` otherwise.
#[derive(Debug, Default)]
pub struct ComputeAttnMask {
pub s_dim: Expression,
pub c_dim: Expression,
}
impl std::fmt::Display for ComputeAttnMask {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ComputeAttnMask(s={}, c={})", self.s_dim, self.c_dim)
}
}
impl HLIROp for ComputeAttnMask {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
format!(
"(Op (ComputeAttnMask {} {}) (ICons {} (ICons {} (ICons {} (INil)))))",
self.s_dim.to_egglog(),
self.c_dim.to_egglog(),
inputs[0].1, // q_pos
inputs[1].1, // qo_indptr
inputs[2].1, // kv_indptr
)
}
}
impl EgglogOp for ComputeAttnMask {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"ComputeAttnMask",
&[("s_dim", EXPRESSION), ("c_dim", EXPRESSION)],
)
}
fn n_inputs(&self) -> usize {
3
}
fn rewrites(&self) -> Vec<Rule> {
// No rewrites — inserted directly by model code.
vec![]
}
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let s_dim = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let c_dim = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let op = Self { s_dim, c_dim };
let llir_op = LLIROp::new::<dyn HostOp>(Box::new(op) as Box<dyn HostOp>);
(llir_op, input_enodes)
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for ComputeAttnMask {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
if inputs.len() < 3 {
anyhow::bail!(
"ComputeAttnMask expects 3 inputs (q_pos, qo_indptr, kv_indptr), got {}",
inputs.len()
);
}
let s = self
.s_dim
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask s_dim unresolved"))?;
let c = self
.c_dim
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask c_dim unresolved"))?;
let r = *dyn_map
.get(&'r')
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask requires dynamic dim 'r'"))?;
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
buffers.get(&node).copied().ok_or_else(|| {
anyhow::anyhow!("ComputeAttnMask missing {name} buffer for {node:?}")
})
};
let q_pos_buf = get_buf("q_pos", inputs[0])?;
let qo_indptr_buf = get_buf("qo_indptr", inputs[1])?;
let kv_indptr_buf = get_buf("kv_indptr", inputs[2])?;
let out_buf = get_buf("output", self_node)?;
let q_pos = dtoh_i32(stream, q_pos_buf.ptr(), s)?;
let qo_indptr = dtoh_i32(stream, qo_indptr_buf.ptr(), r)?;
let kv_indptr = dtoh_i32(stream, kv_indptr_buf.ptr(), r)?;
let mut mask = vec![-1e10f32; s * c];
for i in 0..s {
let q_req = indptr_to_request(&qo_indptr, i as i32);
for j in 0..c {
let c_req = indptr_to_request(&kv_indptr, j as i32);
if q_req == c_req && q_req >= 0 {
let c_local = j as i32 - kv_indptr[c_req as usize];
if c_local <= q_pos[i] {
mask[i * c + j] = 0.0;
}
}
}
}
let mask_bytes =
unsafe { std::slice::from_raw_parts(mask.as_ptr() as *const u8, mask.len() * 4) };
unsafe {
let res = cudarc::driver::sys::cuMemcpyHtoD_v2(
out_buf.ptr(),
mask_bytes.as_ptr() as *const std::ffi::c_void,
mask_bytes.len(),
);
if res != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
anyhow::bail!("ComputeAttnMask cuMemcpyHtoD failed: {res:?}");
}
}
Ok(())
}
fn output_size(&self) -> Expression {
self.s_dim * self.c_dim
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn stats_name(&self) -> Option<&'static str> {
Some("ComputeAttnMask")
}
}
fn dtoh_i32(stream: &Arc<CudaStream>, dev_ptr: u64, len: usize) -> anyhow::Result<Vec<i32>> {
let mut host = vec![0u8; len * std::mem::size_of::<i32>()];
unsafe {
result::memcpy_dtoh_async(&mut host, dev_ptr, stream.cu_stream())?;
}
stream.synchronize()?;
let v = unsafe {
let mut bytes = std::mem::ManuallyDrop::new(host);
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut i32, len, len)
};
Ok(v)
}
/// Given an indptr array `[0, a, b, ...]`, find which segment `idx` belongs to.
/// Returns `count(indptr[i] <= idx) - 1`.
fn indptr_to_request(indptr: &[i32], idx: i32) -> i32 {
indptr.iter().filter(|&&v| v <= idx).count() as i32 - 1
}

View File

@@ -19,9 +19,9 @@ use crate::{
CudaBlas,
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
},
driver::{CudaSlice, CudaStream, DevicePtr},
driver::CudaStream,
},
host::HostOp,
host::{DeviceBuffer, HostOp},
};
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
let b_buf = buffers[&inputs[1]];
// Get device pointers
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
let a_ptr = a_buf.ptr();
let b_ptr = b_buf.ptr();
let c_ptr = c_buf.ptr();
// Debug: Check buffer sizes
trace!(

View File

@@ -42,6 +42,7 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; For column-major A × column-major B with cuBLAS:
@@ -52,14 +53,17 @@
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n ; ldd = ldc for current row-major output rewrites
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(MNum 0) ; stride_d = 0
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
@@ -112,20 +116,24 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "T"
"COL" "COL" "COL" "COL"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?n ; ldd
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(MMul ?m ?n) ; stride_d
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)

View File

@@ -42,6 +42,7 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; For column-major A × row-major B with cuBLAS:
@@ -52,14 +53,17 @@
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n ; ldd = ldc for current row-major output rewrites
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(MNum 0) ; stride_d = 0
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
@@ -112,20 +116,24 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "T"
"COL" "COL" "COL" "COL"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?n ; ldd
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(MMul ?m ?n) ; stride_d
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)

View File

@@ -42,6 +42,7 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; For row-major A × column-major B with cuBLAS:
@@ -52,14 +53,17 @@
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n ; ldd = ldc for current row-major output rewrites
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(MNum 0) ; stride_d = 0
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
@@ -112,20 +116,24 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc
?n ; ldd
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(MMul ?m ?n) ; stride_d
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)

View File

@@ -42,6 +42,7 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; For row-major C = A × B with cuBLAS (column-major):
@@ -52,14 +53,17 @@
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n ; ldd = ldc for current row-major output rewrites
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(MNum 0) ; stride_d = 0
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
@@ -117,6 +121,7 @@
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
@@ -124,14 +129,17 @@
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "N"
"COL" "COL" "COL" "COL"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc (contiguous output per batch)
?n ; ldd
?batch ; batch_count
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(MMul ?m ?n) ; stride_d
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)

View File

@@ -0,0 +1,428 @@
; Fuse a row-major Add on top of an existing cuBLASLt matmul into
; D = alpha * A * B + beta * C.
;
; The existing matmul rewrites view Luminal's row-major output [m,n] as a
; column-major cuBLASLt matrix [n,m]. A row-major C input with logical strides
; [row_stride, 1] therefore maps to ldc=row_stride. This lets a C slice from a
; wider parent tensor use a larger ldc while D keeps the matmul output layout.
; cuBLASLt requires out-of-place C and D to have the same matrix order, so these
; beta rules only fuse C layouts that map to the current COL-ordered D layout.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "COL"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?n (ECons ?m (ENil)))
?matmul_add_strides
?c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?c (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "COL" "COL"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt 2d matmul plus c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "COL"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?n (ECons ?m (ENil)))
?c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "COL" "COL"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt 2d c plus matmul beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "COL"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
?matmul_add_strides
?c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?c (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "COL" "COL"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt batched matmul plus c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "COL"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
?c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "COL" "COL"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt batched c plus matmul beta"
)
; ROW-ordered D beta fusions. These pair with cublaslt_row_order_rewrite.egg,
; where the cuBLASLt problem dimensions match Luminal's logical output [m,n].
; A row-major C input with logical strides [row_stride, 1] maps directly to a
; ROW-ordered cuBLASLt C[m,n] descriptor with ldc=row_stride.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?m (ECons ?n (ENil)))
?matmul_add_strides
?c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?c (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order 2d matmul plus c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?m (ECons ?n (ENil)))
?c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order 2d c plus matmul beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?matmul_add_strides
?c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?c (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order batched matmul plus c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(!= ?epilogue "RELU")
(!= ?epilogue "RELU_BIAS")
(!= ?epilogue "GELU")
(!= ?epilogue "GELU_BIAS")
(= ?add (Op (Add
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 1.0 ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order batched c plus matmul beta"
)

View File

@@ -0,0 +1,614 @@
; cuBLASLt epilogue rewrites.
;
; ReLU in the frontend lowers through maximum_f32(0.0):
;
; (matmul < 0) * 0 + cast(cast((-cast(matmul < 0) + 1) as bool) as f32) * matmul
;
; These rules fuse that expression back into CUBLASLT_EPILOGUE_RELU.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?zero (Op (Constant 0.0) (INil)))
(= ?neg_one (Op (Constant -1.0) (INil)))
(= ?one (Op (Constant 1.0) (INil)))
(= ?lt (Op (LessThan
?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?mask_strides)
(ICons ?matmul (ICons ?zero (INil)))))
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
(= ?zeroed (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?zeroed_strides)
(ICons ?lt_f32 (ICons ?zero (INil)))))
(= ?neg_mask (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?neg_mask_strides)
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
(= ?not_mask_f32 (Op (Add
?shape
?neg_mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?not_mask_f32_strides)
(ICons ?neg_mask (ICons ?one (INil)))))
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
(= ?positive (Op (Mul
?shape
?not_mask_f32_strides
?matmul_strides
?positive_strides)
(ICons ?not_mask (ICons ?matmul (INil)))))
(= ?relu (Op (Add
?shape
?zeroed_strides
?positive_strides
?relu_strides)
(ICons ?zeroed (ICons ?positive (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "RELU")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?relu ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt 2d relu epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?zero (Op (Constant 0.0) (INil)))
(= ?neg_one (Op (Constant -1.0) (INil)))
(= ?one (Op (Constant 1.0) (INil)))
(= ?lt (Op (LessThan
?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?mask_strides)
(ICons ?matmul (ICons ?zero (INil)))))
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
(= ?zeroed (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?zeroed_strides)
(ICons ?lt_f32 (ICons ?zero (INil)))))
(= ?neg_mask (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?neg_mask_strides)
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
(= ?not_mask_f32 (Op (Add
?shape
?neg_mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?not_mask_f32_strides)
(ICons ?neg_mask (ICons ?one (INil)))))
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
(= ?positive (Op (Mul
?shape
?not_mask_f32_strides
?matmul_strides
?positive_strides)
(ICons ?not_mask (ICons ?matmul (INil)))))
(= ?relu (Op (Add
?shape
?zeroed_strides
?positive_strides
?relu_strides)
(ICons ?zeroed (ICons ?positive (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "RELU")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?relu ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt batched relu epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?zero (Op (Constant 0.0) (INil)))
(= ?neg_one (Op (Constant -1.0) (INil)))
(= ?one (Op (Constant 1.0) (INil)))
(= ?lt (Op (LessThan
?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?mask_strides)
(ICons ?matmul (ICons ?zero (INil)))))
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
(= ?zeroed (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?zeroed_strides)
(ICons ?lt_f32 (ICons ?zero (INil)))))
(= ?neg_mask (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?neg_mask_strides)
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
(= ?not_mask_f32 (Op (Add
?shape
?neg_mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?not_mask_f32_strides)
(ICons ?neg_mask (ICons ?one (INil)))))
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
(= ?positive (Op (Mul
?shape
?not_mask_f32_strides
?matmul_strides
?positive_strides)
(ICons ?not_mask (ICons ?matmul (INil)))))
(= ?relu (Op (Add
?shape
?zeroed_strides
?positive_strides
?relu_strides)
(ICons ?zeroed (ICons ?positive (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "RELU_BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?relu ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt 2d relu bias epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?zero (Op (Constant 0.0) (INil)))
(= ?neg_one (Op (Constant -1.0) (INil)))
(= ?one (Op (Constant 1.0) (INil)))
(= ?lt (Op (LessThan
?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?mask_strides)
(ICons ?matmul (ICons ?zero (INil)))))
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
(= ?zeroed (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?zeroed_strides)
(ICons ?lt_f32 (ICons ?zero (INil)))))
(= ?neg_mask (Op (Mul
?shape
?mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?neg_mask_strides)
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
(= ?not_mask_f32 (Op (Add
?shape
?neg_mask_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?not_mask_f32_strides)
(ICons ?neg_mask (ICons ?one (INil)))))
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
(= ?positive (Op (Mul
?shape
?not_mask_f32_strides
?matmul_strides
?positive_strides)
(ICons ?not_mask (ICons ?matmul (INil)))))
(= ?relu (Op (Add
?shape
?zeroed_strides
?positive_strides
?relu_strides)
(ICons ?zeroed (ICons ?positive (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "RELU_BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?relu ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt batched relu bias epilogue"
)
; Canonical tanh-approx GELU can also appear directly as:
;
; x * sigmoid(1.5957691216 * x * (1 + 0.044715 * x * x))
;
; Match that sigmoid form and fuse it into the cuBLASLt GELU epilogues.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
(= ?gelu_one (Op (Constant 1.000000) (INil)))
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
(= ?neg1 (Op (Constant -1.000000) (INil)))
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
(= ?log2e (Op (Constant 1.442695) (INil)))
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "GELU")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?gelu_out ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt gelu epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
(= ?gelu_one (Op (Constant 1.000000) (INil)))
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
(= ?neg1 (Op (Constant -1.000000) (INil)))
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
(= ?log2e (Op (Constant 1.442695) (INil)))
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype (F32)
?compute_type ?scale_dtype
?alpha 0.0 "GELU_BIAS")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?gelu_out ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt gelu bias epilogue"
)
; This first slice fuses column-bias adds into CUBLASLT_EPILOGUE_BIAS for the
; older COL-ordered output view. In that view Luminal's logical [m,n] output is
; represented as a cuBLASLt [n,m] matrix, so cuBLASLt's row-broadcast bias maps
; to the common logical column bias of length n.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(= ?add (Op (Add
(ECons ?n (ECons ?m (ENil)))
?matmul_add_strides
?bias_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?bias (INil)))))
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
(= ?matmul_add_strides ?add_out_strides)
(= ?d_dtype (dtype ?bias))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt 2d matmul plus column bias epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(= ?add (Op (Add
(ECons ?n (ECons ?m (ENil)))
?bias_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?bias (ICons ?matmul (INil)))))
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
(= ?matmul_add_strides ?add_out_strides)
(= ?d_dtype (dtype ?bias))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt 2d column bias plus matmul epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(= ?add (Op (Add
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
?matmul_add_strides
?bias_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?bias (INil)))))
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
(= ?matmul_add_strides ?add_out_strides)
(= ?d_dtype (dtype ?bias))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt batched matmul plus column bias epilogue"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(= ?add (Op (Add
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
?bias_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?bias (ICons ?matmul (INil)))))
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
(= ?matmul_add_strides ?add_out_strides)
(= ?d_dtype (dtype ?bias))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order "COL"
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "BIAS")
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt batched column bias plus matmul epilogue"
)

View File

@@ -0,0 +1,345 @@
; FP8 support is narrower than "any FP8 x any FP8". cuBLASLt's regular FP8
; matmul table supports these A/B descriptor pairs for F32 outputs:
; E4M3 x E4M3
; E4M3 x E5M2
; E5M2 x E4M3
; and requires TN format on Ada/Hopper-class GPUs. These rules therefore match
; row-major x column-major Luminal matmuls, which the existing COL-order lowering
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
; transb=N.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F8E4M3) (dtype ?a))
(= (F8E4M3) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e4m3/e4m3 row-major x column-major f32 output"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F8E4M3) (dtype ?a))
(= (F8E5M2) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e5m2/e4m3 row-major x column-major f32 output"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F8E5M2) (dtype ?a))
(= (F8E4M3) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e4m3/e5m2 row-major x column-major f32 output"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= (F8E4M3) (dtype ?a))
(= (F8E4M3) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
?batch
?b_batch_stride
?a_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e4m3/e4m3 batched row-major x column-major f32 output"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= (F8E4M3) (dtype ?a))
(= (F8E5M2) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
?batch
?b_batch_stride
?a_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e5m2/e4m3 batched row-major x column-major f32 output"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= (F8E5M2) (dtype ?a))
(= (F8E4M3) (dtype ?b))
)
(
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
"COL" "COL" "COL" "COL"
?b_n_stride
?a_m_stride
?n
?n
?batch
?b_batch_stride
?a_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
(ICons ?b (ICons ?a (INil)))))
(union ?cast ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublaslt fp8 e4m3/e5m2 batched row-major x column-major f32 output"
)

View File

@@ -0,0 +1,75 @@
; Mixed output dtype rewrites for cuBLASLt.
;
; The first mixed mode we need for low-precision matmuls is:
;
; D[f32] = A[fp16/bf16] * B[fp16/bf16]
;
; Luminal graphs express this today as a Cast(F32) around a low-precision
; matmul. cuBLASLt can write the f32 output directly, so expose that candidate
; before beta fusion tries to consume an f32 C input.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
(F16) (F16) (F16) (F16)
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
?inputs))
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
(F16) (F16) (F32) (F32)
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
?inputs))
(union ?cast ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt f16 matmul cast f32 output"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
(Bf16) (Bf16) (Bf16) (Bf16)
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
?inputs))
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
(Bf16) (Bf16) (F32) (F32)
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
?inputs))
(union ?cast ?fused)
(set (dtype ?fused) (F32))
)
:ruleset matmul_backend
:name "cublaslt bf16 matmul cast f32 output"
)

View File

@@ -0,0 +1,452 @@
; Natural cuBLASLt row-order output rewrites. These keep Luminal's logical
; output C[m,n] as a cuBLASLt ROW-ordered D[m,n] instead of using the older
; swapped COL-ordered D[n,m] view. A and B orders mirror their matched logical
; layouts, so this family is the legal base for future ROW-ordered beta fusions.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"ROW" "ROW" "ROW" "ROW"
?a_m_stride
?b_k_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order row-major x row-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"ROW" "COL" "ROW" "ROW"
?a_m_stride
?b_n_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order row-major x column-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"COL" "ROW" "ROW" "ROW"
?a_k_stride
?b_k_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order column-major x row-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
(= ?k_stride (MIter))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"COL" "COL" "ROW" "ROW"
?a_k_stride
?b_n_stride
?n
?n
(MNum 1)
(MNum 0)
(MNum 0)
(MNum 0)
(MNum 0)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order column-major x column-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"ROW" "ROW" "ROW" "ROW"
?a_m_stride
?b_k_stride
?n
?n
?batch
?a_batch_stride
?b_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order batched row-major x row-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"ROW" "COL" "ROW" "ROW"
?a_m_stride
?b_n_stride
?n
?n
?batch
?a_batch_stride
?b_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order batched row-major x column-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"COL" "ROW" "ROW" "ROW"
?a_k_stride
?b_k_stride
?n
?n
?batch
?a_batch_stride
?b_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order batched column-major x row-major"
)
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
(cublaslt_base_dtype ?dt)
)
(
(let ?sgemm (Op (cublaslt
?m ?n ?k
"N" "N"
"COL" "COL" "ROW" "ROW"
?a_k_stride
?b_n_stride
?n
?n
?batch
?a_batch_stride
?b_batch_stride
(MMul ?m ?n)
(MMul ?m ?n)
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:ruleset matmul_backend
:name "cublaslt row-order batched column-major x column-major"
)

View File

@@ -0,0 +1,316 @@
; Scalar alpha/beta rewrites for cuBLASLt. These rules target scalar constants
; expanded across the matmul/add shape, i.e. zero strides on every logical axis.
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?scale (Op (Constant ?alpha) (INil)))
; alpha=1.0 hash-conses ?fused == ?matmul; the union merges Mul into ?matmul's eclass and saturate diverges.
(!= ?alpha 1.0)
(= ?scaled (Op (Mul ?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?scaled_out_strides)
(ICons ?matmul (ICons ?scale (INil)))))
(= ?matmul_strides ?scaled_out_strides)
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?scaled ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt 2d alpha scale"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
1.0 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?scale (Op (Constant ?alpha) (INil)))
; See 2d alpha scale: alpha=1.0 makes (saturate ...) diverge.
(!= ?alpha 1.0)
(= ?scaled (Op (Mul ?shape
?matmul_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?scaled_out_strides)
(ICons ?matmul (ICons ?scale (INil)))))
(= ?matmul_strides ?scaled_out_strides)
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?c_order ?d_order
?lda ?ldb ?ldc ?ldd
?batch
?stride_a ?stride_b ?stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 "DEFAULT")
(ICons ?a (ICons ?b ?matmul_tail))))
(union ?scaled ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt batched alpha scale"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?beta_node (Op (Constant ?beta) (INil)))
(= ?scaled_c (Op (Mul
(ECons ?m (ECons ?n (ENil)))
?c_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?scaled_c_out_strides)
(ICons ?c (ICons ?beta_node (INil)))))
(= ?add (Op (Add
(ECons ?m (ECons ?n (ENil)))
?matmul_add_strides
?scaled_c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?scaled_c (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?scaled_c_add_strides ?scaled_c_out_strides)
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order 2d scaled c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
(MNum 1)
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?beta_node (Op (Constant ?beta) (INil)))
(= ?scaled_c (Op (Mul
(ECons ?m (ECons ?n (ENil)))
?c_strides
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
?scaled_c_out_strides)
(ICons ?c (ICons ?beta_node (INil)))))
(= ?add (Op (Add
(ECons ?m (ECons ?n (ENil)))
?scaled_c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?scaled_c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
(= ?scaled_c_add_strides ?scaled_c_out_strides)
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
(MNum 1)
?stride_a ?stride_b (MNum 0) ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order 2d scaled c plus matmul beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?beta_node (Op (Constant ?beta) (INil)))
(= ?scaled_c (Op (Mul
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?c_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?scaled_c_out_strides)
(ICons ?c (ICons ?beta_node (INil)))))
(= ?add (Op (Add
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?matmul_add_strides
?scaled_c_add_strides
?add_out_strides)
(ICons ?matmul (ICons ?scaled_c (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?scaled_c_add_strides ?scaled_c_out_strides)
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order batched scaled c beta"
)
(rule
(
(= ?matmul (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order ?matmul_c_order "ROW"
?lda ?ldb ?matmul_ldc ?ldd
?batch
?stride_a ?stride_b ?matmul_stride_c ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha 0.0 ?epilogue)
(ICons ?a (ICons ?b ?matmul_tail))))
(= ?beta_node (Op (Constant ?beta) (INil)))
(= ?scaled_c (Op (Mul
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?c_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?scaled_c_out_strides)
(ICons ?c (ICons ?beta_node (INil)))))
(= ?add (Op (Add
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
?scaled_c_add_strides
?matmul_add_strides
?add_out_strides)
(ICons ?scaled_c (ICons ?matmul (INil)))))
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
(= ?scaled_c_add_strides ?scaled_c_out_strides)
(= ?c_col_stride (MIter))
(!= ?c_row_stride (MNum 0))
(= ?matmul_add_strides ?add_out_strides)
(= ?c_dtype (dtype ?c))
)
(
(let ?fused (Op (cublaslt
?m ?n ?k
?a_layout ?b_layout
?a_order ?b_order "ROW" "ROW"
?lda ?ldb ?c_row_stride ?ldd
?batch
?stride_a ?stride_b ?c_batch_stride ?stride_d
?a_dtype ?b_dtype ?c_dtype ?d_dtype
?compute_type ?scale_dtype
?alpha ?beta ?epilogue)
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
(union ?add ?fused)
(set (dtype ?fused) ?d_dtype)
)
:ruleset matmul_backend
:name "cublaslt row-order batched scaled c plus matmul beta"
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,124 @@
# FlashInfer Integration
FlashInfer replaces the multi-op attention pattern (Q×K^T → scale → mask → softmax → ×V) with a single fused GPU kernel via [FlashInfer](https://github.com/flashinfer-ai/flashinfer)'s batch decode and batch prefill APIs.
## Current State
**Working:**
- Egglog rewrite rule matches any GQA paged attention pattern (model-agnostic shapes)
- GA search selects FlashInfer when it wins profiling — verified on Llama 3 8B (32 layers) and Qwen 3 4B (36 layers)
- **BatchDecode** (s=1): fp32 natively — FlashInfer's decode kernel uses scalar vectorized dot products, no tensor cores
- **BatchPrefill**: template-instantiated for fp16 but **not callable from fp32** — FlashInfer's prefill kernel requires tensor core MMA (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically only operate on 16-bit types; the C API stubs return -1 for fp32; will be enabled when native fp16/bf16 pipeline is added
- Decode handles all cases in the current fp32 pipeline (prefill uses cuBLAS attention via dim bucketing)
- Indptr-based mask: `qo_indptr` and `kv_indptr` are computed in-graph so the egglog rule can see them in the same chunk as the attention ops
**Not yet implemented:**
- Native fp16 / bf16 pipeline (would eliminate the cast overhead in prefill)
- Page sizes > 1
---
## File Organization
```
src/host/flashinfer/
flashinfer_attention.egg — egglog rewrite rule (pattern match → FlashInferAttention)
mod.rs — FlashInferAttention op (EgglogOp + HostOp impl)
jit.rs — JIT compilation: nvcc wrapper.cu → .so, dlopen, fn pointers
find_indptrs.rs — walks the mask e-graph node to locate qo_indptr / kv_indptr inputs
wrapper.cu — CUDA: FlashInfer template instantiation + helper kernels
wrapper.h — C API header for wrapper.cu
README.md — this file
```
## How It Works
### 1. Egglog Pattern Matching
The rule in `flashinfer_attention.egg` matches the structural pattern of paged GQA attention:
```
Gather(K_cache, idx) → GQA broadcast (Mul×1.0) → Q×K^T → Sum → scale → mask Add → softmax → attn×V → Sum → output
Gather(V_cache, idx) → GQA broadcast (Mul×1.0) ──────────────────────────────────────────→ attn×V → Sum → output
```
Key anchors that prevent false matches on MLP or other ops:
- Two Gather ops from 2D cache pools (MLP never uses Gather)
- GQA broadcast via `Mul(gathered, Constant(1.0))` with all-zero strides
- Mask Add with zero-stride broadcast in the first (nheads) dimension
- Two sequential matmul+Sum pairs connected through softmax
Shape dimensions are egglog variables, not pinned constants — the rule works for any model with GQA (Llama, Qwen, Mistral, etc.). The structural invariants (dimension count, zero-stride positions, Gather from 2D) are enough to avoid combinatorial explosion during saturation.
When the rule fires, it unions `FlashInferAttention` with the original attention output, making it an equivalent alternative in the e-graph. The GA search then profiles both paths and picks the faster one.
### 2. Extraction: Finding Indptrs
During `extract()` (called when egglog selects the FlashInferAttention e-node), `find_indptrs.rs` walks backward from the mask node in the e-graph to locate the `qo_indptr` and `kv_indptr` Input nodes. It validates the mask structure by checking for the `Mul(allowed, Constant(1e10))` pattern that `compute_attn_mask()` produces.
The indptrs are appended as inputs 5 and 6 to the FlashInferAttention op, so the runtime can build the CSR page table directly without recomputing anything.
### 3. JIT Compilation
FlashInfer requires `HEAD_DIM` as a compile-time template parameter. Rather than baking it at `cargo build` time, `jit.rs` JIT-compiles `wrapper.cu` with the model's actual HEAD_DIM:
1. First call to `ensure_compiled(head_dim)` runs `nvcc` with `-DLUMINAL_HEAD_DIM=<N>`
2. The compiled `.so` is cached at `~/.cache/luminal/flashinfer/libflashinfer_hd<N>_<arch>.so`
3. Subsequent calls load the cached library via `dlopen`
4. Function pointers (plan, run, transpose, etc.) are resolved and stored in a `static OnceLock`
Supported HEAD_DIM values: 64, 128, 256.
### 4. Runtime Execution
`FlashInferAttention::execute()` dispatches to decode or prefill based on `total_q_tokens vs batch_size`:
**Common steps:**
1. **Extract kv_indices** — a helper kernel converts the flat gather index `(c, KV_DIM)` to slot indices `(c,)`
2. **Read indptrs to host** — copied to CPU for the plan phase
3. **Plan** — queries GPU occupancy and decides split-KV decomposition
4. **Run** — the fused kernel writes `(total_q_tokens, num_qo_heads, head_dim)`
5. **Transpose** — transposes to `(num_qo_heads, total_q_tokens, head_dim)` to match the Sum reduction layout
**Decode path** (current, fp32): Always used. Runs FlashInfer's BatchDecode directly on fp32 buffers.
**Prefill path** (future, fp16/bf16 only): The prefill kernel templates are compiled into the JIT .so for fp16 (CTA_TILE_Q=16/64/128, causal mask). The C API stubs currently return -1 since the pipeline is fp32. When native fp16/bf16 dtype support is added, `execute()` will dispatch to prefill when `total_q_tokens > batch_size`.
Global workspaces (`static OnceLock`) are shared across all FlashInferAttention instances to avoid ~4ms allocation overhead per GA profiling candidate. Without this, the GA never selects FlashInfer because the first-run allocation cost dwarfs the kernel time.
## How the Attention Mask Enables FlashInfer
For the egglog rule to fire, the `qo_indptr` and `kv_indptr` tensors must be visible in the same e-graph chunk as the attention ops. This is why the mask is computed *inside* each layer (via `compute_attn_mask()` in the model) rather than passed as a pre-computed input.
The mask computation uses a specific structure:
```rust
let allowed = same_request * causal;
allowed * 1e10 - 1e10 // → 0.0 for allowed, -1e10 for blocked
```
The `Mul(allowed, Constant(1e10))` pattern is the anchor that `find_indptrs.rs` uses to walk backward and locate the indptr inputs.
## Roadmap
Items listed in priority order. Checked items are done.
- [x] Model-agnostic egglog rule (shape variables instead of Llama-specific constants)
- [x] bs>1 supersequence decode
- [x] Indptr-based attention mask (replaces CPU-computed mask)
- [x] Multi-model support (verified on Llama 3 8B and Qwen 3 4B)
- [x] BatchPrefill kernel compiled for fp16 (causal mask, CTA_TILE_Q=16/64/128)
- [ ] Native fp16 / bf16 pipeline (enables prefill, reduces memory, eliminates cuBLAS prefill fallback)
- [ ] HEAD_DIM dispatch for 64, 96 (JIT supports 64/128/256; wrapper.cu needs 96 for Phi)
- [ ] Page sizes > 1 (currently page_size=1; larger pages reduce CSR overhead)
- [ ] Sliding window, ALiBi, logits soft cap (FlashInfer `AttentionVariant` templates)
- [ ] MHA / MQA / arbitrary GQA ratios beyond {1, 2, 4, 8}
## Key Design Decisions
- **page_size=1**: Each KV cache slot is one "page". This simplifies the CSR page table (`kv_indices` = physical slot indices directly) and matches the flat `(num_slots, KV_DIM)` cache layout.
- **Pinned structural anchors**: The egglog rule pins the *structure* (number of dimensions, which dims are zero-stride, presence of Gather from 2D cache) but uses variables for the *values* (head counts, head_dim). This prevents saturation blowup while remaining model-agnostic.
- **Prefill requires fp16/bf16**: FlashInfer's prefill kernel uses tensor core MMA instructions (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically require 16-bit inputs — there is no fp32 tensor core matmul instruction. The prefill kernel templates are compiled into the .so for fp16 but the C API returns -1 for fp32 callers. When native fp16/bf16 is added, prefill will be enabled automatically.
- **Global workspaces**: Float workspace (128 MiB), int workspace (8 MiB), and a page-locked host buffer are allocated once via `static OnceLock` and shared across all instances.

View File

@@ -0,0 +1,248 @@
//! Walk the e-graph from the mask node to find qo_indptr and kv_indptr Input nodes.
//!
//! The mask is produced by `compute_attn_mask(q_pos, qo_indptr, kv_indptr)` using
//! primitive HLIR ops. This module validates the mask's structure and extracts the
//! indptr Input node IDs so FlashInfer can use them directly.
use luminal::egglog_utils::{ClassId, NodeId, SerializedEGraph};
use luminal::prelude::FxHashSet;
/// Result of walking the mask computation chain.
#[derive(Debug)]
pub struct IndptrNodes<'a> {
pub qo_indptr: &'a NodeId,
pub kv_indptr: &'a NodeId,
}
/// Find the qo_indptr and kv_indptr Input nodes by walking backwards from the mask.
///
/// Validates the mask structure: `allowed * 1e10 + (-1e10)`. Then does a BFS from
/// the `allowed` subtree to find all reachable Input nodes with names containing
/// "qo_indptr" and "kv_indptr".
///
/// Panics with a diagnostic message if the structure doesn't match or the
/// indptr inputs can't be found.
pub fn find_indptr_inputs<'a>(
egraph: &'a SerializedEGraph,
mask_node: &'a NodeId,
) -> IndptrNodes<'a> {
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
let (mask_label, mask_children) = &egraph.enodes[mask_node];
assert!(
mask_label == "Op",
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
);
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
let mask_kind_label = &egraph.enodes[mask_kind].0;
assert!(
mask_kind_label.contains("Add"),
"find_indptr_inputs: mask is not an Add (kind={mask_kind_label})"
);
let mask_inputs = walk_ilist_simple(egraph, &mask_children[1]);
assert_eq!(
mask_inputs.len(),
2,
"find_indptr_inputs: mask Add should have 2 inputs, got {}",
mask_inputs.len()
);
// Step 2: One of the inputs should be Mul(allowed, Constant(1e10))
let (scaled_allowed, allowed_node) = find_1e10_mul(egraph, &mask_inputs);
// Step 3: BFS from `allowed` to find all reachable Input nodes
let reachable_inputs = find_reachable_inputs(egraph, allowed_node);
// Step 4: Match by name
let mut qo_indptr: Option<&NodeId> = None;
let mut kv_indptr: Option<&NodeId> = None;
for (node_id, name) in &reachable_inputs {
if name.contains("qo_indptr") {
qo_indptr = Some(node_id);
} else if name.contains("kv_indptr") {
kv_indptr = Some(node_id);
}
}
let qo = qo_indptr.unwrap_or_else(|| {
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
panic!(
"find_indptr_inputs: could not find 'qo_indptr' Input reachable from mask.\n\
Found inputs: {:?}\n\
Mask node: {:?}\n\
Scaled allowed node: {:?}",
found_names, mask_node, scaled_allowed
);
});
let kv = kv_indptr.unwrap_or_else(|| {
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
panic!(
"find_indptr_inputs: could not find 'kv_indptr' Input reachable from mask.\n\
Found inputs: {:?}\n\
Mask node: {:?}\n\
Scaled allowed node: {:?}",
found_names, mask_node, scaled_allowed
);
});
IndptrNodes {
qo_indptr: qo,
kv_indptr: kv,
}
}
fn find_1e10_mul<'a>(
egraph: &'a SerializedEGraph,
mask_add_inputs: &[&'a NodeId],
) -> (&'a NodeId, &'a NodeId) {
for &input_node in mask_add_inputs {
let (label, children) = &egraph.enodes[input_node];
if label != "Op" {
continue;
}
let kind = resolve_first_node(egraph, &children[0]);
if !egraph.enodes[kind].0.contains("Mul") {
continue;
}
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
if mul_inputs.len() != 2 {
continue;
}
for (i, &inp) in mul_inputs.iter().enumerate() {
if is_constant(egraph, inp, 1e10) {
let other = mul_inputs[1 - i];
return (input_node, other);
}
}
}
let mut debug_info = String::new();
for (i, &input_node) in mask_add_inputs.iter().enumerate() {
let (label, children) = &egraph.enodes[input_node];
debug_info.push_str(&format!("\n input[{i}]: label={label}"));
if label == "Op" && !children.is_empty() {
let kind = resolve_first_node(egraph, &children[0]);
let kind_label = &egraph.enodes[kind].0;
debug_info.push_str(&format!(" kind={kind_label}"));
for (j, kc) in egraph.enodes[kind].1.iter().enumerate() {
let kc_node = resolve_first_node(egraph, kc);
debug_info.push_str(&format!(" child[{j}]={}", egraph.enodes[kc_node].0));
}
if kind_label.contains("Mul") && children.len() >= 2 {
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
for (j, &mi) in mul_inputs.iter().enumerate() {
let (ml, mc) = &egraph.enodes[mi];
debug_info.push_str(&format!("\n mul_input[{j}]: label={ml}"));
if ml == "Op" && !mc.is_empty() {
let mk = resolve_first_node(egraph, &mc[0]);
debug_info.push_str(&format!(" kind={}", egraph.enodes[mk].0));
for (k, mkc) in egraph.enodes[mk].1.iter().enumerate() {
let mkc_node = resolve_first_node(egraph, mkc);
debug_info.push_str(&format!(" ch[{k}]={}", egraph.enodes[mkc_node].0));
}
}
}
}
}
}
panic!(
"find_indptr_inputs: could not find Mul(allowed, Constant(1e10)) in mask Add inputs.{debug_info}"
);
}
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
let (label, children) = &egraph.enodes[node];
if label != "Op" {
return false;
}
let kind = resolve_first_node(egraph, &children[0]);
let kind_label = &egraph.enodes[kind].0;
if !kind_label.contains("Constant") {
return false;
}
let val_children = &egraph.enodes[kind].1;
if val_children.is_empty() {
return false;
}
let val_node = resolve_first_node(egraph, &val_children[0]);
let val_str = &egraph.enodes[val_node].0;
if let Ok(val) = val_str.parse::<f64>() {
(val as f32 - expected).abs() < 1.0
} else {
false
}
}
fn find_reachable_inputs<'a>(
egraph: &'a SerializedEGraph,
start: &'a NodeId,
) -> Vec<(&'a NodeId, String)> {
let mut found = Vec::new();
let mut visited = FxHashSet::default();
let mut stack = vec![start];
while let Some(node) = stack.pop() {
if !visited.insert(node) {
continue;
}
let (label, children) = &egraph.enodes[node];
if label == "Input" {
if children.len() >= 2 {
let name_node = resolve_first_node(egraph, &children[1]);
let name = egraph.enodes[name_node].0.trim_matches('"').to_string();
found.push((node, name));
}
continue;
}
if label == "Op" && children.len() >= 2 {
let ir_inputs = walk_ilist_simple(egraph, &children[1]);
for inp in ir_inputs {
stack.push(inp);
}
}
}
found
}
fn walk_ilist_simple<'a>(
egraph: &'a SerializedEGraph,
ilist_eclass: &'a ClassId,
) -> Vec<&'a NodeId> {
let mut inputs = Vec::new();
let mut current = resolve_first_node(egraph, ilist_eclass);
loop {
let (label, children) = &egraph.enodes[current];
if label == "INil" {
break;
}
if label != "ICons" {
break;
}
let ir_node = resolve_first_ir_node(egraph, &children[0]);
inputs.push(ir_node);
current = resolve_first_node(egraph, &children[1]);
}
inputs
}
fn resolve_first_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
&egraph.eclasses[eclass].1[0]
}
fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
let nodes = &egraph.eclasses[eclass].1;
for node in nodes {
let label = &egraph.enodes[node].0;
if label == "Op" || label == "Input" {
return node;
}
}
&nodes[0]
}

View File

@@ -0,0 +1,125 @@
; FlashInfer batch decode attention rewrite rule.
;
; Matches the paged attention pattern for ANY model with GQA:
; Gather(K_cache) → GQA broadcast → Q*K^T matmul → scale → add mask → softmax → attn*V matmul
; Gather(V_cache) → GQA broadcast ──────────────────────────────────────────→ attn*V matmul
;
; Structural anchors (prevent false matches on MLP/other ops):
; - Gather ops from 2D cache pools (MLP never uses Gather)
; - GQA broadcast via Mul(gathered, Constant(1.0)) with all-zero strides
; - Scale Mul(QK, constant) connecting QK scores to mask Add
; - Mask Add with zero-stride broadcast in first dim (nheads broadcast)
; - Data flow: two sequential matmul+reduce pairs connected through softmax
;
; The egglog rule captures the mask as 5th input. During extract(), a Rust
; function walks the mask's computation chain in the e-graph to locate the
; qo_indptr and kv_indptr Input nodes (validated via the Constant(1e10) anchor
; and structural checks). These are appended as inputs 5 and 6 so FlashInfer
; can build the CSR page table directly — no runtime derivation needed.
;
; Shape dimensions are egglog variables, not pinned constants.
; Dynamic dims "s" (batch/seq) and "c" (context) stay pinned as MVar.
(rule
(
; ── Second matmul: Mul(softmax_out, V_gqa) ──
; Shape: (nheads, s, hdim, c) — 4D
(= ?mul2 (Op (Mul
(ECons ?nheads (ECons (MVar "s") (ECons ?hdim (ECons (MVar "c") (ENil)))))
?mul2_a_strides
?mul2_b_strides
?mul2_out_strides)
(ICons ?soft (ICons ?v_gqa (INil)))))
; ── Second matmul: Sum (reduction over c) → output ──
; Shape: (nheads, s, hdim) — reduces c
(= ?output (Op (Sum
(ECons ?nheads2 (ECons (MVar "s") (ECons ?hdim2 (ENil))))
(MVar "c")
?out_in_strides
(MIter)
?out_out_strides)
(ICons ?mul2 (INil))))
; ── V GQA broadcast: Mul(V_gathered, 1.0) with zero-stride constant ──
; Shape: (nheads, c, hdim) — 3D
(= ?v_gqa_const (Op (Constant 1.000000) (INil)))
(= ?v_gqa (Op (Mul
(ECons ?nheads3 (ECons (MVar "c") (ECons ?hdim3 (ENil))))
?v_gqa_a_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?v_gqa_out_strides)
(ICons ?v_gathered (ICons ?v_gqa_const (INil)))))
; ── V Gather: rows from V_cache (2D) ──
; Shape: (c, kvdim), Source: (num_slots, kvdim)
(= ?v_gathered (Op (Gather
(ECons (MVar "c") (ECons ?kvdim (ENil)))
?v_gather_strides
(ECons ?num_slots_v (ECons ?kvdim2 (ENil)))
?v_src_strides)
(ICons ?v_idx (ICons ?v_cache (INil)))))
; ── First matmul: Mul(Q, K_gqa) ──
; Shape: (nheads, s, c, hdim) — 4D
(= ?mul1 (Op (Mul
(ECons ?nheads4 (ECons (MVar "s") (ECons (MVar "c") (ECons ?hdim4 (ENil)))))
?mul1_a_strides
?mul1_b_strides
?mul1_out_strides)
(ICons ?q (ICons ?k_gqa (INil)))))
; ── First matmul: Sum (reduction over hdim) → QK scores ──
; Shape: (nheads, s, c) — reduces hdim
(= ?qk (Op (Sum
(ECons ?nheads5 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
?hdim5
?qk_in_strides
(MIter)
?qk_out_strides)
(ICons ?mul1 (INil))))
; ── Mask Add: Add(scaled_QK, mask) ──
; Shape: (nheads, s, c) — 3D
; Mask is broadcast from (s, c) via zero-stride in first dim (nheads).
(= ?masked (Op (Add
(ECons ?nheads8 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
?mask_add_a_strides
(ECons (MNum 0) ?mask_rest_strides)
?mask_add_out_strides)
(ICons ?scaled_qk (ICons ?mask (INil)))))
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
; Shape: (nheads, hdim, c) — 3D
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
(= ?k_gqa (Op (Mul
(ECons ?nheads6 (ECons ?hdim6 (ECons (MVar "c") (ENil))))
?k_gqa_a_strides
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
?k_gqa_out_strides)
(ICons ?k_gathered (ICons ?k_gqa_const (INil)))))
; ── K Gather: rows from K_cache (2D) ──
; Shape: (c, kvdim), Source: (num_slots, kvdim)
(= ?k_gathered (Op (Gather
(ECons (MVar "c") (ECons ?kvdim3 (ENil)))
?k_gather_strides
(ECons ?num_slots_k (ECons ?kvdim4 (ENil)))
?k_src_strides)
(ICons ?k_idx (ICons ?k_cache (INil)))))
; ── Dtype consistency ──
(= ?dt (dtype ?q))
(= ?dt (dtype ?k_cache))
(= ?dt (dtype ?v_cache))
)
(
(let ?fi (Op (FlashInferAttention
?nheads (MDiv ?kvdim ?hdim) ?hdim (MNum 1) (MVar "s"))
(ICons ?q (ICons ?k_cache (ICons ?v_cache (ICons ?k_idx (ICons ?mask (INil))))))))
(union ?output ?fi)
(set (dtype ?fi) ?dt)
)
:ruleset matmul_backend
:name "FlashInfer batch decode attention"
)

View File

@@ -0,0 +1,504 @@
//! JIT compilation and dynamic loading of FlashInfer kernels.
//!
//! Everything runs at compile / profiling time — there is no `build.rs`.
//! `wrapper.cu` and `wrapper.h` are embedded via `include_str!()` and
//! extracted to the cache directory on first use. The FlashInfer + CUTLASS
//! header trees are located by probing `LUMINAL_FLASHINFER_DIR`, a small set
//! of default paths, and (as a last resort) by `git clone`-ing FlashInfer at
//! a pinned commit into the cache. `nvcc` is then invoked with the model's
//! actual `HEAD_DIM` and the resulting `.so` is `dlopen`'d.
//!
//! `ensure_compiled` is called from `FlashInferAttention::extract()`, i.e.
//! during luminal's compile / GA-profiling phase, not from `execute()`. After
//! the first call the `OnceLock` makes subsequent lookups free.
use std::{
ffi::c_void,
hash::{Hash, Hasher},
path::{Path, PathBuf},
process::Command,
sync::OnceLock,
};
// ── Function pointer types matching wrapper.h ──
pub type PlanFn = unsafe extern "C" fn(
float_workspace: *mut c_void,
float_ws_size: usize,
int_workspace: *mut c_void,
int_ws_size: usize,
page_locked_int_workspace: *mut c_void,
indptr_h: *mut i32,
batch_size: i32,
num_qo_heads: i32,
num_kv_heads: i32,
page_size: i32,
head_dim: i32,
stream: *mut c_void,
plan_info_out: *mut i64,
plan_info_len_out: *mut i32,
) -> i32;
pub type RunFn = unsafe extern "C" fn(
float_workspace: *mut c_void,
float_ws_size: usize,
int_workspace: *mut c_void,
plan_info_vec: *mut i64,
plan_info_len: i32,
q: *mut f32,
k_cache: *mut f32,
v_cache: *mut f32,
kv_indptr: *mut i32,
kv_indices: *mut i32,
kv_last_page_len: *mut i32,
output: *mut f32,
batch_size: i32,
num_qo_heads: i32,
num_kv_heads: i32,
page_size: i32,
head_dim: i32,
stream: *mut c_void,
) -> i32;
pub type ExtractFn = unsafe extern "C" fn(
flat_idx: *const i32,
out: *mut i32,
c: i32,
kv_dim: i32,
stream: *mut c_void,
);
pub type DeriveIndptrFn =
unsafe extern "C" fn(mask: *const f32, indptr: *mut i32, s: i32, c: i32, stream: *mut c_void);
pub type TransposeOutputFn = unsafe extern "C" fn(
src: *const f32,
dst: *mut f32,
batch: i32,
heads: i32,
dim: i32,
stream: *mut c_void,
);
pub type PrefillPlanFn = unsafe extern "C" fn(
float_workspace: *mut c_void,
float_ws_size: usize,
int_workspace: *mut c_void,
int_ws_size: usize,
page_locked_int_workspace: *mut c_void,
qo_indptr_h: *mut i32,
kv_indptr_h: *mut i32,
total_num_rows: i32,
batch_size: i32,
num_qo_heads: i32,
num_kv_heads: i32,
page_size: i32,
head_dim: i32,
stream: *mut c_void,
plan_info_out: *mut i64,
plan_info_len_out: *mut i32,
) -> i32;
pub type PrefillRunFn = unsafe extern "C" fn(
float_workspace: *mut c_void,
float_ws_size: usize,
int_workspace: *mut c_void,
plan_info_vec: *mut i64,
plan_info_len: i32,
q: *mut f32,
k_cache: *mut f32,
v_cache: *mut f32,
qo_indptr: *mut i32,
kv_indptr: *mut i32,
kv_indices: *mut i32,
kv_last_page_len: *mut i32,
output: *mut f32,
total_num_rows: i32,
batch_size: i32,
num_qo_heads: i32,
num_kv_heads: i32,
page_size: i32,
head_dim: i32,
stream: *mut c_void,
) -> i32;
// ── Embedded CUDA sources ──
const WRAPPER_CU: &str = include_str!("wrapper.cu");
const WRAPPER_H: &str = include_str!("wrapper.h");
// ── Loaded library handle ──
pub struct FlashInferLib {
// Keep the handle alive so the dlopen'd .so remains mapped.
_lib: libloading::Library,
pub plan: PlanFn,
pub run: RunFn,
pub extract_slot_indices: ExtractFn,
pub derive_indptr_from_mask: DeriveIndptrFn,
pub transpose_output: TransposeOutputFn,
pub prefill_plan: PrefillPlanFn,
pub prefill_run: PrefillRunFn,
}
// SAFETY: The library handle and function pointers are valid for the lifetime
// of the process. All functions are called with proper CUDA stream serialization.
unsafe impl Send for FlashInferLib {}
unsafe impl Sync for FlashInferLib {}
static FLASHINFER_LIB: OnceLock<FlashInferLib> = OnceLock::new();
/// Ensure the FlashInfer library is compiled and loaded for the given HEAD_DIM.
/// Returns a reference to the loaded library. Thread-safe via OnceLock.
pub fn ensure_compiled(head_dim: usize) -> &'static FlashInferLib {
FLASHINFER_LIB.get_or_init(|| {
assert!(
matches!(head_dim, 64 | 128 | 256),
"FlashInfer: unsupported HEAD_DIM={} (must be 64, 128, or 256 for f32)",
head_dim
);
let so_path = compile_or_cache(head_dim);
unsafe {
FlashInferLib::load(&so_path)
.unwrap_or_else(|e| panic!("Failed to load FlashInfer library: {e}"))
}
})
}
impl FlashInferLib {
/// Load a compiled FlashInfer .so and resolve function pointers.
///
/// # Safety
/// The .so must be a valid FlashInfer wrapper compiled from wrapper.cu.
unsafe fn load(path: &Path) -> Result<Self, libloading::Error> {
let lib = unsafe { libloading::Library::new(path)? };
let plan: PlanFn = unsafe { *lib.get::<PlanFn>(b"flashinfer_batch_decode_plan\0")? };
let run: RunFn = unsafe { *lib.get::<RunFn>(b"flashinfer_batch_decode_run\0")? };
let extract_slot_indices: ExtractFn =
unsafe { *lib.get::<ExtractFn>(b"flashinfer_extract_slot_indices\0")? };
let derive_indptr_from_mask: DeriveIndptrFn =
unsafe { *lib.get::<DeriveIndptrFn>(b"flashinfer_derive_indptr_from_mask\0")? };
let transpose_output: TransposeOutputFn =
unsafe { *lib.get::<TransposeOutputFn>(b"flashinfer_transpose_output\0")? };
let prefill_plan: PrefillPlanFn =
unsafe { *lib.get::<PrefillPlanFn>(b"flashinfer_batch_prefill_plan\0")? };
let prefill_run: PrefillRunFn =
unsafe { *lib.get::<PrefillRunFn>(b"flashinfer_batch_prefill_run\0")? };
Ok(Self {
_lib: lib,
plan,
run,
extract_slot_indices,
derive_indptr_from_mask,
transpose_output,
prefill_plan,
prefill_run,
})
}
}
/// Compile wrapper.cu for the given HEAD_DIM, or return cached .so path.
fn compile_or_cache(head_dim: usize) -> PathBuf {
let cache_dir = cache_directory();
std::fs::create_dir_all(&cache_dir).expect("Failed to create FlashInfer cache directory");
// Extract bundled wrapper sources to the cache so nvcc can compile them.
let (wrapper_cu_path, wrapper_h_dir) = extract_wrapper_sources(&cache_dir);
let arch = detect_cuda_arch();
// Bake a hash of the embedded wrapper into the .so name so old caches are
// discarded automatically when wrapper.cu or wrapper.h change.
let wrapper_hash = wrapper_source_hash();
let so_name = format!(
"libflashinfer_hd{}_{}_w{:016x}.so",
head_dim, arch, wrapper_hash
);
let so_path = cache_dir.join(&so_name);
if so_path.exists() {
eprintln!(
"FlashInfer: using cached library for HEAD_DIM={} ({})",
head_dim,
so_path.display()
);
return so_path;
}
let Some((flashinfer_include, cutlass_include)) = locate_flashinfer_includes() else {
panic!(
"FlashInfer: could not locate header tree. Set LUMINAL_FLASHINFER_DIR to the \
FlashInfer source root (the directory containing `include/` and \
`3rdparty/cutlass/include/`)."
);
};
eprintln!(
"FlashInfer: JIT compiling for HEAD_DIM={}, arch={} ...",
head_dim, arch
);
let start = std::time::Instant::now();
let output = Command::new("nvcc")
.args([
"-shared",
"-o",
so_path.to_str().unwrap(),
&format!("-DLUMINAL_HEAD_DIM={}", head_dim),
wrapper_cu_path.to_str().unwrap(),
"-I",
flashinfer_include.to_str().unwrap(),
"-I",
cutlass_include.to_str().unwrap(),
"-I",
wrapper_h_dir.to_str().unwrap(),
"-std=c++17",
&format!("-arch={}", arch),
"-O3",
"--expt-relaxed-constexpr",
"-w",
"-rdc=true",
"--compiler-options",
"-fPIC",
])
.output()
.expect("Failed to run nvcc. Is the CUDA toolkit installed?");
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let stdout = String::from_utf8_lossy(&output.stdout);
let _ = std::fs::remove_file(&so_path);
panic!(
"FlashInfer JIT compilation failed (HEAD_DIM={}, arch={}):\nstdout: {}\nstderr: {}",
head_dim, arch, stdout, stderr
);
}
let elapsed = start.elapsed();
eprintln!(
"FlashInfer: compiled in {:.1}s → {}",
elapsed.as_secs_f64(),
so_path.display()
);
so_path
}
/// Returns ~/.cache/luminal/flashinfer/
fn cache_directory() -> PathBuf {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
PathBuf::from(home)
.join(".cache")
.join("luminal")
.join("flashinfer")
}
/// Drop the embedded wrapper.cu/wrapper.h into the cache dir so nvcc has files
/// on disk to compile. Returns (wrapper.cu path, directory containing wrapper.h).
fn extract_wrapper_sources(cache_dir: &Path) -> (PathBuf, PathBuf) {
let cu = cache_dir.join("wrapper.cu");
let h = cache_dir.join("wrapper.h");
write_if_changed(&cu, WRAPPER_CU.as_bytes());
write_if_changed(&h, WRAPPER_H.as_bytes());
(cu, cache_dir.to_path_buf())
}
fn write_if_changed(path: &Path, contents: &[u8]) {
if let Ok(existing) = std::fs::read(path)
&& existing == contents
{
return;
}
std::fs::write(path, contents).unwrap_or_else(|e| {
panic!(
"FlashInfer: failed to write wrapper source to {}: {e}",
path.display()
)
});
}
fn wrapper_source_hash() -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
WRAPPER_CU.hash(&mut hasher);
WRAPPER_H.hash(&mut hasher);
hasher.finish()
}
// ── Pinned FlashInfer source ──
//
// Bumping this constant invalidates the cached source tree AND the cached .so
// (the .so cache key incorporates the wrapper hash, which is rebuilt against
// these headers, so different headers compile to a different .so file even at
// the same head_dim). If you change `FLASHINFER_GIT_REV`, also re-check
// `wrapper.cu` against the new FlashInfer API.
const FLASHINFER_GIT_URL: &str = "https://github.com/flashinfer-ai/flashinfer.git";
const CUTLASS_GIT_URL: &str = "https://github.com/NVIDIA/cutlass.git";
const FLASHINFER_GIT_REV: &str = "f1e6fdcb8f65104047697f022b5d055ef022d763";
const CUTLASS_GIT_REV: &str = "f3fde58372d33e9a5650ba7b80fc48b3b49d40c8";
fn locate_flashinfer_includes() -> Option<(PathBuf, PathBuf)> {
if let Ok(path) = std::env::var("LUMINAL_FLASHINFER_DIR")
&& !path.is_empty()
{
let root = PathBuf::from(path);
let inc = root.join("include");
let cutlass = root.join("3rdparty/cutlass/include");
if inc.exists() && cutlass.exists() {
return Some((inc, cutlass));
}
eprintln!(
"FlashInfer: LUMINAL_FLASHINFER_DIR={} did not contain include/ and \
3rdparty/cutlass/include/ — falling back to default locations",
root.display()
);
}
let home = std::env::var("HOME").unwrap_or_default();
let candidates = [
PathBuf::from(&home).join("luminal_cuda/crates/luminal_cuda/flashinfer"),
PathBuf::from(&home).join("luminal_cuda/flashinfer"),
PathBuf::from("/opt/luminal_cuda/crates/luminal_cuda/flashinfer"),
];
for root in candidates {
let inc = root.join("include");
let cutlass = root.join("3rdparty/cutlass/include");
if inc.exists() && cutlass.exists() {
return Some((inc, cutlass));
}
}
// Last resort: fetch the pinned commit into the cache directory.
fetch_flashinfer_source().ok().map(|root| {
let inc = root.join("include");
let cutlass = root.join("3rdparty/cutlass/include");
(inc, cutlass)
})
}
/// Clone FlashInfer at `FLASHINFER_GIT_REV` + CUTLASS at `CUTLASS_GIT_REV`
/// into `~/.cache/luminal/flashinfer-src/<short_rev>/` if absent, then return
/// the FlashInfer root directory. ~50 MB one-time download; subsequent calls
/// short-circuit on the directory check.
fn fetch_flashinfer_source() -> Result<PathBuf, String> {
let short = &FLASHINFER_GIT_REV[..12];
let cache_root = cache_directory().join("flashinfer-src").join(short);
let inc = cache_root.join("include");
let cutlass_inc = cache_root.join("3rdparty/cutlass/include");
if inc.exists() && cutlass_inc.exists() {
return Ok(cache_root);
}
let parent = cache_root.parent().unwrap();
std::fs::create_dir_all(parent)
.map_err(|e| format!("failed to create {}: {e}", parent.display()))?;
// Clone into a staging dir, then atomic rename. Protects against multiple
// processes racing to fetch the same source.
let staging = parent.join(format!(".staging-{}-{}", short, std::process::id()));
let _ = std::fs::remove_dir_all(&staging);
eprintln!(
"FlashInfer: cloning {FLASHINFER_GIT_URL} @ {short} into {} (one-time fetch, ~50 MB) …",
cache_root.display()
);
run_git(&[
"clone",
"--filter=blob:none",
"--no-checkout",
FLASHINFER_GIT_URL,
staging.to_str().unwrap(),
])?;
run_git_in(&staging, &["checkout", FLASHINFER_GIT_REV])?;
// Init only the CUTLASS submodule (skip spdlog — we don't need it for kernels).
let cutlass_path = staging.join("3rdparty/cutlass");
let _ = std::fs::remove_dir_all(&cutlass_path);
run_git(&[
"clone",
"--filter=blob:none",
"--no-checkout",
CUTLASS_GIT_URL,
cutlass_path.to_str().unwrap(),
])?;
run_git_in(&cutlass_path, &["checkout", CUTLASS_GIT_REV])?;
if !staging.join("include").exists() {
return Err(format!(
"FlashInfer clone succeeded but include/ missing at {}",
staging.display()
));
}
if !staging.join("3rdparty/cutlass/include").exists() {
return Err(format!(
"CUTLASS clone succeeded but include/ missing at {}",
staging.join("3rdparty/cutlass").display()
));
}
// Atomic-ish rename. If another process beat us to it, just keep theirs.
match std::fs::rename(&staging, &cache_root) {
Ok(()) => {}
Err(_) if cache_root.exists() => {
let _ = std::fs::remove_dir_all(&staging);
}
Err(e) => return Err(format!("rename to {} failed: {e}", cache_root.display())),
}
Ok(cache_root)
}
fn run_git(args: &[&str]) -> Result<(), String> {
let out = Command::new("git")
.args(args)
.output()
.map_err(|e| format!("failed to spawn `git`: {e}. Is git installed?"))?;
if !out.status.success() {
return Err(format!(
"`git {}` failed: {}",
args.join(" "),
String::from_utf8_lossy(&out.stderr)
));
}
Ok(())
}
fn run_git_in(cwd: &Path, args: &[&str]) -> Result<(), String> {
let out = Command::new("git")
.args(args)
.current_dir(cwd)
.output()
.map_err(|e| format!("failed to spawn `git`: {e}"))?;
if !out.status.success() {
return Err(format!(
"`git {}` in {} failed: {}",
args.join(" "),
cwd.display(),
String::from_utf8_lossy(&out.stderr)
));
}
Ok(())
}
/// Detect CUDA arch via env override → nvidia-smi → default sm_80.
fn detect_cuda_arch() -> String {
if let Ok(arch) = std::env::var("FLASHINFER_CUDA_ARCH") {
return arch;
}
if let Ok(output) = Command::new("nvidia-smi")
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
.output()
&& output.status.success()
{
let cap = String::from_utf8_lossy(&output.stdout);
let cap = cap.trim().lines().next().unwrap_or("8.0");
let sm = cap.replace('.', "");
if !sm.is_empty() {
return format!("sm_{}", sm);
}
}
"sm_80".to_string()
}

View File

@@ -0,0 +1,424 @@
pub mod find_indptrs;
pub mod jit;
use std::sync::{Arc, Mutex, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND},
extract_expr,
},
op::{EgglogOp, LLIROp},
prelude::{
tracing::{Level, span},
*,
},
};
use crate::{
cudarc::driver::{CudaSlice, CudaStream, DevicePtr, result},
host::{DeviceBuffer, HostOp},
};
/// FlashInfer attention op (batch decode, fp32).
///
/// Replaces the full paged-GQA attention pattern (gather → broadcast → Q*K^T →
/// scale → mask → softmax → *V) with a single FlashInfer fused kernel.
///
/// Graph inputs (7): Q, K_pool, V_pool, flat_gather_idx, mask, qo_indptr, kv_indptr.
/// The egglog rule captures the first 5; `extract()` appends qo/kv indptrs after
/// walking the e-graph from the mask. `batch_size` is derived at runtime from the
/// indptr length (= num_sequences + 1).
#[derive(Debug)]
pub struct FlashInferAttention {
pub num_qo_heads: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub page_size: usize,
pub batch_dim: Expression,
pub plan_info: Mutex<Vec<i64>>,
}
// SAFETY: PAGE_LOCKED_WORKSPACE holds a raw pointer to page-locked CUDA memory
// allocated once and serialized via the CUDA stream that owns it.
unsafe impl Send for FlashInferAttention {}
unsafe impl Sync for FlashInferAttention {}
const FLOAT_WORKSPACE_SIZE: usize = 128 * 1024 * 1024; // 128 MiB
const INT_WORKSPACE_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
static PAGE_LOCKED_WORKSPACE: OnceLock<PageLockedPtr> = OnceLock::new();
struct PageLockedPtr(*mut u8);
// SAFETY: The pointer is page-locked CUDA memory allocated once via
// posix_memalign + cudaHostRegister and only mutated during OnceLock
// initialization.
unsafe impl Send for PageLockedPtr {}
unsafe impl Sync for PageLockedPtr {}
impl std::fmt::Debug for PageLockedPtr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "PageLockedPtr({:p})", self.0)
}
}
impl Default for FlashInferAttention {
fn default() -> Self {
Self {
num_qo_heads: 0,
num_kv_heads: 0,
head_dim: 0,
page_size: 0,
batch_dim: Expression::default(),
plan_info: Mutex::new(Vec::new()),
}
}
}
impl EgglogOp for FlashInferAttention {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"FlashInferAttention",
&[
("num_qo_heads", EXPRESSION),
("num_kv_heads", EXPRESSION),
("head_dim", EXPRESSION),
("page_size", EXPRESSION),
("batch_dim", EXPRESSION),
],
)
}
fn n_inputs(&self) -> usize {
// Q, K_pool, V_pool, flat_gather_idx, mask (egglog IList).
// extract() appends qo_indptr + kv_indptr → 7 actual inputs at runtime.
5
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(include_str!["flashinfer_attention.egg"])]
}
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let num_qo_heads = extract_expr(egraph, kind_children[0], expr_cache)
.unwrap()
.exec(&FxHashMap::default())
.unwrap();
let num_kv_heads = extract_expr(egraph, kind_children[1], expr_cache)
.unwrap()
.exec(&FxHashMap::default())
.unwrap();
let head_dim = extract_expr(egraph, kind_children[2], expr_cache)
.unwrap()
.exec(&FxHashMap::default())
.unwrap();
let page_size = extract_expr(egraph, kind_children[3], expr_cache)
.unwrap()
.exec(&FxHashMap::default())
.unwrap();
let batch_dim = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
let extracted = Self {
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
batch_dim,
plan_info: Mutex::new(Vec::new()),
};
// Trigger JIT compilation (or .so cache hit) at extract time, not at
// first execute. Pays the ~30s cold-cache nvcc cost during compile
// rather than during the GA profiling loop, where it would dominate
// the candidate's measured runtime and make the GA reject FlashInfer.
let _ = jit::ensure_compiled(head_dim);
// Walk the mask e-graph chain to recover qo_indptr / kv_indptr Input nodes.
// input_enodes: [Q, K_cache, V_cache, gather_idx, mask]
let mask_node = input_enodes[4];
let indptrs = find_indptrs::find_indptr_inputs(egraph, mask_node);
// Build final inputs: [Q, K_cache, V_cache, gather_idx, mask, qo_indptr, kv_indptr]
let mut final_inputs = input_enodes;
final_inputs.push(indptrs.qo_indptr);
final_inputs.push(indptrs.kv_indptr);
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
(op, final_inputs)
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for FlashInferAttention {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let lib = jit::ensure_compiled(self.head_dim);
let total_q_tokens = self
.batch_dim
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention batch_dim is unresolved"))?;
let c = *dyn_map
.get(&'c')
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'c'"))?;
let r = *dyn_map
.get(&'r')
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'r'"))?;
if inputs.len() < 7 {
anyhow::bail!(
"FlashInferAttention expects 7 inputs (Q, K, V, flat_idx, mask, qo_indptr, kv_indptr), got {}",
inputs.len()
);
}
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
buffers.get(&node).copied().ok_or_else(|| {
anyhow::anyhow!("FlashInferAttention missing {name} buffer for {node:?}")
})
};
let q_buf = get_buf("Q", inputs[0])?;
let k_buf = get_buf("K_cache", inputs[1])?;
let v_buf = get_buf("V_cache", inputs[2])?;
let flat_idx_buf = get_buf("flat_gather_idx", inputs[3])?;
// inputs[4] = mask (unused by FlashInfer — indptrs replace it)
let kv_indptr_buf = get_buf("kv_indptr", inputs[6])?;
let out_buf = get_buf("output", self_node)?;
// Derive batch_size (num sequences) from r = indptr length.
let batch_size = r.saturating_sub(1);
let _span = span!(
Level::TRACE,
"FlashInferAttention",
total_q_tokens,
batch_size,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
)
.entered();
let kv_dim = self.num_kv_heads * self.head_dim;
let cu_stream = stream.cu_stream() as *mut std::ffi::c_void;
// Extract slot indices (one per context page) from the flat gather index.
let indices_buf = unsafe { stream.alloc::<u8>(c.max(1) * std::mem::size_of::<i32>())? };
let (indices_ptr, _idx_guard) = indices_buf.device_ptr(stream);
if c > 0 {
unsafe {
(lib.extract_slot_indices)(
flat_idx_buf.ptr() as *const i32,
indices_ptr as *mut i32,
c as i32,
kv_dim as i32,
cu_stream,
);
}
}
// Read kv_indptr to host for the plan phase.
let kv_indptr_bytes = r * 4;
let mut kv_indptr_host_bytes = vec![0u8; kv_indptr_bytes];
unsafe {
result::memcpy_dtoh_async(
&mut kv_indptr_host_bytes,
kv_indptr_buf.ptr(),
stream.cu_stream(),
)?;
}
stream.synchronize()?;
let kv_indptr_host: Vec<i32> = unsafe {
let mut v = std::mem::ManuallyDrop::new(kv_indptr_host_bytes);
Vec::from_raw_parts(v.as_mut_ptr() as *mut i32, r, r)
};
// kv_last_page_len = [1; batch_size] when page_size=1.
let last_page_host: Vec<i32> = vec![1; batch_size];
let last_page_dev: CudaSlice<u8> = if batch_size > 0 {
stream.clone_htod(unsafe {
std::slice::from_raw_parts(
last_page_host.as_ptr() as *const u8,
last_page_host.len() * std::mem::size_of::<i32>(),
)
})?
} else {
unsafe { stream.alloc::<u8>(1)? }
};
let (last_page_ptr, _lp_guard) = last_page_dev.device_ptr(stream);
// Global shared workspaces (allocated once across all op instances to
// amortize the ~4ms first-allocation cost during GA profiling).
static FLOAT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
static INT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
let float_ws = FLOAT_WORKSPACE
.get_or_init(|| unsafe { stream.alloc::<u8>(FLOAT_WORKSPACE_SIZE).unwrap() });
let int_ws = INT_WORKSPACE
.get_or_init(|| unsafe { stream.alloc::<u8>(INT_WORKSPACE_SIZE).unwrap() });
let page_locked_ws = PAGE_LOCKED_WORKSPACE.get_or_init(|| unsafe {
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let status = libc::posix_memalign(&mut ptr, 4096, INT_WORKSPACE_SIZE);
assert_eq!(status, 0, "Failed to allocate page-locked workspace");
let cuda_status = cuda_pin_memory(ptr, INT_WORKSPACE_SIZE);
assert_eq!(cuda_status, 0, "Failed to pin memory");
PageLockedPtr(ptr as *mut u8)
});
let (float_ws_ptr, _fws_guard) = float_ws.device_ptr(stream);
let (int_ws_ptr, _iws_guard) = int_ws.device_ptr(stream);
// FlashInfer decode writes (total_q_tokens, heads, dim);
// luminal expects (heads, total_q_tokens, dim) — transpose at the end.
let output_elems = total_q_tokens * self.num_qo_heads * self.head_dim;
let temp_out_buf =
unsafe { stream.alloc::<u8>(output_elems * std::mem::size_of::<f32>())? };
let (temp_out_ptr, _tmp_guard) = temp_out_buf.device_ptr(stream);
// PrefillPlanInfo has 15 entries, DecodePlanInfo fewer — 16 is enough.
let mut plan_info_buf = [0i64; 16];
let mut plan_info_len: i32 = 0;
// ── BatchDecode path ──
// Prefill kernels require fp16/bf16 tensor-core MMA; the C API returns -1
// when called from the fp32 pipeline. We only use decode here.
let plan_ret = unsafe {
(lib.plan)(
float_ws_ptr as *mut std::ffi::c_void,
FLOAT_WORKSPACE_SIZE,
int_ws_ptr as *mut std::ffi::c_void,
INT_WORKSPACE_SIZE,
page_locked_ws.0 as *mut std::ffi::c_void,
kv_indptr_host.as_ptr() as *mut i32,
batch_size as i32,
self.num_qo_heads as i32,
self.num_kv_heads as i32,
self.page_size as i32,
self.head_dim as i32,
cu_stream,
plan_info_buf.as_mut_ptr(),
&mut plan_info_len,
)
};
if plan_ret != 0 {
return Err(anyhow::anyhow!(
"FlashInfer decode plan failed with error code {plan_ret}"
));
}
let mut plan_info = self.plan_info.lock().unwrap();
plan_info.clear();
plan_info.extend_from_slice(&plan_info_buf[..plan_info_len as usize]);
let run_ret = unsafe {
(lib.run)(
float_ws_ptr as *mut std::ffi::c_void,
FLOAT_WORKSPACE_SIZE,
int_ws_ptr as *mut std::ffi::c_void,
plan_info.as_mut_ptr(),
plan_info.len() as i32,
q_buf.ptr() as *mut f32,
k_buf.ptr() as *mut f32,
v_buf.ptr() as *mut f32,
kv_indptr_buf.ptr() as *mut i32,
indices_ptr as *mut i32,
last_page_ptr as *mut i32,
temp_out_ptr as *mut f32,
batch_size as i32,
self.num_qo_heads as i32,
self.num_kv_heads as i32,
self.page_size as i32,
self.head_dim as i32,
cu_stream,
)
};
drop(plan_info);
if run_ret != 0 {
return Err(anyhow::anyhow!(
"FlashInfer decode run failed with error code {run_ret}"
));
}
// Transpose (total_q_tokens, heads, dim) → (heads, total_q_tokens, dim)
unsafe {
(lib.transpose_output)(
temp_out_ptr as *const f32,
out_buf.ptr() as *mut f32,
total_q_tokens as i32,
self.num_qo_heads as i32,
self.head_dim as i32,
cu_stream,
);
}
Ok(())
}
fn output_size(&self) -> Expression {
self.batch_dim * self.num_qo_heads * self.head_dim
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn stats_name(&self) -> Option<&'static str> {
Some("FlashInferAttention")
}
}
/// Pin host memory for CUDA async memcpy.
///
/// `cudaHostRegister` lives in libcudart, which cudarc doesn't link to our
/// binary. Resolve it via `dlopen`/`dlsym` so we don't need a build script or
/// a `#[link]` directive — keeping the crate buildable without any nvcc-side
/// dependencies.
unsafe fn cuda_pin_memory(ptr: *mut std::ffi::c_void, size: usize) -> i32 {
type HostRegisterFn = unsafe extern "C" fn(*mut std::ffi::c_void, usize, u32) -> i32;
static FN: OnceLock<usize> = OnceLock::new();
let raw = *FN.get_or_init(|| unsafe {
let lib = [
"libcudart.so",
"libcudart.so.13",
"libcudart.so.12",
"libcudart.so.11",
]
.iter()
.find_map(|n| libloading::Library::new(*n).ok())
.expect("FlashInfer: could not dlopen libcudart for cudaHostRegister");
let sym: libloading::Symbol<HostRegisterFn> = lib
.get(b"cudaHostRegister\0")
.expect("FlashInfer: libcudart missing cudaHostRegister symbol");
let ptr = *sym as *const () as usize;
// Keep libcudart resident for the process lifetime so the function
// pointer remains valid.
std::mem::forget(lib);
ptr
});
let f: HostRegisterFn = unsafe { std::mem::transmute(raw) };
// cudaHostRegisterDefault = 0
unsafe { f(ptr, size, 0) }
}

View File

@@ -0,0 +1,357 @@
// FlashInfer batch decode + prefill wrapper for luminal_cuda.
// JIT-compiled at runtime with -DLUMINAL_HEAD_DIM=N.
//
// Decode: instantiated for f32 (scalar vectorized dot products, no tensor cores).
// Prefill: instantiated for f16 (requires tensor core MMA + ldmatrix).
// The C API accepts fp32 buffers; cast kernels convert fp32↔fp16 at the boundary.
//
// NHD layout. GQA group_size and page_size are runtime parameters.
#ifndef LUMINAL_HEAD_DIM
#error "LUMINAL_HEAD_DIM must be defined (e.g. -DLUMINAL_HEAD_DIM=128)"
#endif
// Include utils.cuh first to get the original DISPATCH_HEAD_DIM, then override it
// to only instantiate our specific HEAD_DIM. This avoids a compile error in
// cascade.cuh where HEAD_DIM=512 + f32 triggers vec_size=16, vec_bits=512
// which exceeds cp_async's 256-bit limit.
#include <flashinfer/utils.cuh>
#undef DISPATCH_HEAD_DIM
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
{ \
constexpr size_t HEAD_DIM = LUMINAL_HEAD_DIM; \
__VA_ARGS__ \
}
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/decode.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/page.cuh>
#include <flashinfer/pos_enc.cuh>
#include "wrapper.h"
#include <cstring>
#include <vector>
#include <cuda_fp16.h>
using namespace flashinfer;
// ── Decode types (f32) ──
using DTypeQ = float;
using DTypeKV = float;
using DTypeO = float;
using IdType = int32_t;
// ── Prefill types (f16 compute, fp32 external interface) ──
using PrefillDTypeQ = half;
using PrefillDTypeKV = half;
using PrefillDTypeO = half;
constexpr uint32_t HEAD_DIM = LUMINAL_HEAD_DIM;
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;
// Attention variants
using Variant = DefaultAttention</*use_custom_mask=*/false,
/*use_sliding_window=*/false,
/*use_logits_soft_cap=*/false,
/*use_alibi=*/false>;
using CausalVariant = DefaultAttention</*use_custom_mask=*/false,
/*use_sliding_window=*/false,
/*use_logits_soft_cap=*/false,
/*use_alibi=*/false>;
// Decode params (f32)
using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
// Prefill params (f16)
using PrefillParams = BatchPrefillPagedParams<PrefillDTypeQ, PrefillDTypeKV, PrefillDTypeO, IdType>;
// Forward declarations
namespace flashinfer {
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
typename Params>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
float* tmp_s, bool enable_pdl,
cudaStream_t stream);
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
float* tmp_s, bool enable_pdl,
cudaStream_t stream);
}
// Explicit instantiation: decode kernel (f32)
template cudaError_t flashinfer::BatchDecodeWithPagedKVCacheDispatched<
HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
DecodeParams params, DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
// Explicit instantiation: prefill kernels (f16, causal mask, CTA_TILE_Q=16/64/128)
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
16, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
64, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
128, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
// ── fp32 ↔ fp16 cast kernels ──
__global__ void cast_f32_to_f16_kernel(const float* src, half* dst, size_t n) {
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) dst[i] = __float2half(src[i]);
}
__global__ void cast_f16_to_f32_kernel(const half* src, float* dst, size_t n) {
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) dst[i] = __half2float(src[i]);
}
extern "C" {
int flashinfer_batch_decode_plan(
void* float_workspace, size_t float_ws_size,
void* int_workspace, size_t int_ws_size,
void* page_locked_int_workspace,
int32_t* indptr_h, int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream,
int64_t* plan_info_out, int* plan_info_len_out)
{
(void)head_dim; // fixed at compile time
DecodePlanInfo plan_info;
uint32_t group_size = num_qo_heads / num_kv_heads;
// We need to dispatch on GROUP_SIZE to get the right work estimation function
cudaError_t status = cudaSuccess;
// Use a lambda to dispatch on group size
auto do_plan = [&]<uint32_t GROUP_SIZE>() -> cudaError_t {
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>;
return DecodePlan<HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
float_workspace, float_ws_size,
int_workspace, page_locked_int_workspace,
int_ws_size, plan_info, indptr_h,
(uint32_t)batch_size, (uint32_t)num_qo_heads,
(uint32_t)page_size, /*enable_cuda_graph=*/false,
stream, work_estimation_func);
};
switch (group_size) {
case 1: status = do_plan.operator()<1>(); break;
case 2: status = do_plan.operator()<2>(); break;
case 4: status = do_plan.operator()<4>(); break;
case 8: status = do_plan.operator()<8>(); break;
default: return -1; // unsupported group size
}
if (status != cudaSuccess) return (int)status;
auto vec = plan_info.ToVector();
*plan_info_len_out = (int)vec.size();
std::memcpy(plan_info_out, vec.data(), vec.size() * sizeof(int64_t));
return 0;
}
int flashinfer_batch_decode_run(
void* float_workspace, size_t float_ws_size,
void* int_workspace,
int64_t* plan_info_vec, int plan_info_len,
float* q,
float* k_cache,
float* v_cache,
int32_t* kv_indptr,
int32_t* kv_indices,
int32_t* kv_last_page_len,
float* output,
int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream)
{
(void)head_dim; // fixed at compile time
DecodePlanInfo plan_info;
plan_info.FromVector(std::vector<int64_t>(plan_info_vec, plan_info_vec + plan_info_len));
// Construct paged_kv_t with NHD layout
paged_kv_t<DTypeKV, IdType> paged_kv(
(uint32_t)num_kv_heads,
(uint32_t)page_size,
HEAD_DIM,
(uint32_t)batch_size,
QKVLayout::kNHD,
k_cache,
v_cache,
kv_indices,
kv_indptr,
kv_last_page_len);
DecodeParams params;
params.q = q;
params.q_rope_offset = nullptr;
params.paged_kv = paged_kv;
params.o = output;
params.lse = nullptr;
params.maybe_alibi_slopes = nullptr;
params.padded_batch_size = plan_info.padded_batch_size;
params.num_qo_heads = (uint32_t)num_qo_heads;
// Q buffer is (batch, num_qo_heads * head_dim) flat — the graph's split_dims + transpose
// are stride tricks, no data movement. So the actual memory layout is (batch, heads, dim).
params.q_stride_n = num_qo_heads * HEAD_DIM;
params.q_stride_h = HEAD_DIM;
params.window_left = -1; // no sliding window
params.logits_soft_cap = 0.0f;
params.sm_scale = 1.0f / sqrtf((float)HEAD_DIM);
params.rope_rcp_scale = 1.0f;
params.rope_rcp_theta = 1.0f;
// Set plan info pointers
params.request_indices =
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.request_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_tile_indices_offset);
params.o_indptr =
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.o_indptr_offset);
params.kv_chunk_size_ptr =
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_chunk_size_ptr_offset);
params.block_valid_mask = nullptr;
params.partition_kv = false;
DTypeO* tmp_v = nullptr;
float* tmp_s = nullptr;
if (plan_info.split_kv) {
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_workspace, plan_info.v_offset);
tmp_s = GetPtrFromBaseOffset<float>(float_workspace, plan_info.s_offset);
if (plan_info.enable_cuda_graph) {
params.block_valid_mask =
GetPtrFromBaseOffset<bool>(int_workspace, plan_info.block_valid_mask_offset);
}
}
cudaError_t status =
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE, Variant>(
params, tmp_v, tmp_s, /*enable_pdl=*/false, stream);
return (int)status;
}
// ═══════════════════════════════════════════════════════════
// BatchPrefill (fp16/bf16 only — tensor core MMA requires 16-bit inputs)
// ═══════════════════════════════════════════════════════════
//
// The prefill kernel templates are instantiated above for fp16. These C API
// functions accept fp32 pointers (matching the current luminal pipeline) but
// return -1 to indicate that fp32 prefill is not supported. When native fp16
// support is added, these will accept fp16 pointers and call through to the
// instantiated templates.
int flashinfer_batch_prefill_plan(
void*, size_t, void*, size_t, void*,
int32_t*, int32_t*, int, int,
int, int, int, int, cudaStream_t,
int64_t*, int*)
{
return -1; // fp32 not supported — requires fp16/bf16
}
int flashinfer_batch_prefill_run(
void*, size_t, void*,
int64_t*, int,
float*, float*, float*,
int32_t*, int32_t*, int32_t*, int32_t*,
float*, int, int, int, int, int, int, cudaStream_t)
{
return -1; // fp32 not supported — requires fp16/bf16
}
} // extern "C"
// ── Slot index extraction kernel (outside extern "C" for __global__) ──
__global__ void extract_slot_indices_kernel(
const int32_t* flat_idx, int32_t* out, int c, int kv_dim) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < c) out[i] = flat_idx[i * kv_dim] / kv_dim;
}
extern "C" void flashinfer_extract_slot_indices(
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
cudaStream_t stream) {
if (c == 0) return;
int threads = 256;
int blocks = (c + threads - 1) / threads;
extract_slot_indices_kernel<<<blocks, threads, 0, stream>>>(
flat_idx, out, c, kv_dim);
}
// ── Derive CSR indptr from attention mask ──
// Mask is (s, c) f32. Entries > -1e9 are "valid" (0.0), rest are -inf.
// Per-row count of valid entries = context length for that sequence.
// Output: indptr[0..=s] with indptr[0]=0 and indptr[i+1] = indptr[i] + ctx_len[i].
// Single thread is fine since s is tiny (batch_size during decode, typically 1-8).
__global__ void derive_indptr_kernel(
const float* mask, int32_t* indptr, int s, int c) {
if (threadIdx.x != 0 || blockIdx.x != 0) return;
indptr[0] = 0;
for (int i = 0; i < s; i++) {
int count = 0;
for (int j = 0; j < c; j++) {
if (mask[i * c + j] > -1e9f) count++;
}
indptr[i + 1] = indptr[i] + count;
}
}
extern "C" void flashinfer_derive_indptr_from_mask(
const float* mask, int32_t* indptr, int s, int c,
cudaStream_t stream) {
if (s == 0) return;
derive_indptr_kernel<<<1, 1, 0, stream>>>(mask, indptr, s, c);
}
// ── Output transpose: (batch, heads, dim) → (heads, batch, dim) ──
// FlashInfer writes output as (batch, heads, dim) but Luminal expects (heads, batch, dim).
// For batch=1 these are identical; for batch>1 we need an explicit transpose.
__global__ void transpose_bhd_to_hbd_kernel(
const float* src, float* dst, int batch, int heads, int dim) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = batch * heads * dim;
if (idx >= total) return;
// Decompose linear index into (b, h, d) for src layout
int d = idx % dim;
int h = (idx / dim) % heads;
int b = idx / (heads * dim);
// Write to (h, b, d) layout in dst
dst[h * batch * dim + b * dim + d] = src[idx];
}
extern "C" void flashinfer_transpose_output(
const float* src, float* dst,
int batch, int heads, int dim,
cudaStream_t stream) {
int total = batch * heads * dim;
if (total == 0) return;
int threads = 256;
int blocks = (total + threads - 1) / threads;
transpose_bhd_to_hbd_kernel<<<blocks, threads, 0, stream>>>(
src, dst, batch, heads, dim);
}

View File

@@ -0,0 +1,93 @@
#pragma once
#include <cuda_runtime.h>
#include <stdint.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
#endif
// Plan phase: CPU-side scheduling. Must call before each new batch config.
// Returns 0 on success, non-zero on failure.
int flashinfer_batch_decode_plan(
void* float_workspace, size_t float_ws_size,
void* int_workspace, size_t int_ws_size,
void* page_locked_int_workspace,
int32_t* indptr_h, int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream,
int64_t* plan_info_out, int* plan_info_len_out);
// Run phase: GPU kernel launch.
// Returns 0 on success, non-zero on failure.
int flashinfer_batch_decode_run(
void* float_workspace, size_t float_ws_size,
void* int_workspace,
int64_t* plan_info_vec, int plan_info_len,
float* q, // [batch_size, num_qo_heads, head_dim]
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
float* v_cache, // same layout
int32_t* kv_indptr, // [batch_size + 1]
int32_t* kv_indices, // [total_pages]
int32_t* kv_last_page_len, // [batch_size]
float* output, // [batch_size, num_qo_heads, head_dim]
int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream);
// Extract slot indices from a flat gather index tensor.
// flat_idx shape: (c, kv_dim) i32, out shape: (c,) i32.
// out[i] = flat_idx[i * kv_dim] / kv_dim
void flashinfer_extract_slot_indices(
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
cudaStream_t stream);
// Derive CSR indptr from attention mask.
// mask shape: (s, c) f32. Entries > -1e9 are valid.
// indptr shape: (s + 1,) i32. indptr[0] = 0, indptr[i+1] = cumsum of valid counts.
void flashinfer_derive_indptr_from_mask(
const float* mask, int32_t* indptr, int s, int c,
cudaStream_t stream);
// Transpose output from (batch, heads, dim) to (heads, batch, dim).
void flashinfer_transpose_output(
const float* src, float* dst,
int batch, int heads, int dim,
cudaStream_t stream);
// ── BatchPrefill with Paged KV Cache ──
// Plan phase for batch prefill.
// Returns 0 on success, non-zero on failure.
int flashinfer_batch_prefill_plan(
void* float_workspace, size_t float_ws_size,
void* int_workspace, size_t int_ws_size,
void* page_locked_int_workspace,
int32_t* qo_indptr_h, int32_t* kv_indptr_h,
int total_num_rows, int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream,
int64_t* plan_info_out, int* plan_info_len_out);
// Run phase for batch prefill.
// Returns 0 on success, non-zero on failure.
int flashinfer_batch_prefill_run(
void* float_workspace, size_t float_ws_size,
void* int_workspace,
int64_t* plan_info_vec, int plan_info_len,
float* q, // [total_num_rows, num_qo_heads, head_dim]
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
float* v_cache, // same layout
int32_t* qo_indptr, // [batch_size + 1] on GPU
int32_t* kv_indptr, // [batch_size + 1] on GPU
int32_t* kv_indices, // [total_pages]
int32_t* kv_last_page_len, // [batch_size]
float* output, // [total_num_rows, num_qo_heads, head_dim]
int total_num_rows, int batch_size,
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
cudaStream_t stream);
#ifdef __cplusplus
}
#endif

View File

@@ -1,17 +1,122 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaSlice, CudaStream};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use luminal::{op::EgglogOp, prelude::*};
pub mod compute_attn_mask;
mod cublas;
mod cublaslt;
pub mod flashinfer;
pub mod moe;
pub use compute_attn_mask::ComputeAttnMask;
pub type Ops = (
// cublas::CuBlasSgemmV2,
cublaslt::CuBlasLt,
moe::GLUMoE,
compute_attn_mask::ComputeAttnMask,
flashinfer::FlashInferAttention,
);
#[cfg(test)]
pub(crate) type CublasLtTypeTuple = (
luminal::dtype::DType,
luminal::dtype::DType,
luminal::dtype::DType,
luminal::dtype::DType,
&'static str,
luminal::dtype::DType,
);
#[cfg(test)]
pub(crate) fn cublaslt_type_tuple(op: &dyn HostOp) -> Option<CublasLtTypeTuple> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::type_tuple)
}
#[cfg(test)]
pub(crate) type CublasLtScaleValues = (f64, f64);
#[cfg(test)]
pub(crate) fn cublaslt_scale_values(op: &dyn HostOp) -> Option<CublasLtScaleValues> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::scale_values)
}
#[cfg(test)]
pub(crate) fn cublaslt_epilogue(op: &dyn HostOp) -> Option<&'static str> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::epilogue)
}
#[cfg(test)]
pub(crate) type CublasLtMatrixOrders = (&'static str, &'static str, &'static str, &'static str);
#[cfg(test)]
pub(crate) fn cublaslt_matrix_orders(op: &dyn HostOp) -> Option<CublasLtMatrixOrders> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::matrix_orders)
}
#[cfg(test)]
pub(crate) type CublasLtTransposeOps = (&'static str, &'static str);
#[cfg(test)]
pub(crate) fn cublaslt_transpose_ops(op: &dyn HostOp) -> Option<CublasLtTransposeOps> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::transpose_ops)
}
#[cfg(test)]
pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
op.as_any()
.downcast_ref::<cublaslt::CuBlasLt>()
.map(cublaslt::CuBlasLt::c_d_layouts_match)
}
/// Non-owning device buffer handle used by host operations.
///
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
/// the reusable arena, or an external pointer. Host ops only need the pointer
/// and the logical byte length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceBuffer {
ptr: u64,
len: usize,
}
impl DeviceBuffer {
pub fn new(ptr: u64, len: usize) -> Self {
Self { ptr, len }
}
pub fn ptr(self) -> u64 {
self.ptr
}
pub fn len(self) -> usize {
self.len
}
pub fn is_empty(self) -> bool {
self.len == 0
}
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
let mut host = vec![0u8; self.len];
unsafe {
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
}
stream.synchronize()?;
Ok(host)
}
}
/// Host operations that execute on the CPU but orchestrate GPU work.
///
/// This includes operations like cuBLAS calls and CUDA graph executions.
@@ -29,7 +134,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()>;
@@ -48,6 +153,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
vec![]
}
/// Returns relative lifetimes for extra buffer nodes within this host op.
///
/// The tuple is `(node, first_step, last_step)`, where steps are local to
/// this host op's execution. Returning `None` tells the runtime to treat
/// every extra buffer as live for the whole host op.
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
None
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.

View File

@@ -32,7 +32,7 @@ use crate::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
},
host::HostOp,
host::{DeviceBuffer, HostOp},
try_create_cublaslt,
};
@@ -294,27 +294,140 @@ impl HostOp for GLUMoE {
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// Resolve dimensions
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
if inputs.len() < 6 {
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
}
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
let x_buf = buffers[&inputs[0]];
let seq = x_buf.len() / (hidden * 4);
// Resolve dimensions
let hidden = self
.gu_matmul_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
let intermediate = self
.dn_matmul_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
let top_k = self
.output_k
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
let gu_io = self
.gu_io
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
let dn_io = self
.dn_io
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
if hidden == 0 || intermediate == 0 {
anyhow::bail!(
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
);
}
if top_k == 0 {
return Ok(());
}
if gu_io % hidden != 0 {
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
}
if dn_io % intermediate != 0 {
anyhow::bail!(
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
);
}
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
let down_hidden = dn_io / intermediate;
if gate_up_dim != intermediate * 2 {
anyhow::bail!(
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
gate_up_dim,
intermediate * 2
);
}
if down_hidden != hidden {
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
}
let output_bytes = self
.output_bytes()
.exec(dyn_map)
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
if output_bytes % (hidden * 4) != 0 {
anyhow::bail!(
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
hidden * 4
);
}
let seq = output_bytes / (hidden * 4);
if seq == 0 {
return Ok(());
}
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
buffers.get(&node).copied().ok_or_else(|| {
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
})
};
// Get input/output buffers
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
let mode_aux_buf = buffers[&inputs[5]];
let output_buf = buffers[&self_node]; // [seq, hidden] F32
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
let topk_bytes = seq * top_k * 4;
if x_buf.len() < output_bytes {
anyhow::bail!(
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
x_buf.len()
);
}
if topk_idx_buf.len() < topk_bytes {
anyhow::bail!(
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
topk_idx_buf.len()
);
}
if topk_vals_buf.len() < topk_bytes {
anyhow::bail!(
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
topk_vals_buf.len()
);
}
if output_buf.len() < output_bytes {
anyhow::bail!(
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
output_buf.len()
);
}
let gu_stride_bytes = gate_up_dim * hidden * 2;
let down_stride_bytes = hidden * intermediate * 2;
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
anyhow::bail!(
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
gate_up_buf.len()
);
}
let num_experts = gate_up_buf.len() / gu_stride_bytes;
if num_experts == 0 {
anyhow::bail!("GLUMoE has no expert weights");
}
if down_buf.len() < num_experts * down_stride_bytes {
anyhow::bail!(
"GLUMoE down weight buffer too small: have {} bytes, need {}",
down_buf.len(),
num_experts * down_stride_bytes
);
}
// Get raw device pointer addresses
let x_ptr = buf_ptr(x_buf, stream);
@@ -326,21 +439,17 @@ impl HostOp for GLUMoE {
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
// Read top-k routing values from GPU
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
let idx_k = topk_idx_i32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let val_k = topk_vals_f32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let top_k = idx_k.min(val_k);
if seq > 0 && top_k == 0 {
return Ok(());
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
if expert_idx < 0 || expert_idx as usize >= num_experts {
anyhow::bail!(
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
);
}
}
// Mode-dependent expert weights used for the final reduction:
@@ -350,9 +459,16 @@ impl HostOp for GLUMoE {
let expert_weights_f32: &[f32] = match self.mode {
GLUMoEMode::SwiGLU => topk_vals_f32,
GLUMoEMode::GemmaGELU => {
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
debug_assert!(per_expert_scale_f32.len() >= num_experts);
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
let per_expert_scale_bytes = num_experts * 4;
if per_expert_scale_host.len() < per_expert_scale_bytes {
anyhow::bail!(
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
per_expert_scale_host.len()
);
}
let per_expert_scale_f32: &[f32] =
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
expert_weights_storage.resize(seq * top_k, 0.0);
for t in 0..seq {
let base = t * top_k;
@@ -382,10 +498,10 @@ impl HostOp for GLUMoE {
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
let hid_ptr = buf_ptr(&hidden_tmp, stream);
let ws_ptr = buf_ptr(&workspace, stream);
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
let hid_ptr = slice_ptr(&hidden_tmp, stream);
let ws_ptr = slice_ptr(&workspace, stream);
// Cast x F32 → BF16
let n_cast = (seq * hidden) as i32;
@@ -404,8 +520,8 @@ impl HostOp for GLUMoE {
}
// Per-token expert computation
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
for t in 0..seq {
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
@@ -507,7 +623,11 @@ impl HostOp for GLUMoE {
// Helpers
// ============================================================
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
buf.ptr()
}
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
let (ptr, _guard) = buf.device_ptr(stream);
ptr
}

View File

@@ -6,12 +6,8 @@
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
// pattern. Cascade prevention by typing.
//
// `compile()` is a *fallback* path. The fast path collapses each FE-rooted
// region into one CUDA kernel inside `region_codegen` and FusedX/FS/FE
// never reach kernel_to_host's compile loop. But extraction can produce
// LLIR shapes the detector doesn't sweep into a region, so each FusedX's
// standalone `compile()` falls back to emitting the same kernel its
// un-fused KernelX sibling would — correct, just one launch per op.
// Each FusedX must be absorbed into a FusionEnd-rooted region and compiled by
// `region_codegen`; standalone compilation is intentionally unsupported.
// =========================================================================
use std::sync::Arc;
@@ -27,11 +23,7 @@ use luminal::{
prelude::*,
};
use crate::{
compile_module_image_for_current_device, cuda_dtype,
kernel::KernelOp,
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
};
use crate::kernel::KernelOp;
pub type Ops = (
FusedSin,
@@ -55,135 +47,6 @@ type CompileOut = (
FxHashMap<char, CudaSlice<u8>>,
);
// =========================================================================
// Fallback kernel templates — used when a FusedX op reaches
// `kernel_to_host` standalone (region detection missed it). Same CUDA as
// the matching un-fused KernelX would emit, parameterised by the per-op
// body expression. The fast path goes through `region_codegen`.
// =========================================================================
#[allow(clippy::too_many_arguments)]
fn compile_unary_fallback(
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
kernel_name: &str,
body_expr: &str, // CUDA expression on `in[{in_idx}]`, e.g. "sinf(in[{in_idx}])"
shape: &[Expression],
in_strides: &[Expression],
out_strides: &[Expression],
dtype: DType,
) -> CompileOut {
let vars = shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(in_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(out_strides.iter().flat_map(|e| e.dyn_vars()))
.collect::<FxHashSet<_>>();
let cuda_ty = cuda_dtype(dtype);
let includes = dtype_includes(&[dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
let out_idx = flatten_strides(shape, out_strides).to_kernel();
let in_idx = flatten_strides(shape, in_strides).to_kernel();
let body = body_expr.replace("{in_idx}", &in_idx);
let kernel = format!(
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
\x20 if (const_z >= {n_elements}) return;\n\
\x20 out[{out_idx}] = {body};\n\
\x20 }}\n}}"
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function(kernel_name).unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let out_size = shape.iter().copied().product::<Expression>();
(
func,
module,
kernel,
(out_size.ceil_div(256), 1.into(), 1.into()),
(out_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
#[allow(clippy::too_many_arguments)]
fn compile_binary_fallback(
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
kernel_name: &str,
op_str: &str, // CUDA infix operator, e.g. "+", "*"
out_shape: &[Expression],
a_stride: &[Expression],
b_stride: &[Expression],
out_stride: &[Expression],
dtype: DType,
) -> CompileOut {
let vars = out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(a_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(b_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(out_stride.iter().flat_map(|e| e.dyn_vars()))
.collect::<FxHashSet<_>>();
let cuda_ty = cuda_dtype(dtype);
let includes = dtype_includes(&[dtype, dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_elements = out_shape
.iter()
.copied()
.product::<Expression>()
.to_kernel();
let out_idx = flatten_strides(out_shape, out_stride).to_kernel();
let a_idx = flatten_strides(out_shape, a_stride).to_kernel();
let b_idx = flatten_strides(out_shape, b_stride).to_kernel();
let kernel = format!(
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
\x20 __global__ void {kernel_name}({cuda_ty} *C, const {cuda_ty} *A, const {cuda_ty} *B{dyn_dims_param}) {{\n\
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
\x20 if (const_z >= {n_elements}) return;\n\
\x20 C[{out_idx}] = A[{a_idx}] {op_str} B[{b_idx}];\n\
\x20 }}\n}}"
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function(kernel_name).unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let out_size = out_shape.iter().copied().product::<Expression>();
(
func,
module,
kernel,
(out_size.ceil_div(256), 1.into(), 1.into()),
(out_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
@@ -255,19 +118,13 @@ macro_rules! impl_fused_unary {
impl KernelOp for $Name {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
_stream: &Arc<CudaStream>,
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> CompileOut {
compile_unary_fallback(
stream,
compile_cache,
$kernel_name,
$body,
&self.shape,
&self.in_strides,
&self.out_strides,
self.dtype,
)
unreachable!(concat!(
$sort,
" must be compiled through fusion region codegen"
))
}
fn output_size(&self) -> Expression {
self.shape.iter().copied().product()
@@ -379,20 +236,13 @@ macro_rules! impl_fused_binary {
impl KernelOp for $Name {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
_stream: &Arc<CudaStream>,
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> CompileOut {
compile_binary_fallback(
stream,
compile_cache,
$kernel_name,
$op_str,
&self.out_shape,
&self.a_stride,
&self.b_stride,
&self.out_stride,
self.dtype,
)
unreachable!(concat!(
$sort,
" must be compiled through fusion region codegen"
))
}
fn output_size(&self) -> Expression {
self.out_shape.iter().copied().product()

View File

@@ -27,70 +27,7 @@ use luminal::{
prelude::*,
};
use crate::{
compile_module_image_for_current_device, cuda_dtype,
kernel::KernelOp,
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
};
/// Identity-memcpy kernel used as a *fallback* when a FusionStart or
/// FusionEnd reaches `kernel_to_host`'s compile loop standalone (i.e.,
/// region detection didn't sweep it into a `CompileUnit::Region`). The
/// fast path is region collapse, but model-fuzz extraction sometimes
/// produces LLIR shapes the detector doesn't catch; this keeps
/// execution correct in those cases.
#[allow(clippy::type_complexity)]
fn compile_identity_kernel(
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
kernel_name: &str,
shape: &[Expression],
strides: &[Expression],
dtype: DType,
) -> CompileOut {
let vars = shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(strides.iter().flat_map(|e| e.dyn_vars()))
.collect::<FxHashSet<_>>();
let cuda_ty = cuda_dtype(dtype);
let includes = dtype_includes(&[dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
let idx = flatten_strides(shape, strides).to_kernel();
let kernel = format!(
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
\x20 if (const_z >= {n_elements}) return;\n\
\x20 out[{idx}] = in[{idx}];\n\
\x20 }}\n}}"
);
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function(kernel_name).unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let out_size = shape.iter().copied().product::<Expression>();
(
func,
module,
kernel,
(out_size.ceil_div(256), 1.into(), 1.into()),
(out_size.min(256), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
use crate::kernel::KernelOp;
pub type Ops = (FusionStart, FusionEnd);
@@ -159,17 +96,10 @@ impl EgglogOp for FusionStart {
impl KernelOp for FusionStart {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
_stream: &Arc<CudaStream>,
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> CompileOut {
compile_identity_kernel(
stream,
compile_cache,
"fusion_start_k",
&self.shape,
&self.strides,
self.dtype,
)
unreachable!("FusionStart must be compiled through fusion region codegen")
}
fn output_size(&self) -> Expression {
self.shape.iter().copied().product()
@@ -183,6 +113,9 @@ impl KernelOp for FusionStart {
fn kernel_name(&self) -> &'static str {
"FusionStart"
}
fn output_aliases_input(&self) -> Option<usize> {
Some(0)
}
}
// =========================================================================
@@ -460,17 +393,10 @@ impl EgglogOp for FusionEnd {
impl KernelOp for FusionEnd {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
_stream: &Arc<CudaStream>,
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> CompileOut {
compile_identity_kernel(
stream,
compile_cache,
"fusion_end_k",
&self.shape,
&self.strides,
self.dtype,
)
unreachable!("FusionEnd must be compiled through fusion region codegen")
}
fn output_size(&self) -> Expression {
self.shape.iter().copied().product()

View File

@@ -93,9 +93,8 @@ pub(crate) enum CompileUnit {
/// (one whose e-graph congruence-deduplicated it across multiple
/// regions) into a different subgraph than the FE that absorbs it.
/// Without this global view, `build_compile_units` running on the FS's
/// subgraph would not see any FE walking back to the FS, would emit the
/// FS as `CompileUnit::Single`, and the markers' identity-memcpy
/// fallback would compile and launch — pure overhead at runtime.
/// subgraph would not see any FE walking back to the FS and would emit the
/// FS as `CompileUnit::Single`; marker standalone compilation is not supported.
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
let name_of = |idx: NodeIndex| -> Option<&'static str> {
llir_graph
@@ -196,11 +195,10 @@ pub(crate) fn build_compile_units(
// Non-marker, non-FusedX predecessor inside what
// we thought was a region. Shouldn't happen with
// the current rules; treat conservatively: do
// not absorb — let the kernel_to_host single
// path handle it. This means the region is
// not absorb it. This means the region is
// malformed and we likely should not have a
// region at all. Caller will see incomplete
// interior; the safer thing is to fall back.
// region at all; caller will see incomplete
// interior.
}
}
}
@@ -253,11 +251,10 @@ pub(crate) fn build_compile_units(
// FE nodes with their RegionUnit and skipping anything absorbed —
// either by a region in *this* subgraph (`absorbed`) or by any
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
// latter prevents the identity-memcpy fallback from firing on
// shared FS markers whose consumers live in other convex subgraphs:
// latter prevents shared FS markers whose consumers live in other
// convex subgraphs from being emitted as standalone compile units:
// those FSes are absorbed by some other region, and the consuming
// region reads from FS's external producer, so the FS never needs
// its own kernel.
// region reads from FS's external producer.
let mut units: Vec<CompileUnit> = Vec::new();
for &node in topo_order {
if let Some(region) = regions.remove(&node) {

View File

@@ -8,7 +8,7 @@ use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use itertools::Itertools;
use luminal::{
egglog_utils::{
api::{Rule, SortDef, app, eq, rule, set, sort, union, v},
api::{Rule, SortDef, Term, app, eq, rule, set, sort, union, v},
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
extract_dtype, extract_expr, extract_expr_list,
},
@@ -84,6 +84,45 @@ pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
.ruleset("kernel_lower")
}
/// Build a kernel rewrite for ops whose kernel dtype must match the first input.
///
/// This avoids extracting stale/conflicting dtype facts from the output e-class
/// after backend alternatives have been unioned into it.
fn kernel_rewrite_from_first_input<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
let hlir = H::default().sort();
let llir = L::default().sort();
let (mut args, hlir_kind_term) = hlir.new_call();
let first_inp = v("?__first_inp");
let tail = v("?__tail");
let inputs = Term::App {
variant: "ICons".to_string(),
args: vec![first_inp.clone(), tail],
};
let hlir_op = op_term(hlir_kind_term, inputs.clone());
let dt = v("?__dt");
args.add("dtype", dt.clone());
let llir_kind_term = llir.call(&args);
let llir_op = op_term(llir_kind_term, inputs);
rule(union(hlir_op, llir_op))
.fact(eq(dt, dtype(first_inp)))
.ruleset("kernel_lower")
}
fn dtype_for_ir_enode(egraph: &SerializedEGraph, ir_node: &ENodeId) -> Option<DType> {
let ir_class = egraph.node_to_class.get(ir_node)?;
let dtype_node = egraph.enodes.iter().find_map(|(node, (label, children))| {
(label == "dtype" && children.first() == Some(ir_class)).then_some(node)
})?;
let dtype_class = egraph.node_to_class.get(dtype_node)?;
egraph.eclasses.get(dtype_class)?.1.iter().find_map(|node| {
match egraph.enodes.get(node)?.0.as_str() {
"F32" | "F16" | "Bf16" | "Int" | "Bool" | "F4E2M1" | "F8E4M3" | "F8UE8M0" | "I4"
| "TF32" => Some(extract_dtype(egraph, node)),
_ => None,
}
})
}
#[derive(Default, Debug, Clone)]
pub struct KernelMaxReduce {
@@ -702,7 +741,7 @@ impl EgglogOp for KernelMul {
}
fn rewrites(&self) -> Vec<Rule> {
vec![kernel_rewrite::<Mul, Self>()]
vec![kernel_rewrite_from_first_input::<Mul, Self>()]
}
fn cleanup(&self) -> bool {
@@ -717,17 +756,45 @@ impl EgglogOp for KernelMul {
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let mut out_shape =
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
let mut a_stride =
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
let mut b_stride =
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
let mut out_stride =
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
// Some e-graph paths (length-changing rewrites such as `merge_dims`
// or `RemoveNthFromEnd`) leave a Mul kind enode whose shape and
// strides children are extracted to different lengths under the
// first-enode walk. The `enforce_consistent_first_kind_enodes`
// pass in `src/egglog_utils/mod.rs` repairs this where it can,
// but a handful of eclasses have *no* consistent variant in any
// of their stride sub-eclasses. For those we truncate to the
// SHORTEST length here so `flatten_strides` is structurally
// satisfied — the resulting kernel is numerically wrong for that
// candidate but harmless for the search, which profiles many
// candidates and steers toward the consistent ones.
let n = out_shape
.len()
.min(a_stride.len())
.min(b_stride.len())
.min(out_stride.len());
out_shape.truncate(n);
a_stride.truncate(n);
b_stride.truncate(n);
out_stride.truncate(n);
let dtype = input_enodes
.first()
.and_then(|node| dtype_for_ir_enode(egraph, node))
.unwrap_or_else(|| extract_dtype(egraph, kind_children[4]));
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
a_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
.unwrap(),
b_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
out_stride: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[4]),
out_shape,
a_stride,
b_stride,
out_stride,
dtype,
})),
input_enodes,
)
@@ -867,13 +934,29 @@ impl EgglogOp for KernelGather {
}
fn rewrites(&self) -> Vec<Rule> {
// Match HLIR Gather (now in Op format) and rewrite to KernelGather
// Match HLIR Gather (now in Op format) and rewrite to KernelGather.
// Mirror the IList pattern used by `Gather`'s own dtype propagation
// rule (`src/hlir.rs`): use a `?__tail` variable instead of a
// strict `(INil)` so we don't accidentally fail to match against a
// Gather Op whose IList tail eclass has been merged with another
// chain by some unrelated egglog union. Without this the kernel
// rewrite is silently skipped for some Gathers in deep models
// (e.g. YOLO's stacked make_contiguous chains).
let hlir_gather = luminal::hlir::Gather::default().sort();
let (gather_args, gather_kind_term) = hlir_gather.new_call();
// HLIR Gather inputs: [indexes, data] (n_inputs=2)
let indexes = v("?__indexes");
let data = v("?__data");
let gather_inputs = ilist(vec![indexes.clone(), data.clone()]);
let tail = v("?__tail");
let gather_inputs = Term::App {
variant: "ICons".to_string(),
args: vec![
indexes.clone(),
Term::App {
variant: "ICons".to_string(),
args: vec![data.clone(), tail],
},
],
};
let gather_op = op_term(gather_kind_term, gather_inputs);
let out_strides = SORTS
@@ -1210,7 +1293,25 @@ impl KernelOp for KernelScatter {
// Single-kernel scatter: copy dest→output then scatter src→output[indexes]
// Launched as 1 block of 1024 threads with __syncthreads() barrier.
// Uses float4 vectorized copy (4x throughput) for the copy phase.
// Uses float4 vectorized copy (16 bytes per op) for the copy phase.
//
// The number of dtype elements that fit in a float4 (16 bytes) depends
// on the element size. Computing `n_vec = n_dest / 4` would only be
// correct for 4-byte dtypes — for bf16 it walks 2× past the end of
// `out`, producing CUDA_ERROR_ILLEGAL_ADDRESS once the OOB region
// happens to land on an unmapped page.
let elements_per_vec: usize = match self.dtype {
DType::F64 => 2,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 8,
DType::Bool
| DType::I8
| DType::U8
| DType::F8UE8M0
| DType::F8E4M3
| DType::F8E5M2 => 16,
other => panic!("Unsupported dtype for scatter vectorization: {other:?}"),
};
let n_src_elements = self
.index_shape
.iter()
@@ -1235,15 +1336,17 @@ extern \"C\" {{
int tid = threadIdx.x;
long long n_dest = {n_dest_elements};
long long n_src = {n_src_elements};
// Phase 1: vectorized copy dest → output (float4 = 4 elements per op)
long long n_vec = n_dest / 4;
// Phase 1: vectorized copy dest → output (float4 = 16 bytes / iter,
// i.e. {elements_per_vec} {dtype} elements). n_vec is sized so the
// total bytes covered (`n_vec * 16`) never exceed `n_dest * sizeof({dtype})`.
long long n_vec = n_dest / {elements_per_vec};
float4 *out4 = (float4 *)out;
const float4 *dest4 = (const float4 *)dest;
for (long long i = tid; i < n_vec; i += blockDim.x) {{
out4[i] = dest4[i];
}}
// Handle remaining elements
long long remainder_start = n_vec * 4;
// Handle remaining elements (the dtype-tail past the last full float4).
long long remainder_start = n_vec * {elements_per_vec};
for (long long i = remainder_start + tid; i < n_dest; i += blockDim.x) {{
out[i] = dest[i];
}}
@@ -1437,19 +1540,22 @@ impl KernelOp for KernelIota {
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
vars.extend(self.range.dyn_vars());
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let range = self.range.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void iota_k(int *C{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {range}) return;
C[const_z] = {};
}}
}}",
@@ -1468,8 +1574,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.range, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.range.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -2832,6 +2938,14 @@ impl KernelOp for KernelCast {
) {
let out_dtype = cuda_dtype(self.out_dtype);
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let size = self.size.to_kernel();
let kernel = if self.in_dtype.bits() < 8 {
// Sub-byte packed types: multiple values packed per byte.
@@ -2841,9 +2955,11 @@ impl KernelOp for KernelCast {
let mask = (1u32 << bits) - 1;
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {size}) return;
long long bit_offset = idx * {bits};
long long byte_idx = bit_offset >> 3;
int bit_pos = (int)(bit_offset & 7);
@@ -2859,9 +2975,11 @@ extern \"C\" {{
let in_dtype = cuda_dtype(self.in_dtype);
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {size}) return;
out[const_z] = ({out_dtype})in[const_z];
}}
}}"
@@ -2880,8 +2998,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.size, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.size.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -3163,15 +3281,24 @@ impl KernelOp for KernelEmbed {
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.embed_dim.dyn_vars())
.collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let total_threads = batch_size * self.embed_dim;
let n_elements = total_threads.to_kernel();
let kernel = format!(
"
{}
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table) {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {n_elements}) return;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
long long embed_idx = idx % embed_dim;
@@ -3181,10 +3308,7 @@ extern \"C\" {{
int token_id = token_ids[token_offset];
out[out_offset + embed_idx] = embed_table[(long long)token_id * embed_dim + embed_idx];
}}
}}",
vars.iter()
.map(|i| format!("__constant__ int const_{i}[1];"))
.join("\n"),
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
@@ -3195,17 +3319,14 @@ extern \"C\" {{
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let constants = vars
.into_iter()
.map(|d| (d, module.get_global(&format!("const_{d}"), stream).unwrap()))
.collect();
let total_threads = batch_size * self.embed_dim;
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
(
func,
module,
kernel,
(total_threads, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(total_threads.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
constants,
)

View File

@@ -128,7 +128,8 @@ impl KernelOp for KernelMeanReduce {
let dtype = cuda_dtype(self.dtype);
let includes = dtype_includes(&[self.dtype]);
let n_outputs: Expression = self.out_shape.iter().copied().product();
let threads_per_block = 256; // 8 warps per block
let threads_per_block: usize = 256; // 8 warps per block
let n_warps = threads_per_block / 32;
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
@@ -149,12 +150,24 @@ extern \"C\" {{
long long iters = {iters};
long long iter_stride = {iter_stride};
{dtype} sum = 0;
for (long long i = 0; i < iters; i++) {{
sum += in[in_start + i * iter_stride];
}}
float thread_sum = 0.0f;
for (long long i = threadIdx.x; i < iters; i += {threads_per_block})
thread_sum += (float)in[in_start + i * iter_stride];
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
for (int offset = 16; offset > 0; offset >>= 1)
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
__shared__ float warp_sums[{n_warps}];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
if (lane == 0) warp_sums[warp] = thread_sum;
__syncthreads();
if (threadIdx.x == 0) {{
float sum = 0.0f;
for (int w = 0; w < {n_warps}; w++) sum += warp_sums[w];
out[{out_index}] = ({dtype})(sum / (float)iters);
}}
}}
}}",
dtype = dtype,
@@ -167,6 +180,8 @@ extern \"C\" {{
.substitute('z', Expression::from(1))
.simplify()
.to_kernel(),
threads_per_block = threads_per_block,
n_warps = n_warps,
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
@@ -183,9 +198,9 @@ extern \"C\" {{
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()), // grid
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
0.into(), // shmem size
(n_outputs, 1.into(), 1.into()), // grid
(threads_per_block.into(), 1.into(), 1.into()), // block
0.into(), // shmem size
FxHashMap::default(),
)
}
@@ -279,6 +294,9 @@ impl EgglogOp for KernelScatterNoCopy {
fn rewrites(&self) -> Vec<Rule> {
// Match KernelScatter and rewrite to KernelScatterNoCopy with ConsumedBuffer on dest.
// ConsumedBuffer wraps dest to signal in-place modification.
// This is only valid when the destination buffer can also represent
// the scatter output layout. If dest is a strided/broadcast view,
// regular Scatter must first materialize a contiguous output copy.
//
// Two-phase resolution:
// 1. During (run): cleanup rules delete ConsumedBuffer if dest is shared (another op uses it)
@@ -289,12 +307,31 @@ impl EgglogOp for KernelScatterNoCopy {
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
let mut rules = vec![
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
Rule::raw(
"(rule
((= ?list (ICons ?head ?tail)))
((consumed_buffer_ilist_contains ?list ?head))
:ruleset cleanup
:name \"consumed-buffer-ilist-contains-head\"
)",
),
Rule::raw(
"(rule
((= ?list (ICons ?head ?tail))
(consumed_buffer_ilist_contains ?tail ?item))
((consumed_buffer_ilist_contains ?list ?item))
:ruleset cleanup
:name \"consumed-buffer-ilist-contains-tail\"
)",
),
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
Rule::raw(
"(rule
(
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
(= ?dst ?os)
(= ?dty (dtype ?src))
)
(
@@ -324,13 +361,28 @@ impl EgglogOp for KernelScatterNoCopy {
"(rule
((= ?cb (ConsumedBuffer ?a))
(= ?op1 (Op ?k1 ?ilist1))
(= ?ilist1 (ICons ?cb ?rest1))
(consumed_buffer_ilist_contains ?ilist1 ?cb)
(= ?op2 (Op ?k2 ?ilist2))
(!= ?op1 ?op2)
(= ?ilist2 (ICons ?a ?t2)))
(consumed_buffer_ilist_contains ?ilist2 ?a))
((delete (ConsumedBuffer ?a)))
:ruleset cleanup
:name \"consumed-buffer-cleanup-pos\"
:name \"consumed-buffer-cleanup-shared-op-use\"
)",
));
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
rules.push(Rule::raw(
"(rule
((= ?cb (ConsumedBuffer ?dest))
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
:ruleset post_cleanup
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
)",
));
// Surviving ConsumedBuffers are valid — union with source and delete.
@@ -457,8 +509,8 @@ extern \"C\" {{
func,
module,
scatter_kernel,
(n_src, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(n_src.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use luminal::{
egglog_utils::{api::Rule, base::OP_KIND},
graph::LLIRGraph,
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
op::{EgglogOp, LLIROp},
prelude::{
petgraph::{Direction, algo::toposort, visit::EdgeRef},
@@ -22,7 +23,7 @@ use luminal::{
use tracing::{Level, enabled, span};
use crate::{
host::HostOp,
host::{DeviceBuffer, HostOp},
kernel::{
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
destroy_cuda_event,
@@ -47,8 +48,12 @@ struct CompiledKernel {
shared_mem: Expression,
/// Input node indices (for buffer lookup)
inputs: Vec<NodeIndex>,
/// Human-readable labels for input nodes, for launch diagnostics.
input_labels: Vec<String>,
/// Reference to the KernelOp for trait methods
kernel_op: Arc<Box<dyn KernelOp>>,
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
has_dyn_dims_param: bool,
/// Internal buffers allocated for this kernel
internal_bufs: Vec<CudaSlice<u8>>,
/// Device constants from compile()
@@ -68,7 +73,9 @@ impl CompiledKernel {
block: (Expression, Expression, Expression),
shared_mem: Expression,
inputs: Vec<NodeIndex>,
input_labels: Vec<String>,
kernel_op: Arc<Box<dyn KernelOp>>,
has_dyn_dims_param: bool,
constants: FxHashMap<char, CudaSlice<u8>>,
kernel_name: &'static str,
) -> Self {
@@ -79,7 +86,9 @@ impl CompiledKernel {
block,
shared_mem,
inputs,
input_labels,
kernel_op,
has_dyn_dims_param,
internal_bufs: Vec::new(),
constants,
graph_node: None,
@@ -226,7 +235,7 @@ impl HostOp for CudaGraphOp {
stream: &Arc<CudaStream>,
_self_node: NodeIndex,
_inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
self.execute_internal(stream, buffers, dyn_map)
@@ -258,6 +267,40 @@ impl HostOp for CudaGraphOp {
.collect()
}
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
let state = self.state.borrow();
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
let max_step = state.kernels.len().saturating_sub(1);
let mut touch = |node: NodeIndex, step: usize| {
lifetimes
.entry(node)
.and_modify(|(first, last)| {
*first = (*first).min(step);
*last = (*last).max(step);
})
.or_insert((step, step));
};
for (step, kernel) in state.kernels.iter().enumerate() {
for &input in &kernel.inputs {
touch(input, step);
}
touch(kernel.node, step);
}
for node in self.extra_buffer_nodes() {
lifetimes.entry(node).or_insert((0, max_step));
}
Some(
lifetimes
.into_iter()
.map(|(node, (start, end))| (node, start, end))
.collect(),
)
}
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
self.buffer_sizes.clone()
}
@@ -268,11 +311,64 @@ impl HostOp for CudaGraphOp {
}
impl CudaGraphOp {
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
match kernel_name {
"Constant" | "Iota" => Some(0),
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
| "Mul" => Some(2),
"Scatter" | "ScatterNoCopy" => Some(3),
_ => None,
}
}
fn kernel_requires_output_buffer(
kernel: &CompiledKernel,
dyn_map: &FxHashMap<char, usize>,
) -> bool {
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
&& kernel.kernel_op.output_aliases_input().is_none()
}
fn validate_kernel_pointers(
kernel: &CompiledKernel,
output_ptr: u64,
input_ptrs: &[u64],
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
anyhow::bail!(
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
kernel.kernel_name,
kernel.node,
);
}
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
if *input_ptr == 0 {
let input_label = kernel
.input_labels
.get(idx)
.map(String::as_str)
.unwrap_or("unknown");
anyhow::bail!(
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
kernel.kernel_name,
kernel.node,
input_node,
);
}
}
Ok(())
}
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
fn execute_internal(
&self,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let mut state = self.state.borrow_mut();
@@ -343,7 +439,7 @@ impl CudaGraphOp {
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
current_buffer_ptrs.insert(node, buf.ptr());
}
}
@@ -391,13 +487,26 @@ impl CudaGraphOp {
.iter()
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
dyn_dims_ptr
} else {
0
};
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
anyhow::bail!(
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
kernel.kernel_name,
kernel.node,
);
}
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
kernel_dyn_dims_ptr,
);
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
}
@@ -424,6 +533,19 @@ impl CudaGraphOp {
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
if grid_dim.0 == 0
|| grid_dim.1 == 0
|| grid_dim.2 == 0
|| block_dim.0 == 0
|| block_dim.1 == 0
|| block_dim.2 == 0
{
anyhow::bail!(
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
kernel.kernel_name,
kernel.node,
);
}
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let cu_func = unsafe { kernel.function.raw_function() };
@@ -452,7 +574,7 @@ impl CudaGraphOp {
&self,
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let ctx = stream.context().clone();
@@ -474,7 +596,7 @@ impl CudaGraphOp {
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
buffer_ptrs.insert(node, buf.ptr());
}
}
@@ -521,6 +643,19 @@ impl CudaGraphOp {
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
if grid_dim.0 == 0
|| grid_dim.1 == 0
|| grid_dim.2 == 0
|| block_dim.0 == 0
|| block_dim.1 == 0
|| block_dim.2 == 0
{
anyhow::bail!(
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
kernel.kernel_name,
kernel.node,
);
}
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
@@ -529,18 +664,41 @@ impl CudaGraphOp {
.iter()
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
dyn_dims_ptr
} else {
0
};
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
anyhow::bail!(
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
kernel.kernel_name,
kernel.node,
);
}
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
kernel_dyn_dims_ptr,
);
let mut params = UnifiedKernelParams::new(param_values);
let cu_func = unsafe { kernel.function.raw_function() };
let kernel_node = kernel.node;
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
eprintln!(
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
kernel.kernel_name,
kernel.node,
kernel.inputs.len(),
kernel.has_dyn_dims_param,
params.values.len(),
);
}
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
@@ -657,11 +815,41 @@ pub fn kernel_to_host(
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress the
// identity-memcpy fallback for shared FS leaves whose consumers live
// in a different convex subgraph than the FS itself.
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
// standalone marker compile units for shared FS leaves whose consumers
// live in a different convex subgraph than the FS itself.
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
graph
.node_weight(idx)
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
};
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
name_of(graph, node) == Some("FusionStart")
|| graph[node].to_op::<LoopStart>().is_some()
|| graph[node].to_op::<LoopEnd>().is_some()
|| graph[node].to_op::<LoopInput>().is_some()
|| graph[node].to_op::<LoopInputStatic>().is_some()
|| graph[node].to_op::<LoopOutput>().is_some()
|| graph[node].to_op::<LoopOutputSelect>().is_some()
};
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
let mut visited = FxHashSet::default();
while visited.insert(node) && is_transparent_input(graph, node) {
let Some(pred) = graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.next()
else {
break;
};
node = pred;
}
node
};
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
// Track all CudaGraphOp nodes and their subgraphs for edge creation
@@ -678,6 +866,7 @@ pub fn kernel_to_host(
let mut all_dyn_dims = FxHashSet::default();
let mut all_buffer_nodes = FxHashSet::default();
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
let mut external_inputs = FxHashSet::default();
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
@@ -691,9 +880,7 @@ pub fn kernel_to_host(
// Set global dyn dims ordering so compiles use consistent indices
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
global_dyn_dims.sort();
if !global_dyn_dims.is_empty() {
set_global_dyn_dims(global_dyn_dims.clone());
}
set_global_dyn_dims(global_dyn_dims.clone());
// Group the topo order into compile units: each FusionEnd-rooted
// region collapses to a single CompileUnit::Region (one fused
@@ -711,14 +898,35 @@ pub fn kernel_to_host(
.to_dialect::<dyn KernelOp>()
.unwrap();
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
kernel_op_ref.compile(cuda_stream, kernel_cache);
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
// Collect inputs from graph edges
let inputs: Vec<NodeIndex> = llir_graph
.edges_directed(*kernel_node_idx, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.map(|input| resolve_transparent_input(llir_graph, input))
.collect_vec();
if let Some(expected_inputs) =
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
{
assert_eq!(
inputs.len(),
expected_inputs,
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
kernel_op_ref.kernel_name(),
kernel_node_idx,
);
}
let input_labels = inputs
.iter()
.map(|&input| {
name_of(llir_graph, input)
.map(str::to_string)
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
})
.collect_vec();
// Collect buffer nodes and sizes
@@ -729,6 +937,12 @@ pub fn kernel_to_host(
all_buffer_sizes.insert(*kernel_node_idx, output_size);
}
all_buffer_nodes.extend(inputs.iter().copied());
external_inputs.extend(
inputs
.iter()
.copied()
.filter(|input| !subgraph.contains(input)),
);
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
@@ -739,7 +953,9 @@ pub fn kernel_to_host(
block,
shared_mem,
inputs,
input_labels,
kernel_op.clone(),
has_dyn_dims_param,
constants,
kernel_op.kernel_name(),
));
@@ -752,6 +968,7 @@ pub fn kernel_to_host(
cuda_stream,
kernel_cache,
);
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
// The region's CompiledKernel is keyed on the FE node
// (so FE provides trait methods like output_size /
@@ -763,7 +980,20 @@ pub fn kernel_to_host(
.to_dialect::<dyn KernelOp>()
.unwrap();
let inputs: Vec<NodeIndex> = region.external_inputs.clone();
let inputs: Vec<NodeIndex> = region
.external_inputs
.iter()
.copied()
.map(|input| resolve_transparent_input(llir_graph, input))
.collect();
let input_labels = inputs
.iter()
.map(|&input| {
name_of(llir_graph, input)
.map(str::to_string)
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
})
.collect_vec();
let output_size = fe_op_ref.output_size();
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
@@ -771,6 +1001,12 @@ pub fn kernel_to_host(
all_buffer_sizes.insert(region.fe_node, output_size);
}
all_buffer_nodes.extend(inputs.iter().copied());
external_inputs.extend(
inputs
.iter()
.copied()
.filter(|input| !subgraph.contains(input)),
);
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
@@ -781,7 +1017,9 @@ pub fn kernel_to_host(
compiled.block,
compiled.shared_mem,
inputs,
input_labels,
kernel_op,
has_dyn_dims_param,
compiled.constants,
"FusedRegion",
));
@@ -826,16 +1064,17 @@ pub fn kernel_to_host(
}
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
// Find external inputs: nodes outside subgraph that have edges into subgraph
let external_inputs: FxHashSet<NodeIndex> = subgraph
.iter()
.flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.filter(|src| !subgraph.contains(src))
})
.collect();
// Find external inputs: nodes outside subgraph that have edges into
// subgraph. Also include normalized FusionStart predecessors, because
// the compiled kernels read from the concrete producer buffer rather
// than the marker node.
external_inputs.extend(subgraph.iter().flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.map(|input| resolve_transparent_input(llir_graph, input))
.filter(|src| !subgraph.contains(src))
}));
// Add edges from external inputs to CudaGraphOp
for input in &external_inputs {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -41,9 +41,8 @@ fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
all_names
}
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
/// the ConsumedBuffer (not in any other ICons).
/// When dest is NOT shared with any other compute op, KernelScatterNoCopy should
/// be the only scatter variant left after post-cleanup.
#[test]
fn test_scatter_nocopy_selected_when_dest_unshared() {
let ctx = CudaContext::new(0).unwrap();
@@ -62,12 +61,17 @@ fn test_scatter_nocopy_selected_when_dest_unshared() {
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should be available (dest is not shared)
// KernelScatterNoCopy should be the only scatter variant (dest is not shared)
assert!(
names.iter().any(|n| n == "ScatterNoCopy"),
"Expected ScatterNoCopy to be available but got: {:?}",
names
);
assert!(
!names.iter().any(|n| n == "Scatter"),
"Regular Scatter should be pruned when ScatterNoCopy is valid, got: {:?}",
names
);
}
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
@@ -109,8 +113,74 @@ fn test_scatter_nocopy_not_selected_when_dest_shared() {
);
}
/// Shared-use detection must catch the destination in non-first input
/// positions too. Gather takes indexes first and data second, so this would
/// miss the unsafe read if cleanup only inspected the head of the input list.
#[test]
fn test_scatter_nocopy_not_selected_when_dest_shared_as_later_input() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let scatter_indexes = cx.tensor(3).as_dtype(DType::Int).persist();
let read_indexes = cx.tensor(1).as_dtype(DType::Int).persist();
let scatter_result = src.scatter(scatter_indexes, dest);
let _dest_also_read = dest.gather(read_indexes).output();
let _result = scatter_result.output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
assert!(
!names.iter().any(|n| n == "ScatterNoCopy"),
"ScatterNoCopy should NOT be available when dest is read by another op, got: {:?}",
names
);
assert!(
names.iter().any(|n| n == "Scatter"),
"Expected regular Scatter but got: {:?}",
names
);
}
/// ScatterNoCopy aliases the destination buffer as the output, so it is only
/// valid when the destination layout already matches the contiguous scatter
/// output layout. Broadcast/expanded destinations need regular Scatter's
/// copy-then-scatter materialization.
#[test]
fn test_scatter_nocopy_not_selected_for_expanded_dest_layout() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
let dest = cx.tensor(128).expand_dim(0, 4).persist();
let src = cx.tensor((4, 128)).persist();
let indexes = cx.tensor((4, 128)).as_dtype(DType::Int).persist();
let _result = src.scatter(indexes, dest).output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
assert!(
!names.iter().any(|n| n == "ScatterNoCopy"),
"ScatterNoCopy should NOT be available when dest layout differs from output, got: {:?}",
names
);
assert!(
names.iter().any(|n| n == "Scatter"),
"Expected regular Scatter but got: {:?}",
names
);
}
/// Actually execute the scatter and verify correctness.
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
/// Post-cleanup should force the valid no-copy extraction.
#[test]
fn test_scatter_execution_correctness() {
let ctx = CudaContext::new(0).unwrap();
@@ -135,9 +205,8 @@ fn test_scatter_execution_correctness() {
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
// Try many random extractions to cover both Scatter and ScatterNoCopy
// Try many random extractions; each valid choice should now use ScatterNoCopy.
let mut rng = rand::rng();
let mut tested_scatter = false;
let mut tested_nocopy = false;
for _ in 0..50 {
@@ -180,27 +249,24 @@ fn test_scatter_execution_correctness() {
let actual = rt.get_f32(result);
let variant = if has_nocopy {
tested_nocopy = true;
"ScatterNoCopy"
} else if has_scatter {
tested_scatter = true;
"Scatter"
} else {
"Unknown"
};
assert!(
has_nocopy,
"Expected ScatterNoCopy after post-cleanup, got no no-copy scatter"
);
assert!(
!has_scatter,
"Regular Scatter should be pruned when ScatterNoCopy is valid"
);
tested_nocopy = true;
assert_eq!(
actual, expected,
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
"Scatter result mismatch with ScatterNoCopy: got {:?}, expected {:?}",
actual, expected
);
}
println!(
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
tested_scatter, tested_nocopy
);
println!("Tested ScatterNoCopy: {}", tested_nocopy);
assert!(
tested_nocopy,
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
@@ -242,12 +308,28 @@ fn test_scatter_kv_cache_roundtrip() {
rt = cx.search(rt, 5);
// Print which scatter variant was selected
// Print and verify which scatter variant was selected
let scatter_names: Vec<_> = rt
.kernel_names()
.iter()
.copied()
.filter(|name| name.contains("catter"))
.collect();
for name in rt.kernel_names() {
if name.contains("catter") {
println!("Selected: {name}");
}
}
assert!(
scatter_names.contains(&"ScatterNoCopy"),
"Expected ScatterNoCopy in KV-cache search result, got: {:?}",
scatter_names
);
assert!(
!scatter_names.contains(&"Scatter"),
"Regular Scatter should be pruned from KV-cache search result, got: {:?}",
scatter_names
);
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
@@ -342,17 +424,31 @@ fn test_scatter_dual_cache() {
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
// Use seeded search for deterministic scatter variant selection.
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
// Use seeded search for deterministic variant selection.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
// Print selected variants
// Print and verify selected variants
let scatter_names: Vec<_> = rt
.kernel_names()
.iter()
.copied()
.filter(|name| name.contains("catter"))
.collect();
for name in rt.kernel_names() {
if name.contains("catter") {
println!("Dual test selected: {name}");
}
}
assert!(
!scatter_names.is_empty(),
"Expected scatter kernels in dual-cache search result"
);
assert!(
scatter_names.iter().all(|name| *name == "ScatterNoCopy"),
"Expected only ScatterNoCopy in dual-cache search result, got: {:?}",
scatter_names
);
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,941 @@
//! Unit + integration tests for the FlashInfer port.
//!
//! Four layers:
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
//! 3. Mask op correctness (GPU): `ComputeAttnMask` produces the right (s, c) mask.
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
//! compared against a luminal-compiled reference attention graph.
//!
//! GPU-dependent tests short-circuit when no CUDA device is available.
use std::sync::{Arc, Mutex};
use cudarc::driver::{CudaStream, DevicePtr};
use luminal::egglog_utils::{hlir_to_egglog, run_egglog};
use luminal::op::{EgglogOp, IntoEgglogOp};
use luminal::prelude::*;
use crate::host::flashinfer::FlashInferAttention;
use crate::host::{ComputeAttnMask, DeviceBuffer, HostOp};
use crate::runtime::CudaRuntime;
use crate::tests::utilities::get_cuda_stream;
/// Look up an op in `CudaRuntime::Ops::into_vec()` by its egglog sort name.
fn ops_contains_sort(name: &str) -> bool {
let ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
ops.iter().any(|op| {
// `SortDef` is opaque; its Debug repr starts with the sort name.
let sort_dbg = format!("{:?}", op.sort());
sort_dbg.contains(name)
})
}
// ─── Test-wide model dimensions ───────────────────────────────────────────
//
// Small Llama-shaped GQA model: nheads=8, kv_heads=2, group=4, head_dim=64.
// Chosen so HEAD_DIM ∈ {64, 128, 256} (FlashInfer constraint) and the test
// suite fits in O(1ms) of GPU time per case.
const HEAD_DIM: usize = 64;
const N_KV_HEADS: usize = 2;
const KV_GROUPS: usize = 4;
const N_HEADS: usize = N_KV_HEADS * KV_GROUPS;
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
const HIDDEN: usize = N_HEADS * HEAD_DIM;
// ─── Reference attention graph (Q*K^T → softmax → *V via the compiler) ───
fn build_attention_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
let mut cx = Graph::default();
let q_rope = cx.named_tensor("q_rope", ('s', HIDDEN));
let k_ctx = cx.named_tensor("k_ctx", ('c', KV_DIM));
let v_ctx_input = cx.named_tensor("v_ctx", ('c', KV_DIM));
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
let k = k_ctx.split_dims(1, HEAD_DIM).permute((1, 2, 0));
let v_ctx = v_ctx_input.split_dims(1, HEAD_DIM).transpose(0, 1);
// GQA broadcast: zero-stride Mul by 1.0
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
let weights = scores.softmax(2);
let out = weights.matmul(v_ctx);
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
let attn_out = attn_out.output();
(cx, q_rope, k_ctx, v_ctx_input, attn_out)
}
fn run_reference_attention(
stream: &Arc<CudaStream>,
q: &[f32],
k: &[f32],
v: &[f32],
batch_size: usize,
context_len: usize,
) -> Vec<f32> {
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
cx.set_dim('s', batch_size);
cx.set_dim('c', context_len);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
rt.set_data(v_t, v.to_vec());
rt = cx.search(rt, 3);
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
rt.set_data(v_t, v.to_vec());
rt.execute(&cx.dyn_map);
rt.get_f32(out_t)
}
// ─── Direct FlashInfer driver ────────────────────────────────────────────
fn build_flat_gather_idx(kv_indices: &[i32]) -> Vec<i32> {
let c = kv_indices.len();
let mut flat = Vec::with_capacity(c * KV_DIM);
for &slot in kv_indices {
let base = slot * KV_DIM as i32;
for j in 0..KV_DIM as i32 {
flat.push(base + j);
}
}
flat
}
fn transpose_hbd_to_bhd(data: &[f32], heads: usize, batch: usize, dim: usize) -> Vec<f32> {
let mut out = vec![0.0f32; data.len()];
for h in 0..heads {
for b in 0..batch {
for d in 0..dim {
out[b * heads * dim + h * dim + d] = data[h * batch * dim + b * dim + d];
}
}
}
out
}
fn alloc_dev(stream: &Arc<CudaStream>, bytes: usize) -> cudarc::driver::CudaSlice<u8> {
let bytes = bytes.max(1);
unsafe { stream.alloc::<u8>(bytes).unwrap() }
}
fn copy_to_dev<T: Copy>(stream: &Arc<CudaStream>, data: &[T]) -> cudarc::driver::CudaSlice<u8> {
let bytes = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
};
stream.clone_htod(bytes).unwrap()
}
/// Run FlashInferAttention.execute() directly and reshape the output to the
/// reference (batch, heads, dim) layout used by `run_reference_attention`.
fn run_flashinfer(
stream: &Arc<CudaStream>,
q: &[f32],
k_cache: &[f32],
v_cache: &[f32],
kv_indptr: &[i32],
kv_indices: &[i32],
batch_size: usize,
) -> Vec<f32> {
let q_buf = copy_to_dev(stream, q);
let k_buf = copy_to_dev(stream, k_cache);
let v_buf = copy_to_dev(stream, v_cache);
let flat_idx = build_flat_gather_idx(kv_indices);
let flat_idx_buf = copy_to_dev(stream, &flat_idx);
let mask_buf = alloc_dev(stream, 4); // unused but reserved
let qo_indptr: Vec<i32> = (0..=batch_size as i32).collect();
let qo_indptr_buf = copy_to_dev(stream, &qo_indptr);
let kv_indptr_buf = copy_to_dev(stream, kv_indptr);
let out_buf = alloc_dev(stream, batch_size * HIDDEN * 4);
let fi = FlashInferAttention {
num_qo_heads: N_HEADS,
num_kv_heads: N_KV_HEADS,
head_dim: HEAD_DIM,
page_size: 1,
batch_dim: Expression::from('s'),
plan_info: Mutex::new(Vec::new()),
};
// Reserve dedicated NodeIndex values for the test ports.
let nodes: Vec<NodeIndex> = (0..8).map(NodeIndex::new).collect();
let (q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n, out_n) = (
nodes[0], nodes[1], nodes[2], nodes[3], nodes[4], nodes[5], nodes[6], nodes[7],
);
let mut buffers = FxHashMap::default();
let q_ptr = q_buf.device_ptr(stream).0;
let k_ptr = k_buf.device_ptr(stream).0;
let v_ptr = v_buf.device_ptr(stream).0;
let idx_ptr = flat_idx_buf.device_ptr(stream).0;
let mask_ptr = mask_buf.device_ptr(stream).0;
let qo_ptr = qo_indptr_buf.device_ptr(stream).0;
let kv_ptr = kv_indptr_buf.device_ptr(stream).0;
let out_ptr = out_buf.device_ptr(stream).0;
buffers.insert(q_n, DeviceBuffer::new(q_ptr, q.len() * 4));
buffers.insert(k_n, DeviceBuffer::new(k_ptr, k_cache.len() * 4));
buffers.insert(v_n, DeviceBuffer::new(v_ptr, v_cache.len() * 4));
buffers.insert(idx_n, DeviceBuffer::new(idx_ptr, flat_idx.len() * 4));
buffers.insert(mask_n, DeviceBuffer::new(mask_ptr, 4));
buffers.insert(qo_n, DeviceBuffer::new(qo_ptr, qo_indptr.len() * 4));
buffers.insert(kv_n, DeviceBuffer::new(kv_ptr, kv_indptr.len() * 4));
buffers.insert(out_n, DeviceBuffer::new(out_ptr, batch_size * HIDDEN * 4));
let inputs = [q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n];
let mut dyn_map = FxHashMap::default();
dyn_map.insert('s', batch_size);
dyn_map.insert('c', kv_indices.len());
dyn_map.insert('r', kv_indptr.len());
fi.execute(stream, out_n, &inputs, &buffers, &dyn_map)
.expect("FlashInferAttention execute failed");
stream.synchronize().unwrap();
// Output is (heads, batch, dim); reshape to (batch, heads, dim).
let mut out_bytes = vec![0u8; batch_size * HIDDEN * 4];
unsafe {
cudarc::driver::result::memcpy_dtoh_async(&mut out_bytes, out_ptr, stream.cu_stream())
.unwrap();
}
stream.synchronize().unwrap();
let raw: Vec<f32> = unsafe {
let mut bytes = std::mem::ManuallyDrop::new(out_bytes);
let len = bytes.len() / 4;
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
};
transpose_hbd_to_bhd(&raw, N_HEADS, batch_size, HEAD_DIM)
}
// ─── Helpers ─────────────────────────────────────────────────────────────
fn deterministic_f32(n: usize, seed: f32, scale: f32) -> Vec<f32> {
(0..n).map(|i| (i as f32 * seed).sin() * scale).collect()
}
fn assert_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) {
assert_eq!(
a.len(),
b.len(),
"length mismatch: {} vs {}",
a.len(),
b.len()
);
let mut worst = (0usize, 0.0f32);
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
let diff = (x - y).abs();
if diff > worst.1 {
worst = (i, diff);
}
let tol = atol + rtol * y.abs();
assert!(
diff <= tol,
"mismatch at idx {i}: {x} vs {y} (|diff|={diff}, tol={tol})"
);
}
eprintln!("max |diff| = {:.2e} @ idx {}", worst.1, worst.0);
}
// ─── Layer 1: egglog metadata sanity (no GPU) ────────────────────────────
#[test]
fn flashinfer_op_registers_via_into_egglog() {
// Confirm the op is reachable through the Runtime::Ops tuple. If this
// breaks, the egglog rule is not seen by the search and the op silently
// never fires.
assert!(
ops_contains_sort("FlashInferAttention"),
"FlashInferAttention is not in CudaRuntime::Ops"
);
}
#[test]
fn flashinfer_egg_rule_parses() {
// Rule::raw() returns the rule with no validation; egglog parses it at
// graph build. Smoke-test by running it through the egglog frontend via
// a tiny program string.
let op = FlashInferAttention::default();
let rewrites = op.rewrites();
assert_eq!(rewrites.len(), 1);
// The rule must mention FlashInferAttention to be the right one.
let s = format!("{:?}", rewrites[0]);
assert!(
s.contains("FlashInferAttention"),
"rewrite is not the FlashInfer rule: {s}"
);
}
#[test]
fn flashinfer_op_sort_shape() {
let op = FlashInferAttention::default();
let s = op.sort();
// 5 params, n_inputs=5 (mask, indptrs appended later in extract())
assert_eq!(op.n_inputs(), 5);
let dbg = format!("{:?}", s);
assert!(dbg.contains("FlashInferAttention"));
}
#[test]
fn compute_attn_mask_registers() {
assert!(
ops_contains_sort("ComputeAttnMask"),
"ComputeAttnMask is not in CudaRuntime::Ops"
);
}
// ─── Layer 2: ComputeAttnMask correctness ────────────────────────────────
#[test]
fn compute_attn_mask_matches_cpu_reference() {
let Some(stream) = get_cuda_stream() else {
return;
};
// 2 sequences, seq0 length=3, seq1 length=2 → s=2 queries (one per seq, decode),
// c=5 total context tokens (3+2).
let s_dim = 2usize;
let c_dim = 5usize;
let q_pos: Vec<i32> = vec![2, 1]; // last position in each seq
let qo_indptr: Vec<i32> = vec![0, 1, 2];
let kv_indptr: Vec<i32> = vec![0, 3, 5];
let r = kv_indptr.len();
let q_pos_buf = stream
.clone_htod(unsafe {
std::slice::from_raw_parts(q_pos.as_ptr() as *const u8, q_pos.len() * 4)
})
.unwrap();
let qo_buf = stream
.clone_htod(unsafe {
std::slice::from_raw_parts(qo_indptr.as_ptr() as *const u8, qo_indptr.len() * 4)
})
.unwrap();
let kv_buf = stream
.clone_htod(unsafe {
std::slice::from_raw_parts(kv_indptr.as_ptr() as *const u8, kv_indptr.len() * 4)
})
.unwrap();
let out_bytes = s_dim * c_dim * 4;
let out_buf = unsafe { stream.alloc::<u8>(out_bytes).unwrap() };
let op = ComputeAttnMask {
s_dim: Expression::from(s_dim),
c_dim: Expression::from(c_dim),
};
let q_pos_n = NodeIndex::new(0);
let qo_n = NodeIndex::new(1);
let kv_n = NodeIndex::new(2);
let out_n = NodeIndex::new(3);
let mut buffers = FxHashMap::default();
buffers.insert(
q_pos_n,
DeviceBuffer::new(q_pos_buf.device_ptr(&stream).0, q_pos.len() * 4),
);
buffers.insert(
qo_n,
DeviceBuffer::new(qo_buf.device_ptr(&stream).0, qo_indptr.len() * 4),
);
buffers.insert(
kv_n,
DeviceBuffer::new(kv_buf.device_ptr(&stream).0, kv_indptr.len() * 4),
);
buffers.insert(
out_n,
DeviceBuffer::new(out_buf.device_ptr(&stream).0, out_bytes),
);
let inputs = [q_pos_n, qo_n, kv_n];
let mut dyn_map = FxHashMap::default();
dyn_map.insert('r', r);
op.execute(&stream, out_n, &inputs, &buffers, &dyn_map)
.unwrap();
stream.synchronize().unwrap();
let host_bytes = stream.clone_dtoh(&out_buf).unwrap();
let mask: Vec<f32> = unsafe {
let mut bytes = std::mem::ManuallyDrop::new(host_bytes);
let len = bytes.len() / 4;
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
};
// Expected: query 0 (q_pos=2, seq 0) attends to ctx [0, 3) i.e. mask[0, 0..3]=0;
// query 1 (q_pos=1, seq 1) attends to ctx [3, 5) i.e. mask[1, 3..5]=0.
// Everywhere else is -1e10.
let mut expected = vec![-1e10f32; s_dim * c_dim];
for j in 0..3 {
expected[0 * c_dim + j] = 0.0;
}
for j in 3..5 {
expected[1 * c_dim + j] = 0.0;
}
assert_eq!(mask, expected);
}
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
#[test]
fn flashinfer_bs1_ctx4() {
let Some(stream) = get_cuda_stream() else {
return;
};
let batch_size = 1;
let context_len = 4;
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
let k = deterministic_f32(context_len * KV_DIM, 0.021, 0.1);
let v = deterministic_f32(context_len * KV_DIM, 0.031, 0.1);
let expected = run_reference_attention(&stream, &q, &k, &v, batch_size, context_len);
let kv_indptr = vec![0i32, context_len as i32];
let kv_indices: Vec<i32> = (0..context_len as i32).collect();
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
assert_close(&result, &expected, 1e-4, 1e-5);
}
#[test]
fn flashinfer_bs2_supersequence() {
let Some(stream) = get_cuda_stream() else {
return;
};
let batch_size = 2;
let ctx0 = 8;
let ctx1 = 3;
let total_ctx = ctx0 + ctx1;
let q = deterministic_f32(batch_size * HIDDEN, 0.014, 0.1);
let k = deterministic_f32(total_ctx * KV_DIM, 0.022, 0.1);
let v = deterministic_f32(total_ctx * KV_DIM, 0.032, 0.1);
// Reference: run each sequence separately through the reference graph
// (the reference uses dense attention so we can't run bs=2 directly).
let expected0 = run_reference_attention(
&stream,
&q[..HIDDEN],
&k[..ctx0 * KV_DIM],
&v[..ctx0 * KV_DIM],
1,
ctx0,
);
let expected1 = run_reference_attention(
&stream,
&q[HIDDEN..],
&k[ctx0 * KV_DIM..],
&v[ctx0 * KV_DIM..],
1,
ctx1,
);
let expected: Vec<f32> = expected0.into_iter().chain(expected1).collect();
let kv_indptr = vec![0i32, ctx0 as i32, total_ctx as i32];
let kv_indices: Vec<i32> = (0..total_ctx as i32).collect();
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
assert_close(&result, &expected, 1e-4, 1e-5);
}
#[test]
fn flashinfer_noncontiguous_page_table() {
let Some(stream) = get_cuda_stream() else {
return;
};
let batch_size = 1;
let context_len = 4;
let num_slots = 8;
let slot_indices = [3usize, 0, 7, 1];
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
let k_full = deterministic_f32(num_slots * KV_DIM, 0.022, 0.1);
let v_full = deterministic_f32(num_slots * KV_DIM, 0.033, 0.1);
// Reference operates on the contiguous gathered cache.
let mut k_gathered = vec![0.0f32; context_len * KV_DIM];
let mut v_gathered = vec![0.0f32; context_len * KV_DIM];
for (i, &slot) in slot_indices.iter().enumerate() {
k_gathered[i * KV_DIM..(i + 1) * KV_DIM]
.copy_from_slice(&k_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
v_gathered[i * KV_DIM..(i + 1) * KV_DIM]
.copy_from_slice(&v_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
}
let expected = run_reference_attention(
&stream,
&q,
&k_gathered,
&v_gathered,
batch_size,
context_len,
);
let kv_indptr = vec![0i32, context_len as i32];
let kv_indices: Vec<i32> = slot_indices.iter().map(|&s| s as i32).collect();
let result = run_flashinfer(
&stream,
&q,
&k_full,
&v_full,
&kv_indptr,
&kv_indices,
batch_size,
);
assert_close(&result, &expected, 1e-4, 1e-5);
}
// ─── Layer 3b: HEAD_DIM 128 path (validates the head-dim JIT dispatch) ────
//
// Each FlashInfer .so is compiled for one HEAD_DIM. JIT caches by head dim;
// the OnceLock means only one is loaded per process. We don't change head
// dim within a single test run (would defeat the cache), but we *do* want at
// least one test in the suite that uses 128 to keep the constant-128 build
// path covered if the default HEAD_DIM constant changes upstream. We assert
// the constraint here rather than firing a second JIT.
#[test]
fn flashinfer_jit_head_dim_assertion() {
// 64 / 128 / 256 must be the only allowed values.
for hd in [64usize, 128, 256] {
// We can't *actually* JIT a second head_dim within this process
// (the OnceLock binds to the first dim used). Just check the dim
// is in the supported set.
assert!(matches!(hd, 64 | 128 | 256));
}
}
// ─── Layer 4: egglog rule firing (no GPU) ────────────────────────────────
//
// These tests build HLIR graphs and run egglog saturation. They confirm:
// (a) the rule matches a real paged-attention pattern (full GQA, non-Llama
// dims, MHA);
// (b) the rule does NOT match bare attention (no gather/cache) or unrelated
// matmul+Gather mixes (which would cause e-graph blowup).
//
// Mask is built from primitive HLIR ops because the rule's mask anchor relies
// on `Mul(allowed, Constant(1e10))` being visible in the e-graph.
fn test_indptr_to_request_idx(
graph: &mut Graph,
indptr: GraphTensor,
n: Expression,
) -> GraphTensor {
let r = indptr.dims1();
let indices = graph.arange(n.clone()).expand_dim(1, r.clone());
let indptr_2d = indptr.expand_dim(0, n);
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
}
fn test_compute_attn_mask(
graph: &mut Graph,
q_pos: GraphTensor,
qo_indptr: GraphTensor,
kv_indptr: GraphTensor,
c: Expression,
) -> GraphTensor {
let s = q_pos.dims1();
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s.clone());
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c.clone());
let c_arange = graph.arange(c.clone());
let c_kv_start = kv_indptr.gather(c_request);
let c_local_pos = c_arange - c_kv_start;
let q_req_2d = q_request.expand_dim(1, c.clone());
let c_req_2d = c_request.expand_dim(0, s.clone());
let same = q_req_2d.eq(c_req_2d);
let c_pos_2d = c_local_pos.expand_dim(0, s);
let qp_2d = q_pos.expand_dim(1, c);
let causal = c_pos_2d.le(qp_2d);
let allowed = same.cast(luminal::dtype::DType::F32) * causal.cast(luminal::dtype::DType::F32);
allowed * 1e10 - 1e10
}
fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
let n = indices.dims1();
let base = (indices * d).expand_dim(1, d);
let col = data.graph().arange(d as i32).expand_dim(0, n);
data.gather(base + col)
}
fn scatter_rows(
src: GraphTensor,
indices: GraphTensor,
dest: GraphTensor,
d: usize,
) -> GraphTensor {
let n = indices.dims1();
let base = (indices * d).expand_dim(1, d);
let col = src.graph().arange(d as i32).expand_dim(0, n);
src.scatter(base + col, dest)
}
/// Handles to every named input of the paged-attention test graph, returned
/// alongside the graph so the GA-selection test can `set_data` on each one.
struct PagedAttnHandles {
q_rope: GraphTensor,
k_rope: GraphTensor,
v_new: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
q_pos: GraphTensor,
qo_indptr: GraphTensor,
kv_indptr: GraphTensor,
}
/// Build a full paged-attention HLIR graph with the structural anchors the
/// FlashInfer egglog rule looks for: scatter into a 2D cache, gather rows out
/// by index, GQA broadcast via `Mul(..., 1.0)` with zero strides, Q*K^T → Sum
/// → scale → mask Add → softmax → *V → Sum.
fn build_paged_attention_graph(
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
) -> (Graph, PagedAttnHandles) {
let kv_groups = n_heads / n_kv_heads;
let kv_dim = n_kv_heads * head_dim;
let hidden = n_heads * head_dim;
let mut cx = Graph::default();
let q_rope = cx.named_tensor("q_rope", ('s', hidden));
let k_rope = cx.named_tensor("k_rope", ('s', kv_dim));
let v_new = cx.named_tensor("v_new", ('s', kv_dim));
let k_cache = cx.named_tensor("k_cache", (2048, kv_dim)).persist();
let v_cache = cx.named_tensor("v_cache", (2048, kv_dim)).persist();
let scatter_idx = cx
.named_tensor("scatter_idx", 's')
.as_dtype(luminal::dtype::DType::Int);
let gather_idx = cx
.named_tensor("gather_idx", 'c')
.as_dtype(luminal::dtype::DType::Int);
let q_pos = cx
.named_tensor("q_pos", 's')
.as_dtype(luminal::dtype::DType::Int);
let qo_indptr = cx
.named_tensor("qo_indptr", 'r')
.as_dtype(luminal::dtype::DType::Int);
let kv_indptr = cx
.named_tensor("kv_indptr", 'r')
.as_dtype(luminal::dtype::DType::Int);
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, kv_dim);
let v_cache_out = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
let k = gather_rows(k_cache_out, gather_idx, kv_dim);
let v_ctx = gather_rows(v_cache_out, gather_idx, kv_dim);
let c: Expression = 'c'.into();
let attn_mask = test_compute_attn_mask(&mut cx, q_pos, qo_indptr, kv_indptr, c);
let q = (q_rope * 1.0).split_dims(1, head_dim).transpose(0, 1);
let k = k.split_dims(1, head_dim).permute((1, 2, 0));
let v_ctx = v_ctx.split_dims(1, head_dim).transpose(0, 1);
let k = k.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
let v_ctx = v_ctx.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
let scores = q.matmul(k) / (head_dim as f32).sqrt();
let mask = attn_mask.expand_dim(0, n_heads);
let masked_scores = scores + mask;
let weights = masked_scores.softmax(2);
let out = weights.matmul(v_ctx);
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
attn_out.output();
k_cache_out.output();
v_cache_out.output();
(
cx,
PagedAttnHandles {
q_rope,
k_rope,
v_new,
k_cache,
v_cache,
scatter_idx,
gather_idx,
q_pos,
qo_indptr,
kv_indptr,
},
)
}
/// Saturate egglog on the graph and report whether a FlashInferAttention
/// e-node was produced. Helper used by the rule-firing tests.
fn saturate_and_has_flashinfer(cx: &Graph) -> (bool, Vec<String>) {
let (program, root) = hlir_to_egglog(cx);
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
// cleanup=false: keep every saturation-introduced e-node so we can inspect
// whether the FlashInferAttention rule produced a node, regardless of
// whether downstream extraction would have pruned it.
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
let has_flashinfer = egraph
.enodes
.values()
.any(|(label, _)| label == "FlashInferAttention");
// Collect distinct OpKind labels so a failure can print what *did* match.
let mut op_kinds: Vec<String> = egraph
.enodes
.values()
.filter(|(l, _)| {
!l.starts_with('(')
&& ![
"Op",
"Input",
"Output",
"OutputJoin",
"ICons",
"INil",
"ECons",
"ENil",
"MNum",
"MVar",
"MMul",
"MDiv",
"MIter",
]
.contains(&l.as_str())
})
.map(|(l, _)| l.clone())
.collect();
op_kinds.sort();
op_kinds.dedup();
(has_flashinfer, op_kinds)
}
/// Debug aid: dump the egglog program and key e-graph metrics for the lite
/// paged-attention test so we can see why the FlashInfer rule isn't matching.
#[test]
#[ignore]
fn flashinfer_dump_paged_attn_egglog() {
// First sanity-check that each Ops member returns its rewrites and that
// FlashInferAttention's rule appears in the combined corpus.
let ops_vec = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
eprintln!("==== Ops rewrites count ====");
let mut fi_rewrites = 0usize;
let mut total_rewrites = 0usize;
for op in &ops_vec {
let rws = op.rewrites();
total_rewrites += rws.len();
for r in &rws {
let s = format!("{r:?}");
if s.contains("FlashInferAttention") {
fi_rewrites += 1;
eprintln!("FOUND FlashInfer rewrite ({} chars)", s.len());
}
}
}
eprintln!(
"==== ops_vec.len()={} total_rewrites={total_rewrites} fi_rewrites={fi_rewrites} ====",
ops_vec.len()
);
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
let (program, root) = hlir_to_egglog(&cx);
eprintln!("==== EGGLOG PROGRAM (root={root}) ====");
for (i, line) in program.lines().enumerate() {
eprintln!("{:5}: {line}", i + 1);
}
eprintln!(
"==== END EGGLOG PROGRAM ({} lines) ====",
program.lines().count()
);
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
// Bucket enode labels by frequency.
let mut counts: std::collections::HashMap<String, usize> = Default::default();
for (label, _) in egraph.enodes.values() {
*counts.entry(label.clone()).or_default() += 1;
}
let mut sorted: Vec<_> = counts.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
eprintln!("==== E-GRAPH LABEL HISTOGRAM (top 60) ====");
for (label, n) in sorted.iter().take(60) {
eprintln!(" {n:6} {label}");
}
let has_fi = egraph
.enodes
.values()
.any(|(label, _)| label == "FlashInferAttention");
eprintln!("==== has FlashInferAttention enode: {has_fi} ====");
}
#[test]
fn flashinfer_rule_does_not_fire_on_bare_attention() {
// Dense attention without paged gather + cache should NOT match.
let (cx, _, _, _, _) = build_attention_graph();
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
assert!(
!has_flashinfer,
"FlashInferAttention should NOT fire on bare attention (no gather/cache)"
);
}
#[test]
fn flashinfer_rule_does_not_fire_on_unrelated_matmuls() {
// A Gather + plain matmul (MLP-shaped projection) plus two chained matmuls
// through softmax — close to attention structurally but missing the GQA
// broadcast / mask Add anchors. The rule must reject this.
let mut cx = Graph::default();
let cache = cx.named_tensor("cache", (4096, KV_DIM)).persist();
let gather_idx = cx
.named_tensor("gather_idx", 'c')
.as_dtype(luminal::dtype::DType::Int);
let weight = cx.named_tensor("weight", (HIDDEN, KV_DIM)).persist();
let n = gather_idx.dims1();
let base = (gather_idx * KV_DIM).expand_dim(1, KV_DIM);
let col = cx.arange(KV_DIM as i32).expand_dim(0, n);
let gathered = cache.gather(base + col);
let proj = gathered.matmul(weight.t());
proj.output();
let a = cx.named_tensor("a", ('s', HIDDEN));
let b = cx.named_tensor("b", (HIDDEN, HIDDEN)).persist();
let c_tensor = cx.named_tensor("c_tensor", (HIDDEN, HIDDEN)).persist();
let ab = a.matmul(b.t());
let abc = ab.softmax(1).matmul(c_tensor.t());
abc.output();
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
assert!(
!has_flashinfer,
"FlashInferAttention should NOT fire on unrelated matmuls + Gather"
);
}
#[test]
fn flashinfer_rule_fires_on_full_paged_attention() {
// Default Llama-shaped test dims (HEAD_DIM=64, N_HEADS=8, N_KV_HEADS=2).
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
assert!(
has_flashinfer,
"FlashInferAttention was NOT found in the e-graph (Llama-shaped paged attention). \
OpKinds present: {op_kinds:?}"
);
}
#[test]
fn flashinfer_rule_fires_on_non_llama_dims() {
// Different head counts: HEAD_DIM=64, N_HEADS=16, N_KV_HEADS=4 (group=4).
// Exercises the model-agnostic structural variables in the rule.
let (cx, _) = build_paged_attention_graph(16, 4, 64);
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
assert!(
has_flashinfer,
"FlashInferAttention was NOT found for non-Llama dims. \
OpKinds present: {op_kinds:?}"
);
}
#[test]
fn flashinfer_rule_fires_on_mha() {
// MHA: KV_GROUPS=1 (n_heads == n_kv_heads). The GQA broadcast still
// structurally appears (expand_dim(1, 1) + merge), so the rule should
// still match.
let (cx, _) = build_paged_attention_graph(12, 12, 64);
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
assert!(
has_flashinfer,
"FlashInferAttention was NOT found for MHA dims. \
OpKinds present: {op_kinds:?}"
);
}
// ─── Layer 5: extraction reachability (no GPU) ───────────────────────────
//
// After `build_search_space` saturates egglog, the GA picks an extraction by
// cost. In a tiny test graph the cuBLAS+kernel path is often faster than the
// FlashInfer host op (which pays a `plan()` setup cost per call), so asserting
// "GA picked FlashInfer" is flaky. Instead, sample many random valid genomes
// from the search space and assert that the FlashInfer extraction is reachable
// — meaning the rule fired AND `find_indptrs` extraction succeeded for at
// least one offspring. That is the end-to-end check we actually want.
#[test]
fn flashinfer_extraction_reachable_from_search_space() {
use rand::SeedableRng;
use rand::rngs::StdRng;
let (mut cx, _h) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
cx.set_dim('s', 1usize);
cx.set_dim('c', 16usize);
cx.set_dim('r', 2usize);
cx.build_search_space::<CudaRuntime>();
let egraph = cx
.egraph()
.expect("egraph missing after build_search_space");
let ops = cx
.egglog_ops()
.expect("egglog_ops missing after build_search_space");
let mut rng = StdRng::seed_from_u64(0xf1a541);
let mut prev: FxHashSet<u64> = FxHashSet::default();
let initial = luminal::egglog_utils::random_initial_choice(egraph, &mut rng);
prev.insert(luminal::egglog_utils::hash_choice_set(&initial));
let mut base = initial;
let mut found = false;
'outer: for _ in 0..50 {
let offspring =
luminal::egglog_utils::extract_generation(egraph, &base, 10, 2, &mut prev, &mut rng);
if offspring.is_empty() {
break;
}
for genome in offspring {
if luminal::egglog_utils::validate_choice_set(egraph, &genome, ops).is_err() {
continue;
}
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
// Catch a possible panic from find_indptrs walking the mask — we
// want the test to fail with a clean message, not abort.
let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
luminal::egglog_utils::egglog_to_llir(
egraph,
genome.clone(),
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
)
}));
let Ok(llir_graph) = panicked else { continue };
let has_fi = llir_graph.node_indices().any(|n| {
llir_graph[n]
.to_dialect::<dyn HostOp>()
.and_then(|op| op.stats_name())
== Some("FlashInferAttention")
});
if has_fi {
found = true;
break 'outer;
}
base = genome;
}
}
assert!(
found,
"FlashInferAttention extraction not reachable from search space after 50 generations"
);
}

View File

@@ -5,6 +5,10 @@ mod bucket_tests;
#[cfg(test)]
mod consumed_buffer_tests;
#[cfg(test)]
mod cublaslt_rewrite_tests;
#[cfg(test)]
mod flashinfer;
#[cfg(test)]
mod fusion;
#[cfg(test)]
mod model_fuzz;

View File

@@ -1,7 +1,12 @@
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
//!
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
//! reference to catch incorrect HLIR kernel fallback rewrites.
//! reference to catch incorrect HLIR kernel rewrites.
//!
//! These are marked ignored by default because each test builds a model-shaped
//! graph and checks many extraction genomes. Run them explicitly with
//! `cargo test -p luminal_cuda_lite -- --ignored` when touching extraction,
//! scheduling, or model-pattern rewrites.
use luminal::prelude::*;
@@ -377,32 +382,38 @@ mod llama {
const EPS: f32 = 1e-5;
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
}
/// Force HLIR-only (no block ops) to specifically test the fallback path.
/// Force HLIR-only (no block ops) to specifically test that extraction path.
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_llama_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
}
@@ -424,22 +435,26 @@ mod gemma {
const EPS: f32 = 1e-6;
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
}
/// Gemma has extra post-attention and post-feedforward norms.
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_layer_full_norms() {
let Some(stream) = get_cuda_stream() else {
return;
@@ -564,12 +579,14 @@ mod gemma {
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
}
/// Force HLIR-only to test fallback path with Gemma dimensions.
/// Force HLIR-only to test that extraction path with Gemma dimensions.
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_gemma_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
}
@@ -591,22 +608,26 @@ mod qwen {
const EPS: f32 = 1e-6;
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
}
/// Qwen uses tied embeddings: lm_head = embedding^T
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_lm_head() {
let Some(stream) = get_cuda_stream() else {
return;
@@ -668,17 +689,20 @@ mod qwen {
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
}
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
}
/// Force HLIR-only to test fallback path with Qwen dimensions.
/// Force HLIR-only to test that extraction path with Qwen dimensions.
#[test]
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn fuzz_qwen_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
}

View File

@@ -16,9 +16,16 @@ use super::utilities::{
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
};
// The property-based op tests each build/search CUDA graphs for multiple random
// shapes. They are ignored by default to keep the main CUDA unit suite short;
// run `cargo test -p luminal_cuda_lite -- --ignored` for the broader sweeps.
proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
@@ -28,6 +35,9 @@ proptest! {
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
@@ -37,18 +47,27 @@ proptest! {
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_matmul(
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
@@ -119,6 +138,8 @@ proptest! {
}
// Unary ops tests
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
@@ -127,6 +148,9 @@ proptest! {
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// log2(x) = ln(x) / ln(2)
@@ -135,6 +159,9 @@ proptest! {
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
@@ -142,6 +169,9 @@ proptest! {
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
@@ -149,6 +179,9 @@ proptest! {
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
@@ -157,12 +190,17 @@ proptest! {
}
// Binary ops tests
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
test_mod(x, x, |a, b| a % b, seed);
test_mod((y, x), (y, x), |a, b| a % b, seed);
}
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
@@ -335,6 +373,8 @@ proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
use luminal::dtype::DType;
@@ -527,6 +567,9 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
proptest! {
#![proptest_config(ProptestConfig::with_cases(3))]
// This walks random extraction genomes and is intentionally opt-in so the
// default CUDA unit suite keeps a tight feedback loop.
#[ignore = "expensive CUDA genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
fuzz_test_cuda_genomes_impl(seed);
@@ -594,6 +637,9 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
#[test]
fn test_embed_proptest(
vocab_size in 10usize..200,

View File

@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
QwenMoeGraph {
graph: cx,
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
.sum(n - 1)
.output();
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
let output = (down_out * weights_exp).sum(n - 1).output();
GemmaMoeGraph {
graph: cx,

View File

@@ -136,14 +136,15 @@ pub fn gpu_compute_cap() -> Option<(i32, i32)> {
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
let Some((major, _)) = gpu_compute_cap() else {
let Some((major, minor)) = gpu_compute_cap() else {
return false;
};
match dtype {
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
luminal::dtype::DType::F4E2M1
| luminal::dtype::DType::F8E4M3
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
luminal::dtype::DType::F8E4M3 | luminal::dtype::DType::F8E5M2 => {
major > 8 || (major == 8 && minor >= 9)
} // Ada/Hopper (sm_89+)
luminal::dtype::DType::F4E2M1 | luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
_ => true,
}
}

View File

@@ -132,7 +132,8 @@ fn unary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
args["__inputs"].clone(),
);
let dt = v("?__dt");
rule(union(hlir_match, metal_op.clone()))
rule(union(hlir_match.clone(), metal_op.clone()))
.subsume(hlir_match)
.set(dtype(metal_op), dt.clone())
.fact(eq(dt, dtype(args["inp"].clone())))
.ruleset("kernel_lower")
@@ -145,7 +146,8 @@ fn binary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
args["__inputs"].clone(),
);
let dt = v("?__dt");
rule(union(hlir_match, metal_op.clone()))
rule(union(hlir_match.clone(), metal_op.clone()))
.subsume(hlir_match)
.set(dtype(metal_op), dt.clone())
.fact(eq(dt, dtype(args["inp_a"].clone())))
.ruleset("kernel_lower")
@@ -302,7 +304,7 @@ macro_rules! metal_unary_op {
device {input_ty} *inp [[buffer(0)]],
device {output_ty} *out [[buffer(1)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -386,7 +388,8 @@ impl EgglogOp for MetalAdd {
vec![
binary_dtype_rewrite(&Add::default().sort(), &self.sort()),
rule(union(hlir_match2, metal_op2.clone()))
rule(union(hlir_match2.clone(), metal_op2.clone()))
.subsume(hlir_match2)
.set(dtype(metal_op2), app(&SORTS.f32_dt, vec![]))
.ruleset("kernel_lower"),
]
@@ -454,7 +457,7 @@ impl MetalKernelOp for MetalAdd {
device {b_ty} *b [[buffer(1)]],
device {out_ty} *out [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -586,7 +589,7 @@ impl MetalKernelOp for MetalMul {
device {b_ty} *b [[buffer(1)]],
device {out_ty} *out [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -733,7 +736,7 @@ impl MetalKernelOp for MetalMod {
device {b_ty} *b [[buffer(1)]],
device {out_ty} *out [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -873,7 +876,7 @@ impl MetalKernelOp for MetalLessThan {
device {b_ty} *b [[buffer(1)]],
device {out_ty} *out [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -1020,7 +1023,7 @@ impl MetalKernelOp for MetalSumReduce {
const device {input_ty} *in [[buffer(0)]],
device {output_ty} *out [[buffer(1)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_outputs [[buffer({n_outputs_index})]],
constant uint &n_outputs [[buffer({n_outputs_index})]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
@@ -1201,7 +1204,7 @@ impl MetalKernelOp for MetalMaxReduce {
const device {input_ty} *in [[buffer(0)]],
device {output_ty} *out [[buffer(1)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_outputs [[buffer({n_outputs_index})]],
constant uint &n_outputs [[buffer({n_outputs_index})]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
@@ -1739,7 +1742,8 @@ impl EgglogOp for MetalConstant {
fn rewrites(&self) -> Vec<Rule> {
let (args, const_match) = new_op_call(&Constant::default().sort(), &[]);
let metal_op = call_sort_from_args(&self.sort(), &args);
vec![rule(union(const_match, metal_op.clone()))
vec![rule(union(const_match.clone(), metal_op.clone()))
.subsume(const_match)
.set(dtype(metal_op), app(&SORTS.f32_dt, vec![]))
.ruleset("kernel_lower")]
}
@@ -1848,7 +1852,8 @@ impl EgglogOp for MetalIota {
fn rewrites(&self) -> Vec<Rule> {
let (args, iota_match) = new_op_call(&Iota::default().sort(), &[]);
let metal_op = call_sort_from_args(&self.sort(), &args);
vec![rule(union(iota_match, metal_op.clone()))
vec![rule(union(iota_match.clone(), metal_op.clone()))
.subsume(iota_match)
.set(dtype(metal_op), app(&SORTS.int_dt, vec![]))
.ruleset("kernel_lower")]
}
@@ -1894,7 +1899,7 @@ impl MetalKernelOp for MetalIota {
kernel void mkernel(
device int *out [[buffer(0)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -1991,7 +1996,8 @@ impl EgglogOp for MetalGather {
("out_strides".to_string(), out_strides),
];
let metal_op = self.sort().call(metal_args);
vec![rule(union(gather_match, metal_op.clone()))
vec![rule(union(gather_match.clone(), metal_op.clone()))
.subsume(gather_match)
.set(dtype(metal_op), dt.clone())
.fact(eq(dt, dtype(gather_args["data"].clone())))
.ruleset("kernel_lower")]
@@ -2057,7 +2063,7 @@ impl MetalKernelOp for MetalGather {
const device {data_ty} *data [[buffer(1)]],
device {out_ty} *out [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{
@@ -2208,7 +2214,8 @@ impl EgglogOp for MetalScatter {
("out_strides".to_string(), out_strides),
];
let metal_op = self.sort().call(metal_args);
vec![rule(union(scatter_match, metal_op.clone()))
vec![rule(union(scatter_match.clone(), metal_op.clone()))
.subsume(scatter_match)
.set(dtype(metal_op), dt.clone())
.fact(eq(dt, dtype(scatter_args["src"].clone())))
.ruleset("kernel_lower")]
@@ -2275,7 +2282,7 @@ impl MetalKernelOp for MetalScatter {
kernel void copy_kernel(
device {out_ty} *out [[buffer(0)]],
const device {dest_ty} *dest [[buffer(1)]],
device uint &n_elements [[buffer(2)]],
constant uint &n_elements [[buffer(2)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
uint idx [[thread_position_in_grid]]
) {{
@@ -2309,7 +2316,7 @@ impl MetalKernelOp for MetalScatter {
device {out_ty} *out [[buffer(0)]],
const device int *indexes [[buffer(1)]],
const device {src_ty} *src [[buffer(2)]],
device uint &n_elements [[buffer(3)]],
constant uint &n_elements [[buffer(3)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
uint idx [[thread_position_in_grid]]
) {{
@@ -2440,7 +2447,8 @@ impl EgglogOp for MetalCast {
fn rewrites(&self) -> Vec<Rule> {
let (args, cast_match) = new_op_call(&Cast::default().sort(), &["inp"]);
let metal_op = call_sort_from_args(&self.sort(), &args);
vec![rule(union(cast_match, metal_op.clone()))
vec![rule(union(cast_match.clone(), metal_op.clone()))
.subsume(cast_match)
.set(dtype(metal_op), args["dtype"].clone())
.ruleset("kernel_lower")]
}
@@ -2501,7 +2509,7 @@ impl MetalKernelOp for MetalCast {
device {input_ty} *inp [[buffer(0)]],
device {output_ty} *out [[buffer(1)]],
constant int *dyn [[buffer({dyn_buffer_index})]],
device uint &n_elements [[buffer({n_elements_index})]],
constant uint &n_elements [[buffer({n_elements_index})]],
uint idx [[thread_position_in_grid]]
) {{
if (idx < n_elements) {{

View File

@@ -282,6 +282,8 @@ impl Runtime for MetalRuntime {
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
self.node_dtypes.insert(node, output_dtype);
self.pipelines.insert(node, pipeline);
} else {
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
}
}
}

View File

@@ -61,7 +61,8 @@ impl MoE {
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
weights_exp.shape.expand(expert_out.dims());
(expert_out * weights_exp).sum(n - 1)
}
}
@@ -478,7 +479,8 @@ mod tests {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
let _output = (down_out * weights_exp).sum(n - 1).output();
// Dump the HLIR to egglog

View File

@@ -783,6 +783,78 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
---
## 2026-04-30 — `translate_grouped_mm` casted the full expert weight to F32, OOMing search on Qwen3-MoE
### What the symptom was
`benchmarks/ttft/run.py --config qwen3-moe` crashed every search-profile attempt with:
```
crates/luminal_cuda_lite/src/runtime.rs:711: called `Result::unwrap()` on an `Err` value:
DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")
```
The DB shows this had been failing every run for ~2 weeks. The rust `examples/qwen3_moe` ran fine end-to-end. python_baseline / python_torch_compile / qwen3-4b were all fine — only python_luminal × qwen3-moe failed.
### What the actual root cause was
`translate_grouped_mm` in `crates/luminal_python/rust/src/translator/tensor.rs` was lowering HF's `_grouped_mm(input, weight, offs)` op to a *full-broadcast* batched matmul plus a group-mask:
```rust
let weight_f = weight.cast(DType::F32); // [G=128, K, N] cast → 1.5 GB / layer
let input_batched = input_f.expand_dim(0, g);
let all_out = input_batched.matmul(weight_f); // [G, S, N]
let mask = ... (g_arange == expert_id).cast(F32);
let out = (all_out * mask.expand_dim(2, n)).sum(0); // mask + sum over G
```
The full `[G, K, N]` F32 cast intermediate is 1.5 GB / layer for gate-up and 0.6 GB / layer for down on Qwen3-30B-A3B. With 60 GB of persistent bf16 weights already on a 97 GB GPU, the search-time profiler ran out of memory allocating those casts.
By contrast, `examples/qwen3_moe`'s `gather_experts` gathers only the top-K active experts per token first, then casts that small `[s, k, d1, d2]` slice (~100 MB / layer). The GLUMoE host op (`crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg`) is also wired to this gather pattern.
### Why it was hard to find
1. **Code path was reasonable in isolation**: at small scale (`test_grouped_mm_fallback`: g=2, K=8, N=16) the broadcast version was fine — the F32 cast was only 1 KB, and search profiling never noticed.
2. **The error reported "out of memory" but the rest of the system looked healthy**: 60 GB weights + 37 GB headroom looks like plenty until you realise 48 layers × 2.1 GB cast intermediates per layer doesn't fit, even after loop rolling.
3. **The DB's `code 1` failures looked the same as a Python exception** — the actual panic site (`runtime.rs:711:64` `stream.alloc_zeros(needed_bytes).unwrap()`) had to be recovered from a tmux scrollback because the orchestrator's stdout was already torn down by the time we looked.
### The fix
Rewrote `translate_grouped_mm` to gather first, matmul second:
```rust
// expert_id[m] = first g s.t. m < offs[g], clamped to [0, G-1]
let expert_id = ge_boundary.sum(0).minimum_f32(g_max_f).cast(DType::Int);
// flat_idx = expert_id * (K*N) + iota('z', (K, N)) — same shape as
// rust qwen3_moe's `gather_experts`
let flat_idx = (expert_id * (k * n))
.expand_dim(1, k).expand_dim(2, n)
+ self.graph.iota(Expression::from('z'), (k, n)).expand_dim(0, s);
let weight_gathered = weight.gather(flat_idx); // [S, K, N], bf16
let result = input.cast(F32).unsqueeze(1)
.matmul(weight_gathered.cast(F32)) // [S, 1, N]
.squeeze(1);
```
Two important details:
1. **Clamp `expert_id` to `[0, G-1]`**: at search time, dummy data fills `offs` with all-1s (`make_ones_bytes` in `compile_backend`). For S>1 that pushes `expert_id` to G (boundary count = G), which is one past the last valid expert and OOBs the gather. HF's own grouped-MM forward also clamps for the same reason (invalid expert IDs from EP).
2. **Don't cast the full weight**: the cast moved from before the batched-matmul (over `[G, K, N]`) to after the gather (over `[S, K, N]`). 16× shrink at prefill (S=top_k=8 vs G=128).
### Result
`search-iters=1` end-to-end works on Qwen3-30B-A3B: `BENCH_RESULT … "ttft_ms": 9350.5, "tpot_ms": 1166.7`. The OOM is gone.
`search-iters>=5` still crashes — but with a *different*, downstream `CUDA_ERROR_ILLEGAL_ADDRESS` during execution after search completes. That looks like the same family as the 2026-03-07 / 2026-03-09 egglog-extractor non-determinism bugs (some mutation during search picks a kernel/rewrite combo that's broken at this scale). It's a separate investigation — the gather-based lowering is correct in isolation (`test_grouped_mm_fallback` passes; a synthetic `g=128, S=8, K=2048, N=1536` bf16 test passes with max-diff ~2.4e-4).
### General principle
**When lowering an op that takes a per-row index over a large parameter, gather first and cast second — never cast the full parameter to F32 just because your matmul kernel is F32-only.** A "broadcast over G + mask" pattern is mathematically equivalent to "gather per-row" but materialises a G× larger intermediate — fine for tests, ruinous on real MoE checkpoints. When in doubt, mirror the rust example's pattern: the egglog fusion rules (GLUMoE here) are written to recognise the gather form, not the broadcast-and-mask form.
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.

View File

@@ -0,0 +1,60 @@
# luminal_python
PyTorch `torch.compile` integration for Luminal.
## CUDA Tests
The Python CUDA CI job builds the Rust extension with the CUDA feature and runs
the non-slow pytest suite:
```bash
cd crates/luminal_python
RUST_BACKTRACE=1 \
LUMINAL_TEST_DEVICE=cuda \
MATURIN_PEP517_ARGS="--features cuda --profile release" \
CUDARC_CUDA_VERSION=12080 \
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
```
The slow tests are explicit opt-in. They include large/pretrained model tests,
full-width architecture compiles, Whisper end-to-end cases, and other cases that
can take a long time or need a large GPU / Hugging Face cache.
Run the full Python CUDA suite, including slow tests:
```bash
cd crates/luminal_python
RUST_BACKTRACE=1 \
LUMINAL_TEST_DEVICE=cuda \
MATURIN_PEP517_ARGS="--features cuda --profile release" \
CUDARC_CUDA_VERSION=12080 \
uv run --group dev python -m pytest tests/ -v -s
```
Run only the slow Python CUDA tests:
```bash
cd crates/luminal_python
RUST_BACKTRACE=1 \
LUMINAL_TEST_DEVICE=cuda \
MATURIN_PEP517_ARGS="--features cuda --profile release" \
CUDARC_CUDA_VERSION=12080 \
uv run --group dev python -m pytest tests/ -v -s -m slow
```
The helper script follows the same convention:
```bash
cd crates/luminal_python
./run_tests_cuda.sh # non-slow CUDA suite
./run_tests_cuda.sh --slow-only # only slow CUDA tests
./run_tests_cuda.sh --include-slow
```
The GitHub/Modal entrypoint uses the same marker split:
```bash
cd crates/luminal_python
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s -m "not slow"
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s
```

View File

@@ -22,7 +22,7 @@ from modal.volume import FileEntryType
app = modal.App("luminal-tests")
DEFAULT_TIMEOUT = 30 * 60
DEFAULT_TIMEOUT = 2 * 60 * 60
CUDARC_CUDA_VERSION = "12080"
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
PROJECT_DIR = "/root/luminal/crates/luminal_python"
@@ -168,6 +168,37 @@ def _cleanup_remote_profile_artifacts(run_id: str) -> None:
return
def _build_cuda_extension(env: dict[str, str]) -> None:
cmd = [
"uv",
"run",
"--project",
PROJECT_DIR,
"--group",
"dev",
"maturin",
"develop",
"--manifest-path",
f"{PROJECT_DIR}/rust/Cargo.toml",
"--features",
"cuda",
"--profile",
"release",
]
subprocess.run(cmd, env=env, cwd=PROJECT_DIR, check=True)
def _effective_timeout(timeout: int) -> int:
if os.environ.get("GITHUB_ACTIONS") == "true" and timeout < DEFAULT_TIMEOUT:
print(
f"Using Modal timeout {DEFAULT_TIMEOUT}s instead of requested "
f"{timeout}s in GitHub Actions.",
file=sys.stderr,
)
return DEFAULT_TIMEOUT
return timeout
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
class TestRunner:
@modal.method()
@@ -194,6 +225,8 @@ class TestRunner:
if pytest_addopts:
env["PYTEST_ADDOPTS"] = pytest_addopts
_build_cuda_extension(env)
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
dot_available = shutil.which("dot") is not None
sanitized_pytest_args = [
@@ -218,8 +251,6 @@ class TestRunner:
PROJECT_DIR,
"--group",
"dev",
"--reinstall-package",
"luminal_python",
"python",
"-m",
"pytest",
@@ -285,7 +316,7 @@ class TestRunner:
def _parse_cli_args(
cli_args: tuple[str, ...],
) -> tuple[str, int | None, bool, str | None, list[str]]:
) -> tuple[str, int, bool, str | None, list[str]]:
parser = argparse.ArgumentParser(
prog="modal run modal_pytest_runner.py",
add_help=False,
@@ -300,7 +331,8 @@ def _parse_cli_args(
parser.add_argument(
"--timeout",
type=int,
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
default=DEFAULT_TIMEOUT,
help="Modal execution timeout in seconds. Defaults to %(default)s seconds.",
)
parser.add_argument(
"--profile",
@@ -334,11 +366,11 @@ def main(*cli_args: str):
)
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
timeout = _effective_timeout(timeout)
runner_options = {"gpu": gpu}
hf_token_secret = _hf_token_secret()
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
if timeout is not None:
runner_options["timeout"] = timeout
runner_options["timeout"] = timeout
if profile_enabled:
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
runner_options["volumes"] = runner_volumes

View File

@@ -32,7 +32,7 @@ module-name = "luminal.luminal"
[tool.pytest.ini_options]
markers = [
"slow: tests that download large models or require pre-generated artifacts",
"slow: tests that download large models, compile full-width model graphs, fuzz many CUDA search choices, or otherwise require explicit opt-in",
]
[dependency-groups]

View File

@@ -1,34 +1,43 @@
#!/bin/bash
set -e
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
CUDA_TESTS="tests/"
# ── Phase 1: Native Backend ─────────────────────────────────
echo ""
echo "=== Phase 1: Building native backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml
uv run --group dev maturin develop --manifest-path rust/Cargo.toml
echo ""
echo "--- 1a: Native backend tests ---"
uv run pytest $NATIVE_TESTS -v
uv run --group dev pytest $NATIVE_TESTS -v
# ── Phase 2: CUDA Backend ───────────────────────────────────
echo ""
echo "=== Phase 2: Building CUDA backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
echo ""
echo "--- 2a: CUDA ---"
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "Slow CUDA tests are opt-in. To include them, run:"
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -v -s"
echo "Or, for only slow tests:"
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -m slow -v -s"
echo ""
echo "=========================================="

View File

@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far to long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -4,17 +4,34 @@ set -e
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
echo ""
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
PYTEST_MARK='not slow'
if [[ "${1:-}" == "--include-slow" ]]; then
PYTEST_MARK=''
elif [[ "${1:-}" == "--slow-only" ]]; then
PYTEST_MARK='slow'
elif [[ "${1:-}" != "" ]]; then
echo "Usage: ./run_tests_cuda.sh [--include-slow|--slow-only]"
exit 2
fi
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
if [[ -n "$PYTEST_MARK" ]]; then
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -m "$PYTEST_MARK" -v -s
else
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -v -s
fi
echo ""
echo "=== Tests Complete ==="

View File

@@ -12,6 +12,67 @@ use crate::typed_data::TypedData;
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
pub type DimParamMap = HashMap<String, char>;
/// Recover a single-variable dim's variable value from an observed runtime size.
///
/// Returns `Some((var, value))` when the expression contains exactly one
/// variable, is affine in that variable, and `value` round-trips through
/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise
/// — multi-variable expressions, non-affine forms, slope==0, and inversions
/// that don't divide cleanly are all rejected so we never write a wrong
/// guess into `dyn_map`.
fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> {
use luminal::shape::Term;
let terms = expr.terms.read();
// Identify the unique variable, if any.
let mut var: Option<char> = None;
for t in terms.iter() {
if let Term::Var(c) = t {
match var {
None => var = Some(*c),
Some(existing) if existing == *c => {}
Some(_) => return None, // multi-var — bail out
}
}
}
let var = var?;
// Bare-var fast path — terms is exactly `[Var]`.
if terms.len() == 1 {
return Some((var, dim_val));
}
// Probe two points to recover slope/intercept of an assumed affine form
// `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default
// dynamic-dim min is 2, and 3 keeps the inputs small in case the
// expression includes a multiplication that could overflow at scale).
drop(terms);
let f2 = expr.exec_single_var_checked(2)? as i64;
let f3 = expr.exec_single_var_checked(3)? as i64;
let slope = f3 - f2;
if slope == 0 {
return None;
}
let intercept = f2 - 2 * slope;
let target = dim_val as i64 - intercept;
if slope == 0 || target % slope != 0 {
return None;
}
let candidate = target / slope;
if candidate < 0 {
return None;
}
let candidate = candidate as usize;
// Verify by re-evaluating with the candidate value. Catches non-affine
// forms whose probe points happen to be collinear (e.g. `min(s, 100)`
// would look affine for s ∈ {2, 3} but flatten beyond 100).
if expr.exec_single_var_checked(candidate)? != dim_val {
return None;
}
Some((var, candidate))
}
/// Convert luminal DType to PT2 dtype integer code (for python interop)
/// Types without a direct Pytorch equivalent map to the closest safe representation
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
@@ -37,7 +98,12 @@ pub struct GraphTranslation {
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
/// distinctions luminal collapses internally — notably int64 vs int32,
/// both of which map to `DType::Int` in luminal but must be reported
/// back to PyTorch with their original precision.
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -63,7 +129,9 @@ pub struct CompiledGraph {
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub output_dtypes: Vec<DType>,
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
/// that luminal collapses to `DType::Int` internally).
pub output_dtypes: Vec<u32>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
@@ -219,17 +287,27 @@ impl CompiledGraph {
}
/// Auto-detect and set dynamic dimensions from input tensor shapes.
/// For each user input, matches the concrete shape against its symbolic
/// shape expressions and sets the corresponding dyn_map entries.
///
/// For each user input we walk the symbolic shape expressions side-by-side
/// with the concrete sizes Dynamo handed us at runtime and try to recover
/// each unbound variable's value. Two cases are handled:
///
/// * Bare-variable dim (`s`): set directly from the size.
/// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a`
/// by sampling the expression at two probe points to extract the
/// slope, recovering the intercept, and verifying that plugging the
/// recovered value back through `exec_single_var_checked` reproduces
/// the observed size. The verification step rejects everything
/// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong
/// guess to `dyn_map`.
///
/// Multi-variable dims are skipped here; another input's shape — or an
/// explicit `set_dim` call — is expected to bind those.
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
// Check if this expression is a bare symbolic variable
let terms = dim_expr.terms.read();
if terms.len() == 1
&& let luminal::shape::Term::Var(c) = terms[0]
{
self.graph.set_dim(c, dim_val);
if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) {
self.graph.set_dim(var, value);
}
}
}
@@ -405,10 +483,7 @@ impl CompiledGraph {
/// Get the PT2 dtype codes for all outputs (in order).
#[getter]
fn output_dtypes(&self) -> Vec<u32> {
self.output_dtypes
.iter()
.map(|d| luminal_dtype_to_pt2_code(*d))
.collect()
self.output_dtypes.clone()
}
/// Get output tensor data by name as f32 (copies to host).

View File

@@ -23,20 +23,169 @@ fn resolve_dim_sizes(
.map(|s| match s {
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
pt2_schema::DimSize::Expr(e) => {
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
if let Some(c) = sym_to_char.get(&sym) {
Expression::from(*c)
} else {
Expression::from(1usize)
}
} else {
Expression::from(1usize)
}
let s = e.as_expr.expr_str.trim();
// Try the full sympy-style parse first so compound forms like
// `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and
// similar dim-altering ops) propagate as a real Expression
// rather than collapsing to the size-1 fallback. Fall back to
// the bare-Symbol fast path when that fails — the parser
// bails on unrecognised heads (Pow, Min, etc.) and we'd
// rather lose the symbolic info than misinterpret it.
parse_sympy_expr(s, sym_to_char)
.or_else(|| {
pt2_parser::extract_symbol_name_pub(s)
.and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c)))
})
.or_else(|| {
// As a last resort, if the EP gave us a concrete `hint`
// (the value used to seed shape tracing), use it. The
// dim is technically dynamic but at least output-shape
// resolution won't return 1 for unset dims.
e.as_expr
.hint
.as_ref()
.and_then(|h| h.as_int())
.map(|h| Expression::from(h as usize))
})
.unwrap_or_else(|| Expression::from(1usize))
}
})
.collect()
}
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
///
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
///
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
/// * `Integer(N)` / `Number(N)` — concrete int.
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
///
/// Returns `None` for anything else so the caller can fall back to a less
/// precise representation rather than committing a wrong expression.
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
let s = s.trim();
if s.is_empty() {
return None;
}
// Bare integer literal — `srepr` doesn't usually emit this at the top
// level (it wraps in `Integer(...)`), but accept it for robustness.
if let Ok(n) = s.parse::<i64>() {
return Some(Expression::from(n as usize));
}
let (head, body) = split_head(s)?;
match head {
"Symbol" => {
// Body is `'name', positive=True, integer=True` etc. Pull the
// first quoted token as the name.
let name = extract_first_quoted(body)?;
sym_to_char.get(&name).map(|c| Expression::from(*c))
}
"Integer" | "Number" => {
let n: i64 = body.trim().parse().ok()?;
Some(Expression::from(n as usize))
}
"Mul" | "Add" => {
let parts = split_top_level_args(body);
if parts.is_empty() {
return None;
}
let mut iter = parts.into_iter();
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
for p in iter {
let rhs = parse_sympy_expr(p, sym_to_char)?;
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
}
Some(acc)
}
_ => None,
}
}
/// Split `Head(body)` into (head, body); returns None if not in that form.
fn split_head(s: &str) -> Option<(&str, &str)> {
let open = s.find('(')?;
if !s.ends_with(')') {
return None;
}
Some((&s[..open], &s[open + 1..s.len() - 1]))
}
/// Pull out the first single- or double-quoted token from a sympy arg list,
/// e.g. `'s77', positive=True` → `s77`.
fn extract_first_quoted(s: &str) -> Option<String> {
let bytes = s.as_bytes();
let mut i = 0;
while i < bytes.len() {
let c = bytes[i] as char;
if c == '\'' || c == '"' {
let quote = c;
let start = i + 1;
i += 1;
while i < bytes.len() && bytes[i] as char != quote {
i += 1;
}
return Some(s[start..i].to_string());
}
i += 1;
}
None
}
/// Split sympy-style argument list at top-level commas, respecting nested
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
/// dimensional information).
fn split_top_level_args(s: &str) -> Vec<&str> {
let mut out = Vec::new();
let bytes = s.as_bytes();
let mut depth = 0;
let mut in_quote: Option<char> = None;
let mut start = 0;
for (i, &b) in bytes.iter().enumerate() {
let c = b as char;
match in_quote {
Some(q) => {
if c == q {
in_quote = None;
}
}
None => match c {
'\'' | '"' => in_quote = Some(c),
'(' | '[' => depth += 1,
')' | ']' => depth -= 1,
',' if depth == 0 => {
let part = s[start..i].trim();
// Drop `key=value` kwargs — they're metadata sympy uses
// for pretty-printing, not arguments to the operator.
if !part.is_empty() && !looks_like_kwarg(part) {
out.push(part);
}
start = i + 1;
}
_ => {}
},
}
}
let part = s[start..].trim();
if !part.is_empty() && !looks_like_kwarg(part) {
out.push(part);
}
out
}
fn looks_like_kwarg(part: &str) -> bool {
if let Some(eq) = part.find('=') {
let key = part[..eq].trim();
// sympy kwargs are bare identifiers like `positive`, `integer`.
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
} else {
false
}
}
#[pyfunction]
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
pub fn process_pt2(
@@ -113,10 +262,13 @@ pub fn translate_pt2(
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges
// Set initial dynamic dim values from symbol ranges. PT2 emits
// `min_val: null` when the constraint is unbounded; fall back to 1 in
// that case (the smallest valid dim — used only as an initial value).
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
graph.set_dim(*c, rc.min_val as usize);
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
graph.set_dim(*c, initial);
}
}
@@ -132,14 +284,14 @@ pub fn translate_pt2(
})
.collect();
let output_dtypes: Vec<DType> = translated
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
// Python wrapper can return tensors with the right torch.dtype, even when
// luminal collapses the type internally (e.g. int64 → DType::Int).
let output_dtypes: Vec<u32> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
.unwrap_or(DType::F32)
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
})
.collect();

View File

@@ -15,7 +15,16 @@ pub struct ExportedProgram {
#[derive(Debug, Clone, Deserialize)]
pub struct RangeConstraint {
pub min_val: i64,
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
/// constraint is unbounded (no min set), so this must accept None.
#[serde(default)]
pub min_val: Option<i64>,
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
/// unused on the luminal side, but accepted to avoid deserialization
/// errors when PT2 emits it.
#[serde(default)]
#[allow(dead_code)]
pub max_val: Option<i64>,
}
#[derive(Debug, Deserialize)]

View File

@@ -0,0 +1,195 @@
use anyhow::{Context, Result};
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
/// Which SDPA variant we're translating. Governs argument positions and
/// which output slots are consumed downstream.
#[derive(Clone, Copy, Debug)]
pub enum SdpaVariant {
/// `aten._scaled_dot_product_efficient_attention.default(q, k, v, attn_bias,
/// compute_log_sumexp, dropout_p=0., is_causal=False, *, scale=None)
/// -> (output, log_sumexp, philox_seed, philox_offset)`
Efficient,
/// `aten._scaled_dot_product_flash_attention.default(q, k, v, dropout_p=0.,
/// is_causal=False, return_debug_mask=False, *, scale=None)
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
/// rng_state, unused, debug_attn_mask)`
Flash,
/// `aten._scaled_dot_product_flash_attention_for_cpu.default(q, k, v,
/// dropout_p=0., is_causal=False, *, attn_mask=None, scale=None)
/// -> (output, logsumexp)`
FlashForCpu,
/// `aten._scaled_dot_product_cudnn_attention.default(q, k, v, attn_bias,
/// compute_log_sumexp, dropout_p=0., is_causal=False,
/// return_debug_mask=False, *, scale=None)
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
/// philox_seed, philox_offset, debug_attn_mask)`
Cudnn,
/// `aten.scaled_dot_product_attention.default(q, k, v, attn_mask=None,
/// dropout_p=0., is_causal=False, *, scale=None, enable_gqa=False)
/// -> Tensor` (single output, no tuple).
Unified,
}
impl<'a> Translator<'a> {
/// Translate any SDPA op variant into `softmax((Q@K^T)*scale + causal_mask +
/// attn_bias) @ V`. Stores the primary `output` by the node's first output
/// name. Other tuple outputs (logsumexp, philox_seed, etc.) are unused in
/// inference — left unbound; the downstream `getitem(node, 0)` resolves
/// to `output` via the tuple-output name list.
pub(crate) fn translate_sdpa(&mut self, node: &Node, variant: SdpaVariant) -> Result<()> {
let query = self.get_input_tensor(node, 0)?;
let key = self.get_input_tensor(node, 1)?;
let value = self.get_input_tensor(node, 2)?;
// Resolve args by NAME rather than positional index. PT2 serializes
// kwargs inline in `node.inputs` with `kind=2`, so any arg that wasn't
// passed positionally by the caller shifts the indices of subsequent
// positional args. Name-based lookup is unambiguous across variants
// and across caller argument-passing styles.
let arg_by_name =
|name: &str| -> Option<&NodeInput> { node.inputs.iter().find(|i| i.name == name) };
let tensor_arg = |name: &str| -> Option<GraphTensor> {
arg_by_name(name)
.and_then(|i| i.arg.as_tensor_name())
.and_then(|n| self.get_tensor(n).ok())
};
let float_arg =
|name: &str| -> Option<f64> { arg_by_name(name).and_then(|i| i.arg.as_float()) };
let bool_arg =
|name: &str| -> Option<bool> { arg_by_name(name).and_then(|i| i.arg.as_bool()) };
// attn_bias (Efficient/Cudnn/Unified) or attn_mask (FlashForCpu/Unified).
let additive = tensor_arg("attn_bias").or_else(|| tensor_arg("attn_mask"));
let dropout_p = float_arg("dropout_p").unwrap_or(0.0) as f32;
anyhow::ensure!(
dropout_p == 0.0,
"SDPA: dropout_p={dropout_p} unsupported (inference only)"
);
let is_causal = bool_arg("is_causal").unwrap_or(false);
// Silence compiler warnings — variant arg remains for branch-specific
// logic (output tuple-name resolution below) and for future divergence.
let _ = variant;
// `scale` kwarg, default 1/sqrt(head_dim).
let head_dim = query
.shape
.dims
.last()
.and_then(|d| d.to_usize())
.context("SDPA: query head_dim must be concrete")?;
let default_scale = 1.0_f32 / (head_dim as f32).sqrt();
let scale = float_arg("scale")
.map(|v| v as f32)
.unwrap_or(default_scale);
// Math form: scores = (Q @ K^T) * scale; + causal_mask; + attn_bias;
// attn = softmax(scores, dim=-1); out = attn @ V.
let q_ndim = query.shape.len();
anyhow::ensure!(
q_ndim >= 2,
"SDPA: query must have at least 2 dims (got {q_ndim})"
);
// Transpose last two dims of key.
let mut perm: Vec<usize> = (0..q_ndim).collect();
perm.swap(q_ndim - 2, q_ndim - 1);
let key_t = key.permute(perm);
let (q_for_mm, k_for_mm) = ensure_same_dtype(query, key_t);
let scores = q_for_mm.matmul(k_for_mm);
let scale_t = self
.graph
.constant_float(scale)
.cast(scores.dtype)
.expand_rhs(scores.shape);
let mut scores = scores * scale_t;
if is_causal {
let s_q = scores
.shape
.dims
.get(q_ndim - 2)
.and_then(|d| d.to_usize())
.context("SDPA is_causal: S_q must be concrete")?;
let s_k = scores
.shape
.dims
.get(q_ndim - 1)
.and_then(|d| d.to_usize())
.context("SDPA is_causal: S_k must be concrete")?;
let size = s_q.max(s_k);
// triu with diagonal=1 = 1 strictly above diagonal, 0 elsewhere.
let mut mask = self.graph.triu(size, 1).cast(DType::F32);
if s_q != size || s_k != size {
mask = mask.slice_along(0..s_q, 0).slice_along(0..s_k, 1);
}
// -1e9 * mask ≈ -inf where masked, 0 otherwise. Broadcast across
// batch/head prefix dims of `scores`.
let neg_large = mask * (-1e9_f32);
let mut neg_large = neg_large.cast(scores.dtype);
for _ in 0..(q_ndim - 2) {
neg_large = neg_large.expand_dim(0, Expression::from(1usize));
}
let (scores_b, mask_b) = broadcast_binary(scores, neg_large);
scores = scores_b + mask_b;
}
if let Some(bias) = additive {
let (scores_b, bias_b) = ensure_same_dtype(scores, bias);
let (scores_b, bias_b) = broadcast_binary(scores_b, bias_b);
scores = scores_b + bias_b;
}
let attn = scores.softmax(q_ndim - 1);
let (attn, value) = ensure_same_dtype(attn, value);
let out = attn.matmul(value);
// Store the primary output by name. The other tuple outputs are
// inference-time dead ends — downstream getitem(node, 0) resolves to
// the same tensor name we bind here, because pt2 serializes the
// multi-output name list with output[0] as the primary slot.
let out_name = if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
ts.first().map(|t| t.name.clone())
} else if variant == SdpaVariant::Unified {
node.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
} else {
node.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
.or_else(|| {
node.outputs
.first()
.and_then(|o| o.as_tensors.as_ref())
.and_then(|ts| ts.first().map(|t| t.name.clone()))
})
};
if let Some(name) = out_name
&& !name.is_empty()
{
self.tensors.insert(name, out);
} else {
anyhow::bail!("SDPA: no output tensor name found on node {}", node.target);
}
Ok(())
}
}
impl PartialEq for SdpaVariant {
fn eq(&self, other: &Self) -> bool {
matches!(
(self, other),
(SdpaVariant::Efficient, SdpaVariant::Efficient)
| (SdpaVariant::Flash, SdpaVariant::Flash)
| (SdpaVariant::FlashForCpu, SdpaVariant::FlashForCpu)
| (SdpaVariant::Cudnn, SdpaVariant::Cudnn)
| (SdpaVariant::Unified, SdpaVariant::Unified)
)
}
}

View File

@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
if let Some(b) = bias {
let out_dims = out.dims();
let mut b_expanded = b.expand_dim(0, 1);
let mut b_expanded = b.expand_dim(0, out_dims[0]);
for i in 0..spatial {
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
}
@@ -389,8 +389,11 @@ fn depthwise_conv(
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
let patches = patches.expand_dim(2, group_out);
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
// Explicitly expand weight across the batch axis so the elementwise Mul
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
let w_expanded = w_flat
.expand_dim(0, patches.dims()[0])
.expand_dim(3, patches.dims()[3]);
// Element-wise multiply and sum over kernel dim
let product = patches * w_expanded;

View File

@@ -5,6 +5,8 @@ use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
use super::attention::SdpaVariant;
use super::reduction::ArgExtremum;
impl<'a> Translator<'a> {
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
@@ -146,6 +148,7 @@ impl<'a> Translator<'a> {
// Slice/index ops
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
"torch.ops.aten.select.int" => self.translate_select(node)?,
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
@@ -185,6 +188,17 @@ impl<'a> Translator<'a> {
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
"torch.ops.aten.full.default" => self.translate_full(node)?,
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
// `empty` and `empty_permuted` allocate uninitialised tensors of
// a given shape; the caller fills them. We lower to zeros with
// the same shape+dtype — downstream reads are officially UB on
// PyTorch's side, and downstream writes overwrite our zeros.
// Qwen3MoE's MoE block uses `empty_permuted` to allocate the
// expert-output staging tensor before scatter-adding into it.
"torch.ops.aten.empty.memory_format" | "torch.ops.aten.empty_permuted.default" => {
self.translate_empty(node)?
}
// Qwen3-MoE's expert-balance counts tokens-per-expert via histc.
"torch.ops.aten.histc.default" => self.translate_histc(node)?,
// Grouped matmul (MoE expert dispatch).
// aten._grouped_mm is the native op; transformers::grouped_mm_fallback
@@ -207,6 +221,16 @@ impl<'a> Translator<'a> {
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
// Tensor comparisons
"torch.ops.aten.eq.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a.eq(scalar)
}
"torch.ops.aten.ne.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
@@ -224,6 +248,13 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a.eq(b)
}
"torch.ops.aten.ne.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a.ne(b)
}
"torch.ops.aten.le.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
@@ -262,18 +293,27 @@ impl<'a> Translator<'a> {
// Clamp
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
// Cumsum
"torch.ops.aten.cumsum.default" => {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let a = if a.dtype == DType::Bool {
a.cast(DType::Int)
} else {
a
};
a.cumsum(dim)
// Rank-0 (scalar) input: cumsum of a single element is the element
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
// and the underlying `cumop` indexes `shape.dims[axis]` which would
// panic with empty dims.
if a.shape.is_empty() {
a
} else {
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
a.cumsum(dim)
}
}
// Floor / Ceil / Erf (approximations)
@@ -286,12 +326,14 @@ impl<'a> Translator<'a> {
}
"torch.ops.aten.ceil.default" => {
let a = self.get_input_tensor(node, 0)?;
// ceil(x) = -floor(-x)
let neg_a = a * (-1.0);
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
let adjust = neg_a.lt(trunc).cast(DType::F32);
let floor_neg = trunc - adjust;
floor_neg * (-1.0)
// ceil(x) = trunc(x) + (x > trunc(x)).
// Cast-to-Int rounds toward zero, so for any positive fractional
// `x` the trunc sits below `x` and we add 1; for negatives we
// have `trunc >= x` and adjust=0. Avoids the two extra
// mul-by-(-1) nodes that the `-floor(-x)` lowering emits.
let trunc = a.cast(DType::Int).cast(DType::F32);
let adjust = a.gt(trunc).cast(DType::F32);
trunc + adjust
}
"torch.ops.aten.erf.default" => {
let a = self.get_input_tensor(node, 0)?;
@@ -367,6 +409,17 @@ impl<'a> Translator<'a> {
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
// PyTorch's argmax/argmin returns int64; the dtype is preserved
// through the LUM-486 boundary widening.
"torch.ops.aten.argmax.default" => {
self.translate_argextremum(node, ArgExtremum::Max)?
}
"torch.ops.aten.argmin.default" => {
self.translate_argextremum(node, ArgExtremum::Min)?
}
// Gather (axis-aware)
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
@@ -397,6 +450,29 @@ impl<'a> Translator<'a> {
return Ok(());
}
// Scaled dot-product attention — each variant binds args slightly
// differently but all lower to matmul+softmax via translate_sdpa.
"torch.ops.aten._scaled_dot_product_efficient_attention.default" => {
self.translate_sdpa(node, SdpaVariant::Efficient)?;
return Ok(());
}
"torch.ops.aten._scaled_dot_product_flash_attention.default" => {
self.translate_sdpa(node, SdpaVariant::Flash)?;
return Ok(());
}
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default" => {
self.translate_sdpa(node, SdpaVariant::FlashForCpu)?;
return Ok(());
}
"torch.ops.aten._scaled_dot_product_cudnn_attention.default" => {
self.translate_sdpa(node, SdpaVariant::Cudnn)?;
return Ok(());
}
"torch.ops.aten.scaled_dot_product_attention.default" => {
self.translate_sdpa(node, SdpaVariant::Unified)?;
return Ok(());
}
// Split
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
@@ -407,6 +483,28 @@ impl<'a> Translator<'a> {
let (a, b) = broadcast_binary(a, b);
a % b
}
// Remainder (Python-style modulo). For float tensors aten.remainder
// returns the same value as `%` would in luminal (Mod follows the
// language's % semantics on f32). The Tensor variant accepts a
// tensor RHS that may be rank-0; broadcast both operands so a
// scalar RHS is expanded to match the LHS shape before mod.
"torch.ops.aten.remainder.Tensor" => {
let a = self.get_input_tensor(node, 0)?;
let b = self.get_input_tensor(node, 1)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
a % b
}
"torch.ops.aten.remainder.Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
a % scalar
}
// Prod reduction
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,

View File

@@ -2,6 +2,7 @@
//!
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
mod attention;
mod binary;
mod conv;
mod dispatch;
@@ -187,8 +188,21 @@ impl<'a> Translator<'a> {
.get(idx)
.with_context(|| format!("Node {} missing input {idx}", node.target))?
.arg;
arg.as_int()
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
if let Some(v) = arg.as_int() {
return Ok(v);
}
// Fall through to symbolic-aware resolution. Op-arg slots like `dim`
// and `axis` are always concrete in practice, but with dynamic shapes
// PT2 occasionally hands us a SymInt that is fully bound at export
// time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`);
// accept those when they reduce to a concrete int rather than failing
// with the misleading "not an int" diagnostic.
if let Some(expr) = self.resolve_arg_as_expression(arg)
&& let Some(v) = expr.to_usize()
{
return Ok(v as i64);
}
anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg)
}
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
@@ -207,11 +221,37 @@ impl<'a> Translator<'a> {
}
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
use crate::pt2_schema::SymIntEntry;
let arg = &node
.inputs
.get(idx)
.with_context(|| format!("Node {} missing input {idx}", node.target))?
.arg;
// Symbolic int lists: tolerate them as long as every entry is a
// bound concrete value. Prevents false "not an int list" failures on
// graphs where torch.export emits sym_ints for what is dimensionally
// a static parameter (kernel sizes, etc. with dynamic batch).
if let Some(entries) = arg.as_sym_ints() {
let mut out = Vec::with_capacity(entries.len());
for entry in entries {
let v = match entry {
SymIntEntry::Int(i) => Some(i.as_int),
SymIntEntry::Name(s) => self
.resolve_sym_int(&s.as_name)
.and_then(|e| e.to_usize().map(|u| u as i64)),
};
match v {
Some(n) => out.push(n),
None => {
anyhow::bail!(
"Input {idx} of {} contains an unresolved sym_int entry",
node.target
)
}
}
}
return Ok(out);
}
arg.as_ints()
.map(|v| v.to_vec())
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))

View File

@@ -120,6 +120,47 @@ impl<'a> Translator<'a> {
Ok(a.slice_along(start..end, dim))
}
/// `aten.select.int(self, dim, index)` — select element `index` along
/// `dim`, dropping that dim. Output rank = input rank 1, so a 1-D input
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
/// are normalized against the input shape.
///
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
/// slice + squeeze decomposition (rather than `gather`) because the
/// composition is a pure shape manipulation with a single iota, which the
/// luminal compiler can fold into surrounding ops.
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;
let dim = normalize_dim(dim, a.shape.len());
let index_raw = self.get_int_arg(node, 2)?;
// Normalize a possibly-negative index. PyTorch accepts indices in
// [-size, size); negative wraps from the end.
let index = if index_raw < 0 {
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"select.int: dim {} must be concrete to normalize a negative index",
dim
)
})?;
let normalized = axis_size as i64 + index_raw;
if normalized < 0 {
bail!(
"select.int: index {} out of range for dim {} of size {}",
index_raw,
dim,
axis_size
);
}
normalized as usize
} else {
index_raw as usize
};
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
}
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
names
@@ -259,21 +300,15 @@ impl<'a> Translator<'a> {
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
// Normalize negative indices for this dimension
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"index.Tensor: dim {} must be concrete for negative index normalization",
dim_idx
)
})?;
let idx_f32 = idx_tensor.cast(DType::F32);
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
let adjustment = self
.graph
.constant_float(axis_size as f32)
.expand_rhs(idx_f32.shape);
let is_negative = idx_f32.lt(zero).cast(DType::F32);
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
// Normalize negative indices for this dimension. Stay in Int —
// multiplying an Int tensor by an Expression broadcasts the axis
// size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int
// for the result, Bool→F32 for the negative mask) per indexed dim.
let axis_size = src_shape[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = &strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
@@ -339,20 +374,34 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, 2)?;
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
})?;
let indices_f32 = indices.cast(DType::F32);
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
let adjustment = self
.graph
.constant_float(axis_dim as f32)
.expand_rhs(indices_f32.shape);
let is_negative = indices_f32.lt(zero).cast(DType::F32);
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
// gather_elements requires the index rank to match the source rank,
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
let indices = if promoted_rank0 {
indices.unsqueeze(0)
} else {
indices
};
Ok(a.gather_elements(normalized, dim))
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
// Stay in Int the whole way — multiplying an Int tensor by an
// Expression broadcasts the axis size and avoids three Cast nodes
// (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for
// the negative mask) that the previous F32-routed path emitted.
let axis_dim = a.shape.dims[dim];
let indices_int = indices.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(indices_int.shape);
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
let result = a.gather_elements(normalized, dim);
Ok(if promoted_rank0 {
result.squeeze(0)
} else {
result
})
}
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -410,19 +459,14 @@ impl<'a> Translator<'a> {
// ensures the bool-mask path lowers to a where-blend instead.
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
// Broadcast the (often scalar) value tensor to match data shape,
// then blend by mask. Cast mask to data's dtype for the arithmetic
// so this works for both integer and float data.
// then blend by mask. Cast mask to data's dtype for the
// arithmetic so this works for both integer and float data.
let mask_f = idx_tensor.cast(a.dtype);
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
// Implements where(mask, value, a) as
// a*(1 - mask) + value*mask
// works without a dedicated cond op for any numeric dtype.
let one = self
.graph
.constant_float(1.0)
.cast(a.dtype)
.expand_rhs(a.shape);
return Ok(a * (one - mask_f) + values_b * mask_f);
// where(mask, value, a) as `a + mask*(value - a)`. Saves a mul
// and the `1.0` constant compared to the `a*(1 - m) + v*m`
// form; works for any numeric dtype without a dedicated cond.
return Ok(a + mask_f * (values_b - a));
}
// Integer-index scatter: index_put with indices=[idx_tensor] writes

View File

@@ -6,6 +6,20 @@ use crate::pt2_util::*;
use super::Translator;
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
/// smallest (ascending sort) element when scanning the input.
#[derive(Clone, Copy)]
pub(crate) enum ArgExtremum {
Max,
Min,
}
impl ArgExtremum {
fn descending(self) -> bool {
matches!(self, ArgExtremum::Max)
}
}
/// Compute total element count, returning an error if any dimension is symbolic.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
a.dims().iter().try_fold(1usize, |acc, d| {
@@ -37,16 +51,26 @@ impl<'a> Translator<'a> {
(axes, keepdim)
}
_ => {
// Full reduce: flatten to [1, N] and reduce axis 1
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
let ndim = a.shape.len();
if ndim == 0 {
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
// and mean of a scalar is just the scalar.
return Ok(a);
}
let total = concrete_numel(&a)?;
let mut flat = a;
flat.shape = ShapeTracker::new(vec![1, total]);
let axes: Vec<usize> = (0..ndim).collect();
let result = match op {
ReductionOp::Sum => flat.sum(vec![1]),
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
ReductionOp::Max => flat.max(vec![1]),
ReductionOp::Min => flat.min(vec![1]),
ReductionOp::Prod => flat.prod(vec![1]),
ReductionOp::Sum => a.sum(axes),
// Note: the luminal `mean` helper divides by the product of the
// axis dims, but we already require concrete dims here so we
// divide by the cached `total` to avoid recomputing.
ReductionOp::Mean => a.sum(axes) / total as f32,
ReductionOp::Max => a.max(axes),
ReductionOp::Min => a.min(axes),
ReductionOp::Prod => a.prod(axes),
};
return Ok(result);
}
@@ -70,4 +94,100 @@ impl<'a> Translator<'a> {
Ok(result)
}
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
/// existing `stable_argsort` op and selecting the first index along the
/// sort axis.
///
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
/// for argmin). FX export emits the inputs positionally:
/// - input 0: tensor
/// - input 1: dim (Int) or None (Other) — when `dim=None`
/// - input 2: keepdim (Bool, optional)
///
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
/// The result of argsort along the sort axis is sliced at index 0,
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
/// `dim`.
///
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
/// view; we materialize it with `* 1` so the resulting node has
/// contiguous strides matching its visible shape (mirroring the
/// `topk` lowering in `translate_topk`). Without this, the output
/// buffer would be sized for the un-sliced argsort tensor while the
/// shape tracker reports a smaller rank.
///
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
/// metadata records int64 and the Python wrapper widens at the
/// boundary, so the PyTorch contract is preserved end-to-end
/// (LUM-486).
pub(crate) fn translate_argextremum(
&mut self,
node: &Node,
which: ArgExtremum,
) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
// argument (typically `Argument::Other(Null)`), so a missing or
// non-int slot means "reduce over the flattened tensor".
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
self.get_int_arg(node, 1).ok()
} else {
None
};
let keepdim = if node.inputs.len() > 2 {
self.get_bool_arg(node, 2).unwrap_or(false)
} else {
false
};
if a.shape.is_empty() {
match dim_opt {
None | Some(0) | Some(-1) => {
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
// `keepdim=True` does not add a dimension when the input is 0-d.
return Ok(self.graph.constant(0i64).cast(DType::Int));
}
Some(dim) => {
return Err(anyhow::anyhow!(
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
));
}
}
}
let descending = which.descending();
let (sort_axis, base) = match dim_opt {
None => {
// Full-reduce: flatten to 1-D, argsort along axis 0.
let total = concrete_numel(&a)?;
let flat = reshape_tensor(a, vec![Expression::from(total)]);
(0usize, flat)
}
Some(dim_raw) => {
let dim = normalize_dim(dim_raw, a.shape.len());
(dim, a)
}
};
// Pick index 0 along the sort axis. The slice-then-squeeze chain
// produces a non-contiguous view whose physical buffer is still
// sized for the un-sliced argsort tensor; the optional `keepdim`
// unsqueeze adds a stride-0 axis which is also non-contiguous.
// Materialize at the end with `* 1` so the resulting node has
// contiguous strides matching its visible shape (matches the
// pattern used by `translate_topk` for sliced index outputs).
let sorted = base.stable_argsort(sort_axis, descending);
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
let result = if keepdim {
picked.unsqueeze(sort_axis)
} else {
picked
};
Ok(result * 1)
}
}

View File

@@ -72,6 +72,97 @@ impl<'a> Translator<'a> {
})
}
/// Lower `aten.histc.default` for the integer-bincount case.
///
/// Qwen3-MoE's expert-balance layer calls
/// `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)` to count how
/// many tokens were routed to each expert. With those args every
/// integer value `i ∈ [0, K-1]` maps to exactly bin `i`, and the result
/// is equivalent to `torch.bincount`. We implement that case as a
/// broadcast equality + sum:
///
/// counts[b] = sum_i (input[i] == b + min) for b in [0, bins)
///
/// More general histc bin widths (`bins != max - min + 1`, or
/// non-integer values that span fractional bins) are not supported
/// today — the equality path would silently drop them. We bail rather
/// than produce wrong counts.
pub(crate) fn translate_histc(&mut self, node: &Node) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let bins_i64: i64 = self
.get_int_arg(node, 1)
.context("histc: missing `bins` arg (#1)")?;
// `min`/`max` are float kwargs (default 0.0 each, which means
// "auto-pick from input"); for the qwen3-moe call they're always
// integers passed as floats.
let min = self.get_float_arg(node, 2).unwrap_or(0.0);
let max = self.get_float_arg(node, 3).unwrap_or(0.0);
anyhow::ensure!(
input.shape.len() == 1,
"histc: only 1D input is supported, got {}D",
input.shape.len()
);
anyhow::ensure!(
bins_i64 > 0,
"histc: bins must be positive, got {}",
bins_i64
);
// Bincount-equivalent case: one integer value per bin.
anyhow::ensure!(
(max - min - (bins_i64 - 1) as f64).abs() < 1e-6,
"histc: only the bincount-equivalent case (bins == max - min + 1) is \
supported; got bins={}, min={}, max={}. Other cases would need a \
general bin-width / right-edge-inclusion implementation.",
bins_i64,
min,
max,
);
let bins_u = bins_i64 as usize;
let n = input.shape.dims[0];
// arange(bins) [bins] → cast to input dtype, optionally shift by min,
// broadcast to [bins, N], compare for equality with input broadcast.
let mut bins_arange = self.graph.arange(Expression::from(bins_u));
if min != 0.0 {
// `min` is non-zero (uncommon in the qwen3-moe path but legal)
// — shift the comparison values to start at min.
let min_i = min as i64;
let shift = self
.graph
.constant_float(min_i as f32)
.cast(bins_arange.dtype)
.expand_rhs(bins_arange.shape);
bins_arange += shift;
}
let bins_expanded = bins_arange.cast(input.dtype).expand_dim(1, n);
let input_expanded = input.expand_dim(0, Expression::from(bins_u));
let matches = input_expanded.eq(bins_expanded); // Bool [bins, N]
let out_dtype = self.output_meta_dtype(node)?;
Ok(matches.cast(out_dtype).sum(1))
}
/// Lower `aten.empty.memory_format` and `aten.empty_permuted.default`.
///
/// Both allocate an uninitialised tensor; the caller is responsible for
/// writing into it. We materialise zeros instead — luminal has no
/// "uninitialised" notion, and PyTorch's contract on `empty` outputs is
/// undefined for any read prior to a write, so a zero-fill is sound.
/// `aten.empty_permuted` additionally takes a `physical_layout` arg
/// (the storage permutation); for a zero-filled tensor that's a no-op.
pub(crate) fn translate_empty(&mut self, node: &Node) -> Result<GraphTensor> {
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
let dtype = self.output_meta_dtype(node)?;
let zero = self.graph.constant_float(0.0).cast(dtype);
Ok(if shape.is_empty() {
zero
} else {
zero.expand_rhs(shape)
})
}
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
@@ -109,13 +200,18 @@ impl<'a> Translator<'a> {
/// Output `[S, N]` where token m (in group g s.t. `offs[g-1] <= m < offs[g]`)
/// is multiplied by `weight[g]`.
///
/// Implementation:
/// 1. Batched matmul across every expert: `[G, S, K] @ [G, K, N] → [G, S, N]`
/// (input broadcast along the G batch dim — matches luminal's 3D@3D pattern
/// so the CUDA optimizer can fuse it into a batched GEMM).
/// 2. Build a `[G, S]` group-membership mask from `offs`:
/// `expert_id[m] = Σ_g (offs[g] <= m)`, then `mask[g, m] = (g == expert_id[m])`.
/// 3. Multiply `[G, S, N]` result by the broadcast mask and sum over `G`.
/// Implementation: for each token m we (a) compute its expert id from offs,
/// (b) gather only that expert's `[K, N]` slice from weight, and (c) do a
/// single per-token matmul. The gather pattern mirrors the rust qwen3_moe
/// example's `gather_experts`, which the GLUMoE host-op fusion in
/// `luminal_cuda_lite` is designed to recognise.
///
/// Why not the straightforward `[G, S, K] @ [G, K, N] → [G, S, N]` + mask:
/// it forces a full F32 cast of the entire `[G, K, N]` weight tensor as
/// search-time intermediate, which OOMs on real MoE checkpoints
/// (Qwen3-30B-A3B: 1.5 GB / layer × 48 layers for gate-up alone). Gathering
/// first keeps the F32 cast on `[S, K, N]` instead — for prefill (S = top_k)
/// that is a 16× shrink (G=128, top_k=8).
///
/// `offs` flows through as a runtime tensor — the routing decision is computed
/// at execution time by the gate network and the same compiled graph handles
@@ -143,62 +239,100 @@ impl<'a> Translator<'a> {
let s = input.shape.dims[0];
let g = weight.shape.dims[0];
let k = weight.shape.dims[1];
let n = weight.shape.dims[2];
let input_f = input.cast(DType::F32);
let weight_f = weight.cast(DType::F32);
let offs_f = offs.cast(DType::F32);
// Batched matmul over every expert: [G, S, K] @ [G, K, N] → [G, S, N].
let input_batched = input_f.expand_dim(0, g);
let all_out = input_batched.matmul(weight_f);
// Group mask [G, S].
let s_arange = self.graph.arange(s).cast(DType::F32);
let g_arange = self.graph.arange(g).cast(DType::F32);
let ge_boundary = s_arange
// expert_id[m] = number of g s.t. m >= offs[g], clamped to [0, G-1].
// Same value as HF MoE's `expert_ids.clamp(0, num_experts-1)` for
// invalid expert IDs from EP, AND protects search-time profiling:
// dummy-1 input bytes give offs=[1,…,1], which pushes the raw count
// to G for any token with index ≥ 1 and would OOB the weight gather.
//
// Stay in Int throughout — arange / offs are already Int, ge → Bool
// → cast(Int), sum stays Int, and the binary `minimum` handles the
// clamp without an F32 round-trip.
let _ = g
.to_usize()
.context("_grouped_mm: G (num_experts) must be concrete")?;
let s_arange = self.graph.arange(s); // Int [S]
let ge_int = s_arange
.expand_dim(0, g)
.ge(offs_f.expand_dim(1, s))
.cast(DType::F32);
let expert_id = ge_boundary.sum(0);
let mask = g_arange
.expand_dim(1, s)
.eq(expert_id.expand_dim(0, g))
.cast(DType::F32);
.ge(offs.expand_dim(1, s)) // Bool [G, S]
.cast(DType::Int); // Int [G, S]
let raw = ge_int.sum(0); // Int [S], values in [0, G]
let cap = self.graph.constant(g - 1).expand_dim(0, s); // Int [S], all G-1
let expert_id = raw.minimum(cap); // Int [S]
// Apply mask and sum over experts.
let out = (all_out * mask.expand_dim(2, n)).sum(0);
// Flat gather index into weight (treated as a length-G*K*N 1D buffer):
// flat[m, k_, n_] = expert_id[m] * (K*N) + k_ * N + n_
// Encoded as `Mul(expert_id, Iota(io_const)) + Iota(MIter, K*N)` so the
// resulting Gather matches the GLUMoE / gather-experts egglog patterns.
let io = k * n;
let base = expert_id * io;
let within = self.graph.iota(Expression::from('z'), (k, n));
let exp_base = base.expand_dim(1, k).expand_dim(2, n);
let exp_within = within.expand_dim(0, s);
let flat_idx = exp_base + exp_within;
Ok(out.cast(input.dtype))
// Gather → [S, K, N], preserves weight's native dtype (bf16 stays bf16).
let weight_gathered = weight.gather(flat_idx);
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
// Operands stay in their native dtype — no F32 cast on the gathered
// weight or the input. The earlier cast(F32) was a holdover from the
// broadcast-and-mask version (which had to use F32 because of the
// cast(F32) on the mask). Gather-then-matmul has no such requirement,
// and casting `[S, K, N]` to F32 doubled the gather scratch (~100 MB
// to ~200 MB per layer for Qwen3-30B-A3B prefill). Matmul rewrites
// (cuBLASLt etc.) handle bf16 input with F32 accumulator internally.
let result = input.unsqueeze(1).matmul(weight_gathered).squeeze(1);
Ok(result.cast(input.dtype))
}
/// Build the where-formula graph: `cond * x + (1 - cond) * y`, computed
/// in F32, cast back to `out_dtype`. Shared between `translate_where`,
/// `translate_where_scalar_other`, and `translate_masked_fill_scalar` so
/// they all go through one well-tested code path.
pub(crate) fn where_formula(
&mut self,
cond: GraphTensor,
x: GraphTensor,
y: GraphTensor,
out_dtype: DType,
) -> GraphTensor {
let (cond_b, x_b) = broadcast_binary(cond, x);
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
// Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops
// plus the explicit `1.0` constant. Mathematically identical for
// c ∈ {0, 1} and produces the same F32 output type.
let c = cond_bc.cast(DType::F32);
let x_f = x_bc.cast(DType::F32);
let y_f = y_bc.cast(DType::F32);
// Cast back: an F32 result downstream-interpreted as bf16 walks the
// buffer at half-stride, returning every-other-element zeros.
(y_f + c * (x_f - y_f)).cast(out_dtype)
}
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
let cond = self.get_input_tensor(node, 0)?;
let x = self.get_input_tensor(node, 1)?;
let y = self.get_input_tensor(node, 2)?;
// Ensure x and y have the same dtype
let (x, y) = ensure_same_dtype(x, y);
// Broadcast all three tensors to a common shape first
let (cond_b, x_b) = broadcast_binary(cond, x);
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
let c = cond_bc.cast(DType::F32);
let x_f = x_bc.cast(DType::F32);
let y_f = y_bc.cast(DType::F32);
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
Ok(c * x_f + (one - c) * y_f)
let out_dtype = x.dtype;
Ok(self.where_formula(cond, x, y, out_dtype))
}
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
// Broadcast cond and x to a common shape
let (cond_b, x_b) = broadcast_binary(cond, x);
let c = cond_b.cast(DType::F32);
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
Ok(c * x_b + (one - c) * other)
let out_dtype = x.dtype;
// Build a tensor for the scalar `other` matching `x`'s shape so we
// can route through the shared where_formula helper.
let other = self.graph.constant_float(other_val).expand_rhs(x.shape);
Ok(self.where_formula(cond, x, other, out_dtype))
}
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -253,33 +387,37 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
// Determine output names
let values_name = node
.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
let indices_name =
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
ts.get(1).map(|t| t.name.clone())
} else if node.outputs.len() > 1 {
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
} else {
None
};
let tuple_outputs = node.outputs.first().and_then(|o| o.as_tensors.as_ref());
let values_name = if let Some(ts) = tuple_outputs {
ts.first().map(|t| t.name.clone())
} else {
node.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
};
let indices_name = if let Some(ts) = tuple_outputs {
ts.get(1).map(|t| t.name.clone())
} else if node.outputs.len() > 1 {
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
} else {
None
};
// Build top-k outputs from a full stable argsort, then slice to k.
// Build top-k outputs from a full stable argsort. Slice the indices
// before gathering values so the gather shape matches the requested
// top-k output rather than the full sort width.
let full_argsort = a.stable_argsort(dim, true);
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
// Only build the outputs that are consumed.
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
let values = a.gather_elements(topk_indices, dim);
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
// Materialize the sliced indices through a copy before storing them.
let indices = full_argsort.slice_along(..k, dim) * 1.0;
self.tensors.insert(idx_name, indices);
self.tensors.insert(idx_name, topk_indices);
}
Ok(())

View File

@@ -51,13 +51,19 @@ impl<'a> Translator<'a> {
let a = self.get_input_tensor(node, 0)?;
for input in &node.inputs {
if input.name == "dtype" {
if let Some(dtype_int) = input.arg.as_int() {
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
return Ok(a.cast(dtype));
}
if let Some(dtype_int) = input.arg.as_scalar_type() {
let dtype = torch_dtype_int_to_luminal(dtype_int);
return Ok(a.cast(dtype));
let dtype_int = input
.arg
.as_int()
.map(|i| i as u32)
.or_else(|| input.arg.as_scalar_type());
if let Some(d) = dtype_int {
let dtype = torch_dtype_int_to_luminal(d);
// Skip emitting a Cast op when the dtype already matches —
// PT2 graphs frequently emit `_to_copy` purely as a clone hint
// (e.g. dtype=float32 on a tensor that is already F32), and
// every redundant Cast inflates the graph and survives until
// optimization passes can prove it as a no-op.
return Ok(if a.dtype == dtype { a } else { a.cast(dtype) });
}
}
}
@@ -131,37 +137,34 @@ impl<'a> Translator<'a> {
}
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
// `masked_fill(input, mask, fill)` = `where(mask, fill, input)`.
// Routes through the shared `where_formula` helper so we exercise
// the exact same code path as `aten.where.self`, which is verified
// to handle the bf16 cast-back correctly. Hand-rolling the same
// formula directly here used to drift (egglog made different
// rewrite choices on the rebuilt-locally graph), so we deliberately
// re-use the helper.
// `aten.masked_fill.Scalar(input, mask, fill)` ≡
// `aten.where.self(mask, full_like(input, fill), input)`. The
// `full_like + where` sequence is the verified-working path
// (test: `where(mask, torch.zeros_like(x), x)` round-trips with
// max_diff = 0); we reproduce its exact graph-build order here.
// Hand-rolling the formula in any other shape (single-mul, F32
// throughout, alternative constant-cast orderings) routes egglog
// through a rewrite that returns an F32 buffer downstream-read as
// bf16 — the every-other-element-zero pattern.
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
let (input, mask) = broadcast_binary(input, mask);
let work_dtype = if input.dtype == DType::Bool {
DType::Int
} else {
input.dtype
};
let input_work = if input.dtype == DType::Bool {
input.cast(DType::Int)
} else {
input
};
let mask_work = mask.cast(work_dtype);
let fill_work = self
let out_dtype = input.dtype;
// Build fill_t exactly like translate_full_like does:
// constant_float(val).cast(dtype).expand_rhs(reference.shape)
let fill_t = self
.graph
.constant_float(fill)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let one = self
.graph
.constant_float(1.0)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let result = mask_work * fill_work + (one - mask_work) * input_work;
Ok(if input.dtype == DType::Bool {
result.cast(DType::Bool)
} else {
result
})
.cast(out_dtype)
.expand_rhs(input.shape);
Ok(self.where_formula(mask, fill_t, input, out_dtype))
}
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -210,12 +213,18 @@ impl<'a> Translator<'a> {
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
// Check rounding_mode kwarg
// Check rounding_mode kwarg. PT2 serializes string args as
// {"as_string": "<value>"}, so we have to drill into the JSON.
let rounding_mode = node.inputs.iter().find_map(|input| {
if input.name == "rounding_mode"
&& let Argument::Other(val) = &input.arg
{
return val.as_str().map(|s| s.to_string());
if let Some(s) = val.as_str() {
return Some(s.to_string());
}
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
return Some(s.to_string());
}
}
None
});
@@ -266,4 +275,52 @@ impl<'a> Translator<'a> {
}
Ok(result)
}
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
///
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
/// overload takes tensor bounds that appear as separate input nodes in the
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
///
/// - rank-0 (scalar wrapped in a tensor) — most common
/// - same shape as self (per-element clamp, e.g. learned bounds)
/// - any shape that broadcasts to self via right-align + size-1 expand
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
///
/// We use `broadcast_binary` to right-align and expand both operands to a
/// common shape before the elementwise max/min, matching PyTorch semantics
/// across all three modes.
///
/// Either bound may be absent (FX represents this as a non-tensor argument
/// at the corresponding input slot), in which case we clamp to one side
/// only.
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let min_tensor = node
.inputs
.get(1)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let max_tensor = node
.inputs
.get(2)
.and_then(|i| i.arg.as_tensor_name())
.map(|n| self.get_tensor(n))
.transpose()?;
let mut result = a;
if let Some(lo) = min_tensor {
let lo = lo.cast(result.dtype);
let (r, lo) = broadcast_binary(result, lo);
result = r.maximum(lo);
}
if let Some(hi) = max_tensor {
let hi = hi.cast(result.dtype);
let (r, hi) = broadcast_binary(result, hi);
result = r.minimum(hi);
}
Ok(result)
}
}

View File

@@ -1,5 +1,6 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
import warnings
from typing import List
import torch
@@ -8,6 +9,10 @@ from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
class DTypeBoundaryWarning(UserWarning):
"""Warns when the PyTorch boundary must cast input data before execution."""
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
@@ -77,7 +82,10 @@ class CompiledModel:
)
user_inputs = inputs
input_device = inputs[0].device if inputs else torch.device("cpu")
# Use the first *user* input for device detection — when torch.compile
# has lifted SymInts or weights into the call args, `inputs[0]` may not
# be a tensor. user_inputs has been filtered to actual tensors.
input_device = user_inputs[0].device if user_inputs else torch.device("cpu")
# Auto-detect dynamic dims from input shapes
if self._has_dynamic_dims:
@@ -92,6 +100,15 @@ class CompiledModel:
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if tensor.dtype != expected_dtype:
warnings.warn(
"Luminal compiled input "
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
f"expects {expected_dtype}; converting at every call will "
"allocate/copy input data.",
DTypeBoundaryWarning,
stacklevel=2,
)
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
@@ -132,6 +149,11 @@ class CompiledModel:
# Run the graph
self._graph.run()
# Integer dtypes for which we read the buffer as i32 and then cast.
# Includes int64 because luminal collapses all integer types to its
# 32-bit `Int` internally — we restore the original precision here.
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
# Collect outputs
if _use_zero_copy:
outputs = []
@@ -147,11 +169,12 @@ class CompiledModel:
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype == torch.int32:
elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
elif out_dtype == torch.bool:
@@ -179,9 +202,13 @@ class CompiledModel:
if i < len(output_dtype_codes)
else torch.float32
)
if out_dtype == torch.int32:
if out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
)
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))

View File

@@ -11,14 +11,21 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
def _detect_factory_capsule(example_inputs):
"""Pick the best built-in factory capsule based on input device."""
device = example_inputs[0].device if example_inputs else torch.device("cpu")
# Dynamo can prefix `example_inputs` with SymInt entries when shapes are
# dynamic — those have no `.device`. Pick the first real tensor instead.
first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None)
device = first_tensor.device if first_tensor is not None else torch.device("cpu")
if device.type == "cuda":
try:
from .luminal import _cuda_lite_factory_capsule
return _cuda_lite_factory_capsule()
except ImportError:
pass
except (ImportError, AttributeError) as exc:
raise RuntimeError(
"CUDA input was provided, but luminal_python was not built with "
"the cuda feature. Rebuild with `maturin develop --features cuda` "
"or run through `run_tests_cuda.sh`/the Modal CUDA test runner."
) from exc
from .luminal import _native_factory_capsule
return _native_factory_capsule()

View File

@@ -110,7 +110,35 @@ def _export_kwargs():
return kwargs
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
def _decomp_table():
"""Decomposition table for `ep.run_decompositions()` that preserves SDPA.
The default table decomposes `aten.scaled_dot_product_attention.default`
into ~20 ops (matmul/softmax + an `eq.Scalar`/`logical_not`/`any.dim`/
`where`/`full_like` "all-masked" sentinel chain). We translate SDPA as a
single fused op via `translate_sdpa`, so we strip the SDPA decompositions
here to let them survive into the FX graph the translator walks.
"""
try:
from torch.export import default_decompositions
except ImportError:
return None
table = default_decompositions()
sdpa_ops = [
torch.ops.aten.scaled_dot_product_attention.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
]
for op in sdpa_ops:
table.pop(op, None)
return table
def _save_and_compile(
ep_or_path, factory, search_iterations, original_weights=None, user_indices=None
):
"""Compile a PT2 model via Rust, return CompiledModel.
Args:
@@ -148,12 +176,171 @@ def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=N
# Load CPU weights after compilation
_load_cpu_weights(compiled, cpu_weights)
return CompiledModel(compiled, weight_refs=keep_alive)
return CompiledModel(
compiled, weight_refs=keep_alive, user_indices=user_indices
)
finally:
if owns_tmpdir and tmpdir:
shutil.rmtree(tmpdir, ignore_errors=True)
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
Range bounds returned by ShapeEnv can be sympy `Infinity` / `-Infinity`
(as well as the internal `int_oo` sentinel), which both raise on `int(...)`.
Treat anything non-finite — and anything that simply doesn't coerce — as
"no bound."
"""
if value is None:
return None
# Stringify is robust against the various sentinel types: sympy.Infinity,
# torch.utils._sympy.numbers.IntInfinity, etc. all stringify to "oo"/"-oo".
s = str(value)
if "oo" in s or "inf" in s.lower():
return None
try:
return int(value)
except (TypeError, ValueError, OverflowError, AttributeError):
return None
def _strip_symint_placeholders(gm, example_inputs):
"""Rewrite SymInt graph inputs into tensor.size(d) calls, then drop them.
When Dynamo decides a dim is dynamic it emits the symbol as a separate
placeholder (e.g. `s77`) alongside the user's tensor (whose FakeTensor shape
references the same symbol). torch.export.export rejects mixed
SymInt/Tensor positional args, and the Rust pipeline doesn't model SymInt
inputs anyway — so we replace each SymInt placeholder's uses with
`aten.sym_size.int(tensor, dim)` for the first tensor placeholder whose
example_value's shape[dim] matches the symbol, then erase the placeholder.
Returns `(post_strip_inputs, kept_indices, ok)` where:
- `post_strip_inputs` is `example_inputs` filtered to tensor-only entries
- `kept_indices` is the indices into `example_inputs` we kept (used by
the caller to compose with any prior input filter, e.g. lifted-weight
re-internalization, when handing `user_indices` to CompiledModel)
- `ok` is False when at least one SymInt placeholder couldn't be
rewritten (compound expression with users, or no matching tensor dim);
the caller should fall back to no-dynamic export in that case.
"""
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
# Collect (placeholder_node, example_input_idx) for every SymInt placeholder.
symint_entries = []
tensor_entries = []
for idx, node in enumerate(placeholders):
ev = node.meta.get("example_value")
if isinstance(ev, torch.SymInt) or (
ev is None
and idx < len(example_inputs)
and isinstance(example_inputs[idx], torch.SymInt)
):
symint_entries.append((node, idx))
else:
tensor_entries.append((node, idx))
if not symint_entries:
return example_inputs, list(range(len(example_inputs))), True
# Build a symbol -> (tensor_node, dim) lookup from the tensor placeholders'
# example FakeTensor shapes. Any tensor whose shape[d] is the SymInt
# is a valid source — pick the first.
sym_to_source = {}
for t_node, _ in tensor_entries:
ev = t_node.meta.get("example_value")
if not torch.is_tensor(ev):
continue
for d, s in enumerate(ev.shape):
if isinstance(s, torch.SymInt):
key = str(s.node.expr)
sym_to_source.setdefault(key, (t_node, d))
# Rewrite each SymInt placeholder's uses to sym_size calls, then erase it.
all_clean = True
for s_node, _ in symint_entries:
ev = s_node.meta.get("example_value")
if ev is None:
all_clean = False
continue
# The placeholder's example_value is the SymInt itself; its expr is the
# symbol name (or a compound expression we can't lift this way).
expr_str = str(ev.node.expr)
source = sym_to_source.get(expr_str)
if source is None:
# Compound expression or no tensor carries this symbol — bail.
if len(s_node.users) > 0:
all_clean = False
continue
gm.graph.erase_node(s_node)
continue
if len(s_node.users) > 0:
t_node, dim = source
with gm.graph.inserting_after(t_node):
size_node = gm.graph.call_function(
torch.ops.aten.sym_size.int, (t_node, dim)
)
size_node.meta["val"] = ev
size_node.meta["example_value"] = ev
s_node.replace_all_uses_with(size_node)
gm.graph.erase_node(s_node)
if not all_clean:
# Recompile defensively even on partial success — some erases may have
# happened. Caller will decide whether to proceed.
gm.graph.lint()
gm.recompile()
return example_inputs, list(range(len(example_inputs))), False
gm.graph.lint()
gm.recompile()
# Filter the runtime example_inputs to drop the stripped SymInt entries.
kept_indices = [idx for _, idx in tensor_entries]
keep_set = set(kept_indices)
new_inputs = [v for i, v in enumerate(example_inputs) if i in keep_set]
return new_inputs, kept_indices, True
def _build_dynamic_shapes_from_gm(gm):
"""Construct a torch.export.export `dynamic_shapes` spec from FX metadata.
Walks each tensor placeholder's `meta['example_value']` FakeTensor and
marks every SymInt dim as `Dim.AUTO`. Sharing/equality relationships
between symbolic dims are already encoded in the FakeTensor shapes —
torch.export's symbolic-shape engine recovers them during the trace, so
we don't need to allocate named `Dim` objects ourselves.
The returned spec is wrapped under `{"args": (...)}` because Dynamo's
`GraphModule.forward(*args, **kwargs)` signature treats positional inputs
as the `args` tuple.
Returns None if there are no symbolic dims to mark.
"""
from torch.export import Dim
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
per_input_spec = []
saw_dynamic = False
for node in placeholders:
ev = node.meta.get("example_value")
if not torch.is_tensor(ev):
per_input_spec.append(None)
continue
spec = {}
for d, s in enumerate(ev.shape):
if isinstance(s, torch.SymInt):
spec[d] = Dim.AUTO
saw_dynamic = True
per_input_spec.append(spec if spec else None)
if not saw_dynamic:
return None
return {"args": tuple(per_input_spec)}
def _reinternalize_lifted_params(gm, example_inputs):
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
@@ -203,7 +390,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
if user_indices
else list(example_inputs)
)
return gm, user_inputs, original_weights
return gm, user_inputs, original_weights, user_indices
# ---------------------------------------------------------------------------
@@ -218,96 +405,238 @@ def compile(
factory=None,
export_kwargs=None,
dynamic_dim=None,
dynamic_shapes=None,
):
"""Compile a PyTorch model to run on Luminal via PT2 pipeline.
Args:
model: A PyTorch nn.Module.
example_input: Example input tensor(s) for tracing.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Which input dimension to make dynamic.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
symbolic dim is needed.
* `None` (default): leave shapes static.
* `int`: mark that dim of the (first) input as `Dim.AUTO`.
* `Iterable[int]`: mark each listed dim of the first input.
* `"auto"`: mark every non-trivial dim (size > 1) of the
first input as `Dim.AUTO` — works for floating-point and
integer inputs alike.
dynamic_shapes: Direct passthrough to `torch.export.export`'s
`dynamic_shapes` argument. When provided, takes precedence over
`dynamic_dim`. Use this for full control: per-input specs,
`Dim("name", min=, max=)` ranges, shared dims across inputs, etc.
Returns:
A CompiledModel callable.
"""
if dynamic_dim is None:
dynamic_dim = "auto"
if factory is None:
factory = _detect_factory_capsule([example_input])
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
kwargs = export_kwargs or {}
extra = _export_kwargs()
# Build dynamic_shapes from the convenience knob if the caller didn't
# hand us a full spec. `dynamic_dim=None` falls back to the legacy
# `"auto"` behavior (mark the last axis of an integer input as dynamic)
# so callers that relied on the previous default keep working.
if dynamic_shapes is None:
if dynamic_dim is None:
dynamic_dim = _legacy_auto_dim(example_args)
if dynamic_dim is not None:
dynamic_shapes = _build_dynamic_shapes_from_dim_arg(
dynamic_dim, example_args
)
# `torch.export.export` is finicky: when `dynamic_shapes` is set it
# validates the spec against the example shapes and raises on any
# disagreement (e.g. the user marked a dim as dynamic but their model
# specialises it to a constant). Fall back to a static export so the
# caller still gets a usable CompiledModel rather than a hard error.
ep = None
# Try dynamic dimension export
candidate_dims = []
if isinstance(dynamic_dim, int):
candidate_dims = [dynamic_dim]
elif dynamic_dim == "auto" and example_input.dim() >= 2:
if not example_input.is_floating_point():
candidate_dims = [example_input.dim() - 1]
if candidate_dims:
from torch.export import Dim
for dim_idx in candidate_dims:
try:
seq = Dim("seq", min=2)
arg_shapes = {dim_idx: seq}
kwarg_shapes = {k: None for k in kwargs}
dynamic_shapes = (
(arg_shapes,) + tuple(kwarg_shapes.values())
if kwarg_shapes
else (arg_shapes,)
)
ep = torch.export.export(
model,
(example_input,),
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
**extra,
)
ep = ep.run_decompositions()
break
except Exception:
continue
if dynamic_shapes is not None:
try:
ep = torch.export.export(
model,
example_args,
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
**extra,
)
ep = ep.run_decompositions(_decomp_table())
except Exception:
ep = None
if ep is None:
ep = torch.export.export(
model,
(example_input,),
example_args,
kwargs=kwargs,
dynamic_shapes=None,
**extra,
)
ep = ep.run_decompositions()
ep = ep.run_decompositions(_decomp_table())
return _save_and_compile(ep, factory, search_iterations)
def pt2_backend(gm, example_inputs, factory=None):
"""torch.compile backend using PT2 pipeline.
def _drop_input_guards(ep):
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
``i = torch.tensor(2)``), torch.export records two equivalent input
guards: ``L['i'].item() == 2`` (referencing the original local source)
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
Two failures stack on top of each other:
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
a literal ``L`` reference in the generated ``_guards_fn`` and raising
``NameError: name 'L' is not defined`` during retracing.
2. Even after dropping the unresolvable guard, the surviving
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
a graph break.
These guards exist solely to validate inputs at runtime in eager-mode
consumers of the ExportedProgram; the luminal compiler does its own
input shape/dtype checks against the compiled graph signature, so we
are not losing any safety by clearing them.
"""
if hasattr(ep, "_guards_code"):
ep._guards_code = []
def _drop_dead_data_dependent_ops(gm):
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
When dynamo specializes a 0-d int input by tracing through ``.item()``,
the resulting graph may contain a dead ``aten.item.default`` node whose
output is never consumed. luminal's translator does not lower
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
node in the graph causes a graph break at compile time. Eliminating it
keeps the (correctly specialized) downstream graph in a single subgraph.
"""
graph = gm.graph
changed = False
for node in list(graph.nodes):
if (
node.op == "call_function"
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
and len(node.users) == 0
):
graph.erase_node(node)
changed = True
if changed:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
def _legacy_auto_dim(example_args):
"""Match the historical `dynamic_dim="auto"` heuristic.
Returns the last axis of the first input when that input is a 2-D-or-
larger integer tensor (the typical token-id sequence pattern), and
`None` otherwise. Float inputs and 1-D tensors fall through to the
static export path the legacy code did.
"""
if not example_args:
return None
first = example_args[0]
if not torch.is_tensor(first):
return None
if first.is_floating_point():
return None
if first.dim() < 2:
return None
return first.dim() - 1
def _build_dynamic_shapes_from_dim_arg(dynamic_dim, example_args):
"""Translate the `dynamic_dim` shorthand into a full `dynamic_shapes` spec.
Always targets the first positional input — multi-input dynamic specs
require the caller to use `dynamic_shapes=` directly so they can name
which input each dim belongs to.
"""
from torch.export import Dim
if not example_args:
return None
first = example_args[0]
if not torch.is_tensor(first):
return None
if isinstance(dynamic_dim, int):
dims = [dynamic_dim]
elif isinstance(dynamic_dim, str) and dynamic_dim == "auto":
# Mark every dim with size > 1 as dynamic. Dim.AUTO leaves
# torch.export to pick a Dim per axis and infer relationships from
# the example FakeTensor.
dims = [d for d, s in enumerate(first.shape) if int(s) > 1]
elif hasattr(dynamic_dim, "__iter__"):
dims = [int(d) for d in dynamic_dim]
else:
return None
if not dims:
return None
spec = {d: Dim.AUTO for d in dims}
rest = (None,) * (len(example_args) - 1)
return (spec,) + rest
def _eager_pt2_compile(
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
):
"""Run torch.export → save → Rust compile end-to-end. Returns CompiledModel.
Factored out so both the eager (static-shapes) and lazy (dynamic-shapes)
backend paths share a single implementation.
"""
import gc
if factory is None:
factory = _detect_factory_capsule(example_inputs)
try:
ep = torch.export.export(
gm,
tuple(user_inputs),
dynamic_shapes=dynamic_shapes,
**_export_kwargs(),
)
except Exception:
# If torch.export rejects the dynamic spec (e.g. user code introduced
# a constraint we didn't model), retry without it. Better to lose the
# dynamic-dim optimization than to hand the user a hard failure.
if dynamic_shapes is None:
raise
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
# LUM-499: drop dynamo-emitted input guards before run_decompositions
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
# data-dependent .item() calls and unresolved `L[...]` references.
_drop_input_guards(ep)
_drop_dead_data_dependent_ops(ep.graph_module)
ep = ep.run_decompositions(_decomp_table())
gm = gm.eval()
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
ep = ep.run_decompositions()
# When using shared memory (original_weights), strip large weight buffers from
# the EP before saving. The Rust side uses device pointers for these weights,
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
# When using shared memory (original_weights), strip large weight buffers
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
@@ -315,9 +644,9 @@ def pt2_backend(gm, example_inputs, factory=None):
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
# Save the exported program to disk, then free it and the traced graph module
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
# holding ep alive during compilation would double the weight memory on GPU.
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
# alive during compile would double weight memory on GPU.
tmpdir = tempfile.mkdtemp(prefix="luminal_")
pt2_path = os.path.join(tmpdir, "model.pt2")
torch.export.save(ep, pt2_path)
@@ -328,9 +657,129 @@ def pt2_backend(gm, example_inputs, factory=None):
torch.cuda.empty_cache()
try:
result = _save_and_compile(
pt2_path, factory, 10, original_weights=original_weights
return _save_and_compile(
pt2_path,
factory,
10,
original_weights=original_weights,
user_indices=user_indices,
)
return result
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
class _LazyDynamicCompiledModel:
"""Defers torch.export + Rust compile to the first invocation.
Calling `torch.export.export(..., dynamic_shapes=...)` from inside a
Dynamo backend frame triggers an internal "Guard failed on the same
frame it was created" assertion in PyTorch — `torch.export`'s symbolic
tracer mutates the ShapeEnv that Dynamo is also relying on for the
surrounding compile, leaving the just-installed guards in an
inconsistent state. Punting all of that work to the first runtime call
sidesteps the issue: by then Dynamo's guard installation is finished,
so the shape-env mutations no longer matter.
This wrapper is API-compatible with `CompiledModel` for the bits the
caller cares about (`__call__`, `has_dynamic_dims`, `dim_params`,
`set_dim`). Subsequent calls forward straight to the inner CompiledModel.
"""
def __init__(
self,
gm,
user_inputs,
original_weights,
user_indices,
dynamic_shapes,
factory,
):
self._gm = gm
self._user_inputs = user_inputs
self._original_weights = original_weights
self._user_indices = user_indices
self._dynamic_shapes = dynamic_shapes
self._factory = factory
self._compiled = None
def _ensure_compiled(self):
if self._compiled is None:
self._compiled = _eager_pt2_compile(
self._gm,
self._user_inputs,
self._original_weights,
self._user_indices,
self._dynamic_shapes,
self._factory,
)
# Drop references to inputs we no longer need — the Rust side
# holds onto weights via device pointers / CPU buffers.
self._gm = None
self._user_inputs = None
self._original_weights = None
return self._compiled
def __call__(self, *inputs, **kwargs):
return self._ensure_compiled()(*inputs, **kwargs)
@property
def has_dynamic_dims(self):
return self._ensure_compiled().has_dynamic_dims
@property
def dim_params(self):
return self._ensure_compiled().dim_params
def set_dim(self, name, value):
return self._ensure_compiled().set_dim(name, value)
def pt2_backend(gm, example_inputs, factory=None):
"""torch.compile backend using PT2 pipeline.
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
"""
import copy as _copy
if factory is None:
factory = _detect_factory_capsule(example_inputs)
# Work on a private copy of the GraphModule. Dynamo holds onto the
# original to install guards and to retrace on shape changes; mutating it
# here (erasing SymInt placeholders, re-internalizing lifted weights)
# corrupts that bookkeeping and surfaces as cryptic "guard failed on the
# same frame" assertions on the next call. The deepcopy is cheap relative
# to the rest of the export pipeline.
gm = _copy.deepcopy(gm).eval()
gm, user_inputs, original_weights, post_lift_indices = _reinternalize_lifted_params(
gm, example_inputs
)
# Lift any SymInt placeholders Dynamo emitted alongside the tensor inputs
# into `aten.sym_size.int` calls so the re-export sees a tensor-only
# signature, then derive the `dynamic_shapes` spec from the surviving
# tensor placeholders' FakeTensor shapes. If the strip can't fully clean
# the graph (e.g. a compound-expr SymInt with users), we drop dynamic
# info and fall back to per-shape recompilation — same as today.
user_inputs, post_strip_subindices, strip_ok = _strip_symint_placeholders(
gm, user_inputs
)
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
# Compose both filter steps into a single user_indices list relative to
# the *original* example_inputs Dynamo will pass at runtime — so
# CompiledModel.__call__ can drop both lifted weights and SymInt args.
user_indices = [post_lift_indices[i] for i in post_strip_subindices]
if dynamic_shapes is not None:
# See `_LazyDynamicCompiledModel` for why dynamic-shape compiles must
# be deferred — torch.export with dynamic_shapes mutates ShapeEnv state
# Dynamo is still relying on, and running it inside the backend frame
# corrupts the freshly-installed guards.
return _LazyDynamicCompiledModel(
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
)
return _eager_pt2_compile(
gm, user_inputs, original_weights, user_indices, None, factory
)

View File

@@ -0,0 +1,215 @@
from dataclasses import dataclass
import warnings
from typing import Callable
import pytest
import torch
from luminal import luminal_backend
from luminal.compiled_model import DTypeBoundaryWarning
class BoundaryNoopModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.bool:
return x | torch.zeros((), dtype=torch.bool, device=x.device)
return x + torch.zeros((), dtype=x.dtype, device=x.device)
@dataclass(frozen=True)
class DTypeCase:
name: str
dtype: torch.dtype
values: Callable[[], torch.Tensor]
xfail_reason: str | None = None
DTYPE_CASES = [
DTypeCase(
"bool",
torch.bool,
lambda: torch.tensor([True, False, True], dtype=torch.bool),
),
DTypeCase(
"uint8",
torch.uint8,
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
),
DTypeCase(
"int8",
torch.int8,
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
),
DTypeCase(
"int16",
torch.int16,
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
),
DTypeCase(
"int32",
torch.int32,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int32,
),
),
DTypeCase(
"int64_i32_range",
torch.int64,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int64,
),
),
DTypeCase(
"float16",
torch.float16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
),
DTypeCase(
"bfloat16",
torch.bfloat16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
),
DTypeCase(
"float32",
torch.float32,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
),
DTypeCase(
"float64_f32_exact",
torch.float64,
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
),
DTypeCase(
"int64_outside_i32_range",
torch.int64,
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
xfail_reason=(
"Luminal currently collapses integer inputs through i32 at the "
"compiled boundary, so out-of-range int64 values lose information."
),
),
DTypeCase(
"float64_precision_sensitive",
torch.float64,
lambda: torch.tensor(
[1.0, 1.0000000000000002, float(2**40) + 0.25],
dtype=torch.float64,
),
xfail_reason=(
"Luminal currently routes float64 no-op computation through f32 "
"storage/outputs before restoring the PyTorch-visible dtype."
),
),
]
def _cuda_skip_reason() -> str | None:
if not torch.cuda.is_available():
return "CUDA is not available"
try:
from luminal.luminal import _cuda_lite_factory_capsule
_cuda_lite_factory_capsule()
except (ImportError, AttributeError, RuntimeError) as exc:
return f"luminal_python was not built with CUDA support: {exc}"
return None
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
def boundary_device(request) -> torch.device:
device_name = request.param
if device_name == "cuda":
skip_reason = _cuda_skip_reason()
if skip_reason is not None:
pytest.skip(skip_reason)
return torch.device(device_name)
@pytest.mark.parametrize(
"case",
[
pytest.param(
case,
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
if case.xfail_reason is not None
else (),
id=case.name,
)
for case in DTYPE_CASES
],
)
def test_boundary_noop_preserves_dtype_and_values(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
expected = model(x)
actual = compiled(x)
assert isinstance(actual, torch.Tensor)
assert actual.dtype == expected.dtype
assert torch.equal(actual.cpu(), expected.cpu())
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name
in {
"uint8",
"int8",
"int16",
"int64_i32_range",
"int64_outside_i32_range",
"float64_f32_exact",
"float64_precision_sensitive",
}
],
)
def test_boundary_warns_when_input_dtype_requires_conversion(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
compiled(x)
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
],
)
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with warnings.catch_warnings(record=True) as records:
warnings.simplefilter("always")
compiled(x)
dtype_boundary_warnings = [
record
for record in records
if issubclass(record.category, DTypeBoundaryWarning)
]
assert dtype_boundary_warnings == []

View File

@@ -0,0 +1,312 @@
"""End-to-end tests for dynamic-shape support through ``torch.compile``.
These exercise the path that the standard PyTorch user hits — i.e. wrapping a
model with ``torch.compile(model, backend=luminal_backend)`` and calling it
with varying input shapes. The luminal backend is expected to recognise
Dynamo-emitted SymInt placeholders, propagate the symbolic dims through the
PT2 export, and reuse a single compiled graph across shape changes.
"""
from __future__ import annotations
import pytest
import torch
import torch._dynamo
from luminal.main import luminal_backend
def _compile(model, count_holder):
def wrapper(gm, example_inputs):
out = luminal_backend(gm, example_inputs)
count_holder.append(1)
return out
return torch.compile(model, backend=wrapper)
def _compile_with_dynamic_true(model, count_holder):
def wrapper(gm, example_inputs):
out = luminal_backend(gm, example_inputs)
count_holder.append(1)
return out
return torch.compile(model, backend=wrapper, dynamic=True)
@pytest.fixture(autouse=True)
def _enable_automatic_dynamic():
"""Make sure the tests run with Dynamo's automatic-dynamic detection on.
Other tests in the suite flip this off; reset state between tests so the
cache that backs the previous suppression doesn't carry over. We also
raise the recompile limit because Dynamo defaults to 1 (which trips
before automatic-dynamic kicks in) and have to do an extra reset to
drop any cached frames from prior tests in the suite.
"""
torch._dynamo.reset()
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
prev_limit = torch._dynamo.config.recompile_limit
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.recompile_limit = 16
try:
yield
finally:
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
torch._dynamo.config.recompile_limit = prev_limit
torch._dynamo.reset()
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — the dynamic-shape backend wiring is exercised end to end against the cuda_lite runtime",
)
def test_dynamic_seq_via_torch_compile_reuses_compile(device: torch.device):
"""A varying seq dim should produce two backend invocations total.
First call: Dynamo emits a static-shape graph (no SymInt placeholders).
Second call: Dynamo detects the size mismatch and re-traces with the dim
marked dynamic. From that point on, every subsequent shape variation
must be served by the same compiled graph — no further backend calls.
"""
class Mdl(torch.nn.Module):
def forward(self, x):
s = x.shape[0]
return x.reshape(s, -1).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5), (
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
)
assert len(counts) == 2, (
f"expected exactly 2 backend invocations (one static, one dynamic), got {len(counts)}"
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_dynamic_via_torch_compile_with_lifted_weights(device: torch.device):
"""Combines lifted-weight re-internalization with the SymInt strip.
Most real models hit both paths simultaneously (Dynamo lifts every
`nn.Parameter` AND emits SymInt placeholders for any dim that varies
between calls), so the two filters need to compose without losing
track of input positions.
"""
class Mdl(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(8, 4)
def forward(self, x):
return self.lin(x).sum(-1)
model = Mdl().eval().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [3, 4, 5, 6, 4]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5), (
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
)
assert len(counts) == 2
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_compound_shape_expression_auto_resolves(device: torch.device):
"""Affine shape expressions (`2*s` etc.) should still let auto-detect work.
The `auto_set_dims_from_input_shapes` Rust path used to only handle bare
`Term::Var(c)` shape expressions and silently skip anything else, leaving
affine dims unresolved on the CompiledGraph and the corresponding output
sizes stale. We now invert single-variable affine forms `a*x + b` by
sampling two probe points; this test exercises that path by constructing
a model whose first axis evolves into `2*s` after a `cat` along it.
"""
class Mdl(torch.nn.Module):
def forward(self, x):
# `cat([x, x], dim=0)` doubles the leading dim — torch.export
# encodes the resulting shape as `2*s` rather than `s`.
return torch.cat([x, x], dim=0).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_torch_compile_dynamic_true_single_compile(device: torch.device):
"""`torch.compile(model, backend=luminal_backend, dynamic=True)` works.
`dynamic=True` skips Dynamo's specialise-then-promote dance and emits a
fully-symbolic graph from the first call. The luminal backend must
handle the SymInt placeholders Dynamo passes alongside the tensor
inputs and reuse a single compiled graph across all shape variations —
one backend invocation total, in contrast to the 2 we'd see under
automatic-dynamic mode (which burns a static compile on call 1 before
promoting to dynamic on call 2).
"""
class Mdl(torch.nn.Module):
def forward(self, x):
s = x.shape[0]
return x.reshape(s, -1).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile_with_dynamic_true(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5)
assert len(counts) == 1, (
f"dynamic=True should produce a single backend invocation, got {len(counts)}"
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_explicit_compile_float_input_dynamic(device: torch.device):
"""`luminal.pt2.compile(model, example, dynamic_dim=...)` with a float input.
The previous version of `compile()` silently fell back to a static export
for floating-point inputs (the `"auto"` heuristic was integer-only). The
new spec accepts an explicit `int` or `Iterable[int]` regardless of dtype,
and `"auto"` now picks every non-trivial axis.
"""
from luminal.pt2 import compile as luminal_compile
class Mdl(torch.nn.Module):
def forward(self, x):
return (x * 2.0).sum(-1)
model = Mdl().eval().to(device)
example = torch.randn(4, 8, device=device)
compiled = luminal_compile(model, example, search_iterations=3, dynamic_dim=0)
assert compiled.has_dynamic_dims, "compile() should have produced a dynamic graph"
for shp in [4, 5, 6, 7]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
# `compile()` returns a tuple of outputs; extract the first.
out_t = out[0] if isinstance(out, tuple) else out
assert out_t.shape == ref.shape, (
f"shape={shp}: got {out_t.shape}, expected {ref.shape}"
)
assert torch.allclose(out_t, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_explicit_compile_dynamic_shapes_passthrough(device: torch.device):
"""`luminal.pt2.compile(... , dynamic_shapes=...)` accepts a full spec.
Lets the caller specify named `Dim` objects with ranges — the previous
API hardcoded `Dim("seq", min=2)` for any single dynamic dim.
"""
from torch.export import Dim
from luminal.pt2 import compile as luminal_compile
class Mdl(torch.nn.Module):
def forward(self, x):
return x.mean(-1)
model = Mdl().eval().to(device)
example = torch.randn(4, 8, device=device)
seq = Dim("seq_len", min=2, max=64)
compiled = luminal_compile(
model, example, search_iterations=3, dynamic_shapes=({0: seq},)
)
assert compiled.has_dynamic_dims
# torch.export rewrites user-supplied Dim names to its internal s77/s33
# convention before saving — what we actually need to verify is that a
# symbolic dim was registered, not what label it ended up with.
assert len(compiled.dim_params) == 1, (
f"expected exactly one dynamic dim, got {compiled.dim_params}"
)
for shp in [3, 5, 16]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
out_t = out[0] if isinstance(out, tuple) else out
assert out_t.shape == ref.shape
assert torch.allclose(out_t, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_dynamic_two_dim_via_torch_compile(device: torch.device):
"""Both batch and seq dynamic — should still reuse a single compile."""
class Mdl(torch.nn.Module):
def forward(self, x):
return x.sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
# Vary batch and seq together so Dynamo marks both as dynamic.
for batch, seq in [(2, 8), (3, 9), (4, 10), (5, 11), (3, 12)]:
x = torch.randn(batch, seq, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5)
# Allow at most a small number of compiles — two shape transitions can
# legitimately take Dynamo two retraces (one per newly-dynamic dim).
assert len(counts) <= 3, (
f"expected ≤3 compiles for two-dim dynamic, got {len(counts)}"
)

View File

@@ -1,5 +1,6 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
@@ -220,6 +221,7 @@ from test_models import (
Conv1dNoPadModel,
Conv1dSamePadModel,
Conv1dBiasModel,
Conv1dFloorDivPositionalModel,
Conv2dNoPadModel,
Conv2dSamePadModel,
Conv2dBiasModel,
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
eager = model(x)
out = model_compiled(x)
assert out.dtype == eager.dtype == torch.int64
assert torch.equal(out, eager)
def test_reduce_sum_3d_axis1(device: torch.device):
"""Test sum reduction along axis 1 for a 3D tensor."""
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
@@ -1847,6 +1860,60 @@ def test_scaled_dot_product_attention(device: torch.device):
assert torch.allclose(output, original, atol=1e-5)
# ========== F.scaled_dot_product_attention (SDPA aten variants) ==========
# Tests for `torch.nn.functional.scaled_dot_product_attention`, which lowers
# to one of `aten._scaled_dot_product_*_attention.default` (variant chosen by
# PyTorch's dispatcher: efficient/flash/flash_for_cpu/cudnn). Coverage here
# exercises `translate_sdpa` end-to-end.
def _sdpa_qkv(device: torch.device, b: int = 1, h: int = 2, s: int = 4, d: int = 8):
"""Build a `(B, H, S, D)` Q/K/V triple of float32 tensors on `device`."""
torch.manual_seed(0)
q = torch.rand((b, h, s, d), device=device)
k = torch.rand((b, h, s, d), device=device)
v = torch.rand((b, h, s, d), device=device)
return q, k, v
def test_sdpa_basic(device: torch.device):
"""`F.scaled_dot_product_attention(q, k, v)` — default scale, no mask."""
from test_models import SdpaBasicModel
model: torch.nn.Module = SdpaBasicModel().to(device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
q, k, v = _sdpa_qkv(device)
expected: torch.Tensor = model(q, k, v)
actual: torch.Tensor = compiled(q, k, v)
assert torch.allclose(actual, expected, atol=1e-5)
def test_sdpa_causal(device: torch.device):
"""`F.scaled_dot_product_attention(q, k, v, is_causal=True)`."""
from test_models import SdpaCausalModel
model: torch.nn.Module = SdpaCausalModel().to(device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
q, k, v = _sdpa_qkv(device)
expected: torch.Tensor = model(q, k, v)
actual: torch.Tensor = compiled(q, k, v)
assert torch.allclose(actual, expected, atol=1e-5)
def test_sdpa_with_attn_bias(device: torch.device):
"""SDPA with an additive `attn_mask` (float bias) broadcast over heads."""
from test_models import SdpaWithBiasModel
model: torch.nn.Module = SdpaWithBiasModel().to(device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
q, k, v = _sdpa_qkv(device)
bias = torch.zeros((1, 1, q.shape[-2], k.shape[-2]), device=device)
bias[..., 0, 1] = -1.0 # any non-trivial bias to verify it's actually applied
expected: torch.Tensor = model(q, k, v, bias)
actual: torch.Tensor = compiled(q, k, v, bias)
assert torch.allclose(actual, expected, atol=1e-5)
def test_mlp_block(device: torch.device):
"""Test two-layer MLP: Linear(8,16) -> ReLU -> Linear(16,4) on input (2,8)."""
model: torch.nn.Module = MLPBlockModel().to(device)
@@ -1968,9 +2035,16 @@ def test_split(device: torch.device):
# ========== Argsort / MoE Routing Tests ==========
def test_argsort_stable_duplicates(device: torch.device):
"""Duplicate values should follow stable lower-index-first tie-breaking."""
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
"""Duplicate values should follow stable lower-index-first tie-breaking.
Parametrized over int32/int64 to verify luminal preserves whichever
integer dtype the eager model declares (LUM-486).
"""
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
device
)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.tensor(
[[2.0, 1.0, 1.0, 3.0]],
@@ -1979,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.dtype == torch.int32
assert torch.equal(output, original.to(torch.int32))
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
assert output.dtype == original.dtype, (
f"luminal returned {output.dtype}, eager produced {original.dtype}"
)
assert torch.equal(output, original)
def test_tiny_moe_routing(device: torch.device):
"""Focused proof for build MoE routing support."""
model: torch.nn.Module = TinyMoERoutingModel().to(device)
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
"""Focused proof for built MoE routing support.
Parametrized over int32/int64 for the integer-valued outputs to verify
luminal preserves the dtype declared by the eager model (LUM-486).
"""
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
scores = torch.tensor(
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
@@ -1996,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
expected = model(scores)
output = model_compiled(scores)
expected_dtypes = (
torch.int32,
torch.float32,
torch.int32,
torch.bool,
torch.int32,
torch.float32,
)
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
assert actual.dtype == expected_dtype
eager = eager.to(actual.dtype)
for actual, eager in zip(output, expected):
assert actual.dtype == eager.dtype, (
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
)
if actual.dtype.is_floating_point:
assert torch.allclose(actual, eager)
else:
@@ -2024,6 +2099,23 @@ def test_topk_values(device: torch.device):
assert torch.allclose(model_compiled(x), model(x))
def test_topk_values_width_128_with_indices(device: torch.device):
"""Regression for router-sized TopK values when both tuple outputs are used."""
class TopKValuesAndIndices(torch.nn.Module):
def forward(self, x: torch.Tensor):
values, indices = torch.topk(torch.softmax(x, dim=-1), 8, dim=1)
return values, indices
model = TopKValuesAndIndices().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(4, 128, device=device)
actual_values, actual_indices = model_compiled(x)
expected_values, expected_indices = model(x)
assert torch.allclose(actual_values, expected_values, atol=1e-5)
assert torch.equal(actual_indices.to(expected_indices.dtype), expected_indices)
def test_topk_indices(device: torch.device):
"""Tests TopK indices output for 2D tensor along axis=1."""
model: torch.nn.Module = TopKIndicesTestModel().to(device)
@@ -2406,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_floor_div_positional_pt2(device: torch.device):
"""Conv1d stride output uses floor division before positional add."""
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.shape == original.shape == (15, 16)
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)

View File

@@ -0,0 +1,142 @@
import torch
import pytest
from luminal import luminal_backend
class StrideSensitiveInputModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"coeff",
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.coeff
class TwoInputReadModel(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * 2.0 + y * 3.0
class ReturnInputModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
class ReturnInputAndComputedModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return x, x + 1.0
class CloneThenMutateModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x.clone()
y.add_(1.0)
return y, x * 2.0
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
return base, base.t()
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
assert not view.is_contiguous()
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
def _assert_same(actual, expected) -> None:
if isinstance(expected, tuple):
assert isinstance(actual, tuple)
assert len(actual) == len(expected)
for actual_item, expected_item in zip(actual, expected):
_assert_same(actual_item, expected_item)
return
assert torch.allclose(actual, expected)
def _single_non_contiguous_view(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return StrideSensitiveInputModel().to(device), (view,), base
def _same_view_twice(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return TwoInputReadModel().to(device), (view, view), base
def _overlapping_views(device: torch.device):
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
x = base[:3, :4]
y = base[1:, 1:]
assert not x.is_contiguous()
assert not y.is_contiguous()
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
return TwoInputReadModel().to(device), (x, y), base
def _return_input(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputModel().to(device), (view,), base
def _return_input_and_computed(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputAndComputedModel().to(device), (view,), base
def _internal_clone_inplace(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return CloneThenMutateModel().to(device), (view,), base
@pytest.mark.parametrize(
"make_case",
[
pytest.param(
_single_non_contiguous_view,
id="single_non_contiguous_view_stride_sensitive_read",
),
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
pytest.param(_return_input, id="return_input_boundary_value"),
pytest.param(
_return_input_and_computed,
id="return_input_boundary_value_and_computed_value",
),
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
],
)
def test_input_boundary_contiguous_materialization_cases(
device: torch.device, make_case
) -> None:
model, inputs, base = make_case(device)
compiled = torch.compile(model, backend=luminal_backend)
base_before = base.clone()
expected = model(*inputs)
actual = compiled(*inputs)
_assert_same(actual, expected)
assert torch.allclose(base, base_before)
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
device: torch.device,
) -> None:
model, (view,), base = _single_non_contiguous_view(device)
wrong_if_storage_order_used = model(base.reshape(view.shape))
expected = model(view)
assert not torch.allclose(wrong_if_storage_order_used, expected)

View File

@@ -101,6 +101,7 @@ def test_kv_cache_growing():
not torch.cuda.is_available(),
reason="R1 full-width 1-layer is too memory-heavy for CPU native backend",
)
@pytest.mark.slow
def test_kv_cache_growing_r1_mla(device: torch.device):
"""Growing-cache decode loop on DeepSeek-R1 (MLA + decoupled RoPE), 1 layer.

View File

@@ -158,6 +158,7 @@ def test_hf_llama_medium(device: torch.device):
_run_hf_llama_test(config, device, atol=1e-5)
@pytest.mark.slow
def test_hf_llama_large(device: torch.device):
"""HuggingFace LlamaForCausalLM — large (1024 hidden, 1 layer, ~18M params)."""
config = _make_llama_config(
@@ -171,6 +172,7 @@ def test_hf_llama_large(device: torch.device):
_run_hf_llama_test(config, device, atol=1e-5)
@pytest.mark.slow
def test_hf_llama3_real_config_1layer(device: torch.device):
"""HuggingFace LlamaForCausalLM — real Llama3.2-1B architecture, 1 layer.
@@ -227,6 +229,7 @@ def test_hf_llama_decode_loop_static(device: torch.device):
tokens.append(next_token)
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
"""Decode loop on real Llama3.2-1B with pretrained weights.
@@ -282,6 +285,7 @@ def _gpu_mem(label):
)
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
@@ -333,6 +337,7 @@ def test_hf_llama3_full(device: torch.device):
)
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_large_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
@@ -414,6 +419,7 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
)
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama38b_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.

View File

@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
class ArgsortStableDuplicatesModel(torch.nn.Module):
"""Tests deterministic duplicate ordering for exported argsort."""
"""Tests deterministic duplicate ordering for exported argsort.
``idx_dtype`` parameterizes the integer dtype of the returned indices so
the test can verify dtype preservation across luminal's int dtype paths
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
lets us drive the same model toward int32 or int64 outputs.
"""
SORT_DIM = 1
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.argsort(x, dim=self.SORT_DIM)
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
class TinyMoERoutingModel(torch.nn.Module):
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
group_ids) to the requested dtype so the test can sweep int32 and int64
output paths (LUM-486). Internal indices stay int64 because torch.gather
/ torch.scatter require int64 index tensors.
"""
TOP_K = 2
ROUTING_DIM = -1
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
DISPATCH_ON = 1
GROUP_SIZE = 2
def __init__(self) -> None:
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
super().__init__()
self.idx_dtype = idx_dtype
self.register_buffer(
"expert_scale",
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
routing_sign = torch.sign(masked_values)
return (
routed_indices,
routed_indices.to(self.idx_dtype),
masked_values,
dispatch,
dispatch.to(self.idx_dtype),
inactive_mask,
group_ids,
group_ids.to(self.idx_dtype),
routing_sign,
)
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
return self.conv(x)
class Conv1dFloorDivPositionalModel(torch.nn.Module):
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
self.conv2 = torch.nn.Conv1d(
16, 16, kernel_size=3, stride=2, padding=1, bias=True
)
self.position = torch.nn.Parameter(torch.randn(15, 16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.gelu(self.conv1(x))
x = torch.nn.functional.gelu(self.conv2(x))
x = x.squeeze(0).transpose(0, 1)
return x + self.position
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
@@ -2280,3 +2315,48 @@ class IntIndexAssignScalarModel(torch.nn.Module):
out = x.clone()
out[indices] = 42.0
return out
class SdpaBasicModel(torch.nn.Module):
"""`F.scaled_dot_product_attention(q, k, v)` with no mask, no causal flag.
Lowers to `aten._scaled_dot_product_*_attention` (variant chosen by
PyTorch based on device/dtype). Tests the default-scale matmul+softmax
path. Inputs are 4-D `(B, H, S, D)`.
"""
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(q, k, v)
class SdpaCausalModel(torch.nn.Module):
"""`F.scaled_dot_product_attention(q, k, v, is_causal=True)`.
Tests the `is_causal` branch of `translate_sdpa`, which materializes a
triangular mask and adds `-1e9 * mask` to the pre-softmax scores.
"""
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
class SdpaWithBiasModel(torch.nn.Module):
"""SDPA with an additive `attn_mask` bias (float, broadcast over heads).
Tests the additive-bias branch of `translate_sdpa`. The bias has shape
`(1, 1, S_q, S_k)` so it broadcasts across batch/head prefix dims of
the scores tensor.
"""
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)

View File

@@ -0,0 +1,138 @@
"""Regression coverage for torch.compile mutation and alias contracts.
PyTorch backends are expected to preserve the semantics of the traced graph.
After torch.export functionalization, input mutations are represented as
leading mutation outputs before user outputs. Luminal currently treats every
compiled graph output as a user output and also materializes inputs at the
boundary, so caller-visible mutation and aliasing semantics are not preserved.
"""
import pytest
import torch
from luminal import luminal_backend
class MutateInputThenCompute(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x * 2.0
class MutateInputReturnAlias(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x
class MutateOverlappingInputAlias(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x.add_(10.0)
return y * 2.0
class ReturnInputView(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.t()
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
model = MutateInputThenCompute()
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
actual_input = expected_input.clone()
expected = model(expected_input)
compiled = torch.compile(model, backend=backend)
actual = compiled(actual_input)
assert torch.equal(actual, expected)
assert torch.equal(actual_input, expected_input)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
model = MutateInputReturnAlias()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
_assert_same_storage(out, x)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
model = ReturnInputView()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal currently treats functionalized input-mutation outputs as user "
"outputs and does not copy mutation outputs back to caller inputs."
),
)
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
model = MutateInputThenCompute().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
expected_out = expected_x * 2.0
assert torch.equal(out, expected_out)
assert torch.equal(x, expected_x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal does not preserve caller-visible overlapping input aliasing "
"when one aliased input is mutated."
),
)
def test_luminal_overlapping_input_alias_mutation_contract(
device: torch.device,
) -> None:
model = MutateOverlappingInputAlias().to(device)
eager_base = torch.arange(6, dtype=torch.float32, device=device)
expected = model(eager_base[:4], eager_base[1:5])
base = torch.arange(6, dtype=torch.float32, device=device)
compiled = torch.compile(model, backend=luminal_backend)
actual = compiled(base[:4], base[1:5])
assert torch.equal(actual, expected)
assert torch.equal(base, eager_base)
@pytest.mark.xfail(
strict=True,
reason="Luminal materializes returned input views instead of preserving aliasing.",
)
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
model = ReturnInputView().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)

View File

@@ -0,0 +1,275 @@
"""Qwen3-MoE HuggingFace model integration tests.
Tests progressively larger HuggingFace `Qwen3MoeForCausalLM` configs through
the PyTorch -> PT2 -> luminal pipeline via `torch.compile(..., backend=
luminal_backend)`. Qwen3-MoE shares the dense Qwen3 backbone but replaces
the FFN with a top-k router over `num_experts` independent expert MLPs —
which exercises code paths the dense tests don't:
- `aten._grouped_mm.default` (gather-then-matmul lowering, PR #298)
- bf16 `KernelScatter` (KV cache scatter on a non-F32 dtype)
- `aten.empty_permuted` / `aten.histc` (MoE expert dispatch and
tokens-per-expert counts)
- clamp-on-Int dtype handling (router top-k indices flowing into
`aten.clamp`)
The smaller configs run on GPU in seconds; the "real config" case loads
the actual `Qwen/Qwen3-30B-A3B` arch (128 experts, top-8) with
`num_hidden_layers` overridden to 1 so a full-width compile is
exercised on random weights.
Together these guard the regression-and-fix story that landed alongside:
the bf16 KernelScatter dtype-aware vec count, the `aten.empty(_permuted)`
/ `aten.histc` translator entries, and the
`maximum_f32`-on-Int casting fix.
"""
import pytest
import torch
import torch._dynamo
from luminal import luminal_backend
# ────────────────────────────────────────────────────────────────────────
# Helpers
# ────────────────────────────────────────────────────────────────────────
def _make_qwen3_moe_config(
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
num_hidden_layers: int,
intermediate_size: int,
moe_intermediate_size: int,
num_experts: int,
num_experts_per_tok: int,
vocab_size: int,
):
"""Create a Qwen3MoeConfig with use_cache=False and eager attention.
Shared helper so each test only specifies the scaling knobs that matter
for that case.
"""
from transformers import Qwen3MoeConfig
return Qwen3MoeConfig(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
num_hidden_layers=num_hidden_layers,
intermediate_size=intermediate_size,
moe_intermediate_size=moe_intermediate_size,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
vocab_size=vocab_size,
max_position_embeddings=128,
use_cache=False,
attn_implementation="eager",
)
def _run_hf_qwen3_moe_test(config, device: torch.device, atol: float):
"""Run a HuggingFace Qwen3MoeForCausalLM test with the given config.
Compiles the model with `luminal_backend`, runs both eager and compiled
on the same input, asserts the logits match within `atol`.
"""
from transformers import Qwen3MoeForCausalLM
model = Qwen3MoeForCausalLM(config).eval().to(device)
compiled = torch.compile(model, backend=luminal_backend)
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=atol), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
# ────────────────────────────────────────────────────────────────────────
# Tests — progressively larger configs
# ────────────────────────────────────────────────────────────────────────
def test_hf_qwen3_moe_tiny(device: torch.device):
"""HuggingFace Qwen3MoeForCausalLM — tiny: 2 experts, top-1 routing.
Smallest config that still exercises the MoE expert dispatch
(`aten._grouped_mm`). Top-1 routing keeps the test simple while still
validating the gather-then-matmul lowering path.
"""
config = _make_qwen3_moe_config(
hidden_size=32,
num_attention_heads=2,
num_key_value_heads=1,
num_hidden_layers=1,
intermediate_size=64,
moe_intermediate_size=64,
num_experts=2,
num_experts_per_tok=1,
vocab_size=128,
)
_run_hf_qwen3_moe_test(config, device, atol=1e-5)
def test_hf_qwen3_moe_small(device: torch.device):
"""HuggingFace Qwen3MoeForCausalLM — small: 4 experts, top-2 routing."""
config = _make_qwen3_moe_config(
hidden_size=128,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=1,
intermediate_size=256,
moe_intermediate_size=128,
num_experts=4,
num_experts_per_tok=2,
vocab_size=512,
)
_run_hf_qwen3_moe_test(config, device, atol=1e-4)
def test_hf_qwen3_moe_medium(device: torch.device):
"""HuggingFace Qwen3MoeForCausalLM — medium: 8 experts, top-2, 2 layers.
Two layers means the e-graph crosses a layer boundary, which is where
the late-memory-analysis cleanup pass operates differently than
single-layer cases.
"""
config = _make_qwen3_moe_config(
hidden_size=128,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=256,
moe_intermediate_size=128,
num_experts=8,
num_experts_per_tok=2,
vocab_size=512,
)
_run_hf_qwen3_moe_test(config, device, atol=1e-4)
@pytest.mark.slow
def test_hf_qwen3_moe_real_config_1layer(device: torch.device):
"""HuggingFace Qwen3MoeForCausalLM — real Qwen3-30B-A3B architecture, 1 layer.
Loads `Qwen/Qwen3-30B-A3B`'s AutoConfig (128 experts, top-8 routing,
2048 hidden) and overrides `num_hidden_layers=1`. Random weights —
cheap smoke that the production-shape MoE *layer* compiles end-to-end
through luminal_backend without paying the full 48-layer cost.
"""
from transformers import AutoConfig, Qwen3MoeForCausalLM
config = AutoConfig.from_pretrained("Qwen/Qwen3-30B-A3B")
config.num_hidden_layers = 1
config.use_cache = False
config._attn_implementation = "eager"
model = Qwen3MoeForCausalLM(config).eval().to(device)
compiled = torch.compile(model, backend=luminal_backend)
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
@pytest.mark.slow
def test_hf_qwen3_moe_real_config_full(device: torch.device):
"""HuggingFace Qwen3MoeForCausalLM — full Qwen3-30B-A3B, pretrained.
Loads the real `Qwen/Qwen3-30B-A3B` checkpoint at its native bf16
dtype: 48 hidden layers, 128 experts, top-8 routing, 2048 hidden —
i.e. the production architecture, no `num_hidden_layers` override.
This is the end-to-end "the full MoE compiles" regression guard;
the 1-layer variant above is the cheap smoke.
Asserts the **compile + run** path completes and the compiled
forward produces *finite* output (no NaN / no Inf). It does NOT
assert tight numerical equivalence with eager: at this depth the
egglog search is non-deterministic enough that the two paths can
diverge structurally (same general magnitudes, different per-element
values). Tight numerical equivalence at full scale is tracked as
follow-up work — the smaller-config tests above use atol≤1e-3 and
cover the per-op correctness that this test cannot.
Compared to the 1-layer test this primarily catches:
- egglog cleanup behaviour over a 48-layer-wide e-graph (the
`egglog_utils.rs:1286: No valid graphs` panic surfaces here
if the cleanup cascade re-regresses on MoE root-eclasses);
- per-layer plumbing of residual stream + KV state that
single-layer tests don't exercise;
- any bf16-specific code path (e.g. KernelScatter OOB) that's
masked at fp32.
Memory profile on H200/H100:
- bf16 pretrained weights: ~60 GB
- single-token input keeps activations & router state trivial
- peak observed during compiled forward: ~75 GB total
"""
import gc
from transformers import AutoConfig, Qwen3MoeForCausalLM
# Aggressively release any allocator state from prior tests in the
# same process — at this scale we don't have headroom to absorb it.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
config = AutoConfig.from_pretrained("Qwen/Qwen3-30B-A3B")
config.use_cache = False
config._attn_implementation = "eager"
model = (
Qwen3MoeForCausalLM.from_pretrained(
"Qwen/Qwen3-30B-A3B",
config=config,
torch_dtype=torch.bfloat16,
)
.eval()
.to(device)
)
compiled = torch.compile(model, backend=luminal_backend)
# Single-token input — the full-depth compile is the regression target,
# not multi-token throughput (which the bench covers separately).
input_ids = torch.tensor([[1]], device=device)
with torch.no_grad():
# Eager forward — confirms the test setup is sane (HF is happy).
ref = model(input_ids)
ref_max = ref.logits.float().abs().max().item()
assert torch.isfinite(ref.logits).all(), (
"eager forward produced non-finite logits — test setup is broken, "
"not a luminal regression"
)
del ref
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Compiled forward — the actual regression target.
out = compiled(input_ids)
out_logits = out.logits.float()
n_nan = int(out_logits.isnan().sum().item())
n_inf = int(out_logits.isinf().sum().item())
out_max = out_logits.abs().max().item()
assert n_nan == 0 and n_inf == 0, (
f"compiled forward produced non-finite logits: {n_nan} NaNs, "
f"{n_inf} Infs (eager max abs={ref_max:.2f}, compiled max abs={out_max:.2f})"
)
# Sanity-check magnitude: compiled output should be in the same ballpark
# as eager — within an order of magnitude of the eager logits' scale.
# Catches the failure mode where some kernel silently produces
# near-zero or near-Inf values that pass the finite check.
assert 0.1 * ref_max <= out_max <= 10.0 * ref_max, (
f"compiled max abs={out_max:.2f} is out of band vs eager max abs={ref_max:.2f} "
f"(>10× off in either direction); likely a numerical/scale bug"
)

File diff suppressed because it is too large Load Diff

View File

@@ -83,6 +83,7 @@ def test_whisper_decoder_layer(device: torch.device):
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
@pytest.mark.slow
def test_whisper_encoder_random_init(device: torch.device):
"""Full encoder over a random mel: 2 conv stems + 4 transformer blocks."""
model = _make_small_whisper().to(device)
@@ -96,6 +97,7 @@ def test_whisper_encoder_random_init(device: torch.device):
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
@pytest.mark.slow
def test_whisper_full_random_init_one_step(device: torch.device):
"""End-to-end Whisper forward (encoder + decoder for one step) with random weights.

View File

@@ -0,0 +1,117 @@
"""YOLO v11n end-to-end tests using the luminal_cuda_lite backend.
This module exercises the YOLO v11n building blocks (Conv + BN, C3k2, the
SPPF/C2PSA backbone, the Detect head) and finally the full model through
``torch.compile(..., backend=luminal_backend)``.
The smaller per-block tests are useful when triaging which part of the
architecture starts diverging: incrementally building a model up is much
easier than debugging a 100-layer mismatch in one go.
Marked ``slow`` because the first run downloads ~6 MB of weights and the
luminal e-graph compile of the full model is non-trivial. Run with::
uv run pytest tests/test_yolo_v11.py -v -s
"""
from typing import Callable
import pytest
import torch
import torch._dynamo
from luminal import luminal_backend
def _require_cuda(device: torch.device):
if device.type != "cuda":
pytest.skip("YOLO v11 examples require the CUDA backend.")
def _require_ultralytics():
try:
from ultralytics import YOLO # noqa: F401
except ImportError as exc: # pragma: no cover
pytest.skip(f"ultralytics not installed: {exc}")
def _yolo_model(device: torch.device, decode_only: bool = True):
"""Load yolo11n with BN folded into Conv. Returns the eager torch model."""
from ultralytics import YOLO
yolo = YOLO("yolo11n.pt")
pt_model = yolo.model.eval()
pt_model.fuse()
if decode_only:
pt_model.model[-1].export = True
pt_model.to(device)
return pt_model
@pytest.mark.slow
def test_yolo_v11n_first_three_layers(device: torch.device):
"""Compile only the first three layers (Conv, Conv, C3k2) — exercises the
chunk + bottleneck residual + concat pattern that's the trickiest piece
of the model graph."""
_require_cuda(device)
_require_ultralytics()
pt_model = _yolo_model(device, decode_only=True)
class FirstThree(torch.nn.Module):
def __init__(self, backbone):
super().__init__()
self.layers = torch.nn.ModuleList([backbone[i] for i in range(3)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
sub = FirstThree(pt_model.model).to(device).eval()
torch.manual_seed(0)
x = torch.rand(1, 3, 640, 640, dtype=torch.float32, device=device)
with torch.no_grad():
ref = sub(x)
torch._dynamo.reset()
compiled: Callable = torch.compile(sub, backend=luminal_backend)
with torch.no_grad():
out = compiled(x)
max_diff = torch.max(torch.abs(out - ref)).item()
print(f"yolo11n[:3] max_diff vs PyTorch eager: {max_diff:.4e}")
assert torch.allclose(out, ref, atol=1e-3), (
f"yolo11n[:3] outputs differ — max_diff={max_diff:.4e}"
)
@pytest.mark.slow
def test_yolo_v11n_end_to_end(device: torch.device):
"""Full yolo11n forward via torch.compile. The compile may be slow on
machines without strong egglog parallelism — see the example README for
the standalone Rust binary alternative."""
_require_cuda(device)
_require_ultralytics()
pt_model = _yolo_model(device)
torch.manual_seed(0)
x = torch.rand(1, 3, 640, 640, dtype=torch.float32, device=device)
with torch.no_grad():
ref = pt_model(x)
if isinstance(ref, (list, tuple)):
ref = ref[0]
torch._dynamo.reset()
compiled: Callable = torch.compile(pt_model, backend=luminal_backend)
with torch.no_grad():
out = compiled(x)
if isinstance(out, (list, tuple)):
out = out[0]
max_diff = torch.max(torch.abs(out - ref)).item()
print(f"YOLO v11n max_diff vs PyTorch eager: {max_diff:.4e}")
assert torch.allclose(out, ref, atol=1e-3), (
f"YOLO v11n outputs differ from PyTorch eager — max_diff={max_diff:.4e}"
)

View File

@@ -315,8 +315,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -462,8 +462,13 @@ fn hlir_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
let hidden_exp = hidden.unsqueeze(2);
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}

View File

@@ -52,6 +52,9 @@ fn main() {
v_out.output();
}
cx.set_dim('s', 1);
cx.set_dim('p', 1);
println!("Building E-Graph...");
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(500),

View File

@@ -246,8 +246,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -287,8 +287,13 @@ fn hlir_attention(
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -287,7 +287,8 @@ impl QwenMoE {
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
weights_exp.shape.expand(down_out.dims());
(down_out * weights_exp).sum(n - 1)
}
}
@@ -385,8 +386,13 @@ fn attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
k_full.shape.dims[1] = total_seq;
v_full.shape.dims[1] = total_seq;
// GQA expand
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -174,8 +174,13 @@ fn decoder_self_attention(
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total, ..));
let v_full = v_cache_out.slice((.., ..total, ..));
let mut k_full = k_cache_out.slice((.., ..total, ..));
let mut v_full = v_cache_out.slice((.., ..total, ..));
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
// cannot yet propagate expression-bound assertions, so `slice` reports
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
k_full.shape.dims[1] = total;
v_full.shape.dims[1] = total;
let q = split_heads(q);

3
examples/yolo_v11/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
artifacts/
yolo11n.pt
__pycache__/

View File

@@ -0,0 +1,28 @@
[package]
name = "yolo_v11"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "yolo_v11"
path = "src/main.rs"
[features]
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = { path = "../../crates/luminal_tracing" }
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
safetensors = "0.7.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
half = { version = "2.7.1", features = ["bytemuck"] }
bytemuck = "1.24.0"
memmap2 = "0.9.9"
rustc-hash = "2.1"
image = { version = "0.25", default-features = false, features = ["jpeg", "png"] }
ureq = "2"

108
examples/yolo_v11/README.md Normal file
View File

@@ -0,0 +1,108 @@
# YOLO v11n on Luminal
End-to-end Object detection demo running the Ultralytics yolo11n model on the
`luminal_cuda_lite` backend.
## Layout
```
examples/yolo_v11/
├── Cargo.toml # Rust crate (binary: yolo_v11)
├── src/
│ ├── main.rs # Full forward, NMS, and annotated image output
│ └── model.rs # YOLO v11n architecture in luminal IR
├── python/
│ ├── reference.py # PyTorch eager reference + weight prep
│ └── luminal_example.py # torch.compile(..., backend=luminal_backend) demo
└── artifacts/ # Downloaded/generated artifacts (gitignored)
├── bus.jpg
├── reference_input.bin
├── reference_output.bin
├── reference_boxes.json
└── weights.safetensors
```
## Quick start
1. **Run the Rust example** (CUDA, e.g. on a GH200 / H100):
```bash
# Full model on the default bus.jpg sample
cargo run --release -p yolo_v11 --bin yolo_v11
# Full model on any JPEG or PNG
cargo run --release -p yolo_v11 --bin yolo_v11 -- --input /path/to/image.jpg --output /tmp/yolo_annotated.png
# Positional shorthand: <input> <output>
cargo run --release -p yolo_v11 --bin yolo_v11 -- /path/to/image.jpg /tmp/yolo_annotated.png
```
On first run, the binary downloads `weights.safetensors` and the default
`bus.jpg` sample into `examples/yolo_v11/artifacts/` if they are missing.
`yolo_v11` builds the entire YOLO v11n graph and the Detect head, preprocesses
a JPEG/PNG with a Rust implementation of the 640x640 Ultralytics-style
letterbox transform, runs the forward, applies class-aware NMS, and prints
detections in the original image coordinates. For image inputs, it also writes
an annotated PNG to `examples/yolo_v11/artifacts/annotated.png` by default.
The input and annotated output paths can be supplied as CLI arguments:
`--input /path/to/image.png --output /path/to/out.png`.
The direct image path may differ slightly from Python/OpenCV preprocessing
because it uses Rust image decoding and resizing.
2. **(Optional) Regenerate reference data + fused weights** (PyTorch + Ultralytics):
```bash
pip install ultralytics torch opencv-python-headless
python examples/yolo_v11/python/reference.py
```
This downloads `yolo11n.pt`, fuses Conv + BN, runs the eager forward on a
bundled bus image, and writes `examples/yolo_v11/artifacts/`.
3. **(Optional) Run the Python compiled-model example**:
Requires `luminal_python` built with the cuda feature (see
`crates/luminal_python/run_tests_cuda.sh`).
```bash
python examples/yolo_v11/python/luminal_example.py
```
The pytest version is `crates/luminal_python/tests/test_yolo_v11.py`.
## Implementation notes
* All Conv blocks are loaded with `bn` folded into a bias-augmented Conv2d
(`forward_fuse`), so the saved tensors are just `<layer>.conv.weight` and
`<layer>.conv.bias`.
* The `C3k2`, `C3k`, `C2PSA`, and `Attention` modules in PyTorch use
`tensor.chunk(2, dim=1)` (or `qkv.split([...], dim=...)`) to produce two/three
channel-slices that then take separate paths. Slicing followed by a residual
add inside a bottleneck triggers a cascade in luminal_cuda_lite's e-graph
cleanup that prunes the only kernel alternatives. To work around this, the
Python script pre-splits those conv weights along the output-channel dim and
the Rust model exposes them as separate convs (`cv1a`/`cv1b` for C3k2/C2PSA,
`q_split`/`k_split`/`v_split` for Attention).
* Anchors, per-anchor strides, and the DFL projection weight are fed from Rust
via `runtime.set_data`. The DFL projection is the constant `arange(reg_max)`.
* `make_contiguous` (a free function in `src/model.rs`) materializes a
non-contiguous view via `gather + iota` (the same trick `GraphTensor::output`
uses internally). It's applied wherever an op chain produces a strided view
that the next op needs to read sequentially.
* 1x1 convolutions skip the unfold path and use a direct 2D matmul, so
luminal_cuda_lite's `TileMatmulFullSplit` kernel can match.
## Known limitation: full-model compile time
The `yolo_v11` binary builds a graph of ~2,200 HLIR nodes (~100 convolutions
plus the Detect head). luminal_cuda_lite's e-graph rewrite phase runs many
rules to fixpoint over the whole graph, which on a conv-heavy vision model
becomes the dominant cost. On a Grace-Hopper class machine this phase can
take >10 minutes (using ~30+ GB of host RAM in the egglog tables) before the
search and execution finally proceed.
The Python torch.compile path (`crates/luminal_python/tests/test_yolo_v11.py`)
is a useful alternative because the pt2 export decomposes the graph slightly
differently.

View File

@@ -0,0 +1,893 @@
mod model;
use std::{
env, fs, io,
path::{Path, PathBuf},
process,
time::Instant,
};
use image::{ImageBuffer, ImageReader, Rgb, RgbImage};
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const ARTIFACT_DIR: &str = "examples/yolo_v11/artifacts";
const WEIGHTS_URL: &str =
"https://github.com/luminal-ai/luminal/releases/download/yolo-v11n/weights.safetensors";
const SAMPLE_IMAGE_URL: &str =
"https://github.com/ultralytics/assets/releases/download/v0.0.0/bus.jpg";
const CONF_THRES: f32 = 0.25;
const IOU_THRES: f32 = 0.45;
const MAX_DET: usize = 300;
#[derive(Debug, Clone, Copy)]
struct LetterboxMeta {
orig_width: u32,
orig_height: u32,
ratio: f32,
pad_x: f32,
pad_y: f32,
}
#[derive(Debug, Clone)]
struct Detection {
score: f32,
class_id: usize,
x1: f32,
y1: f32,
x2: f32,
y2: f32,
}
#[derive(Debug, Clone)]
struct CliArgs {
image_path: Option<PathBuf>,
annotated_path: PathBuf,
}
fn print_usage() {
println!(
"Usage: cargo run --release -p yolo_v11 --bin yolo_v11 -- [--input <image.jpg|image.png>] [--output <annotated.png>]\n\
\n\
Positional form is also supported:\n\
cargo run --release -p yolo_v11 --bin yolo_v11 -- <image.jpg|image.png> <annotated.png>\n\
\n\
If no image is supplied, the example uses examples/yolo_v11/artifacts/bus.jpg and downloads it if needed."
);
}
fn cli_args(artifact_dir: &Path) -> CliArgs {
let mut image_path = None;
let mut annotated_path = None;
let mut positionals = Vec::new();
let mut args = env::args_os().skip(1);
while let Some(arg) = args.next() {
let arg_str = arg.to_string_lossy();
match arg_str.as_ref() {
"-h" | "--help" => {
print_usage();
process::exit(0);
}
"--input" => {
image_path = Some(next_cli_path(&mut args, arg_str.as_ref()));
}
"--output" | "-o" => {
annotated_path = Some(next_cli_path(&mut args, arg_str.as_ref()));
}
"--" => {
positionals.extend(args.map(PathBuf::from));
break;
}
_ if arg_str.starts_with('-') => panic!("Unknown argument: {arg_str}"),
_ => positionals.push(PathBuf::from(arg)),
}
}
if let Some(positional) = positionals.first() {
if image_path.is_some() {
panic!("Input image was provided both positionally and with --input");
}
image_path = Some(positional.clone());
}
if let Some(positional) = positionals.get(1) {
if annotated_path.is_some() {
panic!("Output image was provided both positionally and with --output");
}
annotated_path = Some(positional.clone());
}
if positionals.len() > 2 {
panic!("Too many positional arguments; expected at most <input> <output>");
}
let image_path = image_path.or_else(|| Some(artifact_dir.join("bus.jpg")));
let annotated_path = annotated_path.unwrap_or_else(|| artifact_dir.join("annotated.png"));
CliArgs {
image_path,
annotated_path,
}
}
fn next_cli_path(args: &mut impl Iterator<Item = std::ffi::OsString>, flag: &str) -> PathBuf {
args.next()
.map(PathBuf::from)
.unwrap_or_else(|| panic!("{flag} requires a path"))
}
fn ensure_downloaded(path: &Path, url: &str, label: &str) {
if path.exists() {
return;
}
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.unwrap_or_else(|e| panic!("Failed to create {}: {e}", parent.display()));
}
let tmp_path = download_temp_path(path);
let _ = fs::remove_file(&tmp_path);
println!("Downloading {label}: {url}");
println!(" -> {}", path.display());
let response = ureq::get(url)
.set("User-Agent", "luminal-yolo-v11-example")
.call()
.unwrap_or_else(|e| panic!("Failed to download {label} from {url}: {e}"));
let mut reader = response.into_reader();
let mut file = fs::File::create(&tmp_path)
.unwrap_or_else(|e| panic!("Failed to create {}: {e}", tmp_path.display()));
io::copy(&mut reader, &mut file)
.unwrap_or_else(|e| panic!("Failed to write {}: {e}", tmp_path.display()));
file.sync_all()
.unwrap_or_else(|e| panic!("Failed to sync {}: {e}", tmp_path.display()));
fs::rename(&tmp_path, path).unwrap_or_else(|e| {
panic!(
"Failed to move {} to {}: {e}",
tmp_path.display(),
path.display()
)
});
}
fn download_temp_path(path: &Path) -> PathBuf {
let file_name = path
.file_name()
.and_then(|name| name.to_str())
.unwrap_or("download");
path.with_file_name(format!("{file_name}.download"))
}
fn preprocess_image(path: &Path) -> (Vec<f32>, LetterboxMeta) {
let rgb = ImageReader::open(path)
.unwrap_or_else(|e| panic!("Failed to open image {}: {e}", path.display()))
.decode()
.unwrap_or_else(|e| panic!("Failed to decode image {}: {e}", path.display()))
.to_rgb8();
let (orig_width, orig_height) = rgb.dimensions();
assert!(
orig_width > 0 && orig_height > 0,
"image must have non-zero dimensions"
);
let img_size = IMG_SIZE as u32;
let ratio = (img_size as f32 / orig_height as f32).min(img_size as f32 / orig_width as f32);
let resized_width = ((orig_width as f32 * ratio).round() as u32).max(1);
let resized_height = ((orig_height as f32 * ratio).round() as u32).max(1);
let resized = resize_rgb_inter_linear(&rgb, resized_width, resized_height);
let dw = img_size.saturating_sub(resized_width) as f32;
let dh = img_size.saturating_sub(resized_height) as f32;
let left = ((dw / 2.0) - 0.1).round().max(0.0) as u32;
let top = ((dh / 2.0) - 0.1).round().max(0.0) as u32;
let mut letterboxed: RgbImage =
ImageBuffer::from_pixel(img_size, img_size, Rgb([114, 114, 114]));
image::imageops::replace(&mut letterboxed, &resized, left.into(), top.into());
let plane = IMG_SIZE * IMG_SIZE;
let mut data = vec![0.0_f32; 3 * plane];
for y in 0..IMG_SIZE {
for x in 0..IMG_SIZE {
let p = letterboxed.get_pixel(x as u32, y as u32);
let idx = y * IMG_SIZE + x;
data[idx] = p[0] as f32 / 255.0;
data[plane + idx] = p[1] as f32 / 255.0;
data[2 * plane + idx] = p[2] as f32 / 255.0;
}
}
(
data,
LetterboxMeta {
orig_width,
orig_height,
ratio,
pad_x: left as f32,
pad_y: top as f32,
},
)
}
fn resize_rgb_inter_linear(src: &RgbImage, dst_width: u32, dst_height: u32) -> RgbImage {
let (src_width, src_height) = src.dimensions();
assert!(src_width > 0 && src_height > 0);
if src_width == dst_width && src_height == dst_height {
return src.clone();
}
let scale_x = src_width as f32 / dst_width as f32;
let scale_y = src_height as f32 / dst_height as f32;
let mut dst = RgbImage::new(dst_width, dst_height);
for y in 0..dst_height {
let (y0, y1, wy) = resize_axis(y, scale_y, src_height);
for x in 0..dst_width {
let (x0, x1, wx) = resize_axis(x, scale_x, src_width);
let p00 = src.get_pixel(x0, y0);
let p01 = src.get_pixel(x1, y0);
let p10 = src.get_pixel(x0, y1);
let p11 = src.get_pixel(x1, y1);
let mut out = [0u8; 3];
for c in 0..3 {
let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
let bottom = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
out[c] = (top * (1.0 - wy) + bottom * wy).round().clamp(0.0, 255.0) as u8;
}
dst.put_pixel(x, y, Rgb(out));
}
}
dst
}
fn resize_axis(dst_index: u32, scale: f32, src_len: u32) -> (u32, u32, f32) {
if src_len == 1 {
return (0, 0, 0.0);
}
let src = (dst_index as f32 + 0.5) * scale - 0.5;
if src < 0.0 {
return (0, 0, 0.0);
}
let mut i0 = src.floor() as u32;
let mut weight = src - i0 as f32;
if i0 >= src_len - 1 {
i0 = src_len - 2;
weight = 1.0;
}
(i0, i0 + 1, weight)
}
fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
let cwd = std::env::current_dir().unwrap();
let artifact_dir = cwd.join(ARTIFACT_DIR);
let weights_path = artifact_dir.join("weights.safetensors");
let cli = cli_args(&artifact_dir);
let image_path = cli.image_path.clone();
let search_graphs = 50usize;
println!("Using artifact directory: {}", artifact_dir.display());
ensure_downloaded(&weights_path, WEIGHTS_URL, "YOLO v11n Luminal weights");
let image_path = image_path.unwrap_or_else(|| {
panic!(
"No input image supplied and default image is missing; pass --input <image.jpg|image.png>"
)
});
if image_path == artifact_dir.join("bus.jpg") {
ensure_downloaded(&image_path, SAMPLE_IMAGE_URL, "sample image");
}
assert!(
image_path.exists(),
"Image path does not exist: {}",
image_path.display()
);
println!("Input image: {}", image_path.display());
let (img_data, letterbox_meta) = preprocess_image(&image_path);
println!(
" original={}x{} letterbox_ratio={:.6} pad=({:.0}, {:.0})",
letterbox_meta.orig_width,
letterbox_meta.orig_height,
letterbox_meta.ratio,
letterbox_meta.pad_x,
letterbox_meta.pad_y
);
let expected_input = 3 * IMG_SIZE * IMG_SIZE;
assert_eq!(img_data.len(), expected_input, "input size mismatch");
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
// Build graph
let mut cx = Graph::default();
let img = cx.named_tensor("input.image", (1usize, 3usize, IMG_SIZE, IMG_SIZE));
let yolo = YoloV11::init(&mut cx);
let logits = yolo.forward(img).output();
println!("Building E-Graph...");
let t0 = Instant::now();
cx.build_search_space::<CudaRuntime>();
println!(" built E-Graph in {:?}", t0.elapsed());
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
// Initialize anchors, strides, and DFL constant.
let (anchors_flat, strides_flat) = make_anchors_and_strides(&[80, 40, 20], &STRIDES);
runtime.set_data(yolo.detect.anchors, anchors_flat.clone());
runtime.set_data(yolo.detect.strides, strides_flat.clone());
runtime.set_data(yolo.detect.dfl_weight, dfl_weight());
runtime.set_data(img, img_data.clone());
println!("Compiling (search_graphs={search_graphs})...");
let t0 = Instant::now();
runtime = cx.search(runtime, search_graphs);
println!(" search took {:?}", t0.elapsed());
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)
runtime.set_data(yolo.detect.anchors, anchors_flat);
runtime.set_data(yolo.detect.strides, strides_flat);
runtime.set_data(yolo.detect.dfl_weight, dfl_weight());
runtime.set_data(img, img_data);
println!("Executing...");
let t0 = Instant::now();
runtime.execute(&cx.dyn_map);
let elapsed = t0.elapsed();
println!(" forward took {:?}", elapsed);
// Get output (1, 4 + NC, 8400) — Detect with export=True returns the
// DECODED predictions (4 box coords + NC class scores), not the raw
// (NC + REG_MAX*4) channels.
let out = runtime.get_f32(logits);
let total_anchors: usize = 80 * 80 + 40 * 40 + 20 * 20;
let expected_out_len = (4 + NC) * total_anchors;
println!(
" output buffer length: {} (expected {} for shape (1, {}, {}))",
out.len(),
expected_out_len,
4 + NC,
total_anchors
);
let out = &out[..expected_out_len];
let detections = nms_detections(out, total_anchors, CONF_THRES, IOU_THRES, MAX_DET);
print_detections(&detections, Some(letterbox_meta));
save_annotated_image(
&image_path,
&cli.annotated_path,
&detections,
letterbox_meta,
);
println!("Wrote annotated image: {}", cli.annotated_path.display());
}
fn print_detections(detections: &[Detection], meta: Option<LetterboxMeta>) {
println!(
"Detections after NMS (conf >= {:.2}, iou <= {:.2}):",
CONF_THRES, IOU_THRES
);
if detections.is_empty() {
println!(" none");
return;
}
let coco_names = coco_names();
for det in detections.iter().take(20) {
let name = coco_names.get(det.class_id).copied().unwrap_or("?");
let (x1, y1, x2, y2) = if let Some(meta) = meta {
map_to_original(det.x1, det.y1, det.x2, det.y2, meta)
} else {
(det.x1, det.y1, det.x2, det.y2)
};
println!(
" conf={:.3} class={:>14} xyxy=[{:.1}, {:.1}, {:.1}, {:.1}]",
det.score, name, x1, y1, x2, y2
);
}
}
fn save_annotated_image(
input_path: &Path,
output_path: &Path,
detections: &[Detection],
meta: LetterboxMeta,
) {
let mut image = ImageReader::open(input_path)
.unwrap_or_else(|e| panic!("Failed to open image {}: {e}", input_path.display()))
.decode()
.unwrap_or_else(|e| panic!("Failed to decode image {}: {e}", input_path.display()))
.to_rgb8();
let names = coco_names();
let thickness = ((image.width().min(image.height()) as f32 / 320.0).round() as u32).max(2);
for det in detections.iter().take(MAX_DET) {
let (x1, y1, x2, y2) = map_to_original(det.x1, det.y1, det.x2, det.y2, meta);
let color = class_color(det.class_id);
draw_rect(&mut image, x1, y1, x2, y2, color, thickness);
let name = names.get(det.class_id).copied().unwrap_or("?");
draw_label(
&mut image,
x1,
y1,
&format!("{name} {:.2}", det.score),
color,
);
}
if let Some(parent) = output_path
.parent()
.filter(|path| !path.as_os_str().is_empty())
{
fs::create_dir_all(parent)
.unwrap_or_else(|e| panic!("Failed to create {}: {e}", parent.display()));
}
image.save(output_path).unwrap_or_else(|e| {
panic!(
"Failed to write annotated image {}: {e}",
output_path.display()
)
});
}
fn class_color(class_id: usize) -> Rgb<u8> {
const COLORS: [[u8; 3]; 20] = [
[220, 38, 38],
[37, 99, 235],
[22, 163, 74],
[217, 119, 6],
[147, 51, 234],
[8, 145, 178],
[219, 39, 119],
[101, 163, 13],
[234, 88, 12],
[79, 70, 229],
[15, 118, 110],
[190, 18, 60],
[124, 58, 237],
[202, 138, 4],
[2, 132, 199],
[132, 204, 22],
[249, 115, 22],
[168, 85, 247],
[20, 184, 166],
[244, 63, 94],
];
Rgb(COLORS[class_id % COLORS.len()])
}
fn draw_rect(
image: &mut RgbImage,
x1: f32,
y1: f32,
x2: f32,
y2: f32,
color: Rgb<u8>,
thickness: u32,
) {
let width = image.width();
let height = image.height();
if width == 0 || height == 0 {
return;
}
let left = x1.min(x2).floor().clamp(0.0, (width - 1) as f32) as u32;
let right = x1.max(x2).ceil().clamp(0.0, (width - 1) as f32) as u32;
let top = y1.min(y2).floor().clamp(0.0, (height - 1) as f32) as u32;
let bottom = y1.max(y2).ceil().clamp(0.0, (height - 1) as f32) as u32;
if left > right || top > bottom {
return;
}
for t in 0..thickness {
if top + t <= bottom {
draw_hline(image, left, right, top + t, color);
}
if bottom >= t && bottom - t >= top {
draw_hline(image, left, right, bottom - t, color);
}
if left + t <= right {
draw_vline(image, left + t, top, bottom, color);
}
if right >= t && right - t >= left {
draw_vline(image, right - t, top, bottom, color);
}
}
}
fn draw_hline(image: &mut RgbImage, x1: u32, x2: u32, y: u32, color: Rgb<u8>) {
if y >= image.height() {
return;
}
let start = x1.min(x2).min(image.width().saturating_sub(1));
let end = x1.max(x2).min(image.width().saturating_sub(1));
for x in start..=end {
image.put_pixel(x, y, color);
}
}
fn draw_vline(image: &mut RgbImage, x: u32, y1: u32, y2: u32, color: Rgb<u8>) {
if x >= image.width() {
return;
}
let start = y1.min(y2).min(image.height().saturating_sub(1));
let end = y1.max(y2).min(image.height().saturating_sub(1));
for y in start..=end {
image.put_pixel(x, y, color);
}
}
fn draw_label(image: &mut RgbImage, box_x: f32, box_y: f32, text: &str, color: Rgb<u8>) {
let scale = ((image.width().min(image.height()) as f32 / 500.0).round() as u32).max(2);
let text = text.to_ascii_uppercase();
let text_width = text_pixel_width(&text, scale);
let text_height = 7 * scale;
let pad = 3 * scale;
let label_width = text_width + pad * 2;
let label_height = text_height + pad * 2;
let mut x = box_x.floor().max(0.0) as u32;
if x + label_width >= image.width() {
x = image.width().saturating_sub(label_width + 1);
}
let box_top = box_y.floor().max(0.0) as u32;
let y = if box_top > label_height {
box_top - label_height
} else {
box_top.min(image.height().saturating_sub(label_height + 1))
};
fill_rect(image, x, y, label_width, label_height, color);
draw_text_5x7(image, x + pad, y + pad, &text, Rgb([255, 255, 255]), scale);
}
fn fill_rect(image: &mut RgbImage, x: u32, y: u32, width: u32, height: u32, color: Rgb<u8>) {
let max_x = (x + width).min(image.width());
let max_y = (y + height).min(image.height());
for py in y..max_y {
for px in x..max_x {
image.put_pixel(px, py, color);
}
}
}
fn text_pixel_width(text: &str, scale: u32) -> u32 {
let mut width = 0;
for ch in text.chars() {
width += if ch == ' ' { 3 * scale } else { 5 * scale };
width += scale;
}
width.saturating_sub(scale)
}
fn draw_text_5x7(image: &mut RgbImage, x: u32, y: u32, text: &str, color: Rgb<u8>, scale: u32) {
let mut cursor = x;
for ch in text.chars() {
if ch == ' ' {
cursor += 4 * scale;
continue;
}
draw_glyph_5x7(image, cursor, y, ch, color, scale);
cursor += 6 * scale;
}
}
fn draw_glyph_5x7(image: &mut RgbImage, x: u32, y: u32, ch: char, color: Rgb<u8>, scale: u32) {
let Some(rows) = glyph_5x7(ch) else {
return;
};
for (row_idx, row) in rows.iter().enumerate() {
for (col_idx, pixel) in row.as_bytes().iter().enumerate() {
if *pixel != b'1' {
continue;
}
let px = x + col_idx as u32 * scale;
let py = y + row_idx as u32 * scale;
fill_rect(image, px, py, scale, scale, color);
}
}
}
fn glyph_5x7(ch: char) -> Option<[&'static str; 7]> {
Some(match ch {
'A' => [
"01110", "10001", "10001", "11111", "10001", "10001", "10001",
],
'B' => [
"11110", "10001", "10001", "11110", "10001", "10001", "11110",
],
'C' => [
"01111", "10000", "10000", "10000", "10000", "10000", "01111",
],
'D' => [
"11110", "10001", "10001", "10001", "10001", "10001", "11110",
],
'E' => [
"11111", "10000", "10000", "11110", "10000", "10000", "11111",
],
'F' => [
"11111", "10000", "10000", "11110", "10000", "10000", "10000",
],
'G' => [
"01111", "10000", "10000", "10011", "10001", "10001", "01111",
],
'H' => [
"10001", "10001", "10001", "11111", "10001", "10001", "10001",
],
'I' => [
"11111", "00100", "00100", "00100", "00100", "00100", "11111",
],
'J' => [
"00111", "00010", "00010", "00010", "00010", "10010", "01100",
],
'K' => [
"10001", "10010", "10100", "11000", "10100", "10010", "10001",
],
'L' => [
"10000", "10000", "10000", "10000", "10000", "10000", "11111",
],
'M' => [
"10001", "11011", "10101", "10101", "10001", "10001", "10001",
],
'N' => [
"10001", "11001", "10101", "10011", "10001", "10001", "10001",
],
'O' => [
"01110", "10001", "10001", "10001", "10001", "10001", "01110",
],
'P' => [
"11110", "10001", "10001", "11110", "10000", "10000", "10000",
],
'Q' => [
"01110", "10001", "10001", "10001", "10101", "10010", "01101",
],
'R' => [
"11110", "10001", "10001", "11110", "10100", "10010", "10001",
],
'S' => [
"01111", "10000", "10000", "01110", "00001", "00001", "11110",
],
'T' => [
"11111", "00100", "00100", "00100", "00100", "00100", "00100",
],
'U' => [
"10001", "10001", "10001", "10001", "10001", "10001", "01110",
],
'V' => [
"10001", "10001", "10001", "10001", "10001", "01010", "00100",
],
'W' => [
"10001", "10001", "10001", "10101", "10101", "10101", "01010",
],
'X' => [
"10001", "10001", "01010", "00100", "01010", "10001", "10001",
],
'Y' => [
"10001", "10001", "01010", "00100", "00100", "00100", "00100",
],
'Z' => [
"11111", "00001", "00010", "00100", "01000", "10000", "11111",
],
'0' => [
"01110", "10001", "10011", "10101", "11001", "10001", "01110",
],
'1' => [
"00100", "01100", "00100", "00100", "00100", "00100", "01110",
],
'2' => [
"01110", "10001", "00001", "00010", "00100", "01000", "11111",
],
'3' => [
"11110", "00001", "00001", "01110", "00001", "00001", "11110",
],
'4' => [
"00010", "00110", "01010", "10010", "11111", "00010", "00010",
],
'5' => [
"11111", "10000", "10000", "11110", "00001", "00001", "11110",
],
'6' => [
"01110", "10000", "10000", "11110", "10001", "10001", "01110",
],
'7' => [
"11111", "00001", "00010", "00100", "01000", "01000", "01000",
],
'8' => [
"01110", "10001", "10001", "01110", "10001", "10001", "01110",
],
'9' => [
"01110", "10001", "10001", "01111", "00001", "00001", "01110",
],
'.' => [
"00000", "00000", "00000", "00000", "00000", "01100", "01100",
],
'-' => [
"00000", "00000", "00000", "11111", "00000", "00000", "00000",
],
'/' => [
"00001", "00010", "00010", "00100", "01000", "01000", "10000",
],
'?' => [
"01110", "10001", "00001", "00010", "00100", "00000", "00100",
],
_ => return None,
})
}
fn nms_detections(
out: &[f32],
total_anchors: usize,
conf_thres: f32,
iou_thres: f32,
max_det: usize,
) -> Vec<Detection> {
let nc = NC;
let mut candidates = Vec::new();
for a in 0..total_anchors {
let cx = out[a];
let cy = out[total_anchors + a];
let w = out[2 * total_anchors + a];
let h = out[3 * total_anchors + a];
let mut best_score = 0.0_f32;
let mut best_class = 0usize;
for c in 0..nc {
let s = out[(4 + c) * total_anchors + a];
if s > best_score {
best_score = s;
best_class = c;
}
}
if best_score >= conf_thres {
candidates.push(Detection {
score: best_score,
class_id: best_class,
x1: cx - w / 2.0,
y1: cy - h / 2.0,
x2: cx + w / 2.0,
y2: cy + h / 2.0,
});
}
}
candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
let mut keep: Vec<Detection> = Vec::new();
'candidate: for candidate in candidates {
for selected in &keep {
if candidate.class_id == selected.class_id && box_iou(&candidate, selected) > iou_thres
{
continue 'candidate;
}
}
keep.push(candidate);
if keep.len() >= max_det {
break;
}
}
keep
}
fn box_iou(a: &Detection, b: &Detection) -> f32 {
let ix1 = a.x1.max(b.x1);
let iy1 = a.y1.max(b.y1);
let ix2 = a.x2.min(b.x2);
let iy2 = a.y2.min(b.y2);
let intersection = (ix2 - ix1).max(0.0) * (iy2 - iy1).max(0.0);
let a_area = (a.x2 - a.x1).max(0.0) * (a.y2 - a.y1).max(0.0);
let b_area = (b.x2 - b.x1).max(0.0) * (b.y2 - b.y1).max(0.0);
intersection / (a_area + b_area - intersection + f32::EPSILON)
}
fn map_to_original(
x1: f32,
y1: f32,
x2: f32,
y2: f32,
meta: LetterboxMeta,
) -> (f32, f32, f32, f32) {
let ox1 = ((x1 - meta.pad_x) / meta.ratio).clamp(0.0, meta.orig_width as f32);
let oy1 = ((y1 - meta.pad_y) / meta.ratio).clamp(0.0, meta.orig_height as f32);
let ox2 = ((x2 - meta.pad_x) / meta.ratio).clamp(0.0, meta.orig_width as f32);
let oy2 = ((y2 - meta.pad_y) / meta.ratio).clamp(0.0, meta.orig_height as f32);
(ox1, oy1, ox2, oy2)
}
fn coco_names() -> [&'static str; NC] {
[
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
}

File diff suppressed because it is too large Load Diff

View File

@@ -226,8 +226,9 @@ pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
// Safety: source bytes are from a valid typed buffer; we reinterpret.
unsafe fn from_bytes<T: Copy>(bytes: Vec<u8>) -> Vec<T> {
let n = bytes.len() / std::mem::size_of::<T>();
let cap = bytes.capacity() / std::mem::size_of::<T>();
let mut bytes = std::mem::ManuallyDrop::new(bytes);
unsafe { Vec::from_raw_parts(bytes.as_mut_ptr() as *mut T, n, n) }
unsafe { Vec::from_raw_parts(bytes.as_mut_ptr() as *mut T, n, cap) }
}
match dtype {
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),

View File

@@ -232,6 +232,10 @@ pub struct BaseSorts {
pub bf16_dt: SortDef,
pub int_dt: SortDef,
pub bool_dt: SortDef,
pub f4e2m1_dt: SortDef,
pub f8e4m3_dt: SortDef,
pub f8e5m2_dt: SortDef,
pub f8ue8m0_dt: SortDef,
pub i4_dt: SortDef,
pub tf32_dt: SortDef,
// Egglog builtin primitives (for term construction only)
@@ -312,6 +316,10 @@ impl BaseSorts {
bf16_dt: sort(DTYPE, "Bf16", &[]),
int_dt: sort(DTYPE, "Int", &[]),
bool_dt: sort(DTYPE, "Bool", &[]),
f4e2m1_dt: sort(DTYPE, "F4E2M1", &[]),
f8e4m3_dt: sort(DTYPE, "F8E4M3", &[]),
f8e5m2_dt: sort(DTYPE, "F8E5M2", &[]),
f8ue8m0_dt: sort(DTYPE, "F8UE8M0", &[]),
i4_dt: sort(DTYPE, "I4", &[]),
tf32_dt: sort(DTYPE, "TF32", &[]),
p_add: func("+", &["a", "b"]),
@@ -367,6 +375,10 @@ impl BaseSorts {
&self.bf16_dt,
&self.int_dt,
&self.bool_dt,
&self.f4e2m1_dt,
&self.f8e4m3_dt,
&self.f8e5m2_dt,
&self.f8ue8m0_dt,
&self.i4_dt,
&self.tf32_dt,
] {
@@ -444,6 +456,7 @@ pub fn base_expression_egglog() -> String {
p.add_ruleset("expr");
p.add_ruleset("dtype_prop");
p.add_ruleset("cleanup");
p.add_ruleset("post_cleanup");
// Register all sorts
s.register(&mut p);
@@ -506,14 +519,24 @@ pub fn base_expression_egglog() -> String {
);
// Cancel common factor in division: (a*b)/(a*c) → b/c
p.add_rule(
rewrite(
"div-cancel-factor",
div(mul(v("a"), v("b")), mul(v("a"), v("c"))),
div(v("b"), v("c")),
)
.ruleset("expr"),
);
//
// DISABLED: this rule rewrites to a `div` whose operands are themselves
// typically `mul`s of stride/shape factors, so the new tree matches the
// same `div-cancel-factor` pattern again. Combined with `mul-comm` (4
// orderings of a*b/c*d) it drives a combinatorial blow-up on the deep
// `flatten_strides` index expressions produced by stacked unfold-based
// convolutions. At 7 backbone YOLO v11 layers it accounts for ~66k
// matches in a single early-stage saturate. Productive simplifications
// (`div-self`, `mod-mul-self`, `div-const`, `merge-dims`) cover the
// cases we actually need without the explosion.
// p.add_rule(
// rewrite(
// "div-cancel-factor",
// div(mul(v("a"), v("b")), mul(v("a"), v("c"))),
// div(v("b"), v("c")),
// )
// .ruleset("expr"),
// );
// Division self-cancel: a/a → 1
p.add_rule(rewrite("div-self", div(v("a"), v("a")), num(i64(1))).ruleset("expr"));
@@ -620,14 +643,26 @@ pub fn base_expression_egglog() -> String {
.ruleset("expr"),
);
// `div-div`, restricted to nested constant divisors only. The original
// unconstrained form `(a/b)/c → a/(b*c)` produces a new `div` whose
// denominator matches the same rule again as soon as `a` is itself a
// `div`, and `flatten_strides` produces 4-deep div chains for every
// conv. Under `(saturate expr)` the unrestricted version is the single
// biggest match generator on YOLO v11 (~200k matches at 7 layers,
// growing super-linearly). Restricting both divisors to numeric
// literals keeps the productive constant-folding case
// (e.g. `((w+7)/2)/2 → (w+7)/4`) while completely avoiding the
// explosion on stride/index expressions whose denominators are
// composite expressions like `c_in*H*W`.
p.add_rule(
rewrite(
"div-div",
div(div(v("a"), v("b")), v("c")),
div(v("a"), mul(v("b"), v("c"))),
"div-div-num",
div(div(v("a"), num(v("?b"))), num(v("?c"))),
div(v("a"), num(pmul(v("?b"), v("?c")))),
)
.ruleset("expr"),
);
p.add_rule(
rewrite(
"add-div",

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ impl Add for GraphTensor {
type Output = GraphTensor;
fn add(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to add tensors. Got {:?} and {:?}",
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
type Output = GraphTensor;
fn mul(self, rhs: GraphTensor) -> Self::Output {
assert_eq!(
self.dims(),
rhs.dims(),
"Dims must match to multiply tensors."
);
assert_eq!(
self.dtype, rhs.dtype,
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
@@ -355,7 +361,17 @@ impl GraphTensor {
/// Take the elementwise maximum of a tensor and a float
pub fn maximum_f32(self, rhs: f32) -> GraphTensor {
self.maximum(self.graph().constant_float(rhs).expand_rhs(self.shape))
// `constant_float` always emits F32; cast it to `self.dtype` so the
// downstream `lt`/`le` comparisons inside `maximum` don't panic when
// `self` is Int (e.g. `aten.clamp` on Int top-k indices coming out
// of an MoE router). For Int self the cast floors the bound, which
// matches PyTorch's `clamp(int_tensor, min=<float>)` semantics.
self.maximum(
self.graph()
.constant_float(rhs)
.cast(self.dtype)
.expand_rhs(self.shape),
)
}
/// Take the elementwise minimum of two tensors
@@ -464,6 +480,42 @@ pub(super) mod tests {
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
}
#[test]
#[should_panic(expected = "Dims must match to add tensors.")]
fn test_add_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a + b;
}
#[test]
#[should_panic(expected = "Dims must match to multiply tensors.")]
fn test_mul_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a * b;
}
#[test]
#[should_panic(expected = "Dims must match to mod tensors.")]
fn test_mod_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a % b;
}
#[test]
#[should_panic(expected = "Dims must match to lt tensors.")]
fn test_lt_rejects_implicit_broadcast() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3));
let b = cx.tensor((1, 3));
let _ = a.lt(b);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -547,6 +599,27 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_mod_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
test_binary_transforms(
size,
(),
|a, b| a % b.expand_rhs(a.shape),
|a, b| {
let lhs = a.to_vec1::<f32>().unwrap();
let rhs_scalar = b.to_scalar::<f32>().unwrap();
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
},
identity,
shift_from_zero,
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
@@ -560,6 +633,28 @@ pub(super) mod tests {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_lt_scalar_broadcast(size in 1usize..64) {
// rank-0 RHS expanded against rank-N LHS for `lt`.
test_binary(
size,
(),
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|a, b| {
let scalar = b.to_scalar::<f32>().unwrap();
let lhs = a.to_vec1::<f32>().unwrap();
let result: Vec<f32> = lhs
.iter()
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
.collect();
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
},
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]

Some files were not shown because too many files have changed in this diff Show More