Compare commits

...

8 Commits

Author SHA1 Message Date
Joe Fioti
62e86f9dc5 Reuse cuBLASLt prepares across matching graph ops 2026-06-01 00:25:30 +00:00
Joe Fioti
75e4e6be0a Simplify example mains and trim CUDA profiling output (#339)
* Simplify example mains and trim CUDA profiling output

* Simplify model examples and adjust CUDA profiling output

* Simplify example model setup and CUDA profiling output
2026-05-29 23:37:13 -04:00
tucker-luminal
4cd47ffa45 luminal_python: dynamic-shape gather/scatter in the PT2 translator (#334)
`gather_elements` / `scatter_elements` / `scatter_nd` in luminal-core
require concrete shape dims, so `torch.compile(model, backend=luminal_backend)`
crashed the moment Dynamo handed us a SymInt for batch or seq_len.

The translator now lowers all three through Expression-typed shape
arithmetic and only calls luminal-core primitives that already accept
Expressions, with a small `dim_arith` helper that keeps every shape
product in canonical commutative order so different code paths don't
build syntactically-different versions of the same logical dim.

Verified end-to-end on Qwen3-30B-A3B across varying prompt lengths.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 16:43:38 -05:00
Joe Fioti
db72cf505c Dyn dim intervals (#333)
* Consolidate compile APIs and bucket config

* Fix Metal compile options API for clippy and llama CI
2026-05-24 18:01:56 -04:00
tucker-luminal
766db93b08 Dtype i64 f64 first class (#323)
* tests for interface specification

* luminal_python: skip CUDA zero-copy for float64 outputs

Luminal collapses `DType::F64` to F32 internally, so a CUDA kernel for an
f64-typed output actually writes f32 bytes. The Python wrapper was
registering an `f64` pre-allocated tensor's `data_ptr` as the zero-copy
destination — handing the kernel a 12-byte payload for a 24-byte buffer,
leaving half of every f64 element as garbage.

Fix: only set the device pointer for the dtypes luminal *natively* writes
end-to-end on CUDA (f32, f16, bf16). For f64, pre-allocate the f64 output
tensor but skip the device-ptr handoff; the collection path then falls
through to `get_output()` (which reads the kernel's actual f32 output)
and casts to f64 via the existing read-and-cast branch.

Pre-existing latent bug — the test scaffolding from the prior commit
exposes it as `test_boundary_noop_preserves_dtype_and_values
[cuda-float64_f32_exact]`. Phase E adds first-class f64 IR support
which will eventually let the kernel write real f64 bytes and restore
zero-copy here; this commit unblocks the CUDA test sweep until then.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal: first-class I64 / F64 in IR + CPU + PT2 boundary

Today luminal collapses every PT2 integer dtype to `DType::Int` (i32) and
`float64` to `DType::F32` at the FFI boundary. The LUM-486 commit papered
over symptoms by storing the user-visible PT2 dtype code in a sidecar and
casting back at the Python wrapper — but the IR still computes in i32 /
f32, so values outside those ranges (`2**40`, `1.0000000000000002`) lose
information before the kernel ever runs.

This commit makes i64 and f64 first-class through the IR end-to-end:

  - `DType::I64` added; custom `Debug` impl maps it to `"Int64"` (not
    `"I64"`) because egglog has a built-in primitive sort named `I64` for
    integer literals in shape expressions, and the egglog-format sites
    in `hlir.rs` serialize `DType` via `{:?}` — emitting `"I64"` would
    shadow the primitive and panic the egraph loader with
    `UnboundFunction("I64", ...)`. Documented at the variant.
  - `f64_dt: sort(DTYPE, "F64", &[])` and `int64_dt: sort(DTYPE, "Int64",
    &[])` registered in `egglog_utils::base`; matching arms added to
    `extract_dtype`.
  - `NativeData::I64(Vec<i64>)` and `NativeData::F64(Vec<f64>)` added.
    `len`, `f32`/`f16`/`bf16`/`i32`/`bool` accessors widen for both; new
    `i64()` and `f64()` accessors mirror the existing access pattern.
    `From<Vec<i64>>` and `From<Vec<f64>>` impls round out the inference.
  - Cast op covers the full new Cartesian product. Cast to `Int` from
    `I64` saturates, matching `tensor.to(torch.int32)` overflow
    semantics. Cast to `F32` from `F64` narrows.
  - CPU kernels handle I64/F64 directly in Add, Mul, Mod, Gather, Scatter,
    SumReduce, MaxReduce. Unary transcendentals (`Log2`, `Exp2`, etc.)
    still bridge through f32 in v1 — the translator inserts cast-bridges
    around them; reaching the kernel with `I64`/`F64` panics with a
    pointer to the missing bridge.
  - `dyn_backend::bytes_to_native_data` preserves i64 / f64 bytes
    directly; `dummy_data_for_dtype` includes i64 fill. New trait methods
    `get_output_i64` / `get_output_f64` on `DynBackend` with the native
    runtime impl.
  - `cuda_dtype` extended (`"long long"` for I64). Full CUDA kernel
    support for i64/f64 elementwise emit is Phase F — the mapping is
    here so the egglog ext correctly types the kernel inputs, but
    several elementwise CUDA paths still need codegen work.
  - PT2 boundary: `torch_dtype_int_to_luminal` returns `I64`/`F64` for
    codes 5/8. `TypedData::from_pytorch_bytes` and
    `pt2_compiled_model::bytes_to_typed` preserve raw bytes for both.
    `luminal_dtype_to_pt2_code` round-trips `I64` to code 5.
  - `CompiledGraph` exposes `get_output_i64` / `get_output_f64`. The
    Python wrapper routes `torch.int64` / `torch.float64` outputs
    through them — no more i32-buffer-then-`.to(int64)` cast-back layer.
  - Test scaffolding updated: the `int64_*` and `float64_*` cases move
    from `test_boundary_warns_when_input_dtype_requires_conversion`
    (where they previously had to warn because a conversion was real)
    to `test_boundary_does_not_warn_when_input_dtype_matches_graph`.
    Reflecting the new contract: int64 / float64 inputs match the
    graph's input dtype directly.

xfails removed from `int64_outside_i32_range` and
`float64_precision_sensitive`. Both now pass on CPU end-to-end. CUDA
parity for i64/f64 elementwise kernels lands in Phase F (commit 17).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal: hard-reject dtype mismatch at the FFI boundary

Before: when a caller passed an input whose dtype didn't match the graph's
declared input dtype, the Python wrapper silently `.to(expected_dtype)`-ed
it and emitted a `DTypeBoundaryWarning`. Two real problems:

1. Precision bugs hid. A user passing `torch.float64` into a graph that
   wanted `torch.float32` lost precision-sensitive values
   (`1.0000000000000002` → `1.0`) without anything in the test suite or
   logs flagging it. The warning only showed up at first call and was
   trivially missed in a CI log.

2. Per-call allocation+copy burnt cycles the caller couldn't see in their
   profile. For a model invoked thousands of times a second, the cast was
   a real cost the user wasn't aware was happening.

The contract is now strict: `model(x)` requires
`x.dtype == model.input_dtypes[i]` for every positional input. Mismatched
dtype raises `DTypeBoundaryError` before any FFI work. Migration: call
`.to(model.input_dtypes[i])` at the call site.

  - Add `DTypeBoundaryError(TypeError)` to `compiled_model.py` with a
    docstring that names the prior precision-bug class and points the
    user to the call-site migration.
  - Delete `.to(expected_dtype)` from the input hot path; replace with a
    direct `raise`. `DTypeBoundaryWarning` removed entirely.
  - Metal backend factory rejects `DType::I64` and `DType::F64` inputs at
    translate-time with `UnsupportedDtype` — Metal codegen has no native
    64-bit kernels, and reaching the kernel emitter with these used to
    panic deep in MSL generation with an unhelpful error.
  - Test scaffolding: `test_boundary_warns_when_input_dtype_requires_conversion`
    becomes `test_input_dtype_mismatch_rejects` and asserts the raise.
    `test_boundary_does_not_warn_when_input_dtype_matches_graph` becomes
    `test_matching_dtype_does_not_raise`. The set of "first-class round-
    trip" dtypes is captured as `_FIRST_CLASS_NOOP_DTYPES` — narrow
    integers (uint8 / int8 / int16) collapse to luminal's `Int` (i32),
    so they can't round-trip the noop model without an explicit
    `.to(int32)` cast and live only in the reject-path test.

Breaks user code that today silently autocasts. Intentional. The
migration message at the raise site names the exact `.to(...)` call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_cuda_lite: I64 / F64 output read paths

Wires the runtime side of `DType::I64` / `DType::F64` for CUDA. The
`cuda_dtype` mapping in `luminal_cuda_lite/src/lib.rs` already returned
`"long long"` / `"double"` for these (added with first-class IR
support), so the kernel emitters were producing correctly-typed
output bytes — but the Python wrapper's `get_output_i64` /
`get_output_f64` calls landed on the trait-default panic
("not supported by 'cuda_lite'"), surfacing as 8 CUDA test failures
on the test_dtype_boundary suite.

Adds:

  - `CudaRuntime::get_i64` / `get_f64` — read raw 8-byte chunks from
    the output buffer and reinterpret. Mirrors the existing `get_f16`
    / `get_bf16` byte-reinterpret pattern.

  - `CudaLiteDynBackend::get_output_i64` / `get_output_f64` — thin
    forwarders to the runtime methods.

Verified end-to-end with `test_boundary_noop_preserves_dtype_and_values[cuda-int64_outside_i32_range]`
(2**40 round-trips bitexactly through the CUDA kernel) and
`[cuda-float64_precision_sensitive]` (1.0000000000000002 round-trips
without f32 truncation). Full CUDA dtype suite: 42 passed, 0 failed.

The design-doc commit 18 (int32 / bool CUDA zero-copy output plumbing)
is deferred to a follow-up. Both dtypes already work end-to-end via
the host-roundtrip `get_output_*` path; zero-copy is a perf
optimization not blocking any test in the contract suite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_cuda_lite: include I64 in scatter elem-size tables

CUDA scatter kernels compute output buffer / load / store byte counts
via per-dtype size tables. After landing first-class I64, the scatter
emission for an i64 output panicked with
`Unsupported dtype for scatter output_bytes: Int64`, which surfaced as
the egglog optimizer reporting "Failed to find a viable initial genome
after 100 attempts" because every candidate genome containing an i64
scatter immediately panicked.

Adds I64 → 8 bytes alongside F64 to the five size tables in
`kernel/other_ops.rs` and `kernel/hlir.rs`. MoE routing (idx_dtype =
int32 and int64) now compiles and runs end-to-end on CUDA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tests: drop input-layout / mutation-alias tests from dtype branch

These two test files came along with `d0cec1fc tests for interface
specification` as the test scaffolding for the broader boundary-
contract work — input layout strides (Phase G) and mutation/alias
writebacks (Phase D). Neither feature is in the dtype-only branch,
so the tests either xfail or skip here and are noise to the reader
trying to understand what this branch ships.

Keep only `test_dtype_boundary.py` since that's the suite that
exercises the I64/F64 IR work and the FFI dtype-mismatch rejection
this branch actually delivers. The two removed files live on
`pt2-boundary-contract` where the features they test land.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tests: drop removed files from run_all_tests.sh and run_test.sh

Follow-up to the previous commit's deletion of test_input_layout.py
and test_mutation_alias_contract.py. Both scripts referenced those
files in their pytest invocations.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* boundary: strict dtype at output read; translator inserts Cast

Reviewer's "no implicit casts at the read boundary" directive,
applied to both runtimes:

* `CudaRuntime::get_i64` / `get_f64` now check the producer buffer's
  `buffer_specs[..].dtype` and panic on anything other than `I64` /
  `F64`. The panic message points at the translator as the place to
  insert an explicit `Cast` — no silent widening from i32 / bool /
  f32 / f16 / bf16.

* `NativeDynBackend::get_output_i64` / `get_output_f64` match only
  `NativeData::I64` / `F64` and panic otherwise. The internal
  `NativeData::i64()` / `f64()` accessors stay (they're load-bearing
  for in-kernel mixed-dtype binary ops); only the user-visible read
  boundary is strict.

* `CompiledGraph::get_output_i64` / `get_output_f64` docstrings drop
  the "widens i32 / bool when the producer chose a narrower dtype"
  line; replaced with "Strict on producer dtype — the graph's output
  node must already be I64 / F64."

For the strict boundary to be reachable when the EP-declared dtype
differs from what the producer chose (e.g. `Argsort` / `TopK` emit
i32 indices but `torch.int64` was requested), the translator's
output loop now inserts an explicit `tensor.cast(declared)` before
`output()` when the declared dtype is `I64` / `F64`. The Cast is in
the graph — egglog can see it.

`Vec<f32>::from([…])` typed-local style applied to test set_data
call sites that previously relied on float-literal inference
collapsing to `Vec<f32>`; after 941b6962 added `From<Vec<f64>>`,
those literals now infer as `Vec<f64>` and the buffer lands as
`NativeData::F64`, panicking the strict read.

CPU: 234 pytest passed, 21 skipped. Core: 112 luminal + 16
luminal_nn tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* translator: explicit F32 bridge around unary transcendentals on F64

The CPU `unary_impl` has no native F64 path — `Log2` / `Exp2` /
`Sin` / `Sqrt` / `Recip` and the higher-level transcendentals that
compose them all bridge through f32 in v1. Previously the panic
inside `unary_impl` for `NativeData::F64` was the only thing keeping
the F32-bridge story honest, and the comment apologized for not
inserting the bridge ourselves.

Two changes:

* Add `Translator::translate_unary_op_f32_bridge` — same shape as
  `translate_unary_op`, but when the input is `DType::F64` wraps the
  op as `f(input.cast(F32)).cast(F64)`. The two `Cast` nodes are in
  the graph; egglog sees them; the kernel only ever sees F32.

* Re-dispatch every transcendental unary in `translator/dispatch.rs`
  (`aten.{log,log2,exp,exp2,sin,cos,sqrt,rsqrt,reciprocal,sigmoid,
  tanh,silu,gelu}.default`) through the f32-bridge variant. Ops that
  don't need transcendentals (`neg` = mul-by-(-1), `relu`, `abs`)
  stay on plain `translate_unary_op` and preserve F64 natively.

* Update the `unary_impl` F64 panic message to direct readers at
  `translate_unary_op_f32_bridge` — reaching the panic now means a
  new transcendental dispatch site forgot to bridge.

Tests: CPU 234 passed, 21 skipped. The
`test_boundary_noop_preserves_dtype_and_values[*-float64_*]` cases
continue to pass via the bridge (they go through the noop addition
not a transcendental, so the bridge doesn't fire for them; but if
anyone adds an F64-transcendental test it'll exercise the bridge
end-to-end).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* ffi: panic on narrow-int dtype codes; defer first-class narrower-int IR

Reviewer flagged the "narrow-int widening" docstring at
typed_data.rs:156 as concerning: today luminal collapses uint8 / int8
/ int16 to `DType::Int` at the byte-conversion boundary. The
restrained answer is to **panic** at the boundary rather than widen
silently — matches the "no implicit casts" directive end-to-end.

Both byte-conversion entry points now reject narrow-int PT2 codes:

* `TypedData::from_pytorch_bytes` (user inputs via
  `set_input_from_ptr`) — codes 1 (uint8) / 2 (int8) / 3 (int16)
  panic with "cast to torch.int32 at the call site, or wait for the
  narrower-int IR follow-up."

* `pt2_compiled_model::bytes_to_typed` (PT2 file weights) — same
  panic, same message.

Models that previously round-tripped through implicit widening
(e.g. quantized int8 weights) will now fail at load time with a
clear message pointing at the missing infrastructure. Follow-up
issue: "Narrower integer dtypes (i8 / u8 / i16) first-class in
`NativeData` + CPU kernels" — once that lands, these panics
disappear and the bytes flow through as `DType::U8` / etc.

Tests: `test_dtype_boundary.py` 21 passed, 21 skipped. The narrow-int
cases in `test_input_dtype_mismatch_rejects` continue to assert
`pytest.raises` — the rejection now comes from the FFI panic
instead of the input-dtype boundary check, but the contract from
the user's perspective is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* compiled_model: unify output read dispatch; clarify zero-copy comment

Three review comments addressed in one place:

* `compiled_model.py:148` — the stale comment ("float64 is collapsed
  to f32 internally; registering an f64 device-ptr would have the
  kernel write 12 bytes into a 24-byte buffer") was wrong after
  941b6962 made F64 first-class. Rewrite to explain why
  pre-allocation is GPU-only: the CUDA kernel needs the device-ptr
  registered before `run()`, while CPU reads back after via
  `_read_typed_output`.

* `compiled_model.py:189` — the per-dtype elif chain duplicated
  across the CUDA-zero-copy and native paths. Refactor into a single
  `_output_readers` dispatch table keyed on `out_dtype` →
  `(getter_name, read_dtype, final_cast)`. The zero-copy fast path
  for f32 / f16 / bf16 stays as a single check at the top; every
  other dtype goes through `_read_typed_output`.

* `compiled_model.py:243` — annotate the `if _use_zero_copy:`
  pre-allocation branch: "the CUDA kernel needs the output's device
  pointer registered *before* `_graph.run()` so the final kernel
  writes directly into PyTorch's buffer. CPU never zero-copies —
  there's no separate device buffer to register against."

Tests: CPU 234 passed, 21 skipped (no behavior change, just refactor).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: clarify scope of kernel-internal widening accessors

Two reviewer comments addressed:

* `src/hlir.rs:1212` — the "Narrowing cast: explicit i64 -> i32 …
  used when the translator bridges an i64 value through a kernel
  that only has an i32 path" comment apologized for a non-existent
  problem. Reword: the `Cast` op IS the explicit graph-level
  conversion; saturating via `as i32` matches
  `tensor.to(torch.int32)` semantics on overflow. No bridging
  framing.

* `src/hlir.rs:2989` (and the matching `f64` accessor at :2914) —
  the docstring said "Used by I64-aware kernels; widens other
  variants when an op promotes a mixed-dtype binary to I64" without
  scoping why that's OK. Rewrite to be explicit: this is a
  **kernel-internal** widening accessor, used by binary kernels to
  read RHS at LHS's width, mirroring PyTorch eager's mixed-dtype
  promotion. The user-visible read boundary
  (`DynBackend::get_output_*`) is strict — that's where the reviewer
  was originally complaining about implicit casts. A follow-up
  translator pass that inserts explicit `Cast` ops on mixed-dtype
  binary operands would remove this in-kernel widening entirely;
  not in scope here.

No code change. Tests unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Revert "translator: explicit F32 bridge around unary transcendentals on F64"

This reverts commit f77a2b92.

The bridge inserted `Cast(F32) → unary → Cast(F64)` inside the
translator whenever a user called `torch.exp(x)` (or sin/cos/log/...)
on an `f64` tensor. The output kept the `torch.float64` dtype tag,
but the math itself ran in single precision — exactly the kind of
silent precision downgrade hidden behind a wider dtype that this
PR's "no implicit casts" directive is meant to reject. The bridge
solved one reviewer comment ("unary_impl panics on F64") by
relocating the implicit cast from the runtime to the translator —
not by removing it.

Restore the original behavior: `unary_impl` panics on `F64`, and now
with a sharper message that says outright "cast inputs to F32 at the
call site" and explicitly names the rejected alternative ("silent
F32 bridging is intentionally rejected: it would hide a precision
downgrade behind an `F64` dtype tag"). The same wording goes on the
Int / I64 / Bool arms so each unsupported variant has a clear,
self-contained recovery path.

A native F64 transcendental kernel is the proper fix for double-
precision `exp`/`log`/`sin`/... — tracked in the F64-CUDA-elementwise
follow-up issue.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* hlir: drop From<Vec<f64>> for NativeData; revert luminal_nn / movement / tests churn

The PR was carrying ~70 lines of `vec![1., 2., 3.]` →
`Vec::<f32>::from([1., 2., 3.])` style churn across
`luminal_nn/src/attention.rs`, `src/frontend/movement.rs`, and
`src/tests/mod.rs`. The trigger was the new
`impl From<Vec<f64>> for NativeData`: it made float literals
ambiguous between `Vec<f32>` and `Vec<f64>` at every `set_data`
call site, forcing the explicit `Vec::<f32>::from([...])` spelling.

Drop the `From<Vec<f64>>` impl. It had no callers (`grep -rn` for
`Vec<f64>` going into NativeData turned up nothing — the F64
buffer-construction sites in `dyn_backend.rs` and `typed_data.rs`
use `as_bytes` on a raw `Vec<f64>`, not the `From` impl). Callers
that genuinely want an F64 buffer can still write
`NativeData::F64(my_vec)` directly. With the impl gone, float
literals re-infer to `f32` via the sole `From<Vec<f32>>` impl —
the original idiom — so the three churn-only files revert cleanly
to their `main` state. A short comment at the deletion site explains
why this impl is intentionally absent.

Net diff on the PR drops by ~70 lines of pure style churn. No
behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* translator: cast argsort / topk / sort indices to I64 at the PT2 boundary

`torch.argsort` / `torch.topk(...).indices` / `torch.sort(...).indices`
always return `int64`. luminal's frontend `stable_argsort` returns `Int`
(i32) — the storage-efficient default that direct Rust callers want and
that the existing `luminal_cuda_lite` op-functional / search-equivalence
tests read back via `rt.get_i32(...)`.

Previously, the gap was bridged with a post-hoc Cast in the translator's
output loop (`translator/mod.rs`) — "if the EP declared `I64` and the
producer chose `Int`, insert a Cast(I64) before Output." That meant a
graph node was being inserted by the framework whose presence and
location the user couldn't see in their dispatch — exactly the kind of
hidden behavior this PR's "no implicit casts" directive is meant to
avoid. It also did nothing to fix the underlying mismatch — the producer
was still emitting i32 indices.

Move the cast to the producer side of the PT2 boundary instead:

* `translate_argsort` casts the `stable_argsort` result to I64 before
  inserting it into the tensor map.
* `translate_topk` casts the sliced `topk_indices` to I64. Same buffer
  feeds both the values-gather (via `gather_elements`, which accepts
  any int dtype on its index operand) and the indices output.
* `translate_sort` casts the indices half of the tuple to I64; the
  values half stays at the source dtype.

The frontend `argsort` / `stable_argsort` are unchanged — direct Rust
callers continue to get i32 indices.

Drops the band-aid output-Cast block from `translator/mod.rs`, which is
no longer needed (the producer now emits the right dtype). The strict
read boundary still catches any future dtype mismatch loudly.

Verification:
* `cargo test -p luminal -p luminal_nn`: 114 + 16 + 5 passed.
* CPU pytest (hlir_ops + unary + dtype_boundary): 250 passed, 21 skipped.
* CUDA pytest (same suites + test_llama3 non-slow): 281 passed
  (previously 278 passed, 3 failed on `test_argsort_stable_duplicates
  [idx_dtype1]`, `test_topk_values_width_128_with_indices`,
  `test_tiny_moe_routing[idx_dtype1]` — all now passing).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* translator: cast argmax / argmin to I64; fix narrow-int range-pattern clippy

CI surfaced two issues on the dtype-i64-f64-first-class branch:

* `Python CUDA Tests` failed on 15 `tests/test_scalars.py` cases for
  `argmax` / `argmin` (all variants — keepdim, 0d, all-reduce, per-dim).
  Same root cause as the previously-fixed argsort/topk/sort cases:
  PyTorch's `torch.argmax` / `torch.argmin` return int64 indices (same
  `kLong` contract as `sort` / `topk`, pinned in the structured kernel
  meta function), but `translate_argextremum` was emitting i32 — and
  the strict CUDA `get_i64` read boundary refused to widen.

  The old docstring for `translate_argextremum` already named the
  trick: "the Python wrapper widens at the boundary." That wrapper
  is gone (strict reads), so the fix is to cast at the translator
  site, same as argsort/topk/sort:

  - `Ok(result * 1)` → `Ok((result * 1).cast(DType::I64))`
  - The 0-d short-circuit path's `.cast(DType::Int)` becomes
    `.cast(DType::I64)`.
  - Docstring updated to reflect the new boundary cast.

  I had missed these locally because `test_scalars.py` wasn't in the
  CUDA sweep I ran while iterating; the PR-CI full pytest run caught
  them.

* `CUDA Clippy` failed on two `1 | 2 | 3 =>` match arms — Rust 1.95
  clippy now flags those under `manual_range_patterns`. Rewrote both
  as `1..=3 =>`. No behavior change.

Verification:
* `cargo clippy -p luminal_python --features cuda --tests
  -- -D warnings`: clean.
* `LUMINAL_TEST_DEVICE=cuda pytest tests/test_scalars.py`:
  171 passed, 4 xfailed (all previously-failing argmax/argmin cases
  now pass).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: refresh get_output_i64 / get_output_f64 comments + panic messages

The doc comments on the strict i64 / f64 readers still pointed at the
band-aid output-loop Cast in `translator::translate_graph` ("the
translator inserts an explicit `Cast(I64)` before the Output; see
`translator::translate_graph`"). That block was reverted earlier in
this PR — casts now live at each producer op's translator dispatch
site (`translate_argsort` / `translate_topk` / `translate_sort` /
`translate_argextremum`, mirroring PyTorch's `kLong` contract pinned
by the structured-kernel meta function in `Sorting.cpp`).

Updates the doc + panic-message wording in three places to match the
post-revert reality:

* `CompiledGraph::get_output_i64` / `get_output_f64` (pyo3 wrapper,
  `compiled_graph.rs`)
* `NativeDynBackend::get_output_i64` / `get_output_f64`
  (`dyn_backend.rs`)
* `CudaRuntime::get_i64` / `get_f64` (`runtime.rs`)

Each one now says, in substance: "the producer's buffer must already
carry the requested dtype; on the PT2 path that's handled at the
per-op translator dispatch site, not in a centralized output loop."
Panic messages reworded from "Insert an explicit Cast(I64) in the
graph before the Output" — which read like advice to an end user
authoring the IR by hand — to "Add a `Cast(DType::I64)` before the
Output in the producer graph," which fits both manual IR-authoring
callers and the translator-dispatch case naturally.

For `get_output_f64`, also added a one-liner pointing readers at the
`unary_impl` F64 panic policy (cast inputs to F32 at the call site;
no silent F32 bridging behind an F64 tag).

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: trim get_output_i64 / get_output_f64 doc + panic strings

Previous pass over these comments name-dropped every per-op translator
dispatch site (`translate_argsort`, `translate_topk`, ...) — context
that's irrelevant to a caller of the read functions. Reduce each to a
one-line contract: "Strict: the buffer must already be `DType::Xxx`;
no widening at the read boundary." Panic strings shortened the same
way — keep the "Add a `Cast(...)` before the Output" pointer, drop
the editorial trailing clause.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: trim argextremum dtype paragraph

Replace the seven-line "PyTorch's kLong contract / structured kernel
meta function / storage-efficient default" exposition with one line:
"The result is cast to `DType::I64` to match PyTorch's int64
argmax / argmin indices." The rest of the docstring (FX positional
inputs, `dim=None` flattening, slice-then-materialize rationale)
stays — those are non-obvious mechanical details a reader fixing a
bug actually needs.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* dtype: PyTorch ScalarType as source-of-truth for PT2 dtype codes

Addresses the last open review comment on `compiled_graph.rs:84` /
`pt2_util.rs:208`: "we maintain our own enum of the datatypes in
pytorch, it would be nice if we could bind over the ones from
pytorch and use that as the source of truth and not our own."

Before: the PT2 dtype-code numbering (`1=uint8, 2=int8, ...,
13=bfloat16`) was duplicated across **four** Rust sites and a
hand-rolled dict in Python. Renumbering or new variants in PyTorch's
PT2 schema (e.g. the float8 family added in
pytorch/pytorch#143343) silently miscompiled at runtime.

After: a single `TorchDType` enum in `crates/luminal_python/rust/src/
torch_dtype.rs` owns the canonical numbering. All four call sites
route through it:

* `pt2_util::torch_dtype_int_to_luminal` — delegates to
  `TorchDType::from_code(...).into()`.
* `typed_data::from_pytorch_bytes` — matches on named variants;
  narrow-int panic now reads `TorchDType::Byte | Char | Short`
  instead of `1..=3`. The silent `_ => f32` fallback is gone —
  unknown codes panic with the variant name.
* `pt2_compiled_model::bytes_to_typed` — collapsed to a one-line
  delegate (`TypedData::from_pytorch_bytes(bytes.to_vec(), dtype)`);
  the duplicated panic block is deleted.
* `compiled_graph::luminal_dtype_to_pt2_code` — delegates to
  `TorchDType::try_from(dtype).map(|t| t.code())`.

Python side: `dtype_util.py`'s hardcoded `_TORCH_DTYPE_TO_CODE`
dict is rebuilt at import time from `torch._export.serde.schema.
ScalarType.<NAME>.value` — PyTorch becomes the runtime source of
truth on both sides of the FFI boundary. `torch._export.serde.
schema` is a quasi-private API (leading underscore) but it's the
module PT2 actually wire-serializes against; documented at the
import site.

Parity test: `tests/test_torch_dtype_parity.py` consumes a new
pyo3-exported `_torch_dtype_codes()` map and asserts every Rust
variant matches PyTorch's enum by name and value. If PyTorch
renumbers or adds a variant, the test fails loudly at CI rather
than miscompiling silently at runtime. Negative-test verified
locally by setting `Long = 99` — fails with
`LONG: luminal=99, pytorch=5`. Added to both `run_test.sh` and
`run_all_tests.sh`; CUDA runner globs `tests/` so it picks it up
automatically.

`TorchDType` enumerates all 19 variants currently in
`torch._export.serde.schema.ScalarType` (including `Unknown`,
the three `Complex*` types, `Uint16`, and the four `Float8E*`
variants); `TryFrom<TorchDType> for DType` returns `Err` for any
variant luminal's IR doesn't model, with the boundary code
panicking on `Err` with the variant name.

Verification:
* `cargo test -p luminal_python` — 8 passed (3 new for the enum,
  5 pre-existing).
* `cargo test -p luminal` — 114 passed.
* `cargo clippy -p luminal_python --features cuda --tests
  -- -D warnings` — clean.
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 252 passed, 21 skipped.
* CUDA pytest (same suites + `test_scalars`, `-m "not slow"`) —
  444 passed, 4 xfailed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fmt: cargo fmt on torch_dtype refactor

Applies `cargo fmt --all` to the three files touched by the previous
commit. The Fmt CI job caught:

* `lib.rs` — `_torch_dtype_codes` chain wrapped over multiple lines.
* `pt2_compiled_model.rs` — `use crate::pt2_parser;` ordered before
  `use crate::pt2_schema;`.
* `typed_data.rs` — `unwrap_or_else` closure inlined onto one line.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* ci: bump Python CUDA Slow Tests to A100-80GB

`test_hf_qwen3_moe_real_config_full` loads the real Qwen3-30B-A3B
checkpoint at bf16 (≈60 GiB of weights). Modal's default `--gpu A100`
is the 40 GiB SKU, which can't hold the full model + PyTorch's
reference forward state. When the test OOMs it doesn't release its
allocated memory back to the CUDA driver, so every subsequent
big-model test in the run inherits a ~39 GiB dead-memory wall and
also OOMs (`test_hf_llama38b_mark_dynamic_seq_dim_before_compile`,
`test_hf_llama3_full`, ...).

Request the 80 GiB SKU explicitly. Aligns with the model-specific
Modal jobs on this PR (`gemma`, `qwen3_moe`, etc.) which already
spec `A100-80GB` and pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* pt2_util: panic on narrow ints instead of widening to Int

`torch_dtype_int_to_luminal` was the one remaining site that silently
collapsed `Byte` / `Char` / `Short` to `DType::Int`. Even though the
byte-loading paths (`typed_data::from_pytorch_bytes`,
`pt2_compiled_model::bytes_to_typed`) already refuse those codes,
the metadata-read path through `pt2_util` was still happy to widen,
which left the user's actual dtype invisible past the FFI boundary
on graphs whose declared inputs were narrow ints.

Reject at this site too. Same panic message as the byte paths
("isn't a first-class IR type yet — cast to torch.int32 at the call
site, or wait for the narrower-int IR follow-up"), so the failure
mode is consistent across all three sites.

Test update: `test_input_dtype_mismatch_rejects[uint8 / int8 /
int16]` previously asserted a `DTypeBoundaryError` raised at *call*
time — that was the artifact of the silent widening flow (the graph
compiled with narrow → int32 substitution, then call-time refused
because the user's tensor still had the narrow dtype). The reject
now fires at *compile* time via the translator panic, so the test
asserts on the panic message instead. `pyo3_runtime.PanicException`
inherits from `BaseException`, not `Exception`, so `pytest.raises`
broadens to `BaseException`; the message match keeps the contract
test specific.

Verification:
* `cargo test -p luminal_python` — 8 passed.
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 252 passed, 21 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* torch_dtype: refuse narrow-int conversions in both directions

The two `TryFrom` impls were silently mapping narrow ints:

* `TryFrom<TorchDType> for DType` mapped `Byte` → `DType::U8`,
  `Char` → `DType::I8`, `Short` → `DType::I16`, `Uint16` → `DType::U16`.
  Those luminal DType variants exist in the enum but aren't first-
  class through the IR (no kernels, no codegen) — handing them out
  produced buffers downstream code couldn't actually run on.

* `TryFrom<DType> for TorchDType` was the mirror: `U8` → `Byte`,
  `I8` → `Char`, `I16` → `Short`, plus a stale `U16` → `Int`
  *workaround* (silently aliased uint16 bytes as int32, predating
  PyTorch's `UINT16 = 28` schema entry).

Move all of those to the `Err` arm in both directions. Downstream
sites (`compiled_graph::luminal_dtype_to_pt2_code`,
`pt2_util::torch_dtype_int_to_luminal`, ...) translate the `Err`
into a typed panic with the variant name, so the failure mode is
consistent with the rest of the no-implicit-cast directive — same
spirit as the previous commit on `pt2_util`.

Test updates:
* `supported_dtypes_roundtrip` no longer includes `U8`/`I8`/`I16` —
  they aren't first-class, can't roundtrip.
* New `narrow_ints_refuse_conversion` asserts the `Err` direction
  on `Byte`/`Char`/`Short` (forward) and `U8`/`I8`/`I16`/`U16`
  (reverse).

Verification:
* `cargo test -p luminal_python --lib torch_dtype` — 4 passed.
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 252 passed, 21 skipped.
* `cargo clippy -p luminal_python --features cuda --tests
  -- -D warnings` — clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* torch_dtype: refuse TF32 in DType → TorchDType conversion

`TryFrom<DType> for TorchDType` was silently aliasing
`DType::TF32 → TorchDType::Float`. TF32 isn't a storage dtype on
the PyTorch side (PyTorch has no `torch.tf32`); it's a compute-mode
hint that affects how matmuls are rounded but the underlying buffer
is still f32. If a luminal graph genuinely carried `DType::TF32`
through to the boundary and we mapped it to `Float`, PyTorch would
receive a tensor tagged as f32 that the caller had been tracking as
TF32 inside luminal — exactly the silent-dtype-aliasing pattern
we've been hunting down through the rest of this PR.

Refuse instead. A caller that needs a real f32 bridge can insert an
explicit `Cast(F32)` upstream — same pattern as the F64
transcendental story (a graph-level Cast rather than a hidden
runtime conversion). The existing `Err`-handling at every caller
(`compiled_graph::luminal_dtype_to_pt2_code`, ...) panics with the
named variant.

Test update: `TF32` joins the narrow-int set in
`narrow_ints_refuse_conversion`.

Verification:
* `cargo test -p luminal_python --lib torch_dtype` — 4 passed.
* CPU pytest sweep — 252 passed, 21 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* compiled_model: strict per-dtype dispatch; drop narrow-int reread cast

Addresses two recent review comments on
`crates/luminal_python/src/luminal/compiled_model.py`:

* "are we making vectors that are int32 and putting narrow int types
  in each slot? What is going on?" — `_output_readers` had three
  entries that read via `get_output_i32` then `.to(narrow_dtype)`'d
  back: a leftover from when the IR silently widened narrow ints to
  i32. After the recent narrow-int rejections in `pt2_util` and
  `torch_dtype.rs`, no graph can actually reach this code with a
  narrow-int declared output, so the dispatch entries are unreachable.
  Delete them.

* "Why do we fallback to f32 instead of erroring?" — `_read_typed_
  output`'s `if entry is None:` branch read the buffer as f32 and
  `.to(out_dtype)`'d back regardless of the declared dtype. That's
  the same silent-dtype-aliasing pattern we've been hunting down
  through the rest of the PR.

  Replace with an explicit `NotImplementedError` naming the unsupported
  dtype. Add explicit `_output_readers` entries for `float32` (which
  was relying on the fallback as a no-op cast on CPU) and for
  `float16` / `bfloat16` (documented as reading via the generic
  f32 getter and `.to()`-ing back — the runtime kernels already emit
  f32 bytes for these, so the cast at the end is the inverse of
  upstream's conversion, not a fresh precision drop; a proper typed
  getter is follow-up work).

Net effect: every supported output dtype is an explicit dispatch
entry, every unsupported one raises a clear `NotImplementedError`,
and the narrow-int reread-and-cast path is gone.

Verification:
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 252 passed, 21 skipped.
* CUDA pytest (same suites + `test_scalars`, `-m "not slow"`) —
  444 passed, 4 xfailed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* compiled_model: error vs silently default to float32 on missing dtypes

Addresses the new review comment on
`compiled_model.py:214` ("probably should error vs assume?"):

The pattern `code_to_torch_dtype(codes[i]) if i < len(codes) else
torch.float32` appeared in three places (one input loop, two
output loops) and silently defaulted to float32 when the Rust side
returned a shorter dtype-code list than the declared input/output
count. Same silent-default pattern the reviewer's been hunting down
through the rest of the PR.

Replace all three sites with up-front length checks that raise
`RuntimeError` if the counts don't match, then build the typed
`torch.dtype` list once from the codes and reuse it. Net effect:
* If the Rust side returns inconsistent counts, the error names the
  declared names and the count mismatch directly — points at the
  graph-construction bug instead of papering it over.
* No `else torch.float32` remains for missing-code fallbacks.

Also tightened `dtype_util.py`:
* `code_to_torch_dtype(unknown_code)` and
  `torch_dtype_code(unsupported_dtype)` now raise `KeyError` listing
  the known set, instead of silently aliasing the unknown to float32.

Verification:
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 273 passed.
* CUDA pytest (same suites + `test_scalars`, `-m "not slow"`) —
  444 passed, 4 xfailed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* hlir: drop unused widening accessors; add typed f16 / bf16 read path

The reviewer flagged the kernel-internal `NativeData::{f32, f16, bf16,
i32, i64, f64, bool}(i)` accessors as silent wideners. After main's
PR #330 made the binary kernels strict on dtype, those accessors
became dead code — `rg` for callers across `src/` and `crates/`
turns up only their own panic strings. Delete all seven.

The bigger silent-widening surface was on the **read** side: the
native backend's `get_output_f32 / get_output_i32 / get_output_bool`
just delegated to `NativeData::to_{f32,i32,bool}_vec()`, which
happily accept any source variant. That's the same "widen on read"
pattern the reviewer's been hammering on for `get_output_i64 /
get_output_f64`. Tighten them with the same match-on-variant +
panic-on-mismatch pattern (`Add a Cast(DType::X) before the Output`).

Tightening the read boundary broke the existing `float16` /
`bfloat16` output paths in `compiled_model.py`, which were dispatching
through the generic f32 getter and `.to(half)`-ing back — relying on
exactly the silent widening we just removed. Add proper typed paths:

* Backend trait: `get_output_f16` / `get_output_bf16` with default
  panic impls (`src/dyn_backend.rs`).
* `NativeDynBackend`: strict match on `F16` / `Bf16` variants.
* `luminal_cuda_lite::CudaRuntime`: pre-existing `get_f16` / `get_bf16`
  reinterpreted bytes without checking dtype — add the same
  buffer-spec strictness as `get_i64` / `get_f64`.
* `CudaLiteDynBackend`: wire `get_output_f16` / `get_output_bf16`
  through.
* `CompiledGraph` (pyo3): new `get_output_f16` / `get_output_bf16`
  methods that return `bytes` (Python has no native f16/bf16);
  caller bit-casts via `torch.frombuffer(..., dtype=torch.float16)`
  / `torch.bfloat16`.
* `compiled_model.py`: dispatch table maps `torch.float16` →
  `get_output_f16` (and same for bf16); the helper bit-casts the
  bytes back, then `.clone()`s so the returned tensor owns its
  storage.

Net effect: every supported read boundary is strict — buffer dtype
must already match the requested width. No silent widening anywhere
in the read path.

Verification:
* `cargo test -p luminal -p luminal_python` — 114 + 9 + 5 passed.
* `cargo clippy -p luminal_python --features cuda --tests
  -- -D warnings` — clean.
* CPU pytest (`test_hlir_ops` + `test_unary` + `test_dtype_boundary`
  + `test_torch_dtype_parity`) — 252 passed, 21 skipped.
* CUDA pytest (same suites + `test_scalars`, `-m "not slow"`) —
  444 passed, 4 xfailed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* compiled_model: use bytearray for f16/bf16 frombuffer to silence warning

The f16/bf16 read path in `_read_typed_output` calls `torch.frombuffer`
on the `bytes` returned by `CompiledGraph::get_output_f16` /
`get_output_bf16`. Python `bytes` is immutable, so PyTorch emits a
`UserWarning` ("The given buffer is not writable... You may want to
copy the buffer to protect its data or make it writable **before
converting** it to a tensor").

That warning's message contains the word "converting", which
`test_dtype_boundary.test_matching_dtype_does_not_raise` catches in
its boundary-warning filter — surfaced in CI as a `[cpu-bfloat16]`
failure on the most recent run.

Wrap the bytes in `bytearray()` before `frombuffer` so the storage is
writable and no warning fires. `bytearray(b)` copies the underlying
bytes once; the returned tensor owns its own storage, so the previous
`.clone()` becomes unnecessary and is removed.

No behavior change. CPU sweep still 252 passed / 21 skipped locally
(verified before push this time).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* ruff format: line-break style on f16/bf16 frombuffer call

Ruff's `pre-commit` hook reformats the multi-line `torch.frombuffer(...)
.reshape(tuple(shape))` chain to break after `.reshape(` instead of
inside `frombuffer(...)`. CI's Ruff Format step flagged it on the
previous commit (`4d882763`). No semantic change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Austin Glover <austin_glover@berekely.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-23 00:35:46 -04:00
tucker-luminal
4e93f02725 Tucker/llama3 rsqrt fix (#331) 2026-05-21 21:29:37 -07:00
Ali
25393a9fdd call for allocate_intermediate_buffers is redundant (#321) 2026-05-21 02:10:51 -04:00
Joe Fioti
81ea750e6b cargo examples (#325)
* cargo examples

* Fix commit message generation for diff context

* Generalize GLUMoE search-space checks and harden NaN tests
2026-05-21 02:09:32 -04:00
109 changed files with 9179 additions and 3930 deletions

View File

@@ -1,3 +1,6 @@
[alias]
examples = "run --release --bin examples-perf --"
[target.aarch64-unknown-linux-gnu]
rustflags = [
"-Ctarget-feature=+fp16,+fhm"

View File

@@ -64,4 +64,4 @@ jobs:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 14400 tests/ -v -s -m slow
run: modal run modal_pytest_runner.py --gpu A100-80GB --timeout 14400 tests/ -v -s -m slow

View File

@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
let c = a.matmul(b).output();
// Compile
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

185
ci/examples_perf.py Normal file
View File

@@ -0,0 +1,185 @@
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field
from example_output import validate_output
DEFAULT_EXAMPLES = ["llama", "gemma", "qwen", "qwen3_moe", "gemma4_moe", "whisper"]
EXAMPLE_CARGO_ARGS = {
"llama": ["run", "--release", "-p", "llama"],
"gemma": ["run", "--release", "-p", "gemma"],
"qwen": ["run", "--release", "-p", "qwen", "--features", "cuda"],
"qwen3_moe": ["run", "--release", "-p", "qwen3_moe"],
"gemma4_moe": ["run", "--release", "-p", "gemma4_moe"],
"whisper": ["run", "--release", "-p", "whisper"],
}
@dataclass
class Metrics:
ttft_ms: float | None = None
tpot_ms: float | None = None
tps: float | None = None
@dataclass
class ExampleResult:
name: str
ok: bool
metrics: Metrics = field(default_factory=Metrics)
wall_s: float = 0.0
error: str | None = None
def main() -> None:
args = [arg for arg in sys.argv[1:] if arg != "--"]
if any(arg in {"-h", "--help"} for arg in args):
print_help()
return
if "--list" in args:
print("\n".join(DEFAULT_EXAMPLES))
return
examples = args or DEFAULT_EXAMPLES
results = [run_example(example) for example in examples]
print_table(results)
if any(not result.ok for result in results):
raise SystemExit(1)
def print_help() -> None:
print(
"Run validated Luminal examples, validate textual output, and summarize perf.\n"
"\n"
"Usage:\n"
" cargo examples\n"
" cargo examples llama qwen whisper\n"
"\n"
"Options:\n"
" --list Print the default validated examples\n"
" -h, --help\n"
"\n"
f"The default set matches the Modal examples CI: {', '.join(DEFAULT_EXAMPLES)}."
)
def run_example(example: str) -> ExampleResult:
cargo_args = EXAMPLE_CARGO_ARGS.get(example)
if cargo_args is None:
known = ", ".join(DEFAULT_EXAMPLES)
return ExampleResult(example, False, error=f"unknown example; known examples: {known}")
print(f"\n=== Running {example} ===")
print(f"$ cargo {' '.join(cargo_args)}")
started = time.monotonic()
env = os.environ.copy()
env.setdefault("CUDARC_CUDA_VERSION", "12080")
process = subprocess.Popen(
["cargo", *cargo_args],
cwd=repo_root(),
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
chunks: list[bytes] = []
while True:
chunk = process.stdout.read1(4096)
if not chunk:
break
sys.stdout.buffer.write(chunk)
sys.stdout.buffer.flush()
chunks.append(chunk)
return_code = process.wait()
output = b"".join(chunks).decode("utf-8", errors="replace")
wall_s = time.monotonic() - started
metrics = parse_metrics(output)
if return_code:
return ExampleResult(
example,
False,
metrics=metrics,
wall_s=wall_s,
error=f"process exited with code {return_code}",
)
try:
validate_output(example, output)
except Exception as exc:
return ExampleResult(example, False, metrics=metrics, wall_s=wall_s, error=str(exc))
return ExampleResult(example, True, metrics=metrics, wall_s=wall_s)
def repo_root() -> str:
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def parse_metrics(output: str) -> Metrics:
metrics = Metrics()
for line in output.splitlines():
if "TTFT:" in line:
metrics.ttft_ms = parse_number_after(line, "TTFT:")
if "TPOT:" in line:
metrics.tpot_ms = parse_number_after(line, "TPOT:")
if "tok/s" in line:
metrics.tps = parse_tok_per_second(line)
if metrics.tps is None and metrics.tpot_ms:
metrics.tps = 1000.0 / metrics.tpot_ms
return metrics
def parse_number_after(line: str, marker: str) -> float | None:
tail = line.split(marker, 1)[1].lstrip()
chars = []
for char in tail:
if char.isdigit() or char == ".":
chars.append(char)
else:
break
if not chars:
return None
return float("".join(chars))
def parse_tok_per_second(line: str) -> float | None:
head = line.split("tok/s", 1)[0].rstrip(" (")
parts = head.split()
if not parts:
return None
try:
return float(parts[-1])
except ValueError:
return None
def print_table(results: list[ExampleResult]) -> None:
print("\nSummary")
print(f"{'example':<14} {'status':<8} {'TTFT ms':>10} {'TPOT ms':>10} {'tok/s':>10} {'wall s':>10}")
print("-" * 68)
for result in results:
status = "ok" if result.ok else "failed"
print(
f"{result.name:<14} {status:<8} "
f"{format_metric(result.metrics.ttft_ms):>10} "
f"{format_metric(result.metrics.tpot_ms):>10} "
f"{format_metric(result.metrics.tps):>10} "
f"{result.wall_s:>10.1f}"
)
if result.error:
print(f" error: {result.error}")
def format_metric(value: float | None) -> str:
return "-" if value is None else f"{value:.2f}"
if __name__ == "__main__":
main()

View File

@@ -39,7 +39,7 @@ fn run_metal_pattern_benchmark(
let mut cx = Graph::default();
pattern.build_graph(&mut cx, *size);
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let mut rng = rand::rng();
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
}
}
let mut rt = cx.search(rt, 5);
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
let mut bench_metrics = None;

View File

@@ -41,7 +41,7 @@ struct PreparedBench {
#[cfg(feature = "metal")]
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let mut rng = rand::rng();
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
rt.set_data(*node, &data);
}
let rt = cx.search(rt, 5);
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
Some(PreparedBench {
rt,

View File

@@ -41,7 +41,7 @@ mod metal_backend {
const NAME: &'static str = "Metal";
fn build_search_space(cx: &mut Graph) {
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
}
}
}

View File

@@ -29,9 +29,21 @@ impl DynBackend for CudaLiteDynBackend {
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
self.runtime.get_f32(node)
}
fn get_output_f16(&self, node: NodeIndex) -> Vec<half::f16> {
self.runtime.get_f16(node)
}
fn get_output_bf16(&self, node: NodeIndex) -> Vec<half::bf16> {
self.runtime.get_bf16(node)
}
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
self.runtime.get_i32(node)
}
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
self.runtime.get_i64(node)
}
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
self.runtime.get_f64(node)
}
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
self.runtime.get_bool(node)
}

View File

@@ -1,258 +0,0 @@
use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND, STRING},
extract_expr,
},
op::{EgglogOp, LLIROp},
prelude::{
tracing::{Level, span, trace},
*,
},
};
use crate::{
cudarc::{
cublas::{
CudaBlas,
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
},
driver::CudaStream,
},
host::{DeviceBuffer, HostOp},
};
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
// Strip quotes if present (egglog strings are stored with quotes)
let stripped = s.trim_matches('"');
match stripped {
"T" => cublasOperation_t::CUBLAS_OP_T,
"N" => cublasOperation_t::CUBLAS_OP_N,
"C" => cublasOperation_t::CUBLAS_OP_C,
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct CuBlasSgemmV2 {
m: Expression,
n: Expression,
k: Expression,
a_layout: cublasOperation_t,
b_layout: cublasOperation_t,
lda: Expression,
ldb: Expression,
ldc: Expression,
/// Lazily initialized cuBLAS handle - created on first execute
cublas: OnceLock<Arc<CudaBlas>>,
}
// Useless default for IntoEgglogOp
impl Default for CuBlasSgemmV2 {
fn default() -> Self {
Self {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
cublas: OnceLock::new(),
}
}
}
impl EgglogOp for CuBlasSgemmV2 {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"cublasSgemmV2",
&[
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
("a_layout", STRING),
("b_layout", STRING),
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
]
}
#[allow(unused_variables)]
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
let extracted_state = Self {
m,
n,
k,
a_layout,
b_layout,
lda,
ldb,
ldc,
cublas: OnceLock::new(),
};
trace!(?extracted_state);
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, input_enodes)
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for CuBlasSgemmV2 {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
let m = self.m.exec(dyn_map).unwrap() as i32;
let n = self.n.exec(dyn_map).unwrap() as i32;
let k = self.k.exec(dyn_map).unwrap() as i32;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = self.lda.exec(dyn_map).unwrap() as i32;
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
let alpha = 1.0f32;
let beta = 0.0f32;
// Get buffers: output is self_node, inputs are from graph edges
let c_buf = buffers[&self_node];
let a_buf = buffers[&inputs[0]];
let b_buf = buffers[&inputs[1]];
// Get device pointers
let a_ptr = a_buf.ptr();
let b_ptr = b_buf.ptr();
let c_ptr = c_buf.ptr();
// Debug: Check buffer sizes
trace!(
"buffer_validation {}=={},{}=={},{}=={}",
a_buf.len(),
m * k * 4,
b_buf.len(),
k * n * 4,
c_buf.len(),
m * n * 4
);
let _sgemm_span = span!(
Level::TRACE,
"cuBLAS_SGEMM_V2",
m,
n,
k,
alpha,
beta,
lda,
ldb,
ldc,
?a_layout,
?b_layout,
)
.entered();
// Use shared cuBLAS handle to avoid per-operation workspace allocation
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
// Set the stream for this operation (cuBLAS handle can work with any stream)
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
unsafe {
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
}
let status = unsafe {
cublasSgemm_v2(
*cublas.handle(),
a_layout,
b_layout,
m,
n,
k,
&alpha as *const f32,
a_ptr as *const f32,
lda,
b_ptr as *const f32,
ldb,
&beta as *const f32,
c_ptr as *mut f32,
ldc,
)
};
stream.synchronize().unwrap();
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return Err(anyhow::anyhow!(
"cuBLAS SGEMM TN failed with status: {:?}",
status
));
}
Ok(())
}
fn output_size(&self) -> Expression {
self.m * self.n
}
fn output_bytes(&self) -> Expression {
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
self.output_size() * 4
}
}

View File

@@ -1,73 +0,0 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublas sgemm column-major × column-major"
)

View File

@@ -1,73 +0,0 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublas sgemm column-major × row-major"
)

View File

@@ -1,73 +0,0 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublas sgemm row-major × column-major"
)

View File

@@ -1,73 +0,0 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:ruleset matmul_backend
:name "cublas sgemm row-major"
)

View File

@@ -11,11 +11,13 @@
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match the generic matmul produced from Mul -> Sum.
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -77,8 +79,12 @@
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))

View File

@@ -11,11 +11,13 @@
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match the generic matmul produced from Mul -> Sum.
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -77,8 +79,12 @@
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))

View File

@@ -11,11 +11,13 @@
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match the generic matmul produced from Mul -> Sum.
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -77,8 +79,12 @@
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))

View File

@@ -11,11 +11,13 @@
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match the generic matmul produced from Mul -> Sum.
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -79,8 +81,12 @@
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
; Output shape: [batch, m, n]
(= ?batch (nth_from_end ?out_shape 2))

View File

@@ -25,8 +25,12 @@
(ICons ?input_scale (INil))))
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
(ICons ?input_scale (ICons ?weight_scale (INil)))))
@@ -96,8 +100,12 @@
(ICons ?input_scale (INil))))
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
(ICons ?input_scale (ICons ?weight_scale (INil)))))
@@ -368,8 +376,12 @@
(ICons ?input_scale (INil))))
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
(ICons ?input_scale (ICons ?weight_scale (INil)))))
@@ -440,8 +452,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -489,8 +505,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -538,8 +558,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
@@ -587,8 +611,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))
@@ -650,8 +678,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))
@@ -713,8 +745,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
(= ?batch (nth_from_end ?out_shape 2))

View File

@@ -5,8 +5,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
@@ -54,8 +58,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
@@ -103,8 +111,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
@@ -152,8 +164,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
@@ -201,8 +217,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
@@ -264,8 +284,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
@@ -327,8 +351,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
@@ -390,8 +418,12 @@
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?out_shape ?mul_shape ?k
?a_stride ?b_stride
?sum_in_stride ?k_stride ?sum_out_stride
?matmul_dtype)
(ICons ?a (ICons ?b (INil)))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))

View File

@@ -1,4 +1,7 @@
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
use half::{bf16, f16};
use luminal::{
@@ -15,6 +18,8 @@ use luminal::{
},
};
#[cfg(test)]
use crate::kernel::CudaGraphHandle;
use crate::{
cudarc::{
cublas::sys::cublasOperation_t,
@@ -33,12 +38,22 @@ use crate::{
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
},
},
driver::{CudaStream, DevicePtr},
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
host::{DeviceBuffer, HostOp},
try_create_cublaslt,
};
fn parse_cublas_op(s: &str) -> cublasOperation_t {
let stripped = s.trim_matches('"');
match stripped {
"T" => cublasOperation_t::CUBLAS_OP_T,
"N" => cublasOperation_t::CUBLAS_OP_N,
"C" => cublasOperation_t::CUBLAS_OP_C,
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct CuBlasLt {
@@ -189,50 +204,50 @@ impl EgglogOp for CuBlasLt {
Rule::raw(include_str!["cublaslt_beta_rewrite.egg"]),
Rule::raw(include_str!["cublaslt_epilogue_rewrite.egg"]),
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
// Delete the matmul-broadcast Mul eclass when the consuming Sum
// eclass has a `cublaslt` or `KernelBatchMatMul` alternative. The
// cuBLASLt / batched-matmul rewrite rules only union those enodes
// into the Sum eclass after the broadcast pattern check passes,
// so their presence is the matmul-broadcast signal — no further
// stride-form check needed.
//
// Delete the HLIR `Mul` fallback from the Mul eclass. Emptying that
// eclass lets the empty-eclass cascade prune the downstream Sum /
// KernelSum fallback. cuBLAS, TileMatmulFullSplit, KernelBatchMatVec,
// and KernelBatchMatMul all take original (a, b) inputs rather than
// the Mul eclass, so they survive the cascade and remain as the
// matmul output alternative.
// cuBLASLt now specializes GenericMatmul, so cleanup should prune
// the matmul output alternatives directly. Do not delete the
// broadcast Mul here; it may still have non-matmul consumers.
Rule::raw("(rule
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs)))
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
:ruleset cleanup
:name \"delete-sum-when-cublaslt-exists\"
)"),
Rule::raw("(rule
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs)))
:ruleset cleanup
)"),
// Also remove any generic fusion wrapper that was unioned with the
// broadcast Mul. This is deliberately a separate rule: requiring a
// FusionEnd in the same eclass made cleanup miss valid cuBLASLt
// matmuls when fusion wrapping was absent.
Rule::raw("(rule
((= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
((delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
:ruleset cleanup
:name \"delete-kernel-sum-when-cublaslt-exists\"
)"),
Rule::raw("(rule
((= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
:ruleset cleanup
:name \"delete-sum-when-scaled-cublaslt-exists\"
)"),
Rule::raw("(rule
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
:ruleset cleanup
:name \"delete-kernel-sum-when-scaled-cublaslt-exists\"
)"),
Rule::raw("(rule
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
:ruleset cleanup
:name \"prefer-cublaslt-over-generic-matmul\"
)"),
Rule::raw("(rule
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
:ruleset cleanup
:name \"prefer-scaled-cublaslt-over-generic-matmul\"
)"),
]
}
@@ -571,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
}
}
#[derive(Debug, Clone, Copy)]
enum LtScalar {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum LtScalar {
F64(f64),
F32(f32),
F16(f16),
@@ -612,16 +627,16 @@ impl LtScalar {
}
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulProblem {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulProblem {
m: u64,
n: u64,
k: u64,
batch_count: i32,
}
#[derive(Debug, Clone, Copy)]
struct LtMatrixSpec {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatrixSpec {
dtype: cudaDataType,
rows: u64,
cols: u64,
@@ -630,8 +645,8 @@ struct LtMatrixSpec {
order: cublasLtOrder_t,
}
#[derive(Debug, Clone, Copy)]
struct LtComputeSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtComputeSpec {
compute_type: cublasComputeType_t,
scale_dtype: cudaDataType,
alpha: LtScalar,
@@ -639,8 +654,8 @@ struct LtComputeSpec {
epilogue: cublasLtEpilogue_t,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulSpec {
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct LtMatmulSpec {
problem: LtMatmulProblem,
trans_a: cublasOperation_t,
trans_b: cublasOperation_t,
@@ -652,8 +667,8 @@ struct LtMatmulSpec {
workspace_size: usize,
}
#[derive(Debug, Clone, Copy)]
struct LtMatmulPointers {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LtMatmulPointers {
a: u64,
b: u64,
c: u64,
@@ -663,7 +678,35 @@ struct LtMatmulPointers {
b_scale: Option<u64>,
}
struct LtRawDescriptors {
impl LtMatmulPointers {
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
let mut fields = Vec::new();
if self.a != other.a {
fields.push("a");
}
if self.b != other.b {
fields.push("b");
}
if self.c != other.c {
fields.push("c");
}
if self.d != other.d {
fields.push("d");
}
if self.bias != other.bias {
fields.push("bias");
}
if self.a_scale != other.a_scale {
fields.push("a_scale");
}
if self.b_scale != other.b_scale {
fields.push("b_scale");
}
fields
}
}
pub(crate) struct LtRawDescriptors {
matmul_desc: cublasLtMatmulDesc_t,
a_desc: cublasLtMatrixLayout_t,
b_desc: cublasLtMatrixLayout_t,
@@ -672,6 +715,23 @@ struct LtRawDescriptors {
preference: cublasLtMatmulPreference_t,
}
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
> = OnceLock::new();
#[cfg(test)]
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
}
#[cfg(test)]
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
}
impl Default for LtRawDescriptors {
fn default() -> Self {
Self {
@@ -710,6 +770,121 @@ impl Drop for LtRawDescriptors {
}
}
pub(crate) struct PreparedCuBlasLtMatmul {
cublaslt: Arc<CudaBlasLT>,
spec: LtMatmulSpec,
resources: LtRawDescriptors,
heuristic: cublasLtMatmulHeuristicResult_t,
_workspace: CudaSlice<u8>,
workspace_ptr: u64,
_a_scale: Option<CudaSlice<f32>>,
default_a_scale_ptr: Option<u64>,
_b_scale: Option<CudaSlice<f32>>,
default_b_scale_ptr: Option<u64>,
_c_scale: Option<CudaSlice<f32>>,
_d_scale: Option<CudaSlice<f32>>,
}
impl PreparedCuBlasLtMatmul {
fn update_descriptor_pointers(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
stream.context().bind_to_thread()?;
if let Some(bias_ptr) = ptrs.bias {
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias_ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
ptr,
)?;
}
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
})?;
set_scalar_scale_pointer(
self.resources.matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
ptr,
)?;
}
Ok(())
}
pub(crate) fn enqueue(
&self,
stream: &Arc<CudaStream>,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
self.update_descriptor_pointers(stream, ptrs)?;
let alpha_ptr = self.spec.compute.alpha.as_ptr();
let beta_ptr = self.spec.compute.beta.as_ptr();
unsafe {
cublasLtMatmul(
*self.cublaslt.handle(),
self.resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
self.resources.a_desc,
ptrs.b as *const std::ffi::c_void,
self.resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
self.resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
self.resources.d_desc,
&self.heuristic.algo,
self.workspace_ptr as *mut std::ffi::c_void,
self.spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtCaptureSignature {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtPrepareKey {
spec: LtMatmulSpec,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct CuBlasLtResolvedGraphCall {
pub(crate) spec: LtMatmulSpec,
pub(crate) ptrs: LtMatmulPointers,
}
impl CuBlasLtResolvedGraphCall {
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
CuBlasLtCaptureSignature {
spec: self.spec,
ptrs: self.ptrs,
}
}
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
CuBlasLtPrepareKey { spec: self.spec }
}
}
fn create_matrix_layout(
desc: &mut cublasLtMatrixLayout_t,
spec: LtMatrixSpec,
@@ -786,12 +961,15 @@ fn set_scalar_scale_pointer(
Ok(())
}
fn run_cublaslt_matmul(
pub(crate) fn prepare_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
#[cfg(test)]
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
return Err(anyhow::anyhow!(
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
@@ -803,17 +981,17 @@ fn run_cublaslt_matmul(
let mut resources = LtRawDescriptors::default();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
drop(workspace_guard);
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
};
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
Some(stream.clone_htod(&[1.0f32])?)
} else {
None
@@ -869,29 +1047,27 @@ fn run_cublaslt_matmul(
}
}
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
(Some(ptr), None)
} else if let Some(scale) = &a_scale {
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
(Some(ptr), None)
} else if let Some(scale) = &b_scale {
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
(None, None)
};
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
let (ptr, guard) = scale.device_ptr(stream);
(Some(ptr), Some(guard))
} else {
@@ -925,6 +1101,7 @@ fn run_cublaslt_matmul(
ptr,
)?;
}
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
create_matrix_layout(&mut resources.a_desc, spec.a)?;
create_matrix_layout(&mut resources.b_desc, spec.b)?;
@@ -942,58 +1119,148 @@ fn run_cublaslt_matmul(
}
}
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
let cached_heuristic = {
let cache = heuristic_cache.lock().unwrap();
cache
.iter()
.find(|(cached_spec, _)| cached_spec == spec)
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
};
if let Some(cached) = cached_heuristic {
heuristic = cached;
} else {
let mut algo_count: i32 = 0;
unsafe {
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
resources.preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&spec.workspace_size as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
resources.matmul_desc,
resources.a_desc,
resources.b_desc,
resources.c_desc,
resources.d_desc,
resources.preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
if algo_count == 0 {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
}
let alpha_ptr = spec.compute.alpha.as_ptr();
let beta_ptr = spec.compute.beta.as_ptr();
cublasLtMatmul(
*cublaslt.handle(),
resources.matmul_desc,
alpha_ptr,
ptrs.a as *const std::ffi::c_void,
resources.a_desc,
ptrs.b as *const std::ffi::c_void,
resources.b_desc,
beta_ptr,
ptrs.c as *const std::ffi::c_void,
resources.c_desc,
ptrs.d as *mut std::ffi::c_void,
resources.d_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
spec.workspace_size,
stream.cu_stream() as *mut _,
)
.result()?;
heuristic_cache
.lock()
.unwrap()
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
}
Ok(())
Ok(PreparedCuBlasLtMatmul {
cublaslt: cublaslt.clone(),
spec: *spec,
resources,
heuristic,
_workspace: workspace,
workspace_ptr,
_a_scale: a_scale,
default_a_scale_ptr,
_b_scale: b_scale,
default_b_scale_ptr,
_c_scale: c_scale,
_d_scale: d_scale,
})
}
fn run_cublaslt_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
spec: &LtMatmulSpec,
ptrs: LtMatmulPointers,
) -> anyhow::Result<()> {
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
prepared.enqueue(stream, ptrs)
}
#[cfg(test)]
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
let capture_stream = stream.context().new_stream()?;
let cublaslt = try_create_cublaslt(stream.clone())
.map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;
let a_buf = stream.clone_htod(&[1.0f32])?;
let b_buf = stream.clone_htod(&[1.0f32])?;
let d_buf = unsafe { stream.alloc::<f32>(1)? };
let (a, a_guard) = a_buf.device_ptr(stream);
let (b, b_guard) = b_buf.device_ptr(stream);
let (d, d_guard) = d_buf.device_ptr(stream);
drop((a_guard, b_guard, d_guard));
let matrix = LtMatrixSpec {
dtype: cudaDataType::CUDA_R_32F,
rows: 1,
cols: 1,
ld: 1,
batch_stride: 1,
order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m: 1,
n: 1,
k: 1,
batch_count: 1,
},
trans_a: cublasOperation_t::CUBLAS_OP_N,
trans_b: cublasOperation_t::CUBLAS_OP_N,
a: matrix,
b: matrix,
c: matrix,
d: matrix,
compute: LtComputeSpec {
compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
scale_dtype: cudaDataType::CUDA_R_32F,
alpha: LtScalar::F32(1.0),
beta: LtScalar::F32(0.0),
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
},
workspace_size: 1024 * 1024,
};
let ptrs = LtMatmulPointers {
a,
b,
c: d,
d,
bias: None,
a_scale: None,
b_scale: None,
};
let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;
let mut graph = CudaGraphHandle::new(stream.context().clone())?;
let entry = graph.add_empty_node(&[])?;
capture_stream.join(stream)?;
graph.begin_capture_to_graph(&capture_stream, &[entry])?;
let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
let end_result = graph.end_capture(&capture_stream);
enqueue_result?;
end_result?;
Ok(())
}
let supported = probe(stream).is_ok();
let _ = stream.synchronize();
supported
}
fn resolve_cublaslt_pointers(
@@ -1116,6 +1383,151 @@ impl CuBlasLt {
Ok(created)
}
pub(crate) fn graph_inputs(&self) -> usize {
self.n_inputs()
}
pub(crate) fn resolve_for_graph(
&self,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
let element_size = (self.d_dtype.bits() / 8) as u64;
assert!(
element_size > 0,
"cuBLAS LT does not support sub-byte dtype {}",
self.d_dtype
);
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
let ptrs = resolve_cublaslt_pointers(
self_node,
inputs,
buffers,
self.beta,
self.epilogue,
self.a_scale_input,
self.b_scale_input,
)?;
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
let _span = span!(
Level::TRACE,
"cuBLASLT_resolve_graph",
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
?self.epilogue,
)
.entered();
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
let c_spec = LtMatrixSpec {
dtype: c_cuda_dtype,
rows: m,
cols: n,
ld: ldc,
batch_stride: stride_c,
order: self.c_order,
};
let d_spec = LtMatrixSpec {
dtype: d_cuda_dtype,
rows: m,
cols: n,
ld: ldd,
batch_stride: stride_d,
order: self.d_order,
};
let spec = LtMatmulSpec {
problem: LtMatmulProblem {
m,
n,
k,
batch_count,
},
trans_a: a_layout,
trans_b: b_layout,
a: LtMatrixSpec {
dtype: a_cuda_dtype,
rows: a_rows,
cols: a_cols,
ld: lda,
batch_stride: stride_a,
order: self.a_order,
},
b: LtMatrixSpec {
dtype: b_cuda_dtype,
rows: b_rows,
cols: b_cols,
ld: ldb,
batch_stride: stride_b,
order: self.b_order,
},
c: c_spec,
d: d_spec,
compute: LtComputeSpec {
compute_type: self.compute_type,
scale_dtype: scale_cuda_dtype,
alpha,
beta,
epilogue: self.epilogue,
},
workspace_size: WORKSPACE_SIZE,
};
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
}
pub(crate) fn prepare_resolved_for_graph(
&self,
stream: &Arc<CudaStream>,
resolved: CuBlasLtResolvedGraphCall,
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
let cublaslt = self.get_cublaslt(stream)?;
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
}
#[cfg(test)]
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
(

View File

@@ -2,13 +2,11 @@ use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaStream, DriverError, result};
use luminal::{op::EgglogOp, prelude::*};
mod cublas;
mod cublaslt;
pub(crate) mod cublaslt;
pub mod flashinfer;
pub mod moe;
pub type Ops = (
// cublas::CuBlasSgemmV2,
cublaslt::CuBlasLt,
cublaslt::CuBlasLtScaled,
moe::GLUMoE,
@@ -169,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
None
}
/// Returns pairs of extra buffer nodes that must not share arena storage.
///
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
/// two buffers may have disjoint positions in one topological order while
/// still being unordered by real dependencies, so CUDA could overlap them.
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
None
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.

View File

@@ -7,7 +7,10 @@ use std::sync::Arc;
use cudarc::driver::{
CudaContext, CudaFunction, CudaStream, DriverError,
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
sys::{
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
},
};
/// A CUDA graph that can be modified and instantiated.
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
}
}
/// Updates a kernel node in the mutable source graph.
pub unsafe fn set_kernel_node_params(
&mut self,
node: CUgraphNode,
func: CUfunction,
grid_dim: (u32, u32, u32),
block_dim: (u32, u32, u32),
shared_mem_bytes: u32,
kernel_params: *mut *mut c_void,
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let params = sys::CUDA_KERNEL_NODE_PARAMS {
func,
gridDimX: grid_dim.0,
gridDimY: grid_dim.1,
gridDimZ: grid_dim.2,
blockDimX: block_dim.0,
blockDimY: block_dim.1,
blockDimZ: block_dim.2,
sharedMemBytes: shared_mem_bytes,
kernelParams: kernel_params,
extra: std::ptr::null_mut(),
kern: std::ptr::null_mut(),
ctx: std::ptr::null_mut(),
};
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, &params).result() }
}
/// Adds an empty dependency node to the graph.
pub fn add_empty_node(
&mut self,
dependencies: &[CUgraphNode],
) -> Result<CUgraphNode, DriverError> {
self.ctx.bind_to_thread()?;
let mut node = MaybeUninit::uninit();
unsafe {
sys::cuGraphAddEmptyNode(
node.as_mut_ptr(),
self.cu_graph,
dependencies.as_ptr(),
dependencies.len(),
)
.result()?;
Ok(node.assume_init())
}
}
/// Destroys a node in the mutable graph.
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe { sys::cuGraphDestroyNode(node).result() }
}
/// Adds dependency edges to the mutable graph.
pub fn add_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Removes dependency edges from the mutable graph.
pub fn remove_dependencies(
&mut self,
from: &[CUgraphNode],
to: &[CUgraphNode],
) -> Result<(), DriverError> {
assert_eq!(from.len(), to.len());
self.ctx.bind_to_thread()?;
unsafe {
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
}
.result()
}
/// Returns all nodes currently in the graph.
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut nodes = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
}
nodes.truncate(count);
Ok(nodes)
}
/// Returns the direct dependencies of a node.
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Returns the direct dependents of a node.
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
self.ctx.bind_to_thread()?;
let mut count = 0usize;
unsafe {
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
}
if count == 0 {
return Ok(Vec::new());
}
let mut deps = vec![std::ptr::null_mut(); count];
unsafe {
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
}
deps.truncate(count);
Ok(deps)
}
/// Begins stream capture that appends captured work into this graph.
pub fn begin_capture_to_graph(
&mut self,
stream: &CudaStream,
dependencies: &[CUgraphNode],
) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe {
sys::cuStreamBeginCaptureToGraph(
stream.cu_stream(),
self.cu_graph,
dependencies.as_ptr(),
std::ptr::null(),
dependencies.len(),
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
)
.result()
}
}
/// Ends stream capture previously started by begin_capture_to_graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut graph = MaybeUninit::uninit();
unsafe {
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
let captured = graph.assume_init();
if captured != self.cu_graph && !captured.is_null() {
sys::cuGraphDestroy(captured).result()?;
}
}
Ok(())
}
/// Adds an event record node to the graph for timing.
pub fn add_event_record_node(
&mut self,
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, &params) }
.result()
}
/// Attempts to update this executable graph from an already-mutated source graph.
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
let mut result = CUgraphExecUpdateResultInfo {
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
errorNode: std::ptr::null_mut(),
errorFromNode: std::ptr::null_mut(),
};
unsafe {
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
}
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
return Err(DriverError(
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
));
}
Ok(())
}
}
impl Drop for CudaGraphExecHandle {
@@ -480,6 +672,38 @@ mod tests {
assert_eq!(result[0], 6.0f32);
}
#[test]
fn test_graph_empty_node_dependency_reconnect() {
let Ok(ctx) = CudaContext::new(0) else { return };
let mut graph = CudaGraphHandle::new(ctx).unwrap();
let entry = graph.add_empty_node(&[]).unwrap();
let middle = graph.add_empty_node(&[entry]).unwrap();
let exit = graph.add_empty_node(&[middle]).unwrap();
let nodes = graph.nodes().unwrap();
assert!(nodes.contains(&entry));
assert!(nodes.contains(&middle));
assert!(nodes.contains(&exit));
assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);
graph.add_dependencies(&[entry], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert!(exit_deps.contains(&entry));
assert!(exit_deps.contains(&middle));
graph.remove_dependencies(&[middle], &[exit]).unwrap();
let exit_deps = graph.dependencies(exit).unwrap();
assert_eq!(exit_deps.len(), 1);
assert!(exit_deps.contains(&entry));
unsafe {
graph.destroy_node(middle).unwrap();
}
assert!(!graph.nodes().unwrap().contains(&middle));
}
// CUDA Graph Tests
#[test]
@@ -498,8 +722,8 @@ mod tests {
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
rt.execute(&cx.dyn_map);
@@ -530,8 +754,8 @@ mod tests {
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
let mut results = Vec::new();
for _ in 0..5 {
rt.execute(&cx.dyn_map);
@@ -568,8 +792,8 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
.iter()
@@ -610,8 +834,8 @@ mod tests {
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
@@ -641,8 +865,8 @@ mod tests {
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
for _ in 0..10 {
rt.execute(&cx.dyn_map);
}
@@ -674,8 +898,8 @@ mod tests {
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Initial execution
rt.execute(&cx.dyn_map);

View File

@@ -89,6 +89,21 @@ impl EgglogOp for CudaUnaryElementwise {
)));
}
rules.push(Rule::raw(
"(rule (
(= ?sqrt (Op (Sqrt ?shape ?x_stride ?sqrt_stride) (ICons ?x (INil))))
(= ?recip (Op (Recip ?shape ?sqrt_stride ?out_stride) (ICons ?sqrt (INil))))
(= ?dt (dtype ?recip))
) (
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
(let ?elem (Op (CudaUnaryElementwise \"Rsqrt\" ?shape ?x_stride ?out_stride ?dt)
(ICons ?fs (INil))))
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
(union ?recip ?fe)
(set (dtype ?fe) ?dt)
) :ruleset kernel_lower :name \"cuda-elem-rsqrt-from-sqrt-recip\")",
));
rules.push(Rule::raw(
"(rule
(

View File

@@ -309,6 +309,61 @@ impl EgglogOp for FusionEnd {
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
// correctly without dissolve.
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
(!= ?fe_shape ?inner_shape)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-unary-shape\")",
));
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
(!= ?fe_s ?inner_s)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-unary-strides\")",
));
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
(!= ?fe_shape ?inner_shape)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-binary-shape\")",
));
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
(!= ?fe_s ?inner_s)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-binary-strides\")",
));
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
(!= ?fe_shape ?inner_shape)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-nested-shape\")",
));
rules.push(Rule::raw(
"(rule (
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
(!= ?fe_s ?inner_s)
) (
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
) :ruleset cleanup :name \"delete-malformed-FE-nested-strides\")",
));
rules
}

View File

@@ -358,6 +358,7 @@ fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
match op {
"Sin" => format!("sinf({})", a()),
"Sqrt" => format!("sqrtf({})", a()),
"Rsqrt" => format!("rsqrtf({})", a()),
"Exp" => format!("expf({})", a()),
"Exp2" => format!("exp2f({})", a()),
"Log2" => format!("log2f({})", a()),

View File

@@ -0,0 +1,319 @@
use std::sync::Arc;
use crate::{
compile_module_image_for_current_device, cuda_dtype,
kernel::{
KernelOp,
hlir::{dtype_includes, generate_dyn_dims_defines},
},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
extract_dtype, extract_expr, extract_expr_list,
},
op::*,
prelude::*,
shape::flatten_strides,
};
#[derive(Default, Debug, Clone)]
pub struct GenericMatmul {
out_shape: Vec<Expression>,
mul_shape: Vec<Expression>,
k: Expression,
lhs_strides: Vec<Expression>,
rhs_strides: Vec<Expression>,
sum_input_strides: Vec<Expression>,
sum_iter_stride: Expression,
out_strides: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for GenericMatmul {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"GenericMatmul",
&[
("out_shape", ELIST),
("mul_shape", ELIST),
("k", EXPRESSION),
("lhs_strides", ELIST),
("rhs_strides", ELIST),
("sum_input_strides", ELIST),
("sum_iter_stride", EXPRESSION),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(
"(rule
(
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
(ICons ?lhs (ICons ?rhs (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
(ICons ?mul (INil))))
(= ?dt (dtype ?sum))
)
(
(let ?generic (Op (GenericMatmul
?out_shape
?mul_shape
?k
?lhs_strides
?rhs_strides
?sum_input_strides
?sum_iter_stride
?out_strides
?dt)
(ICons ?lhs (ICons ?rhs (INil)))))
(union ?sum ?generic)
(set (dtype ?generic) ?dt)
)
:ruleset matmul_backend
:name \"generic-matmul-cuda-mul-sum\"
)",
),
Rule::raw(
"(rule
(
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
(ICons ?lhs (ICons ?rhs (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
(ICons ?mul (INil))))
(= ?sum (Op (GenericMatmul
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
?generic_inputs))
)
(
(delete (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
(ICons ?mul (INil))))
)
:ruleset cleanup
:name \"delete-sum-when-generic-matmul-exists\"
)",
),
Rule::raw(
"(rule
(
(= ?kernel_sum (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
?sum_inputs))
(= ?kernel_sum (Op (GenericMatmul
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
?generic_inputs))
)
((delete (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
?sum_inputs)))
:ruleset cleanup
:name \"delete-kernel-sum-when-generic-matmul-exists\"
)",
),
]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
mul_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
.unwrap(),
k: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
lhs_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
.unwrap(),
rhs_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
.unwrap(),
sum_input_strides: extract_expr_list(
egraph,
kind_children[5],
list_cache,
expr_cache,
)
.unwrap(),
sum_iter_stride: extract_expr(egraph, kind_children[6], expr_cache).unwrap(),
out_strides: extract_expr_list(egraph, kind_children[7], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[8]),
})),
input_enodes,
)
}
}
impl KernelOp for GenericMatmul {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self.all_dyn_vars();
let dtype = cuda_dtype(self.dtype);
let includes = dtype_includes(&[self.dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_outputs = self.output_size();
let sum_base_idx = flatten_strides(&self.out_shape, &self.sum_input_strides).to_kernel();
let iter_offset = self.sum_iter_stride.to_kernel().replace("const_z", "i");
let lhs_idx = flatten_strides(&self.mul_shape, &self.lhs_strides)
.to_kernel()
.replace("const_z", "mul_idx");
let rhs_idx = flatten_strides(&self.mul_shape, &self.rhs_strides)
.to_kernel()
.replace("const_z", "mul_idx");
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
let k = self.k.to_kernel();
let kernel = format!(
"{includes}
#define WARP_SIZE 32
#define THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
{dyn_defines}
extern \"C\" {{
__global__ void generic_matmul({dtype} *out, const {dtype} *lhs, const {dtype} *rhs{dyn_dims_param}) {{
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
if (const_z >= {n_outputs}) return;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
int warp_id = tid / WARP_SIZE;
long long base_idx = {sum_base_idx};
long long iters = {k};
float partial = 0.0f;
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
long long mul_idx = base_idx + {iter_offset};
partial += static_cast<float>(lhs[{lhs_idx}]) * static_cast<float>(rhs[{rhs_idx}]);
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {{
partial += __shfl_down_sync(FULL_MASK, partial, s);
}}
if (lane_id == 0) {{
warp_sums[warp_id] = partial;
}}
__syncthreads();
if (warp_id == 0) {{
float block_sum = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? warp_sums[tid] : 0.0f;
#pragma unroll
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s >>= 1) {{
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
}}
if (tid == 0) {{
out[{out_idx}] = ({dtype})block_sum;
}}
}}
}}
}}",
n_outputs = n_outputs.to_kernel(),
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("generic_matmul").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
32.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.out_shape
.iter()
.copied()
.product::<Expression>()
.max(Expression::from(1))
}
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.mul_shape.iter().flat_map(|e| e.dyn_vars()))
.chain(self.k.dyn_vars())
.chain(self.lhs_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(self.rhs_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(self.sum_input_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(self.sum_iter_stride.dyn_vars())
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
.collect()
}
fn output_bytes(&self) -> Expression {
(self.output_size() * self.dtype.bits()).ceil_div(8)
}
fn bytes_loaded(&self) -> Expression {
(self.output_size() * self.k * self.dtype.bits() * 2).ceil_div(8)
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.output_size() * self.k * 2
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"GenericMatmul"
}
}

View File

@@ -987,7 +987,7 @@ extern \"C\" {{
fn output_bytes(&self) -> Expression {
let elem_size: Expression = match self.dtype {
DType::F64 => 8,
DType::F64 | DType::I64 => 8,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
DType::Bool
@@ -1021,7 +1021,7 @@ extern \"C\" {{
fn bytes_loaded(&self) -> Expression {
let data_elem_size: Expression = match self.dtype {
DType::F64 => 8,
DType::F64 | DType::I64 => 8,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
DType::Bool

View File

@@ -12,6 +12,7 @@ use uuid::Uuid;
pub mod conv2d;
pub mod cuda_graph;
pub mod fusion;
pub mod generic_matmul;
pub mod hlir;
pub mod matmul2d;
pub mod other_ops;
@@ -19,13 +20,20 @@ pub mod rope;
pub use conv2d::KernelConv2D;
pub use cuda_graph::*;
pub use generic_matmul::GenericMatmul;
pub use matmul2d::{
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
matmul_3d, matmul_3d_t,
};
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
pub type Ops = (hlir::Ops, other_ops::Ops, conv2d::KernelConv2D, fusion::Ops);
pub type Ops = (
hlir::Ops,
other_ops::Ops,
conv2d::KernelConv2D,
GenericMatmul,
fusion::Ops,
);
/// Build a mapping from interned string IDs to their string values for a given sequence.
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
@@ -296,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
// Kernel to host op compilation
mod to_host;
#[cfg(test)]
pub(crate) use to_host::CudaGraphDebugSummary;
pub use to_host::{CudaGraphOp, kernel_to_host};

View File

@@ -17,13 +17,7 @@ use luminal::{
prelude::*,
};
pub type Ops = (
KernelMeanReduce,
KernelBatchMatVec,
KernelBatchMatMul,
KernelScatterNoCopy,
KernelSoftmax,
);
pub type Ops = (KernelMeanReduce, KernelScatterNoCopy, KernelSoftmax);
#[derive(Default, Debug, Clone)]
@@ -532,7 +526,7 @@ extern \"C\" {{
fn output_bytes(&self) -> Expression {
let elem_size: Expression = match self.dtype {
DType::F64 => 8,
DType::F64 | DType::I64 => 8,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
DType::Bool
@@ -566,7 +560,7 @@ extern \"C\" {{
fn bytes_loaded(&self) -> Expression {
let data_elem_size: Expression = match self.dtype {
DType::F64 => 8,
DType::F64 | DType::I64 => 8,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
DType::Bool
@@ -585,7 +579,7 @@ extern \"C\" {{
fn bytes_stored(&self) -> Expression {
let data_elem_size: Expression = match self.dtype {
DType::F64 => 8,
DType::F64 | DType::I64 => 8,
DType::F32 | DType::Int => 4,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
DType::Bool
@@ -619,569 +613,6 @@ extern \"C\" {{
}
}
// =============================================================================
// KernelBatchMatVec: Fused batched matrix-vector product for attention
// Matches: Mul(broadcast) + Sum pattern for [B, 1, K] x [B, K, N] -> [B, 1, N]
// or [B, M, K] x [B, K, N] -> [B, M, N] with small M
// Replaces the broadcast elementwise Mul + single-threaded KernelSumReduce pipeline
// =============================================================================
#[derive(Default, Debug, Clone)]
pub struct KernelBatchMatVec {
// Output shape: the final reduced shape [B..., M, N]
out_shape: Vec<Expression>,
// K: the reduction dimension (was the Sum iters)
k_dim: Expression,
// Strides for input A (with K dim removed)
a_stride: Vec<Expression>,
a_k_stride: Expression,
// Strides for input B (with K dim removed)
b_stride: Vec<Expression>,
b_k_stride: Expression,
// Output strides
out_stride: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for KernelBatchMatVec {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelBatchMatVec",
&[
("out_shape", ELIST),
("k_dim", EXPRESSION),
("a_stride", ELIST),
("a_k_stride", EXPRESSION),
("b_stride", ELIST),
("b_k_stride", EXPRESSION),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
; Match Mul node (broadcast multiply)
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape must have 3+ dimensions (batched)
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
; k_stride must be contiguous
(= ?k_stride (MIter))
; Get A's k-dimension stride (second from end in Mul's a_stride)
(= ?a_k_stride (nth_from_end ?a_stride 1))
; Get B's k-dimension stride (second from end in Mul's b_stride)
(= ?b_k_stride (nth_from_end ?b_stride 1))
; A's k stride must be contiguous (row-major A)
(= ?a_k_stride (MIter))
; B's k stride must be contiguous (col-major B)
(= ?b_k_stride (MIter))
; Must be F32
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; Remove the k-dimension from A strides for the kernel
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
; Remove the k-dimension from B strides
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
(let ?bmv (Op (KernelBatchMatVec
?out_shape ?k
?a_kern_stride ?a_k_stride
?b_kern_stride ?b_k_stride
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
(union ?sum ?bmv)
(set (dtype ?bmv) (F32))
)
:ruleset matmul_backend
:name \"batch mat-vec\"
)"
)]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
.unwrap(),
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[7]),
})),
input_enodes, // A, B
)
}
}
impl KernelOp for KernelBatchMatVec {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars: FxHashSet<char> = self
.out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.k_dim.dyn_vars())
.chain(self.a_k_stride.dyn_vars())
.chain(self.b_k_stride.dyn_vars())
.collect();
let n_outputs: Expression = self.out_shape.iter().copied().product();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
// Each output element is a dot product of length K.
// We launch one block of 256 threads per output element.
// Threads cooperatively reduce K using warp shuffles.
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
let k_expr = self.k_dim.to_kernel();
let a_k_stride_expr = self
.a_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
let b_k_stride_expr = self
.b_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
let kernel = format!(
"
#define WARP_SIZE 32
#define THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
{dyn_defines}
extern \"C\" {{
__global__ void batch_matvec(float *out, const float *A, const float *B{dyn_dims_param}) {{
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
int warp_id = tid / WARP_SIZE;
long long a_base = {a_idx};
long long b_base = {b_idx};
long long K = {k_expr};
long long a_k_stride = {a_k_stride_expr};
long long b_k_stride = {b_k_stride_expr};
float partial = 0.0f;
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
partial += __shfl_down_sync(FULL_MASK, partial, s);
}}
if (lane_id == 0) {{
warp_sums[warp_id] = partial;
}}
__syncthreads();
if (warp_id == 0) {{
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
#pragma unroll
for (int s = cnt / 2; s > 0; s /= 2) {{
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
}}
if (tid == 0) {{
out[{out_idx}] = block_sum;
}}
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("batch_matvec").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()), // grid: one block per output
(256.into(), 1.into(), 1.into()), // block: 256 threads
32.into(), // shared mem for warp_sums
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.out_shape.iter().copied().product()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
let n = self.output_size();
// Each output loads K elements from A and K elements from B
n * self.k_dim * 2 * 4
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
// Each output: K multiply-adds = 2*K FLOPs
self.output_size() * self.k_dim * 2
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"BatchMatVec"
}
}
// =============================================================================
// KernelBatchMatMul: General batched matmul with arbitrary strides
// Like KernelBatchMatVec but handles non-contiguous K strides (e.g., transposed
// inputs) and non-uniform batch strides (e.g., GQA expansion). One block of 256
// threads per output element; threads cooperatively reduce along K.
// =============================================================================
#[derive(Default, Debug, Clone)]
pub struct KernelBatchMatMul {
out_shape: Vec<Expression>,
k_dim: Expression,
a_stride: Vec<Expression>,
a_k_stride: Expression,
b_stride: Vec<Expression>,
b_k_stride: Expression,
out_stride: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for KernelBatchMatMul {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelBatchMatMul",
&[
("out_shape", ELIST),
("k_dim", EXPRESSION),
("a_stride", ELIST),
("a_k_stride", EXPRESSION),
("b_stride", ELIST),
("b_k_stride", EXPRESSION),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
; Match Mul node (broadcast multiply)
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape must have 3+ dimensions (batched)
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
; k_stride must be contiguous in the Sum output
(= ?k_stride (MIter))
; K must be > 1 (K=1 is a degenerate outer product, not a real matmul)
(!= ?k (MNum 1))
; Get A's and B's k-dimension strides (no contiguity requirement)
(= ?a_k_stride (nth_from_end ?a_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 1))
; One of A's non-k strides must be 0 (broadcast along n)
(= (MNum 0) (nth_from_end ?a_stride 0))
; One of B's non-k strides must be 0 (broadcast along m)
(= (MNum 0) (nth_from_end ?b_stride 2))
; Must be F32
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
(let ?bmm (Op (KernelBatchMatMul
?out_shape ?k
?a_kern_stride ?a_k_stride
?b_kern_stride ?b_k_stride
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
(union ?sum ?bmm)
(set (dtype ?bmm) (F32))
)
:ruleset matmul_backend
:name \"batch matmul\"
)"
)]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
.unwrap(),
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[7]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelBatchMatMul {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars: FxHashSet<char> = self
.out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.k_dim.dyn_vars())
.chain(self.a_k_stride.dyn_vars())
.chain(self.b_k_stride.dyn_vars())
.collect();
let n_outputs: Expression = self.out_shape.iter().copied().product();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
let k_expr = self.k_dim.to_kernel();
let a_k_stride_expr = self
.a_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
let b_k_stride_expr = self
.b_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
let kernel = format!(
"
#define WARP_SIZE 32
#define THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
{dyn_defines}
extern \"C\" {{
__global__ void batch_matmul(float *out, const float *A, const float *B{dyn_dims_param}) {{
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
int warp_id = tid / WARP_SIZE;
long long a_base = {a_idx};
long long b_base = {b_idx};
long long K = {k_expr};
long long a_k_stride = {a_k_stride_expr};
long long b_k_stride = {b_k_stride_expr};
float partial = 0.0f;
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
partial += __shfl_down_sync(FULL_MASK, partial, s);
}}
if (lane_id == 0) {{
warp_sums[warp_id] = partial;
}}
__syncthreads();
if (warp_id == 0) {{
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
#pragma unroll
for (int s = cnt / 2; s > 0; s /= 2) {{
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
}}
if (tid == 0) {{
out[{out_idx}] = block_sum;
}}
}}
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("batch_matmul").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
32.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.out_shape.iter().copied().product()
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn bytes_loaded(&self) -> Expression {
let n = self.output_size();
n * self.k_dim * 2 * 4
}
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
fn flops(&self) -> Expression {
self.output_size() * self.k_dim * 2
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"BatchMatMul"
}
}
// =============================================================================
// KernelSoftmax: Fused softmax over last dimension
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@ fn cuda_dtype(dtype: DType) -> &'static str {
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I64 => "long long",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,10 @@ fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeInde
(cx, a.id, b.id, c.id)
}
fn bucket_options(buckets: &[DimBucket]) -> CompileOptions {
CompileOptions::default().dim_buckets('s', buckets)
}
#[test]
fn test_bucket_dispatch_simple() {
// Tests that bucketed compilation produces correct results for different dim values
@@ -31,9 +35,10 @@ fn test_bucket_dispatch_simple() {
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(bucket_options(&[
DimBucket::new(1, 1),
DimBucket::new(2, 4),
]));
let mut rt = CudaRuntime::initialize(stream);
// Set dummy input for search
@@ -41,7 +46,11 @@ fn test_bucket_dispatch_simple() {
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Test bucket 1: s=1
cx.set_dim('s', 1);
@@ -73,9 +82,10 @@ fn test_bucket_matmul_dynamic() {
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(bucket_options(&[
DimBucket::new(1, 1),
DimBucket::new(2, 8),
]));
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
@@ -85,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute at s=1
cx.set_dim('s', 1);
@@ -135,12 +149,16 @@ fn test_bucket_results_match_unbucketed() {
// Non-bucketed run
let (mut cx1, a1, b1) = build_dynamic_add_graph();
cx1.set_dim('s', 3);
cx1.build_search_space::<CudaRuntime>();
cx1.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt1 = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
rt1.set_data(a1, input_data.clone());
let mut rng1 = SmallRng::seed_from_u64(seed);
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
rt1 = cx1.search_with_rng(
rt1,
CompileOptions::default().search_graph_limit(5),
&mut rng1,
);
rt1.set_data(a1, input_data.clone());
rt1.execute(&cx1.dyn_map);
let result_unbucketed = rt1.get_f32(b1);
@@ -148,12 +166,15 @@ fn test_bucket_results_match_unbucketed() {
// Bucketed run with bucket that covers s=3
let (mut cx2, a2, b2) = build_dynamic_add_graph();
cx2.set_dim('s', 3);
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
cx2.build_search_space::<CudaRuntime>();
cx2.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 4)]));
let mut rt2 = CudaRuntime::initialize(stream.clone());
rt2.set_data(a2, input_data.clone());
let mut rng2 = SmallRng::seed_from_u64(seed);
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
rt2 = cx2.search_with_rng(
rt2,
CompileOptions::default().search_graph_limit(5),
&mut rng2,
);
rt2.set_data(a2, input_data.clone());
rt2.execute(&cx2.dyn_map);
let result_bucketed = rt2.get_f32(b2);
@@ -172,14 +193,20 @@ fn test_bucket_out_of_range_panics() {
};
let (mut cx, a, _b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(bucket_options(&[
DimBucket::new(1, 1),
DimBucket::new(2, 4),
]));
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// s=10 is outside all buckets — should panic
cx.set_dim('s', 10);
@@ -197,14 +224,18 @@ fn test_bucket_no_buckets_backward_compat() {
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim('s', 2);
// No set_dim_buckets call
// No bucket options
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
rt.set_data(a, input_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
@@ -237,9 +268,10 @@ fn test_bucket_switch_preserves_weights() {
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(bucket_options(&[
DimBucket::new(1, 1),
DimBucket::new(2, 4),
]));
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
@@ -249,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute with bucket 1 (s=1)
cx.set_dim('s', 1);
@@ -297,15 +333,17 @@ fn test_bucket_multiple_executions_same_bucket() {
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 8)]));
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// Execute at different sizes within the same bucket
for s in [1, 2, 4, 8] {
@@ -323,8 +361,7 @@ fn test_bucket_multiple_executions_same_bucket() {
#[test]
#[should_panic(expected = "Overlapping buckets")]
fn test_bucket_overlapping_ranges_panics() {
let mut cx = Graph::default();
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
let _ = bucket_options(&[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
}
#[test]

View File

@@ -10,7 +10,7 @@ use crate::runtime::CudaRuntime;
/// Helper: build search space and extract all possible kernel names across many random choices.
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
let custom_ops = &cx.custom_ops;
@@ -199,7 +199,7 @@ fn test_scatter_execution_correctness() {
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
@@ -298,7 +298,7 @@ fn test_scatter_kv_cache_roundtrip() {
// Return cache for round-trip
let cache_output = cache_out.output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Print and verify which scatter variant was selected
let scatter_names: Vec<_> = rt
@@ -415,7 +415,7 @@ fn test_scatter_dual_cache() {
let k_cache_out = k_out.output();
let v_cache_out = v_out.output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
// Use seeded search for deterministic variant selection.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Print and verify selected variants
let scatter_names: Vec<_> = rt
@@ -535,7 +539,7 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
let gathered = gather_rows(updated, gather_idx, D).output();
let cache_out = updated.output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
cx.set_dim('s', S);
let mut rt = CudaRuntime::initialize(stream);
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
rt.set_data(gather_idx, scatter);
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
assert_eq!(rt.get_f32(gathered), expected_gather);
@@ -733,7 +741,7 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
cx.set_dim('s', S);
cx.set_dim('c', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let q_data: Vec<f32> = (0..S * Q_DIM)
.map(|i| ((i as f32 + 1.0) * 0.031).sin())
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -844,7 +856,7 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
cx.set_dim('s', S);
cx.set_dim('p', 0);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let q_data: Vec<f32> = (0..S * Q_DIM)
.map(|i| ((i as f32 + 1.0) * 0.031).sin())
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -925,7 +941,7 @@ fn test_dynamic_expanded_causal_mask_softmax() {
cx.set_dim('s', S);
cx.set_dim('c', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut mask_data = vec![0.0f32; S * S];
for row in 0..S {
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(weights);
@@ -991,7 +1011,7 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
cx.set_dim('s', S);
cx.set_dim('c', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let v_data: Vec<f32> = (0..S * KV_DIM)
.map(|i| ((i as f32 + 5.0) * 0.029).sin())
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
rt.set_data(v_full, v_data.clone());
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1055,7 +1079,7 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
cx.set_dim('s', S);
cx.set_dim('c', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let v_data: Vec<f32> = (0..N_KV_HEADS * S * HEAD_DIM)
.map(|i| ((i as f32 + 5.0) * 0.029).sin())
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
rt.set_data(v_full, v_data.clone());
rt.set_data(weights, weights_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1115,7 +1143,7 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
let roundtrip = flat.split_dims(1, D).transpose(0, 1).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let x_data: Vec<f32> = (0..H * S * D)
.map(|i| ((i as f32 + 0.75) * 0.051).sin())
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(x, x_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(roundtrip);
@@ -1158,7 +1190,7 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
.output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let x_data: Vec<f32> = (0..S * H)
.map(|i| ((i as f32 + 0.5) * 0.137).sin())
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
rt.set_data(x, x_data.clone());
rt.set_data(w, w_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1211,7 +1247,7 @@ fn test_batched_topk_axis1_matches_cpu() {
let topk = routing.topk_indexes(K, 1).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let routing_data: Vec<f32> = (0..S * E)
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(topk);
@@ -1250,7 +1290,7 @@ fn test_batched_argsort_axis1_matches_cpu() {
let argsort = routing.argsort(1, true).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let routing_data: Vec<f32> = (0..S * E)
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(argsort);
@@ -1290,7 +1334,7 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
let out = input.sum(1).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let data: Vec<f32> = (0..S * A * B)
.map(|i| ((i as f32 + 4.0) * 0.031).sin())
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1347,7 +1395,7 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
.output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let routing_data: Vec<f32> = (0..S * E)
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(ranks);
@@ -1391,11 +1443,15 @@ fn test_dynamic_3d_flat_index_iota_rows() {
.output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(idx);
@@ -1431,14 +1487,18 @@ fn test_dynamic_2d_to_3d_gather_rows() {
let out = data.gather(idx).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let data_values: Vec<i32> = (0..S * E).map(|i| ((i * 17 + 5) % 1000) as i32).collect();
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(data, data_values.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(out);
@@ -1479,7 +1539,7 @@ fn test_batched_gather_experts_matches_cpu() {
let out = weights.gather(exp_base + exp_within).output();
cx.set_dim('s', S);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let topk_data: Vec<i32> = (0..S * K).map(|i| ((i * 5 + 3) % E) as i32).collect();
let weights_data: Vec<f32> = (0..E * D1 * D2)
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
rt.set_data(topk, topk_data.clone());
rt.set_data(weights, weights_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);

View File

@@ -132,7 +132,7 @@ fn conv2d_matmul_without_conv_output_shape(
#[test]
fn generic_conv2d_rewrite_matches_unfold_matmul_bias() {
let (mut cx, _, _, _, _) = build_conv_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
@@ -148,7 +148,7 @@ fn generic_conv2d_rewrite_matches_unfold_matmul_bias() {
#[test]
fn generic_conv2d_rewrite_matches_conv1x1_matmul_bias() {
let (mut cx, _, _, _, _) = build_conv1x1_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
@@ -165,7 +165,7 @@ fn generic_conv2d_rewrite_requires_conv_output_shape() {
let bias = cx.tensor(3usize);
conv2d_matmul_without_conv_output_shape(x, weight, bias, 3, 2).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
@@ -181,7 +181,7 @@ fn generic_conv2d_candidate_executes_unfold_matmul_bias() {
};
let (mut cx, x, weight, bias, out) = build_conv_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
let input: Vec<f32> = (0..2 * 5 * 6).map(|i| i as f32 * 0.03 - 0.4).collect();
@@ -222,7 +222,7 @@ fn generic_conv2d_candidate_executes_conv1x1_matmul_bias() {
};
let (mut cx, x, weight, bias, out) = build_conv1x1_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.07 - 1.0).collect();
@@ -261,7 +261,7 @@ fn generic_conv2d_candidate_executes_padded_unfold_matmul_bias() {
};
let (mut cx, x, weight, bias, out) = build_padded_conv_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.05 - 0.5).collect();
@@ -302,7 +302,7 @@ fn generic_conv2d_candidate_executes_upsample_view_input() {
};
let (mut cx, x, weight, bias, out) = build_upsample_conv_graph();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 * 0.09 - 0.8).collect();

View File

@@ -7,10 +7,12 @@ use luminal::{
prelude::*,
};
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;
use crate::{
host::{
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
cublaslt_type_tuple,
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
expected
}
fn reference_mixed_chain(
a: &[f32],
pre: &[f32],
b: &[f32],
m: usize,
n: usize,
k: usize,
) -> Vec<f32> {
let mut expected = vec![0.0; m * n];
for row in 0..m {
for col in 0..n {
let mut acc = 0.0;
for inner in 0..k {
acc += (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col];
}
expected[row * n + col] = acc.exp();
}
}
expected
}
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
crate::try_create_cublaslt(stream.clone()).is_ok()
}
fn build_mixed_chain_graph(
m: impl Into<Expression>,
n: usize,
k: usize,
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
let m = m.into();
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let pre = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = ((a + pre).matmul(b).exp()).output();
(cx, a.id, pre.id, b.id, out.id)
}
fn add_in_place(values: &mut [f32], addends: &[f32]) {
for (value, addend) in values.iter_mut().zip(addends) {
*value += *addend;
@@ -481,7 +522,7 @@ fn cublaslt_cleanup_prunes_flux2_broadcast_mul_fallback() {
let k = cx.tensor((8usize, 4usize));
let _out = q.matmul(k.t()).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
!cublaslt_ir_nodes(egraph).is_empty(),
@@ -507,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
}
}
#[test]
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |llir| {
cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
});
let a_data = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
let b_data = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
let summaries = rt.debug_cuda_graph_summaries();
let mixed = summaries
.iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected one CudaGraphOp to capture the cuBLASLt island");
assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
assert_eq!(mixed.absorbed_host_nodes.len(), 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_executes_correctly() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);
let a_data = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (5, 8, 8);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let first = a.matmul(b);
let out = (a + first.sin()).matmul(b).output();
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"ordered matching cuBLASLt prepared reuse",
|llir| {
let orders = cublaslt_matrix_order_tuples(llir);
orders.len() == 2 && orders[0] == orders[1]
},
);
let a_data = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
let first = reference_matmul_2d(&a_data, &b_data, m, n, k);
let dep = a_data
.iter()
.zip(&first)
.map(|(a, first)| a + first.sin())
.collect::<Vec<_>>();
let expected = reference_matmul_2d(&dep, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 2)
.expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
assert_eq!(
summary.n_cublaslt_prepared, 1,
"dependency-ordered matching cuBLASLt calls should share prepared resources"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let mut cx = Graph::new();
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
a.output();
b.output();
cx.set_dim('p', 1);
let llir = extract_forced_cublaslt_llir_where(
&mut cx,
"cuBLASLt unchanged under unrelated dyn dim",
|_| true,
);
let a_data = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
reset_cublaslt_prepare_count_for_test();
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
let first_prepare_count = cublaslt_prepare_count_for_test();
assert!(
first_prepare_count > 0,
"first execution should prepare the captured cuBLASLt island"
);
cx.set_dim('p', 2);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count,
"unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('m', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('m', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [
(7usize, 0xC002_0001),
(9usize, 0xC002_0002),
(7usize, 0xC002_0003),
] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
let summary = rt
.debug_cuda_graph_summaries()
.into_iter()
.find(|summary| summary.n_cublaslt == 1)
.expect("expected a cuBLASLt-only CudaGraphOp after recapture");
assert_eq!(summary.n_kernels, 0);
assert_eq!(summary.n_steps, 1);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn cublaslt_with_dynamic_c_spec_is_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('c', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('c', 7);
let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
cx.set_dim('c', c);
let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
}
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"c-dependent cuBLASLt should be absorbed into a CUDA graph"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let mut cx = Graph::new();
let a = cx.tensor(('s', k));
let b = cx.tensor((k, n));
let out = a.matmul(b).output();
cx.set_dim('s', 1);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);
let dim_buckets = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
.into_iter()
.collect();
let bucket_llirs = vec![
(
[('s', 0usize)].into_iter().collect(),
[('s', 1usize)].into_iter().collect(),
llir.clone(),
),
(
[('s', 1usize)].into_iter().collect(),
[('s', 3usize)].into_iter().collect(),
llir,
),
];
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir_buckets(&dim_buckets, &bucket_llirs);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
let b_data = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 1, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data.clone());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"singleton s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
cx.set_dim('s', 3);
let a_data = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
let expected = reference_matmul_2d(&a_data, &b_data, 3, n, k);
rt.set_data(a.id, a_data);
rt.set_data(b.id, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
assert!(
rt.debug_cuda_graph_summaries()
.iter()
.any(|summary| summary.n_cublaslt == 1),
"range s bucket should capture s-dependent cuBLASLt"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
assert!(
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
"bucket with captured cuBLASLt needs stable intermediate pointers"
);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (m, n, k) = (7, 11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
reset_cublaslt_prepare_count_for_test();
let mut first_prepare_count = None;
for seed in [0xCC00_0001, 0xCC00_0002] {
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
if first_prepare_count.is_none() {
first_prepare_count = Some(cublaslt_prepare_count_for_test());
}
}
assert_eq!(
cublaslt_prepare_count_for_test(),
first_prepare_count.unwrap(),
"A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
);
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after pointer recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
let Some(stream) = get_cuda_stream() else {
return;
};
if !cublaslt_available_for_runtime(&stream) {
return;
}
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
return;
}
let (n, k) = (11, 5);
let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
cx.set_dim('m', 7);
let llir =
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);
let mut rt = CudaRuntime::initialize(stream);
rt.load_llir(&llir);
for (m, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
cx.set_dim('m', m);
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
rt.set_data(a, a_data);
rt.set_data(pre, pre_data);
rt.set_data(b, b_data);
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
}
let summaries = rt.debug_cuda_graph_summaries();
assert!(
summaries.iter().any(|summary| summary.n_cublaslt == 1),
"expected cuBLASLt to remain captured after dynamic-shape recapture"
);
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
}
#[test]
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {
@@ -1027,7 +1525,7 @@ fn cublaslt_fp8_scaled_candidate_reaches_fused_output_scale_consumer() {
let scaled_out = scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n));
(scaled_out * side).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
@@ -1061,7 +1559,7 @@ fn cublaslt_fp8_scaled_candidates_reach_fused_mlp_consumer() {
* (up_input_scale * up_weight_scale).expand_rhs((m, n));
(gate.swish() * up).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
assert!(
@@ -2936,7 +3434,7 @@ fn extract_forced_cublaslt_llir_where(
case_name: &str,
matches: impl Fn(&LLIRGraph) -> bool,
) -> LLIRGraph {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
let ops = cx
@@ -2991,7 +3489,7 @@ fn assert_no_forced_cublaslt_llir_where(
case_name: &str,
matches: impl Fn(&LLIRGraph) -> bool,
) {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
let ops = cx
@@ -3040,7 +3538,7 @@ fn assert_no_cublaslt_llir_where(
case_name: &str,
matches: impl Fn(&LLIRGraph) -> bool,
) {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("search space should have an e-graph");
let ops = cx

View File

@@ -83,13 +83,13 @@ fn run_reference_attention(
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
cx.set_dim('s', batch_size);
cx.set_dim('c', context_len);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
rt.set_data(v_t, v.to_vec());
rt = cx.search(rt, 3);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
@@ -779,7 +779,7 @@ fn flashinfer_extraction_reachable_from_search_space() {
cx.set_dim('s', 1usize);
cx.set_dim('c', 16usize);
cx.set_dim('r', 2usize);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx
.egraph()

View File

@@ -293,7 +293,7 @@ struct FusedRegion {
/// Helper: collect every distinct fused region reachable across many random
/// extractions of the search space.
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
let custom_ops = &cx.custom_ops;

View File

@@ -0,0 +1,169 @@
use luminal::{
egglog_utils::{
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
},
prelude::*,
};
use rand::{SeedableRng, rngs::StdRng};
use crate::{kernel::KernelOp, runtime::CudaRuntime};
use super::utilities::{assert_close, get_cuda_stream};
#[test]
fn generic_matmul_covers_noncontiguous_merged_head_projection() {
let mut cx = Graph::default();
let heads = 3;
let seq = 4;
let head_dim = 5;
let hidden = heads * head_dim;
let out_dim = 7;
let attn = cx.tensor((heads, seq, head_dim));
let weight = cx.tensor((out_dim, hidden));
let merged = attn.transpose(0, 1).merge_dims(1, 2);
merged.matmul(weight.t()).output();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let llir = extract_forced_kernel_llir(&mut cx, "GenericMatmul");
let names = llir_kernel_names(&llir);
assert!(
names.contains(&"GenericMatmul"),
"expected generic matmul fallback, kernels: {names:?}"
);
assert!(
!names.contains(&"Mul") && !names.contains(&"SumReduce"),
"generic matmul should prune the broadcast multiply/sum fallback, kernels: {names:?}"
);
}
#[test]
fn generic_matmul_executes_noncontiguous_merged_head_projection() {
let mut cx = Graph::default();
let heads = 3;
let seq = 4;
let head_dim = 5;
let hidden = heads * head_dim;
let out_dim = 7;
let attn = cx.tensor((heads, seq, head_dim));
let weight = cx.tensor((out_dim, hidden));
let merged = attn.transpose(0, 1).merge_dims(1, 2);
let output = merged.matmul(weight.t()).output();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let stream = get_cuda_stream().expect("CUDA device required for GenericMatmul execution test");
let mut rt = CudaRuntime::initialize(stream);
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
rt.set_data(attn, attn_data.as_slice());
rt.set_data(weight, weight_data.as_slice());
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
assert!(
rt.kernel_names().contains(&"GenericMatmul"),
"expected GenericMatmul to be selected, kernels: {:?}",
rt.kernel_names()
);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output.id);
let mut expected = vec![0.0; seq * out_dim];
for token in 0..seq {
for out_col in 0..out_dim {
let mut sum = 0.0;
for inner in 0..hidden {
let head = inner / head_dim;
let dim = inner % head_dim;
let attn_idx = head * seq * head_dim + token * head_dim + dim;
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
}
expected[token * out_dim + out_col] = sum;
}
}
assert_close(&result, &expected, 1e-5, 1e-5);
}
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
(0..len)
.map(|i| {
let x = ((i * 37 + 11) % 97) as f32 / 97.0;
x * scale + bias
})
.collect()
}
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
let egraph = cx.egraph().expect("search space should have an e-graph");
let ops = cx
.egglog_ops()
.expect("search space should have registered egglog ops");
let kernel_nodes = op_ir_nodes(egraph, kernel_name);
assert!(
!kernel_nodes.is_empty(),
"expected at least one {kernel_name} candidate"
);
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
let mut choices = random_initial_choice(egraph, &mut rng);
let kernel_class = &egraph.node_to_class[*kernel_node];
choices.insert(kernel_class, kernel_node);
if validate_choice_set(egraph, &choices, ops).is_err() {
continue;
}
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
let llir = egglog_to_llir(
egraph,
choices,
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
if llir_kernel_names(&llir).contains(&kernel_name) {
return llir;
}
}
panic!("could not extract a valid {kernel_name} candidate");
}
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
llir.node_indices()
.filter_map(|node| {
llir[node]
.to_dialect::<dyn KernelOp>()
.map(|kernel| kernel.kernel_name())
})
.collect()
}
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
let op_kind_classes = egraph
.enodes
.iter()
.filter(|(_, (label, _))| label == kind_label)
.map(|(node, _)| egraph.node_to_class[node].clone())
.collect::<Vec<_>>();
egraph
.enodes
.iter()
.filter_map(|(node, (label, children))| {
(label == "Op"
&& children
.first()
.is_some_and(|kind| op_kind_classes.contains(kind)))
.then_some(node)
})
.collect()
}

View File

@@ -13,6 +13,8 @@ mod flashinfer;
#[cfg(test)]
mod fusion;
#[cfg(test)]
mod generic_matmul_rewrite;
#[cfg(test)]
mod model_fuzz;
#[cfg(test)]
mod op_functional_tests;

View File

@@ -83,7 +83,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -143,7 +143,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
let proj_w = cx.tensor((proj_dim, hidden));
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -219,7 +219,7 @@ fn fuzz_layer_no_attn(
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
let out = (x + mlp_out).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -318,7 +318,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -481,7 +481,7 @@ mod gemma {
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
let out = (x + mlp_normed).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 800u64;
@@ -518,7 +518,7 @@ mod gemma {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -641,7 +641,7 @@ mod qwen {
let embedding = cx.tensor((VOCAB, HIDDEN));
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 1300u64;
@@ -655,7 +655,7 @@ mod qwen {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -256,10 +256,10 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data);
rt = cx.search(rt, 10);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&cx.dyn_map);
let out_dim0 = rt.get_i32(sorted_dim0.id);
let out_dim1 = rt.get_i32(sorted_dim1.id);
@@ -424,7 +424,7 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
let e = (d + c).relu();
let out = e.output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let egraph = cx.egraph().unwrap();
let ops = cx.egglog_ops().unwrap();
@@ -592,7 +592,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
)
.output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let token_data: Vec<i32> = random_i32_vec(seq_len, seed, 0, vocab_size as i32 - 1);
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
rt.set_data(token_ids, token_data.clone());
rt.set_data(embed_table, embed_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);

View File

@@ -27,11 +27,11 @@ pub fn kernel_add_bandwidth_test() {
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Warm up
rt.execute(&cx.dyn_map);

View File

@@ -2,10 +2,7 @@ use half::bf16;
use luminal::{dtype::DType, prelude::*, shape::Expression};
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
use crate::{
host::moe::{GLUMoE, GLUMoEMode},
runtime::CudaRuntime,
};
use crate::{host::moe::GLUMoE, runtime::CudaRuntime};
const SEQ: usize = 2;
const HIDDEN: usize = 32;
@@ -173,30 +170,51 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
x * scaled.sigmoid()
}
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
rt.host_ops()
.into_iter()
.filter_map(|op| {
op.as_any()
.downcast_ref::<GLUMoE>()
.map(|glumoe| glumoe.mode)
})
.collect()
fn search_space_contains(cx: &Graph, op_name: &str) -> bool {
let egraph = cx.egraph().expect("test should build an e-graph");
for (label, children) in egraph.enodes.values() {
if label != "Op" {
continue;
}
let Some(kind_eclass) = children.first() else {
continue;
};
let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) else {
continue;
};
if kind_enodes
.iter()
.any(|kind_node| egraph.enodes[kind_node].0 == op_name)
{
return true;
}
}
false
}
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
fn assert_glumoe_in_search_space(cx: &Graph) {
assert!(
search_space_contains(cx, "GLUMoE"),
"GLUMoE was not in the e-graph search space"
);
}
fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
let Some(stream) = get_cuda_stream() else {
return (vec![], vec![]);
return vec![];
};
let mut model = build_qwen_moe_graph();
model.graph.set_dim('s', SEQ);
if use_glumoe {
model.graph.build_search_space::<CudaRuntime>();
if include_glumoe {
model
.graph
.build_search_space::<CudaRuntime>(CompileOptions::default());
} else {
model
.graph
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
}
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
@@ -215,25 +233,29 @@ fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
rt.set_data(model.router, router_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, 10);
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
(rt.get_f32(model.output.id), glumoe_modes(&rt))
rt.get_f32(model.output.id)
}
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
let Some(stream) = get_cuda_stream() else {
return (vec![], vec![]);
return vec![];
};
let mut model = build_gemma_moe_graph();
model.graph.set_dim('s', SEQ);
if use_glumoe {
model.graph.build_search_space::<CudaRuntime>();
if include_glumoe {
model
.graph
.build_search_space::<CudaRuntime>(CompileOptions::default());
} else {
model
.graph
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
}
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
@@ -258,54 +280,60 @@ fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
rt.set_data(model.per_expert_scale, per_expert_scale_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, 10);
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
(rt.get_f32(model.output.id), glumoe_modes(&rt))
rt.get_f32(model.output.id)
}
#[test]
fn test_glumoe_matches_qwen_swiglu_pattern() {
let (_result, modes) = run_qwen_moe(true);
if modes.is_empty() {
if get_cuda_stream().is_none() {
return;
}
assert_eq!(modes, vec![GLUMoEMode::SwiGLUNormalized]);
let mut model = build_qwen_moe_graph();
model.graph.set_dim('s', SEQ);
model
.graph
.build_search_space::<CudaRuntime>(CompileOptions::default());
assert_glumoe_in_search_space(&model.graph);
}
#[test]
fn test_glumoe_matches_gemma_gelu_pattern() {
let (_result, modes) = run_gemma_moe(true);
if modes.is_empty() {
if get_cuda_stream().is_none() {
return;
}
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
let mut model = build_gemma_moe_graph();
model.graph.set_dim('s', SEQ);
model
.graph
.build_search_space::<CudaRuntime>(CompileOptions::default());
assert_glumoe_in_search_space(&model.graph);
}
#[test]
fn test_glumoe_swiglu_matches_unfused_output() {
let (expected, baseline_modes) = run_qwen_moe(false);
let expected = run_qwen_moe(false);
if expected.is_empty() {
return;
}
assert!(baseline_modes.is_empty());
let (actual, fused_modes) = run_qwen_moe(true);
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLUNormalized]);
let actual = run_qwen_moe(true);
assert_close(&actual, &expected, 3e-2, 3e-2);
}
#[test]
fn test_glumoe_gemma_gelu_matches_unfused_output() {
let (expected, baseline_modes) = run_gemma_moe(false);
let expected = run_gemma_moe(false);
if expected.is_empty() {
return;
}
assert!(baseline_modes.is_empty());
let (actual, fused_modes) = run_gemma_moe(true);
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
let actual = run_gemma_moe(true);
assert_close(&actual, &expected, 3e-2, 3e-2);
}

View File

@@ -1,5 +1,8 @@
use cudarc::driver::CudaContext;
use luminal::{graph::Graph, op::Runtime};
use luminal::{
graph::{CompileOptions, Graph},
op::Runtime,
};
use crate::{kernel::apply_rope, runtime::CudaRuntime};
@@ -42,12 +45,12 @@ fn rope_matches_cpu_reference() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);
@@ -90,12 +93,12 @@ fn rope_flux2_shape() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);

View File

@@ -2,7 +2,7 @@
//!
//! These tests do not compare against a hand-written reference. They assert the
//! stronger search invariant: every selectable LLIR graph from the same e-graph
//! must produce the same outputs for the same runtime inputs.
//! must produce finite, numerically close outputs for the same runtime inputs.
#[allow(dead_code)]
#[path = "../../../../examples/llama/src/model.rs"]
@@ -92,8 +92,8 @@ fn llama_architecture_search_space_equivalence_fuzz() {
.samples(SEARCH_EQUIV_SAMPLES)
.generation_size(8)
.mutations(3)
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
.output_f32(logits.id, "logits", 3e-3, 3e-3);
.build_options(CompileOptions::default().max_memory_mib(512))
.output_f32(logits.id, "logits", 5e-2, 5e-2);
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
let k_out = k_out.output();
let v_out = v_out.output();
@@ -168,7 +168,7 @@ fn gemma_architecture_search_space_equivalence_fuzz() {
.samples(SEARCH_EQUIV_SAMPLES)
.generation_size(8)
.mutations(3)
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
.build_options(CompileOptions::default().max_memory_mib(512))
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 101, -0.15, 0.15))
.input_f32(attn_norm_w.id, random_f32_vec(HIDDEN, 102, 0.7, 1.3))
.input_f32(post_attn_norm_w.id, random_f32_vec(HIDDEN, 103, 0.7, 1.3))
@@ -263,7 +263,7 @@ fn moe_architecture_search_space_equivalence_fuzz() {
.samples(SEARCH_EQUIV_SAMPLES)
.generation_size(8)
.mutations(3)
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
.build_options(CompileOptions::default().max_memory_mib(512))
.input_f32(
router_input.id,
random_f32_vec(SEQ * HIDDEN, 201, -0.15, 0.15),
@@ -353,7 +353,7 @@ fn moe_architecture_native_reference_fuzz() {
.samples(SEARCH_EQUIV_SAMPLES)
.generation_size(8)
.mutations(3)
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
.build_options(CompileOptions::default().max_memory_mib(512))
.native_reference()
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 301, -0.15, 0.15))
.input_f32(

View File

@@ -267,7 +267,7 @@ fn test_mini_transformer_layer() {
let layer = MiniTransformerLayer::init(&mut cx);
let out = layer.forward(input).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
// Use minimal search iterations to avoid excessive graph rewriting
// which can cause float drift through softmax/RMSNorm reordering
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -303,7 +303,7 @@ fn test_mini_transformer_two_layers() {
let x = layer1.forward(input);
let out = layer2.forward(x).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -361,7 +361,7 @@ fn test_transformer_multi_seed() {
let layer = MiniTransformerLayer::init(&mut cx);
let out = layer.forward(input).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -394,7 +394,7 @@ fn test_rms_norm_cuda() {
let weight = cx.tensor(HIDDEN);
let out = rms_norm(input, weight, 1e-5).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 1, -0.5, 0.5);
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
.collect();
rt.set_data(input, input_data.clone());
rt.set_data(weight, weight_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -433,7 +433,7 @@ fn test_self_attention_cuda() {
let wo = cx.tensor((HIDDEN, HIDDEN));
let out = self_attention(input, wq, wk, wv, wo).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 10, -0.5, 0.5);
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
rt.set_data(wk, wk_data.clone());
rt.set_data(wv, wv_data.clone());
rt.set_data(wo, wo_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -479,7 +479,7 @@ fn test_swiglu_mlp_cuda() {
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 20, -0.5, 0.5);
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -526,11 +526,11 @@ fn test_rolled_chained_scalar_muls() {
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
let out = (chained + x).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
rt.set_data(x, x_data.clone());
rt = cx.search(rt, 3);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -1,11 +1,15 @@
use candle_core::{Device, Tensor, WithDType};
use cudarc::driver::CudaContext;
use half::{bf16, f16};
use itertools::Itertools;
use luminal::egglog_utils::{
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
validate_choice_set,
};
use luminal::prelude::*;
use luminal::prelude::{
petgraph::{Direction, algo::toposort, visit::EdgeRef},
*,
};
use num_traits::{Num, Signed};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::sync::Arc;
@@ -180,7 +184,7 @@ pub struct SearchEquivalenceFuzzConfig {
pub generation_size: usize,
pub mutations: usize,
pub max_attempts: usize,
pub build_options: BuildSearchSpaceOptions,
pub build_options: CompileOptions,
pub reference: SearchEquivalenceReference,
}
@@ -198,7 +202,7 @@ impl Default for SearchEquivalenceFuzzConfig {
generation_size: 16,
mutations: 2,
max_attempts: 1_000,
build_options: BuildSearchSpaceOptions::default(),
build_options: CompileOptions::default(),
reference: SearchEquivalenceReference::FirstCudaExtraction,
}
}
@@ -210,6 +214,11 @@ pub struct SearchEquivalenceFuzzReport {
pub skipped_invalid: usize,
}
struct ChoiceRun {
outputs: Vec<Vec<f32>>,
llir_summary: String,
}
pub struct CudaSearchEquivalenceFuzzer<'a> {
cx: &'a mut Graph,
stream: &'a Arc<cudarc::driver::CudaStream>,
@@ -249,7 +258,7 @@ impl<'a> CudaSearchEquivalenceFuzzer<'a> {
self
}
pub fn build_options(mut self, build_options: BuildSearchSpaceOptions) -> Self {
pub fn build_options(mut self, build_options: CompileOptions) -> Self {
self.config.build_options = build_options;
self
}
@@ -302,7 +311,8 @@ impl<'a> CudaSearchEquivalenceFuzzer<'a> {
/// LLIR graphs, runs each with identical inputs, and verifies every requested
/// f32 output matches the first valid extraction. The reference is intentionally
/// another selected LLIR graph, not a hand-written CPU implementation: this
/// catches cases where supposedly equivalent e-graph choices diverge.
/// catches cases where supposedly equivalent e-graph choices diverge, including
/// candidates that produce non-finite outputs.
pub fn fuzz_cuda_search_space_equivalence(
cx: &mut Graph,
stream: &Arc<cudarc::driver::CudaStream>,
@@ -317,11 +327,11 @@ pub fn fuzz_cuda_search_space_equivalence(
let native_reference_outputs = if config.reference == SearchEquivalenceReference::NativeRuntime
{
cx.build_search_space::<NativeRuntime>();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut native_rng = StdRng::seed_from_u64(config.seed);
let mut native_rt = cx.search_options(
let mut native_rt = cx.search_with_rng(
NativeRuntime::default(),
SearchOptions::new(1),
CompileOptions::default().search_graph_limit(1),
&mut native_rng,
);
for input in inputs {
@@ -338,7 +348,7 @@ pub fn fuzz_cuda_search_space_equivalence(
None
};
cx.build_search_space_with_options::<CudaRuntime>(config.build_options);
cx.build_search_space::<CudaRuntime>(config.build_options);
let egraph = cx.egraph().expect("search space should be built");
let ops = cx.egglog_ops().expect("search ops should be built");
@@ -354,12 +364,12 @@ pub fn fuzz_cuda_search_space_equivalence(
let mut skipped_invalid = 0usize;
let reference_is_cuda = native_reference_outputs.is_none();
let (reference_hash, reference_outputs, mut tested) =
let (reference_hash, reference_outputs, reference_llir_summary, mut tested) =
if let Some(reference_outputs) = native_reference_outputs {
(0, reference_outputs, 0usize)
(0, reference_outputs, None, 0usize)
} else {
let mut attempts = 0usize;
let (reference_hash, reference_outputs) = loop {
let (reference_hash, reference_run) = loop {
attempts += 1;
if attempts > config.max_attempts {
panic!(
@@ -372,17 +382,19 @@ pub fn fuzz_cuda_search_space_equivalence(
} else {
let hash = hash_choice_set(&base);
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
Ok(values) => break (hash, values),
Err(err) => {
skipped_invalid += 1;
eprintln!("skipping invalid reference candidate hash={hash}: {err}");
}
Ok(run) => break (hash, run),
Err(err) => panic!("reference candidate hash={hash} failed: {err}"),
}
}
base = random_initial_choice(egraph, &mut rng);
prev_selected.insert(hash_choice_set(&base));
};
(reference_hash, reference_outputs, 1usize)
(
reference_hash,
reference_run.outputs,
Some(reference_run.llir_summary),
1usize,
)
};
let mut attempts = 0usize;
@@ -415,12 +427,14 @@ pub fn fuzz_cuda_search_space_equivalence(
continue;
}
let candidate_outputs = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
let candidate_run = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
assert_fuzz_outputs_close(
outputs,
&reference_outputs,
&candidate_outputs,
&candidate_run.outputs,
&candidate_run.llir_summary,
reference_llir_summary.as_deref(),
reference_hash,
candidate_hash,
);
@@ -446,7 +460,7 @@ fn run_choice_outputs<'a>(
inputs: &[CudaFuzzInput],
outputs: &[F32OutputCheck],
choices: &EGraphChoiceSet<'a>,
) -> Result<Vec<Vec<f32>>, String> {
) -> Result<ChoiceRun, String> {
let egraph = cx.egraph().ok_or("search space was not built")?;
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
let mut list_cache = FxHashMap::default();
@@ -461,21 +475,86 @@ fn run_choice_outputs<'a>(
None,
);
unroll_loops_in_llir(&mut llir_graph);
let llir_summary = summarize_llir(&llir_graph);
let mut rt = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir_graph);
rt.preserve_intermediate_buffers_for_debug();
for input in inputs {
input.apply(&mut rt);
}
if std::env::var_os("LUMINAL_FUZZ_DUMP_LAST_LLIR").is_some() {
let _ = std::fs::write("/tmp/luminal_fuzz_last_candidate_llir.txt", &llir_summary);
}
rt.execute(&cx.dyn_map);
let topo_order = toposort(&llir_graph, None).map_err(|cycle| {
format!(
"extracted LLIR contains cycle at node {:?}",
cycle.node_id()
)
})?;
if let Some(report) = rt.first_nonfinite_f32_buffer_in_nodes(topo_order) {
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
let _ = std::fs::write(dump_path, &llir_summary);
let op = llir_graph
.node_weight(report.node)
.map(|op| format!("{op:?}"))
.unwrap_or_else(|| "unknown op".to_string());
return Err(format!(
"LLIR produced non-finite F32 buffer node={} index={} value={} op={}; llir={dump_path}",
report.node.index(),
report.index,
report.value,
op
));
}
Ok(outputs.iter().map(|out| rt.get_f32(out.id)).collect())
let values = outputs
.iter()
.map(|out| rt.get_f32(out.id))
.collect::<Vec<_>>();
for (spec, values) in outputs.iter().zip(&values) {
if let Some((idx, value)) = values
.iter()
.enumerate()
.find(|(_, value)| !value.is_finite())
{
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
let _ = std::fs::write(dump_path, &llir_summary);
let internal = rt
.first_nonfinite_f32_buffer()
.map(|report| {
let op = llir_graph
.node_weight(report.node)
.map(|op| format!("{op:?}"))
.unwrap_or_else(|| "unknown op".to_string());
format!(
"; first observed non-finite buffer node={} index={} value={} op={}",
report.node.index(),
report.index,
report.value,
op
)
})
.unwrap_or_default();
return Err(format!(
"output {} produced non-finite value {value} at index {idx}{internal}; llir={dump_path}",
spec.name
));
}
}
Ok(ChoiceRun {
outputs: values,
llir_summary,
})
}
fn assert_fuzz_outputs_close(
outputs: &[F32OutputCheck],
expected: &[Vec<f32>],
actual: &[Vec<f32>],
candidate_llir_summary: &str,
reference_llir_summary: Option<&str>,
reference_hash: u64,
candidate_hash: u64,
) {
@@ -508,8 +587,16 @@ fn assert_fuzz_outputs_close(
worst = i;
}
if abs > spec.atol + spec.rtol * b.abs() {
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
let _ = std::fs::write(dump_path, candidate_llir_summary);
if let Some(reference_llir_summary) = reference_llir_summary {
let _ = std::fs::write(
"/tmp/luminal_fuzz_bad_reference_llir.txt",
reference_llir_summary,
);
}
panic!(
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={}",
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={} candidate_llir={dump_path}",
spec.name,
spec.atol + spec.rtol * b.abs()
);
@@ -522,6 +609,22 @@ fn assert_fuzz_outputs_close(
}
}
fn summarize_llir(llir_graph: &LLIRGraph) -> String {
llir_graph
.node_indices()
.map(|idx| {
let inputs = llir_graph
.edges_directed(idx, Direction::Incoming)
.sorted_by_key(|edge| edge.id())
.map(|edge| edge.source().index().to_string())
.collect::<Vec<_>>()
.join(", ");
format!("{} <- [{}]: {:?}", idx.index(), inputs, &llir_graph[idx])
})
.collect::<Vec<_>>()
.join("\n")
}
/// Get the GPU compute capability as (major, minor).
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
let ctx = CudaContext::new(0).ok()?;
@@ -593,12 +696,12 @@ pub fn test_unary_cuda<T: TestDType>(
let a = cx.tensor(shape.clone());
let b = func(a).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = generator(n_elements, seed);
rt.set_data(a, input_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, b.id);
@@ -666,14 +769,14 @@ pub fn test_binary_cuda<T: TestDType>(
let b = cx.tensor(b_shape.clone());
let c = func(a, b).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let a_data = a_generator(a_elements, seed);
let b_data = b_generator(b_elements, seed.wrapping_add(1));
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, c.id);
@@ -733,7 +836,7 @@ pub fn test_mod(
let b = cx.tensor(b_shape.clone());
let c = func(a, b).output();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream.clone());
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
@@ -741,7 +844,7 @@ pub fn test_mod(
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(c);

View File

@@ -1,7 +1,7 @@
use hf_hub::api::sync::Api;
use luminal::{
dtype::DType,
graph::{BuildSearchSpaceOptions, DimBucket, Graph},
graph::{CompileOptions, DimBucket, Graph},
prelude::{F32Pow, GraphTensor, Runtime},
};
use luminal_metal::MetalRuntime;
@@ -449,12 +449,34 @@ fn main() -> Result<(), Box<dyn Error>> {
cx.set_dim('s', 1);
cx.set_dim('c', 1);
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let search_s = 16.min(max_prefill).max(2);
let search_c = 16.min(max_context).max(2);
let build_options = CompileOptions::default()
.max_memory_mib(SEARCH_MEMORY_MIB)
.dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
)
.dim_buckets(
'c',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_context).representative(search_c),
],
);
println!("Building E-Graph...");
let egraph_start = Instant::now();
cx.build_search_space_with_options::<MetalRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
);
cx.build_search_space::<MetalRuntime>(build_options);
println!(
" E-Graph build: {:.2} s",
egraph_start.elapsed().as_secs_f64()
@@ -474,28 +496,6 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("Compiling...");
let compile_start = Instant::now();
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let search_s = 16.min(max_prefill).max(2);
let search_c = 16.min(max_context).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim_buckets(
'c',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_context).representative(search_c),
],
);
cx.set_dim('s', search_s);
cx.set_dim('c', search_c);
runtime.set_data(input, vec![1; search_s]);
@@ -503,7 +503,8 @@ fn main() -> Result<(), Box<dyn Error>> {
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, SEARCH_GRAPHS);
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
runtime = cx.search(runtime, search_options);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -31,10 +31,42 @@ impl DynBackend for MetalDynBackend {
}
}
/// Reject dtypes the Metal kernel emitters don't support.
///
/// Metal codegen has no native 64-bit integer or 64-bit float paths.
/// Reaching the kernel emitter with one of these dtypes used to panic deep
/// in MSL generation with an unhelpful error; surfacing a clean message
/// at translate-time lets the user fall back to CPU or pick a narrower
/// dtype before any Metal compilation runs.
fn reject_unsupported_dtype(graph: &Graph) -> Result<(), String> {
for node_id in graph.graph.node_indices() {
if let Some(input) = (*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
{
match input.dtype {
DType::I64 | DType::F64 => {
return Err(format!(
"Metal backend does not support {:?} (input `{}`). \
Metal codegen has no native 64-bit kernels; either \
narrow the dtype (e.g. `.to(torch.int32)` / \
`.to(torch.float32)`) before the boundary or \
compile with the CPU / CUDA backend.",
input.dtype, input.label
));
}
_ => {}
}
}
}
Ok(())
}
pub fn metal_factory(
graph: &mut Graph,
args: BackendCompileArgs,
) -> Result<Box<dyn DynBackend>, String> {
reject_unsupported_dtype(graph)?;
compile_backend::<MetalRuntime>(
graph,
args,

View File

@@ -326,7 +326,7 @@ impl Runtime for MetalRuntime {
fn late_egglog_passes(
ops: &[std::sync::Arc<Box<dyn luminal::op::EgglogOp>>],
options: &luminal::graph::BuildSearchSpaceOptions,
options: &luminal::graph::CompileOptions,
dyn_map: &FxHashMap<char, usize>,
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
vec![crate::memory_analysis::metal_memory_analysis_pass(

View File

@@ -41,7 +41,11 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
let mut rng = StdRng::seed_from_u64(0);
cx.search_options(rt, SearchOptions::new(limit), &mut rng)
cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(limit),
&mut rng,
)
}
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
@@ -297,11 +301,11 @@ fn dynamic_dim_sum_reduce_runs() {
let input = cx.tensor(('a', 2));
let output = input.sum(0).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -315,13 +319,14 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
let input = cx.tensor(('s', 4));
let output = (input + input).output();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.set_dim('s', 1);
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(
CompileOptions::default().dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]),
);
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, vec![1.0f32; 4]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
cx.set_dim('s', 1);
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
@@ -346,10 +351,10 @@ fn metal_int_arithmetic_preserves_large_values() {
let large_index = (token * 1024) + 123;
let mod_output = (large_index % 65_537).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(token, &[16_385i32]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -368,11 +373,11 @@ proptest! {
let input = cx.tensor(len);
let output = (input + input).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -390,11 +395,11 @@ proptest! {
let input = cx.tensor(len);
let output = (input * input).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -412,11 +417,11 @@ proptest! {
let input = cx.tensor(len);
let output = input.exp2().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -433,9 +438,7 @@ fn metal_build_search_space_accepts_memory_budget() {
let b = cx.tensor(4);
(a * b).output();
cx.build_search_space_with_options::<MetalRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(1),
);
cx.build_search_space::<MetalRuntime>(CompileOptions::default().max_memory_mib(1));
}
/// Simple deterministic test for add
@@ -446,11 +449,11 @@ fn metal_simple_add() {
let b = cx.tensor(4);
let output = (a + b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -466,11 +469,11 @@ fn metal_simple_mul() {
let b = cx.tensor(4);
let output = (a * b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -485,10 +488,10 @@ fn metal_simple_exp2() {
let input = cx.tensor(4);
let output = input.exp2().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -502,10 +505,10 @@ fn metal_simple_log2() {
let input = cx.tensor(4);
let output = input.log2().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -519,7 +522,7 @@ fn metal_simple_sin() {
let input = cx.tensor(4);
let output = input.sin().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(
input,
@@ -530,7 +533,7 @@ fn metal_simple_sin() {
3.0 * std::f32::consts::FRAC_PI_2,
],
);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -544,10 +547,10 @@ fn metal_simple_sqrt() {
let input = cx.tensor(4);
let output = input.sqrt().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -561,10 +564,10 @@ fn metal_simple_recip() {
let input = cx.tensor(4);
let output = input.reciprocal().output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -579,11 +582,11 @@ fn metal_simple_mod() {
let b = cx.tensor(4);
let output = (a % b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -598,11 +601,11 @@ fn metal_simple_less_than() {
let b = cx.tensor(4);
let output = a.lt(b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -618,11 +621,11 @@ fn metal_simple_sum_reduce() {
// sum over axis 1
let output = input.sum(1).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -637,11 +640,11 @@ fn metal_simple_max_reduce() {
// max over axis 1
let output = input.max(1).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
rt = cx.search(rt, 5);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -655,10 +658,10 @@ fn metal_f16_cast_roundtrip() {
let input = cx.tensor(4);
let output = input.cast(DType::F16).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
rt = cx.search(rt, 3);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -675,11 +678,11 @@ fn metal_f16_intermediate_add_roundtrip() {
.cast(DType::F32)
.output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
rt = cx.search(rt, 3);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -694,7 +697,7 @@ fn metal_specialized_matmul() {
let b = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -734,7 +737,7 @@ fn metal_regular_tiled_matmul_path() {
let b = cx.tensor((k, n));
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -769,7 +772,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
let weight = cx.tensor((n, k));
let output = a.matmul(weight.t()).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -804,7 +807,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
let rhs = cx.tensor((k, n));
let output = lhs_storage.t().matmul(rhs).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -843,7 +846,7 @@ fn metal_mps_batched_matmul_row_row_layout() {
let b = cx.tensor((batch, k, n));
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSBatchedMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -887,7 +890,7 @@ fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
let merged = attn.transpose(0, 1).merge_dims(1, 2);
let output = merged.matmul(weight.t()).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert!(
egraph_has_op(&cx, "GenericMatmul"),
"expected GenericMatmul rewrite option in e-graph"
@@ -946,7 +949,7 @@ fn metal_mps_batched_matmul_transposed_rhs_layout() {
let weight = cx.tensor((batch, n, k));
let output = a.matmul(weight.permute((0, 2, 1))).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSBatchedMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -987,7 +990,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
let weight = cx.tensor((n, k)).as_dtype(DType::F16);
let output = a.matmul(weight.t()).cast(DType::F32).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
@@ -1019,7 +1022,7 @@ fn metal_rms_norm() {
let weight = cx.tensor(TRANSFORMER_HIDDEN);
let output = rms_norm(input, weight, 1e-5).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1027,7 +1030,7 @@ fn metal_rms_norm() {
rt.set_data(input, &input_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1053,7 +1056,7 @@ fn metal_self_attention() {
let wo = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let output = self_attention(input, wq, wk, wv, wo).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1067,7 +1070,7 @@ fn metal_self_attention() {
rt.set_data(wk, &wk_data);
rt.set_data(wv, &wv_data);
rt.set_data(wo, &wo_data);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1112,7 +1115,7 @@ fn metal_self_attention_f16_weights() {
.cast(DType::F32)
.output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1126,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
rt.set_data(wk, to_f16_vec(&wk_data));
rt.set_data(wv, to_f16_vec(&wv_data));
rt.set_data(wo, to_f16_vec(&wo_data));
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1158,7 +1161,7 @@ fn metal_swiglu_mlp() {
let w_down = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE));
let output = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1170,7 +1173,7 @@ fn metal_swiglu_mlp() {
rt.set_data(w_gate, &gate_data);
rt.set_data(w_up, &up_data);
rt.set_data(w_down, &down_data);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1210,7 +1213,7 @@ fn metal_mini_transformer_layer() {
let layer = MiniTransformerLayer::init(&mut cx);
let output = layer.forward(input).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1220,7 +1223,7 @@ fn metal_mini_transformer_layer() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1276,7 +1279,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
.cast(DType::F32);
let output = (x + mlp_out).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
@@ -1286,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1323,12 +1326,12 @@ fn test_scatter_basic() {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[10.0, 20.0, 30.0]);
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1345,12 +1348,12 @@ fn test_scatter_buffer_roundtrip() {
let cache_out = src.scatter(indexes, cache);
let read = cache_out.output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[0.0]);
rt.set_data(indexes, &[0.0]);
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
for (pos, value, expected) in [
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
@@ -1379,12 +1382,12 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
let tensors = [("weights", Dtype::F32, vec![3], bytes_of(&weight_values))];
let path = write_test_safetensors(&tensors);
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(weights, &[99.0, 99.0, 99.0]);
rt.set_data(bias, &[0.5, 1.0, -1.5]);
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1446,10 +1449,10 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
];
let path = write_test_safetensors(&tensors);
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1469,14 +1472,14 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
let indexes = cx.tensor((2, 2)).as_dtype(DType::Int);
let out = data.gather(indexes).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(
input,
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
);
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1491,12 +1494,12 @@ fn test_scatter_into_nonzero_dest() {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[2f32]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1518,12 +1521,12 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[7.0, 8.0]);
rt.set_data(indexes, &[1.0, 3.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1547,12 +1550,12 @@ fn test_scatter_no_copy_handles_2d_destination() {
let dest = cx.tensor((2, 3));
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[9.0, 8.0]);
rt.set_data(indexes, &[2.0, 4.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1574,12 +1577,12 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
let scatter = src.scatter(indexes, dest).output();
let dest_plus_one = (dest + 1.0).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[1.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1601,12 +1604,12 @@ fn test_scatter_all_positions() {
let dest = cx.tensor(4);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1621,11 +1624,11 @@ fn test_gather_preserves_data_dtype() {
let indexes = cx.tensor(1).as_dtype(DType::Int);
let out = data.gather(indexes).output();
cx.build_search_space::<MetalRuntime>();
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(data, &[1.25, 2.5]);
rt.set_data(indexes, &[1.0]);
rt = cx.search(rt, 1);
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);

View File

@@ -166,8 +166,11 @@ mod tests {
let indices = cx.tensor(3).as_dtype(DType::Int);
let result = gather_rows(data, indices, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
rt.set_data(
@@ -192,8 +195,11 @@ mod tests {
let dest = cx.tensor((4, 3));
let result = scatter_rows(src, indices, dest, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
rt.set_data(indices.id, vec![1, 3]);
@@ -218,8 +224,11 @@ mod tests {
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
@@ -271,8 +280,11 @@ mod tests {
let k_cache_new = k_cache_new.output();
let v_cache_new = v_cache_new.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
rt.set_data(q.id, vec![1., 0., 1., 0.]);
@@ -344,8 +356,11 @@ mod tests {
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
// K cached at slot 0: [1, 0]
@@ -416,8 +431,11 @@ mod tests {
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Cache has 1 token at slot 0
let mut k_cache_data = vec![0.; num_slots * kv_dim];

View File

@@ -183,8 +183,11 @@ mod tests {
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 2.0, 3.0];
// Router strongly favors expert 0
@@ -238,8 +241,11 @@ mod tests {
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 1.0];
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
@@ -292,8 +298,11 @@ mod tests {
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
@@ -349,8 +358,11 @@ mod tests {
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(in_dim);
let router_data = random_vec(in_dim * n_experts);
@@ -394,8 +406,11 @@ mod tests {
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(batch * in_dim);
let router_data = random_vec(in_dim * n_experts);

View File

@@ -8,7 +8,7 @@ echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py"
CUDA_TESTS="tests/"
# ── Phase 1: Native Backend ─────────────────────────────────

View File

@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far to long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -5,6 +5,7 @@ use luminal::{
visualization::ToDot,
};
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use std::collections::HashMap;
use crate::typed_data::TypedData;
@@ -73,22 +74,13 @@ fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usiz
Some((var, candidate))
}
/// Convert luminal DType to PT2 dtype integer code (for python interop)
/// Types without a direct Pytorch equivalent map to the closest safe representation
/// Convert luminal `DType` to a PT2 dtype code via `TorchDType`. Panics
/// for luminal-specific dtypes that have no PyTorch counterpart (`I4`,
/// `U4`, the F6 / F4 families, ...).
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
match dtype {
DType::U8 => 1,
DType::I8 => 2,
DType::I16 => 3,
DType::Int => 4, // i32
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
DType::F16 => 6,
DType::F32 | DType::TF32 => 7,
DType::F64 => 8,
DType::Bool => 12,
DType::Bf16 => 13,
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
}
crate::torch_dtype::TorchDType::try_from(dtype)
.map(|t| t.code())
.unwrap_or_else(|d| panic!("luminal_dtype_to_pt2_code: unsupported dtype {d:?}"))
}
/// Common intermediate result from translating a model graph.
@@ -512,6 +504,65 @@ impl CompiledGraph {
Ok(self.runtime.get_output_i32(*node_id))
}
/// Read an output as f16 (returned as raw little-endian bytes —
/// Python has no native f16, so the caller bit-casts via
/// `torch.frombuffer(..., dtype=torch.float16)`). Strict: the
/// producer node must already be `DType::F16`; no widening at
/// the read boundary.
fn get_output_f16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
let data = self.runtime.get_output_f16(*node_id);
let bytes: &[u8] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
Ok(PyBytes::new(py, bytes))
}
/// Read an output as bf16 (returned as raw little-endian bytes —
/// caller bit-casts via `torch.frombuffer(..., dtype=torch.
/// bfloat16)`). Strict: the producer node must already be
/// `DType::Bf16`; no widening at the read boundary.
fn get_output_bf16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
let data = self.runtime.get_output_bf16(*node_id);
let bytes: &[u8] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
Ok(PyBytes::new(py, bytes))
}
/// Read an output as i64. Strict: the producer node must already
/// be `DType::I64`; no widening at the read boundary.
fn get_output_i64(&self, name: &str) -> PyResult<Vec<i64>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
Ok(self.runtime.get_output_i64(*node_id))
}
/// Read an output as f64. Strict: the producer node must already
/// be `DType::F64`; no widening at the read boundary.
fn get_output_f64(&self, name: &str) -> PyResult<Vec<f64>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
Ok(self.runtime.get_output_f64(*node_id))
}
/// Get output tensor data by name as bool (copies to host).
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {

View File

@@ -0,0 +1,120 @@
//! Canonical-form helpers for dimension `Expression` arithmetic — used
//! by the translator to keep shape arithmetic syntactically consistent
//! across code paths.
//!
//! `Expression` equality is syntactic; `a * 8` and `8 * a` are distinct
//! objects despite being mathematically equal. When two translator code
//! paths build the same logical dim via differently-ordered
//! multiplications, downstream `assert_eq!(self.dims(), rhs.dims())`
//! checks in `GraphTensor::Add` / `Sub` / `Mul` / `Rem` panic. These
//! helpers solve that at the construction site: every shape product
//! goes through `product_of_dims`, which sorts the operand list before
//! folding, so two callers passing the operands in different orders
//! produce identical `Expression`s.
//!
//! Lives in `luminal_python` (rather than upstream `luminal::shape`) so
//! the change is contained to the translator. luminal-core callers of
//! `gather_elements` / `scatter_elements` / `scatter_nd` historically
//! pass concrete dims, so they don't need this; the translator-local
//! lowerings in `translator::movement_dynamic` do.
//!
//! The ordering matches what `pt2_expr.rs::normalize_mul_expr` was
//! using locally before being promoted here — see that file for the
//! original canonical-sort logic.
use luminal::prelude::Expression;
/// Sort key for the canonical commutative ordering. Sorts by RPN-term
/// count first so single-term operands (variables, literals) sort
/// before compound subexpressions; ties broken by debug repr so two
/// single-term operands have a stable alphabetic order.
///
/// O(n) string alloc per compare — only call on shape products, never
/// per-element in a kernel.
#[inline]
pub(crate) fn commutative_key(expr: &Expression) -> (usize, String) {
(expr.len(), format!("{expr:?}"))
}
/// Order `(a, b)` so the canonically-smaller expression is first.
#[inline]
pub(crate) fn sort_pair(a: Expression, b: Expression) -> (Expression, Expression) {
if commutative_key(&a) <= commutative_key(&b) {
(a, b)
} else {
(b, a)
}
}
/// Multiply two dim expressions with canonical operand ordering.
#[inline]
pub(crate) fn mul_dims(a: Expression, b: Expression) -> Expression {
let (a, b) = sort_pair(a, b);
a * b
}
/// Add two dim expressions with canonical operand ordering.
#[inline]
pub(crate) fn add_dims(a: Expression, b: Expression) -> Expression {
let (a, b) = sort_pair(a, b);
a + b
}
/// Product of a sequence of dim expressions. Operands are sorted
/// canonically before folding so callers passing the same logical
/// dim set in different orders produce identical `Expression`s.
/// Empty sequence → `Expression::from(1usize)`.
pub(crate) fn product_of_dims<I>(dims: I) -> Expression
where
I: IntoIterator<Item = Expression>,
{
let mut v: Vec<Expression> = dims.into_iter().collect();
v.sort_by_key(commutative_key);
v.into_iter()
.fold(Expression::from(1usize), |acc, d| acc * d)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mul_dims_canonicalises_commutative_order() {
let a = Expression::from('a');
let n = Expression::from(8i64);
assert_eq!(mul_dims(a, n), mul_dims(n, a));
}
#[test]
fn product_of_dims_independent_of_input_order() {
let a = Expression::from('a');
let b = Expression::from('b');
let n = Expression::from(8i64);
let p1 = product_of_dims([a, n, b]);
let p2 = product_of_dims([n, b, a]);
let p3 = product_of_dims([b, a, n]);
assert_eq!(p1, p2);
assert_eq!(p1, p3);
}
#[test]
fn empty_product_is_one() {
let empty: Vec<Expression> = vec![];
assert_eq!(product_of_dims(empty), Expression::from(1usize));
}
#[test]
fn mixed_numeric_types_canonicalise_together() {
// `pt2_util` builds with `Expression::from(usize)` while tests /
// direct callers reach for `i64`. The two literal paths must
// produce identical reprs or `product_of_dims` will sort them
// into different positions and we lose the canonical-form
// guarantee across call sites.
assert_eq!(Expression::from(8usize), Expression::from(8i64));
let a = Expression::from('a');
assert_eq!(
product_of_dims([Expression::from(8usize), a]),
product_of_dims([Expression::from(8i64), a]),
);
}
}

View File

@@ -1,4 +1,6 @@
mod compiled_graph;
mod dim_arith;
pub mod torch_dtype;
pub mod typed_data;
// PT2 modules
@@ -13,17 +15,32 @@ use compiled_graph::CompiledGraph;
use pt2_compiled_model::process_pt2;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::collections::HashMap;
use torch_dtype::TorchDType;
#[pymodule]
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
m.add_class::<CompiledGraph>()?;
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
m.add_function(wrap_pyfunction!(_torch_dtype_codes, m)?)?;
#[cfg(feature = "cuda")]
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
Ok(())
}
/// `{variant_name: pt2_code}` for every `TorchDType` variant. The Python
/// parity test (`tests/test_torch_dtype_parity.py`) consumes this and
/// asserts every entry matches `torch._export.serde.schema.ScalarType.<name>
/// .value` — drift fails CI rather than silently miscompiling at runtime.
#[pyfunction]
fn _torch_dtype_codes() -> HashMap<&'static str, u32> {
TorchDType::ALL
.iter()
.map(|v| (v.name(), v.code()))
.collect()
}
// ---------------------------------------------------------------------------
// Factory capsule helpers
// ---------------------------------------------------------------------------

View File

@@ -7,10 +7,10 @@ use std::collections::HashMap;
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
use crate::pt2_expr::parse_sympy_expr;
use crate::pt2_parser;
use crate::pt2_schema;
use crate::translator;
use crate::typed_data::TypedData;
use crate::{pt2_parser, pt2_util};
/// Pre-loaded weight/constant data paired with tensor sizes.
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
@@ -374,52 +374,10 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
}
}
/// Convert raw bytes to TypedData using PT2 dtype numbering.
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
/// Converts i64/f64/i16 to the closest luminal-native representation.
/// Convert raw bytes to `TypedData` using PT2 dtype numbering. Thin
/// wrapper around `TypedData::from_pytorch_bytes` — the dtype dispatch
/// (including the narrow-int panic and unknown-code rejection) lives
/// there, so this site stays a one-liner that just clones the slice.
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
match dtype {
// Types that map directly — preserve raw bytes
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
// i64 → i32 (truncate, matching luminal's Int type)
5 => {
let i32s: Vec<i32> = bytes
.chunks_exact(8)
.map(|b| {
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
})
.collect();
TypedData::from_i32_vec(i32s)
}
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
8 => {
let f32s: Vec<f32> = bytes
.chunks_exact(8)
.map(|b| {
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
})
.collect();
TypedData::from_f32_vec(f32s)
}
// i16 → i32 (widen to luminal's Int)
3 => {
let i32s: Vec<i32> = bytes
.chunks_exact(2)
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
.collect();
TypedData::from_i32_vec(i32s)
}
_ => {
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
}
}
TypedData::from_pytorch_bytes(bytes.to_vec(), dtype)
}

View File

@@ -251,26 +251,12 @@ fn normalize_expr(expr: Expression) -> Expression {
}
}
fn commutative_key(expr: Expression) -> (usize, String) {
(expr.len(), format!("{expr:?}"))
}
fn sort_commutative(lhs: Expression, rhs: Expression) -> (Expression, Expression) {
if commutative_key(lhs) <= commutative_key(rhs) {
(lhs, rhs)
} else {
(rhs, lhs)
}
}
fn normalize_add_expr(lhs: Expression, rhs: Expression) -> Expression {
let (lhs, rhs) = sort_commutative(lhs, rhs);
normalize_expr(lhs + rhs)
normalize_expr(crate::dim_arith::add_dims(lhs, rhs))
}
fn normalize_mul_expr(lhs: Expression, rhs: Expression) -> Expression {
let (lhs, rhs) = sort_commutative(lhs, rhs);
normalize_expr(lhs * rhs)
normalize_expr(crate::dim_arith::mul_dims(lhs, rhs))
}
fn checked_add_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {

View File

@@ -114,29 +114,17 @@ pub fn resolve_neg1_dim(target: &[i64], current_dims: &[Expression]) -> Vec<Expr
}
if let Some(idx) = neg1_idx {
let mut total = Expression::from(1usize);
for d in current_dims {
total *= *d;
}
if let (Some(total_val), Some(_)) = (
{
let mut t = 1i64;
let mut all_concrete = true;
for d in current_dims {
if let Some(v) = d.to_usize() {
t *= v as i64;
} else {
all_concrete = false;
}
}
if all_concrete { Some(t) } else { None }
},
Some(known_product),
) {
result[idx] = Expression::from((total_val / known_product) as usize);
} else {
result[idx] = total / Expression::from(known_product as usize);
}
result[idx] = match current_dims
.iter()
.map(|d| d.to_usize())
.collect::<Option<Vec<_>>>()
{
Some(vs) => Expression::from(vs.iter().product::<usize>() / known_product as usize),
None => {
crate::dim_arith::product_of_dims(current_dims.iter().copied())
/ Expression::from(known_product as usize)
}
};
}
result
@@ -185,11 +173,12 @@ pub fn resolve_neg1_dim_exprs(
if input_symbolic.is_empty() {
result[idx] = Expression::from((input_concrete / target_concrete) as usize);
} else {
let mut expr = Expression::from((input_concrete / target_concrete) as usize);
for s in &input_symbolic {
expr *= *s;
}
result[idx] = expr;
let mut operands: Vec<Expression> = Vec::with_capacity(input_symbolic.len() + 1);
operands.push(Expression::from(
(input_concrete / target_concrete) as usize,
));
operands.extend(input_symbolic.iter().copied());
result[idx] = crate::dim_arith::product_of_dims(operands);
}
result
@@ -198,16 +187,29 @@ pub fn resolve_neg1_dim_exprs(
}
}
/// Map torch dtype integer (PT2 format) to luminal DType.
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
/// Map a PT2 dtype code to luminal `DType`. Panics for variants the IR
/// doesn't model as first-class types (narrow ints `Byte` / `Char` /
/// `Short`, the complex family, the float8 family) and for unknown
/// codes — better to fail loudly at the translator boundary than to
/// silently widen and lie about the user's dtype.
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
match dtype {
6 => DType::F16,
7 => DType::F32,
8 => DType::F32, // float64 → F32 (no F64 in luminal)
13 => DType::Bf16,
12 => DType::Bool,
1..=5 => DType::Int, // uint8, int8, int16, int32, int64
_ => DType::F32,
let t = crate::torch_dtype::TorchDType::from_code(dtype)
.unwrap_or_else(|c| panic!("torch_dtype_int_to_luminal: unknown PT2 dtype code {c}"));
match t {
crate::torch_dtype::TorchDType::Byte
| crate::torch_dtype::TorchDType::Char
| crate::torch_dtype::TorchDType::Short => panic!(
"torch_dtype_int_to_luminal: PT2 dtype {} (code {}) isn't a first-class \
IR type yet — cast to torch.int32 at the call site, or wait for the \
narrower-int IR follow-up.",
t.name(),
t.code(),
),
other => DType::try_from(other).unwrap_or_else(|t| {
panic!(
"torch_dtype_int_to_luminal: {} isn't a first-class luminal IR type",
t.name()
)
}),
}
}

View File

@@ -0,0 +1,235 @@
//! Typed mirror of PyTorch's PT2 export-schema `ScalarType` enum.
//!
//! The PT2 export pipeline wire-serializes tensor dtypes as `u32` codes drawn
//! from `torch._export.serde.schema.ScalarType` (an `IntEnum` on the Python
//! side). Three sites in this crate used to carry duplicate raw-`u32` match
//! arms with the canonical numbering hand-rolled in each — silent miscompile
//! risk when PyTorch renumbers or adds a code. This module collapses those
//! sites onto one typed enum and pins the numbering with a parity test that
//! asserts every Rust variant matches `torch._export.serde.schema.ScalarType`
//! at CI time (see `crates/luminal_python/tests/test_torch_dtype_parity.py`).
//!
//! Note: PyTorch's C++ `c10::ScalarType` uses a different numbering than the
//! PT2 schema (PT2 reserves 0 for `Unknown`); we bind to the **PT2 schema**,
//! not the c10 header, because that is what flows over our wire.
use luminal::prelude::DType;
/// PT2 export-schema dtype code. Discriminants match
/// `torch._export.serde.schema.ScalarType` variant values exactly; drift is
/// caught by `tests/test_torch_dtype_parity.py`.
#[repr(u32)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum TorchDType {
Unknown = 0,
Byte = 1,
Char = 2,
Short = 3,
Int = 4,
Long = 5,
Half = 6,
Float = 7,
Double = 8,
ComplexHalf = 9,
ComplexFloat = 10,
ComplexDouble = 11,
Bool = 12,
BFloat16 = 13,
Uint16 = 28,
Float8E4m3Fn = 29,
Float8E5m2 = 30,
Float8E4m3Fnuz = 31,
Float8E5m2Fnuz = 32,
}
impl TorchDType {
/// All variants, in declaration order. Used by the pyo3-exported parity
/// table and by tests; add new variants here when PyTorch adds them.
pub const ALL: &'static [TorchDType] = &[
TorchDType::Unknown,
TorchDType::Byte,
TorchDType::Char,
TorchDType::Short,
TorchDType::Int,
TorchDType::Long,
TorchDType::Half,
TorchDType::Float,
TorchDType::Double,
TorchDType::ComplexHalf,
TorchDType::ComplexFloat,
TorchDType::ComplexDouble,
TorchDType::Bool,
TorchDType::BFloat16,
TorchDType::Uint16,
TorchDType::Float8E4m3Fn,
TorchDType::Float8E5m2,
TorchDType::Float8E4m3Fnuz,
TorchDType::Float8E5m2Fnuz,
];
/// Canonical wire code (matches `ScalarType.<name>.value` in Python).
#[inline]
pub fn code(self) -> u32 {
self as u32
}
/// PyTorch schema variant name (e.g. `"LONG"`, `"BFLOAT16"`). Used by the
/// parity test to align Rust variants with `ScalarType.<name>`.
pub fn name(self) -> &'static str {
match self {
TorchDType::Unknown => "UNKNOWN",
TorchDType::Byte => "BYTE",
TorchDType::Char => "CHAR",
TorchDType::Short => "SHORT",
TorchDType::Int => "INT",
TorchDType::Long => "LONG",
TorchDType::Half => "HALF",
TorchDType::Float => "FLOAT",
TorchDType::Double => "DOUBLE",
TorchDType::ComplexHalf => "COMPLEXHALF",
TorchDType::ComplexFloat => "COMPLEXFLOAT",
TorchDType::ComplexDouble => "COMPLEXDOUBLE",
TorchDType::Bool => "BOOL",
TorchDType::BFloat16 => "BFLOAT16",
TorchDType::Uint16 => "UINT16",
TorchDType::Float8E4m3Fn => "FLOAT8E4M3FN",
TorchDType::Float8E5m2 => "FLOAT8E5M2",
TorchDType::Float8E4m3Fnuz => "FLOAT8E4M3FNUZ",
TorchDType::Float8E5m2Fnuz => "FLOAT8E5M2FNUZ",
}
}
/// Parse from a wire code. `Err(code)` if the code isn't a known PyTorch
/// variant — the caller decides whether to panic with context or fall
/// through to a non-PT2 path.
pub fn from_code(code: u32) -> Result<Self, u32> {
for v in Self::ALL {
if v.code() == code {
return Ok(*v);
}
}
Err(code)
}
}
/// PyTorch dtype → luminal `DType`. `Err(self)` for variants luminal's IR
/// doesn't model as first-class types — the narrow ints (`Byte` / `Char` /
/// `Short`), the complex family, and the float8 NUZ variants. `DType::U8`,
/// `DType::I8`, `DType::I16` exist on the luminal side but the IR has no
/// kernels / codegen for them, so we refuse the conversion here rather
/// than silently producing a buffer the kernels can't actually run.
/// Boundary code panics with the variant name on `Err`; cf.
/// `typed_data::from_pytorch_bytes`, `pt2_util::torch_dtype_int_to_luminal`.
impl TryFrom<TorchDType> for DType {
type Error = TorchDType;
fn try_from(t: TorchDType) -> Result<Self, Self::Error> {
Ok(match t {
TorchDType::Int => DType::Int,
TorchDType::Long => DType::I64,
TorchDType::Half => DType::F16,
TorchDType::Float => DType::F32,
TorchDType::Double => DType::F64,
TorchDType::Bool => DType::Bool,
TorchDType::BFloat16 => DType::Bf16,
TorchDType::Float8E4m3Fn => DType::F8E4M3,
TorchDType::Float8E5m2 => DType::F8E5M2,
TorchDType::Byte
| TorchDType::Char
| TorchDType::Short
| TorchDType::Uint16
| TorchDType::Unknown
| TorchDType::ComplexHalf
| TorchDType::ComplexFloat
| TorchDType::ComplexDouble
| TorchDType::Float8E4m3Fnuz
| TorchDType::Float8E5m2Fnuz => return Err(t),
})
}
}
/// luminal `DType` → PyTorch dtype. `Err(dtype)` for luminal-specific
/// variants without a first-class PyTorch counterpart — the narrow ints
/// (`U8` / `I8` / `I16` / `U16`), the sub-byte / exotic widths (`I4`,
/// `U4`, `F6E2M3`, ...), and `TF32`.
///
/// `TF32` is a compute-mode hint inside luminal, not a storage dtype on
/// the PyTorch side (PyTorch has no `torch.tf32`); silently mapping it to
/// `Float` would hand PyTorch an f32 buffer that the caller had been
/// tracking as TF32 inside luminal. Refuse instead — a real cast to
/// `DType::F32` upstream is the explicit way to bridge.
impl TryFrom<DType> for TorchDType {
type Error = DType;
fn try_from(d: DType) -> Result<Self, Self::Error> {
Ok(match d {
DType::F32 => TorchDType::Float,
DType::F64 => TorchDType::Double,
DType::F16 => TorchDType::Half,
DType::Bf16 => TorchDType::BFloat16,
DType::Int => TorchDType::Int,
DType::I64 => TorchDType::Long,
DType::Bool => TorchDType::Bool,
DType::F8E4M3 => TorchDType::Float8E4m3Fn,
DType::F8E5M2 => TorchDType::Float8E5m2,
_ => return Err(d),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn roundtrip_codes() {
for v in TorchDType::ALL {
assert_eq!(TorchDType::from_code(v.code()).unwrap(), *v);
}
}
#[test]
fn supported_dtypes_roundtrip() {
// Only the variants luminal's IR models as first-class can
// roundtrip cleanly. Narrow ints (`U8` / `I8` / `I16` / `U16`)
// are intentionally excluded — see the `TryFrom` impls.
for d in [
DType::F32,
DType::F64,
DType::F16,
DType::Bf16,
DType::Int,
DType::I64,
DType::Bool,
] {
let t = TorchDType::try_from(d).expect("known DType");
let back = DType::try_from(t).expect("known TorchDType");
assert_eq!(d, back, "roundtrip mismatch for {d:?}");
}
}
#[test]
fn narrow_ints_refuse_conversion() {
// Forward (PyTorch → luminal) and reverse (luminal → PyTorch)
// both refuse the narrow-int variants; downstream sites translate
// the `Err` into a typed panic with the variant name.
for t in [TorchDType::Byte, TorchDType::Char, TorchDType::Short] {
assert!(DType::try_from(t).is_err(), "expected Err for {t:?}");
}
for d in [
DType::U8,
DType::I8,
DType::I16,
DType::U16,
// TF32 is a luminal-internal compute-mode hint, not a PyTorch
// storage dtype — refuse to silently alias it as `Float`.
DType::TF32,
] {
assert!(TorchDType::try_from(d).is_err(), "expected Err for {d:?}");
}
}
#[test]
fn unknown_code_errors() {
assert!(TorchDType::from_code(99).is_err());
assert!(TorchDType::from_code(14).is_err()); // gap in PT2 numbering
}
}

View File

@@ -175,7 +175,11 @@ impl<'a> Translator<'a> {
"torch.ops.aten.pow.Tensor_Scalar" => {
let a = self.get_input_tensor(node, 0)?;
let exp = self.get_float_arg(node, 1)?;
a.pow(exp as f32)
if (exp - 2.0).abs() < f64::EPSILON {
a * a
} else {
a.pow(exp as f32)
}
}
"torch.ops.aten.pow.Tensor_Tensor" => {
let a = self.get_input_tensor(node, 0)?;

View File

@@ -7,6 +7,7 @@ mod binary;
mod conv;
mod dispatch;
mod movement;
mod movement_dynamic;
mod reduction;
mod tensor;
mod unary;

View File

@@ -306,7 +306,11 @@ impl<'a> Translator<'a> {
let mut target: Vec<Expression> = src_dims.to_vec();
target[first_non_none_dim] = idx_dim_size;
expanded.shape.expand(target);
return Ok(source.gather_elements(expanded, first_non_none_dim));
return Ok(super::movement_dynamic::pt2_gather_elements(
source,
expanded,
first_non_none_dim,
));
}
} else {
bail!(
@@ -426,7 +430,7 @@ impl<'a> Translator<'a> {
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
let result = a.gather_elements(normalized, dim);
let result = super::movement_dynamic::pt2_gather_elements(a, normalized, dim);
Ok(if promoted_rank0 {
result.squeeze(0)
} else {
@@ -440,7 +444,12 @@ impl<'a> Translator<'a> {
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, 2)?;
let src = self.get_input_tensor(node, 3)?;
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
Ok(super::movement_dynamic::pt2_scatter_elements(
a,
indices.cast(DType::Int),
src,
dim,
))
}
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -463,7 +472,12 @@ impl<'a> Translator<'a> {
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
}
.expand_rhs(indices.shape);
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
Ok(super::movement_dynamic::pt2_scatter_elements(
a,
indices.cast(DType::Int),
value,
dim,
))
}
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -508,7 +522,7 @@ impl<'a> Translator<'a> {
let indices = idx_tensor.cast(DType::Int);
let new_last = indices.shape.len();
let indices = indices.expand_dim(new_last, Expression::from(1usize));
Ok(a.scatter_nd(indices, values))
Ok(super::movement_dynamic::pt2_scatter_nd(a, indices, values))
} else {
bail!("index_put with multiple index tensors not yet supported");
}

View File

@@ -0,0 +1,231 @@
//! Symbolic-dim-safe `gather_elements` / `scatter_elements` / `scatter_nd`
//! lowerings for the PT2 translator.
//!
//! The luminal-core versions in `luminal::frontend::movement` require
//! concrete shape dims — they call `d.to_usize().expect(...)` on every
//! input dim and panic at translate-time when `torch.compile` hands us a
//! batch dim, sequence-length dim, or any other dynamic dim. PT2's whole
//! point is dynamic shapes, so we re-implement the same three ops here
//! using `Expression`-typed shape arithmetic and only call luminal-core
//! primitives that already accept `Expression`s (`Graph::constant`,
//! `Graph::iota`, `flatten_strides`, `ShapeTracker::new(Vec<Expression>)`,
//! `expand_dim`, `expand_rhs`, `flatten`, `slice_along`, `squeeze`,
//! `cast`, `scatter`, `gather`).
//!
//! Every shape product flows through `crate::dim_arith::product_of_dims`
//! so the `Expression`s we build are canonical: two callers that produce
//! the same logical dim via differently-ordered multiplications end up
//! with byte-identical `Expression`s. Without this, downstream dim-equality
//! asserts in luminal-core's `Add` / `Sub` (see `src/frontend/binary.rs`)
//! panic on `a*8` ≠ `8*a` after these helpers feed into broadcast paths.
use luminal::prelude::*;
use crate::dim_arith::product_of_dims;
/// Row-major strides as `Expression`s. `stride[i] = prod(dims[i+1..])`.
fn row_major_strides(dims: &[Expression]) -> Vec<Expression> {
let rank = dims.len();
(0..rank)
.map(|i| product_of_dims(dims[i + 1..].iter().copied()))
.collect()
}
/// Build the additive non-axis contribution to a flat index over a
/// rank-`rank` output of shape `out_shape`. The axis dim contributes
/// 0; every other dim `d` contributes `iota_d * strides[d]`. Materialised
/// via one `Graph::iota` call with `flatten_strides(out_shape, axis_exprs)`
/// — same pattern luminal core uses, just with `Expression` throughout.
fn non_axis_flat(
graph: &mut Graph,
out_shape: &[Expression],
strides: &[Expression],
axis: usize,
) -> GraphTensor {
let rank = out_shape.len();
let axis_exprs: Vec<Expression> = (0..rank)
.map(|d| {
if d == axis {
Expression::from(0)
} else {
Expression::from('z') * strides[d]
}
})
.collect();
graph.iota(flatten_strides(out_shape, &axis_exprs), out_shape.to_vec())
}
/// Wrap negative axis indices into `[0, axis_dim)`. Equivalent to
/// `if idx < 0 { idx + axis_dim } else { idx }` in tensor form.
fn normalize_negative_index(indices: GraphTensor, axis_dim: Expression) -> GraphTensor {
let idx_f32 = indices.cast(DType::F32);
let zero = idx_f32
.graph()
.constant_float(0.0)
.expand_rhs(idx_f32.shape);
let adj = idx_f32
.graph()
.constant(axis_dim)
.cast(DType::F32)
.expand_rhs(idx_f32.shape);
let is_neg = idx_f32.lt(zero).cast(DType::F32);
(idx_f32 + (is_neg * adj)).cast(DType::Int)
}
/// Translator-local `gather_elements` that accepts symbolic shape dims.
/// Mirrors `GraphTensor::gather_elements` semantics but uses
/// `Expression`-typed shape arithmetic and only calls symbol-safe
/// luminal-core primitives.
///
/// `output[i0,..,ik] = self[i0,..,i_{axis-1}, indices[i0,..,ik], i_{axis+1},..,ik]`
pub fn pt2_gather_elements(data: GraphTensor, indexes: GraphTensor, axis: usize) -> GraphTensor {
let dims = data.dims();
let out_shape: Vec<Expression> = indexes.dims();
let strides = row_major_strides(&dims);
let idx_normalized = normalize_negative_index(indexes, dims[axis]);
let non_axis_flat = non_axis_flat(data.graph(), &out_shape, &strides, axis);
let stride_tensor = data
.graph()
.constant(strides[axis])
.expand_rhs(idx_normalized.shape);
let flat_idx = non_axis_flat + idx_normalized * stride_tensor;
data.gather(flat_idx)
}
/// Translator-local `scatter_elements` that accepts symbolic shape dims.
/// Same semantics as `GraphTensor::scatter_elements`.
pub fn pt2_scatter_elements(
data: GraphTensor,
indices: GraphTensor,
updates: GraphTensor,
axis: usize,
) -> GraphTensor {
let data_dims = data.dims();
let idx_shape: Vec<Expression> = indices.dims();
let strides = row_major_strides(&data_dims);
let idx_normalized = normalize_negative_index(indices, data_dims[axis]);
let non_axis_flat = non_axis_flat(data.graph(), &idx_shape, &strides, axis);
let stride_tensor = data
.graph()
.constant(strides[axis])
.expand_rhs(idx_normalized.shape);
let flat_dest = non_axis_flat + idx_normalized * stride_tensor;
let flat_dest_1d = flat_dest.flatten();
let flat_updates = updates.flatten();
let flat_data = data.flatten();
let output_flat = flat_updates.scatter(flat_dest_1d, flat_data);
// View-only reshape back to data shape; the buffer is already laid
// out row-major from the scatter, so swapping the tracker is safe.
let mut result = output_flat;
result.shape = ShapeTracker::new(data_dims);
result
}
/// Translator-local `scatter_nd` that accepts symbolic shape dims.
/// Mirrors `GraphTensor::scatter_nd` semantics.
pub fn pt2_scatter_nd(
data: GraphTensor,
indices: GraphTensor,
updates: GraphTensor,
) -> GraphTensor {
let indices = indices.cast(DType::Int);
let data_dims = data.dims();
let data_rank = data_dims.len();
let idx_dims = indices.dims();
let idx_rank = idx_dims.len();
// The last dim of indices is the index width K — it must be
// concrete at translate-time because it controls how many
// contribution terms we build statically. HuggingFace's MoE
// accumulator (the path that brought us here via `index_put`)
// always passes a literal; non-HF callers with a SymInt K would
// need a different lowering.
let k = idx_dims[idx_rank - 1]
.to_usize()
.expect("scatter_nd: indices innermost dim (K) must be concrete");
assert!(k <= data_rank, "scatter_nd: K must be <= data rank");
// Batch shape = indices shape without last dim.
let batch_shape: Vec<Expression> = idx_dims[..idx_rank - 1].to_vec();
let batch_numel = product_of_dims(batch_shape.iter().copied());
// Trailing shape = data_shape[K..]
let trailing_shape: Vec<Expression> = data_dims[k..].to_vec();
let trailing_numel = product_of_dims(trailing_shape.iter().copied());
let data_strides = row_major_strides(&data_dims);
// Flatten batch dims of indices to [batch_numel, K] via view reshape.
let mut indices_flat = indices;
if idx_rank > 2 {
indices_flat.shape = ShapeTracker::new(vec![batch_numel, Expression::from(k)]);
}
let mut flat_base: Option<GraphTensor> = None;
for (k_dim, stride) in data_strides.iter().copied().enumerate().take(k) {
let idx_k = indices_flat.slice_along(k_dim..k_dim + 1, indices_flat.dims().len() - 1);
let idx_k = idx_k.squeeze(idx_k.dims().len() - 1);
let stride_tensor = data.graph().constant(stride).expand_rhs(idx_k.shape);
let contribution = idx_k * stride_tensor;
flat_base = Some(match flat_base {
Some(fb) => fb + contribution,
None => contribution,
});
}
let flat_base = flat_base.unwrap();
// Trailing-numel concreteness drives whether we need the expand-and-fold
// path. If trailing_shape is empty OR its numel collapses to 1, the flat
// base is already the full destination index.
let trailing_is_unit = trailing_shape.is_empty() || trailing_numel.to_usize() == Some(1);
let mut full_flat_dest = if trailing_is_unit {
flat_base
} else {
let mut base_expanded = flat_base.expand_dim(1, trailing_numel);
let trailing_rank = trailing_shape.len();
for (ti, d) in (k..data_rank).enumerate() {
let ar = data.graph().arange(data_dims[d]);
let mut ar_shaped = ar;
for _ in ti + 1..trailing_rank {
let n = ar_shaped.dims().len();
ar_shaped = ar_shaped.expand_dim(n, 1);
}
for _ in 0..ti {
ar_shaped = ar_shaped.expand_dim(0, 1);
}
ar_shaped.shape.expand(trailing_shape.clone());
let mut ar_flat = ar_shaped;
ar_flat.shape = ShapeTracker::new(vec![trailing_numel]);
ar_flat = ar_flat.expand_dim(0, batch_numel);
let stride_tensor = data
.graph()
.constant(data_strides[d])
.expand_rhs(ar_flat.shape);
base_expanded += ar_flat * stride_tensor;
}
base_expanded
};
full_flat_dest = full_flat_dest.flatten();
let flat_updates = updates.flatten();
let flat_data = data.flatten();
let output_flat = flat_updates.scatter(full_flat_dest, flat_data);
let mut result = output_flat;
result.shape = ShapeTracker::new(data_dims);
result
}

View File

@@ -119,10 +119,8 @@ impl<'a> Translator<'a> {
/// buffer would be sized for the un-sliced argsort tensor while the
/// shape tracker reports a smaller rank.
///
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
/// metadata records int64 and the Python wrapper widens at the
/// boundary, so the PyTorch contract is preserved end-to-end
/// (LUM-486).
/// The result is cast to `DType::I64` to match PyTorch's int64
/// argmax / argmin indices.
pub(crate) fn translate_argextremum(
&mut self,
node: &Node,
@@ -149,7 +147,7 @@ impl<'a> Translator<'a> {
None | Some(0) | Some(-1) => {
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
// `keepdim=True` does not add a dimension when the input is 0-d.
return Ok(self.graph.constant(0i64).cast(DType::Int));
return Ok(self.graph.constant(0i64).cast(DType::I64));
}
Some(dim) => {
return Err(anyhow::anyhow!(
@@ -188,6 +186,6 @@ impl<'a> Translator<'a> {
} else {
picked
};
Ok(result * 1)
Ok((result * 1).cast(DType::I64))
}
}

View File

@@ -413,15 +413,18 @@ impl<'a> Translator<'a> {
// Build top-k outputs from a full stable argsort. Slice the indices
// before gathering values so the gather shape matches the requested
// top-k output rather than the full sort width.
// top-k output rather than the full sort width. Cast to I64 so the
// emitted indices match PyTorch's `torch.topk` semantics (indices
// are int64); `gather_elements` accepts any int dtype on its index
// operand, so a single I64 tensor serves both consumers.
let full_argsort = a.stable_argsort(dim, true);
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
let topk_indices = (full_argsort.slice_along(..k, dim) * 1.0).cast(DType::I64);
// Only build the outputs that are consumed.
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
let values = a.gather_elements(topk_indices, dim);
let values = super::movement_dynamic::pt2_gather_elements(a, topk_indices, dim);
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
@@ -465,11 +468,12 @@ impl<'a> Translator<'a> {
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
let values = a.gather_elements(full_argsort, dim);
let values = super::movement_dynamic::pt2_gather_elements(a, full_argsort, dim);
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
let indices = full_argsort * 1.0;
// `torch.sort` returns int64 indices; cast at the PT2 boundary.
let indices = (full_argsort * 1.0).cast(DType::I64);
self.tensors.insert(idx_name, indices);
}

View File

@@ -35,7 +35,12 @@ impl<'a> Translator<'a> {
false
};
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
Ok(a.stable_argsort(dim, descending))
// PyTorch's `torch.argsort` returns int64 unconditionally;
// luminal's frontend `stable_argsort` returns i32 (storage-
// efficient default for native Rust callers). Cast at the
// PT2↔luminal boundary so the strict output-read path sees
// an I64 buffer.
Ok(a.stable_argsort(dim, descending).cast(DType::I64))
}
pub(crate) fn translate_unary_op(

View File

@@ -4,7 +4,6 @@
//! through the PT2 path without forcing everything to f32.
use luminal::hlir::NativeData;
use luminal::prelude::tracing::warn;
use luminal::prelude::*;
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
@@ -149,62 +148,40 @@ impl TypedData {
}
}
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
/// in luminal's native format. Handles widening/narrowing conversions for types where
/// PyTorch's byte layout differs from luminal's:
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype
/// code) to `TypedData`. Supported dtypes preserve their raw bytes —
/// no width changes at the FFI boundary. Narrow integer widths
/// (`Byte` / `Char` / `Short`) panic: luminal's `NativeData` has no
/// narrower-integer variants yet, so the only way they could pass
/// through is via implicit widening to `i32`, which the no-implicit-
/// cast directive forbids. Cast at the call site
/// (`x.to(torch.int32)`) or wait for the narrower-int IR follow-up.
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
match dtype_code {
// Types that map directly — preserve raw bytes
7 => Self::from_raw(bytes, DType::F32),
6 => Self::from_raw(bytes, DType::F16),
13 => Self::from_raw(bytes, DType::Bf16),
4 => Self::from_raw(bytes, DType::Int), // i32
12 => Self::from_raw(bytes, DType::Bool),
// i64 → i32 (truncate)
5 => {
let i32s: Vec<i32> = bytes
.chunks_exact(8)
.map(|b| {
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
})
.collect();
Self::from_i32_vec(i32s)
}
// f64 → f32 (downcast)
8 => {
let f32s: Vec<f32> = bytes
.chunks_exact(8)
.map(|b| {
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
})
.collect();
Self::from_f32_vec(f32s)
}
// i16 → i32 (widen)
3 => {
let i32s: Vec<i32> = bytes
.chunks_exact(2)
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
.collect();
Self::from_i32_vec(i32s)
}
// u8 → i32 (widen)
1 => {
let i32s: Vec<i32> = bytes.iter().map(|&b| b as i32).collect();
Self::from_i32_vec(i32s)
}
// i8 → i32 (widen, signed)
2 => {
let i32s: Vec<i32> = bytes.iter().map(|&b| (b as i8) as i32).collect();
Self::from_i32_vec(i32s)
}
// Unknown: best-effort pass-through as f32
_ => {
warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
Self::from_raw(bytes, DType::F32)
}
let t = crate::torch_dtype::TorchDType::from_code(dtype_code)
.unwrap_or_else(|c| panic!("from_pytorch_bytes: unknown PT2 dtype code {c}"));
match t {
crate::torch_dtype::TorchDType::Float => Self::from_raw(bytes, DType::F32),
crate::torch_dtype::TorchDType::Half => Self::from_raw(bytes, DType::F16),
crate::torch_dtype::TorchDType::BFloat16 => Self::from_raw(bytes, DType::Bf16),
crate::torch_dtype::TorchDType::Int => Self::from_raw(bytes, DType::Int),
crate::torch_dtype::TorchDType::Bool => Self::from_raw(bytes, DType::Bool),
crate::torch_dtype::TorchDType::Long => Self::from_raw(bytes, DType::I64),
crate::torch_dtype::TorchDType::Double => Self::from_raw(bytes, DType::F64),
crate::torch_dtype::TorchDType::Byte
| crate::torch_dtype::TorchDType::Char
| crate::torch_dtype::TorchDType::Short => panic!(
"from_pytorch_bytes: PT2 dtype {} (code {}) isn't a first-class \
IR type yet — cast to torch.int32 at the call site, or wait \
for the narrower-int IR follow-up.",
t.name(),
t.code(),
),
other => panic!(
"from_pytorch_bytes: PT2 dtype {} (code {}) isn't a first-class \
IR type — no luminal mapping.",
other.name(),
other.code(),
),
}
}

View File

@@ -8,6 +8,20 @@ from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
class DTypeBoundaryError(TypeError):
"""Raised when the caller passes an input whose dtype does not match the
compiled graph's declared input dtype.
The previous behaviour cast silently at every call, which (a) hid real
precision bugs (e.g. f64 → f32 truncation on values outside the f32
range) and (b) burnt CPU/GPU on a per-call allocation+copy that the
user couldn't see in their profile. The contract is now strict:
`model(x)` requires `x.dtype == model.input_dtypes[i]` for every
positional input. Convert at the call site with
`x.to(model.input_dtypes[i])` if you need a different dtype.
"""
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
@@ -35,14 +49,18 @@ class CompiledModel:
self._supports_device_ptrs = getattr(
graph_result, "supports_device_ptrs", False
)
# Expected input dtypes from graph (used to convert user inputs)
# Expected input dtypes from graph. Every declared input MUST
# have a dtype code — refuse to silently default to float32 if
# the Rust side returned a shorter list than `input_names`.
input_dtype_codes = graph_result.input_dtypes
self._input_dtypes = [
code_to_torch_dtype(input_dtype_codes[i])
if i < len(input_dtype_codes)
else torch.float32
for i in range(len(self._input_names))
]
if len(input_dtype_codes) != len(self._input_names):
raise RuntimeError(
f"CompiledGraph returned {len(input_dtype_codes)} input dtype "
f"codes for {len(self._input_names)} declared inputs "
f"({self._input_names!r}) — every declared input needs a "
f"matching dtype."
)
self._input_dtypes = [code_to_torch_dtype(c) for c in input_dtype_codes]
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -95,13 +113,22 @@ class CompiledModel:
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if tensor.dtype != expected_dtype:
raise DTypeBoundaryError(
f"Luminal compiled input '{name}' expects "
f"{expected_dtype} but got {tensor.dtype}. "
"Convert at the call site with "
f"`x.to({expected_dtype})` — the boundary used to silently "
"cast (and warn) on every call, which masked precision "
"bugs and burnt cycles on per-call allocation+copy."
)
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
t = tensor.detach().contiguous()
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().to(expected_dtype)
t = tensor.detach().cpu().contiguous()
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
@@ -112,100 +139,120 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
# Every declared output MUST have a dtype code; refuse to default
# to float32 the way we used to if the Rust side returned fewer
# codes than declared outputs.
output_dtype_codes = self._graph.output_dtypes
if len(output_dtype_codes) != len(self._output_names):
raise RuntimeError(
f"CompiledGraph returned {len(output_dtype_codes)} output "
f"dtype codes for {len(self._output_names)} declared outputs "
f"({self._output_names!r}) — every declared output needs a "
f"matching dtype."
)
output_torch_dtypes = [code_to_torch_dtype(c) for c in output_dtype_codes]
# CUDA zero-copy path: pre-allocate output tensors and register their device
# pointers so the final kernel writes directly into PyTorch's buffer.
# Per-dtype dispatch table mapping `torch_dtype` → the typed
# `_graph` getter for that dtype. Every supported dtype has an
# explicit native-width getter; anything not listed raises
# `NotImplementedError` from `_read_typed_output`. There is no
# open-ended fallback — a missing entry means we don't know how
# to read that dtype yet, and we'd rather fail loudly than
# silently reinterpret bytes.
#
# `float16` / `bfloat16` getters return `uint16` bit patterns
# (Python has no native `f16` / `bf16`); the helper below
# bit-casts them back to the declared dtype via
# `torch.frombuffer`. That's a reinterpret, not a numeric
# cast — no precision change.
#
# Narrow ints (`int8` / `int16` / `uint8`) are intentionally
# absent — luminal's IR refuses them at the FFI boundary (cf.
# `pt2_util::torch_dtype_int_to_luminal`,
# `typed_data::from_pytorch_bytes`), so a graph can never
# declare a narrow-int output that reaches this dispatch.
_zero_copy_native_floats = (torch.float32, torch.float16, torch.bfloat16)
_output_readers = {
torch.float32: ("get_output", torch.float32),
torch.float64: ("get_output_f64", torch.float64),
torch.float16: ("get_output_f16", torch.float16),
torch.bfloat16: ("get_output_bf16", torch.bfloat16),
torch.int64: ("get_output_i64", torch.int64),
torch.int32: ("get_output_i32", torch.int32),
torch.bool: ("get_output_bool", torch.bool),
}
def _read_typed_output(name: str, shape, out_dtype) -> torch.Tensor:
"""Pull one output back from the runtime at the right dtype.
Strict: any `out_dtype` not in `_output_readers` raises
`NotImplementedError`. The previous code's open-ended
fallback read the buffer as f32 and `.to(out_dtype)`'d
back, which silently aliased dtypes we don't really
support; refusing surfaces the gap.
For `float16` / `bfloat16` the typed getter returns
`uint16` bit patterns (Python has no native half-precision
float type); we bit-cast via `torch.tensor(..., uint16)`
and `.view(half)` so the conversion is a reinterpret of the
bytes, not a numeric cast.
"""
entry = _output_readers.get(out_dtype)
if entry is None:
raise NotImplementedError(
f"Output '{name}' declared dtype {out_dtype} isn't "
f"supported by the luminal read boundary. Add a typed "
f"getter for this dtype (see `_output_readers`) or cast "
f"the output to a supported dtype upstream."
)
getter_name, read_dtype = entry
data = getattr(self._graph, getter_name)(name)
if out_dtype in (torch.float16, torch.bfloat16):
# Getter returned an immutable `bytes` from Rust; wrap in
# `bytearray` to make the storage writable (suppresses
# the "non-writable buffer" warning), then bit-cast via
# `frombuffer` — no numeric conversion.
tensor = torch.frombuffer(bytearray(data), dtype=out_dtype).reshape(
tuple(shape)
)
else:
tensor = torch.tensor(data, dtype=read_dtype).reshape(tuple(shape))
return tensor.to(input_device)
# Pre-allocation is GPU-only: the CUDA kernel needs the
# output's device pointer registered *before* `_graph.run()`
# so the final kernel writes directly into PyTorch's buffer.
# Only the float dtypes luminal natively writes
# (`_zero_copy_native_floats`) take the zero-copy path; other
# dtypes (int*, bool, f64) read back via `_read_typed_output`
# after `run()` and so don't need a pre-allocated tensor at
# this layer. CPU never zero-copies — there's no separate
# device buffer to register against.
_use_zero_copy = self._supports_device_ptrs
output_tensors = []
if _use_zero_copy:
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
out_dtype = output_torch_dtypes[i]
out = torch.empty(shape, dtype=out_dtype, device=input_device)
if out_dtype.is_floating_point:
if out_dtype in _zero_copy_native_floats:
self._graph.set_output_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
output_tensors.append(out)
# Run the graph
self._graph.run()
# Integer dtypes for which we read the buffer as i32 and then cast.
# Includes int64 because luminal collapses all integer types to its
# 32-bit `Int` internally — we restore the original precision here.
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
# Collect outputs
if _use_zero_copy:
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = output_torch_dtypes[i]
if _use_zero_copy and out_dtype in _zero_copy_native_floats:
out = output_tensors[i]
if out_dtype.is_floating_point:
if not self._graph.output_is_zero_copy(name):
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
if not self._graph.output_is_zero_copy(name):
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = (
torch.tensor(data, dtype=torch.bool)
.reshape(tuple(shape))
.to(input_device)
)
else:
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
outputs.append(out)
else:
# Native path: retrieve as f32, then convert to target dtype if needed.
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
if out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(out_dtype)
)
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
else:
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(out_dtype)
)
out = out.to(input_device)
outputs.append(out)
else:
out = _read_typed_output(name, shape, out_dtype)
outputs.append(out)
return tuple(outputs)

View File

@@ -1,28 +1,61 @@
"""Shared dtype utility functions for the luminal Python Bridge"""
"""Shared dtype utility functions for the luminal Python bridge.
The PT2 dtype-code numbering is sourced from
``torch._export.serde.schema.ScalarType`` at import time — PyTorch is the
canonical source of truth on both sides of the FFI boundary. The Rust side
mirrors the same enum in ``luminal_python/rust/src/torch_dtype.rs`` and is
held in agreement by ``tests/test_torch_dtype_parity.py``.
``torch._export.serde.schema`` is a quasi-private API (leading underscore),
but it is the module PT2 export actually wire-serializes against; binding
to it here is the right boundary. If PyTorch reorganizes the module path,
the import below will fail loudly at module load.
"""
import torch
from torch._export.serde.schema import ScalarType
# Map each `torch.dtype` we care about to the PT2 code PyTorch itself
# would emit for it. Looking up `ScalarType.<NAME>.value` keeps the
# numbering in lockstep with PyTorch — if PyTorch renumbers, we pick
# up the new code automatically (and the Rust parity test catches the
# drift from the other side).
_TORCH_DTYPE_TO_CODE = {
torch.uint8: 1,
torch.int8: 2,
torch.int16: 3,
torch.int32: 4,
torch.int64: 5,
torch.float16: 6,
torch.float32: 7,
torch.float64: 8,
torch.bool: 12,
torch.bfloat16: 13,
torch.uint8: ScalarType.BYTE.value,
torch.int8: ScalarType.CHAR.value,
torch.int16: ScalarType.SHORT.value,
torch.int32: ScalarType.INT.value,
torch.int64: ScalarType.LONG.value,
torch.float16: ScalarType.HALF.value,
torch.float32: ScalarType.FLOAT.value,
torch.float64: ScalarType.DOUBLE.value,
torch.bool: ScalarType.BOOL.value,
torch.bfloat16: ScalarType.BFLOAT16.value,
}
_CODE_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_CODE.items()}
def torch_dtype_code(dtype):
"""Map torch.dtype to PT2 dtype integer code."""
return _TORCH_DTYPE_TO_CODE.get(dtype, 7) # default to f32
"""Map torch.dtype to PT2 dtype integer code. Raises `KeyError`
on an unsupported dtype rather than silently aliasing to FLOAT."""
try:
return _TORCH_DTYPE_TO_CODE[dtype]
except KeyError:
raise KeyError(
f"torch_dtype_code: {dtype} isn't a supported PT2 dtype "
f"(supported: {sorted(_TORCH_DTYPE_TO_CODE.keys(), key=str)})"
) from None
def code_to_torch_dtype(code):
"""Map PT2 dtype integer code to torch.dtype."""
return _CODE_TO_TORCH_DTYPE.get(code, torch.float32)
"""Map PT2 dtype integer code to torch.dtype. Raises `KeyError`
on an unknown code rather than silently defaulting to float32."""
try:
return _CODE_TO_TORCH_DTYPE[code]
except KeyError:
raise KeyError(
f"code_to_torch_dtype: PT2 dtype code {code} isn't mapped "
f"to a torch.dtype (known codes: "
f"{sorted(_CODE_TO_TORCH_DTYPE.keys())})"
) from None

View File

@@ -0,0 +1,250 @@
from dataclasses import dataclass
import warnings
from typing import Callable
import pytest
import torch
from luminal import luminal_backend
class BoundaryNoopModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.bool:
return x | torch.zeros((), dtype=torch.bool, device=x.device)
return x + torch.zeros((), dtype=x.dtype, device=x.device)
@dataclass(frozen=True)
class DTypeCase:
name: str
dtype: torch.dtype
values: Callable[[], torch.Tensor]
xfail_reason: str | None = None
DTYPE_CASES = [
DTypeCase(
"bool",
torch.bool,
lambda: torch.tensor([True, False, True], dtype=torch.bool),
),
DTypeCase(
"uint8",
torch.uint8,
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
),
DTypeCase(
"int8",
torch.int8,
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
),
DTypeCase(
"int16",
torch.int16,
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
),
DTypeCase(
"int32",
torch.int32,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int32,
),
),
DTypeCase(
"int64_i32_range",
torch.int64,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int64,
),
),
DTypeCase(
"float16",
torch.float16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
),
DTypeCase(
"bfloat16",
torch.bfloat16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
),
DTypeCase(
"float32",
torch.float32,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
),
DTypeCase(
"float64_f32_exact",
torch.float64,
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
),
DTypeCase(
"int64_outside_i32_range",
torch.int64,
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
),
DTypeCase(
"float64_precision_sensitive",
torch.float64,
lambda: torch.tensor(
[1.0, 1.0000000000000002, float(2**40) + 0.25],
dtype=torch.float64,
),
),
]
def _cuda_skip_reason() -> str | None:
if not torch.cuda.is_available():
return "CUDA is not available"
try:
from luminal.luminal import _cuda_lite_factory_capsule
_cuda_lite_factory_capsule()
except (ImportError, AttributeError, RuntimeError) as exc:
return f"luminal_python was not built with CUDA support: {exc}"
return None
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
def boundary_device(request) -> torch.device:
device_name = request.param
if device_name == "cuda":
skip_reason = _cuda_skip_reason()
if skip_reason is not None:
pytest.skip(skip_reason)
return torch.device(device_name)
# Dtypes that round-trip the BoundaryNoopModel without an explicit
# `x.to(model.input_dtypes[0])` cast at the call site. Anything not in this
# set is a narrow integer (uint8 / int8 / int16) that luminal collapses to
# `DType::Int` internally — the hard-reject contract makes the boundary
# refuse the mismatched dtype, and the test for those lives in
# `test_input_dtype_mismatch_rejects` instead.
_FIRST_CLASS_NOOP_DTYPES = {
"bool",
"int32",
"int64_i32_range",
"int64_outside_i32_range",
"float16",
"bfloat16",
"float32",
"float64_f32_exact",
"float64_precision_sensitive",
}
@pytest.mark.parametrize(
"case",
[
pytest.param(
case,
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
if case.xfail_reason is not None
else (),
id=case.name,
)
for case in DTYPE_CASES
if case.name in _FIRST_CLASS_NOOP_DTYPES
],
)
def test_boundary_noop_preserves_dtype_and_values(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
expected = model(x)
actual = compiled(x)
assert isinstance(actual, torch.Tensor)
assert actual.dtype == expected.dtype
assert torch.equal(actual.cpu(), expected.cpu())
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
# Narrow integer widths (uint8 / int8 / int16) aren't first-class in
# luminal's IR — the translator refuses them outright. int64 /
# float64 are first-class and round-trip without rejection.
if case.name in {"uint8", "int8", "int16"}
],
)
def test_input_dtype_mismatch_rejects(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
"""Hard-reject contract: a graph whose declared input dtype is one of
the narrow ints (uint8 / int8 / int16) fails at compile time with a
clear panic from `torch_dtype_int_to_luminal`. Previously the
translator silently widened narrow ints to `Int` (i32), which left
the user's actual dtype invisible past the FFI boundary; today the
failure points at the missing IR support directly.
"""
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
# `pyo3_runtime.PanicException` inherits from `BaseException` (not
# `Exception`), so `pytest.raises(Exception, ...)` would miss it.
# Match on the panic message text — stable across torch versions.
with pytest.raises(BaseException, match="isn't a first-class IR type yet"):
compiled(x)
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name
in {
"bool",
"int32",
"float16",
"bfloat16",
"float32",
# int64 / float64 are first-class in the IR — passing a tensor
# of either dtype matches the graph's input dtype directly, no
# conversion needed.
"int64_i32_range",
"int64_outside_i32_range",
"float64_f32_exact",
"float64_precision_sensitive",
}
],
)
def test_matching_dtype_does_not_raise(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
"""Round-trip contract: a user input whose dtype matches the graph's
declared input dtype runs without raising, with no warnings emitted at
the boundary."""
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with warnings.catch_warnings(record=True) as records:
warnings.simplefilter("always")
compiled(x)
boundary_warnings = [
record
for record in records
if "boundary" in str(record.message).lower()
or "convert" in str(record.message).lower()
]
assert boundary_warnings == [], (
f"unexpected boundary-related warning(s): {boundary_warnings}"
)

View File

@@ -0,0 +1,109 @@
"""Dynamic-shape regression coverage for the movement ops Qwen3-MoE /
Gemma4-MoE exercise via `torch.compile`.
Three failure modes surfaced while debugging the Qwen3-30B-A3B path:
1. `gather_elements: index dim must be concrete` — `gather_elements`
/ `scatter_elements` collected index dims as `Vec<usize>` via
`.to_usize().expect(...)`. First forward worked; the second forward
at a different seq_len made Dynamo emit a SymInt dim and tripped
the assertion.
2. `Dims must match to add tensors. left: [(a*8), 2048] right: [(8*a), 2048]`
— different translator paths produced semantically-equal but
syntactically-different `Expression` dims.
3. `scatter_nd: data dim must be concrete` — same family as (1),
reached via `translate_index_put` (HF's MoE accumulator).
"""
from __future__ import annotations
import torch
from luminal.main import luminal_backend
def _compile(model):
return torch.compile(model, backend=luminal_backend)
def test_gather_elements_dynamic_index_shape(device: torch.device) -> None:
"""`torch.gather` with a dynamic batch dim on the index tensor."""
class GatherModel(torch.nn.Module):
def forward(self, table: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
expanded = table.unsqueeze(0).expand(indices.shape[0], -1, -1)
idx = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 32)
return torch.gather(expanded, 1, idx).squeeze(1)
model = GatherModel().to(device)
compiled = _compile(model)
table = torch.randn(8, 32, device=device)
for batch in [4, 7, 11, 4]:
idx = torch.randint(0, 8, (batch,), device=device, dtype=torch.int64)
assert torch.allclose(compiled(table, idx), model(table, idx), atol=1e-4)
def test_scatter_elements_dynamic_index_shape(device: torch.device) -> None:
"""`torch.scatter` with a dynamic batch dim on the index tensor."""
class ScatterModel(torch.nn.Module):
def forward(self, values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
dest = torch.zeros(
values.shape[0], 16, device=values.device, dtype=values.dtype
)
return dest.scatter(1, indices, values)
model = ScatterModel().to(device)
compiled = _compile(model)
for batch in [4, 7, 11, 4]:
# Distinct indices per row → no-overlap scatter for allclose.
idx = torch.stack(
[torch.randperm(16, device=device)[:4] for _ in range(batch)]
).to(torch.int64)
vals = torch.randn(batch, 4, device=device)
assert torch.allclose(compiled(vals, idx), model(vals, idx), atol=1e-4)
def test_scatter_nd_dynamic_data_shape(device: torch.device) -> None:
"""`tensor[idx] = value` → `translate_index_put` → `scatter_nd`."""
class ScatterNDModel(torch.nn.Module):
def forward(
self, base: torch.Tensor, idx: torch.Tensor, vals: torch.Tensor
) -> torch.Tensor:
out = base.clone()
out[idx] = vals
return out
model = ScatterNDModel().to(device)
compiled = _compile(model)
for batch in [4, 7, 11, 4]:
base = torch.randn(16, 4, device=device)
idx = torch.randperm(16, device=device)[:batch].to(torch.int64)
vals = torch.randn(batch, 4, device=device)
assert torch.allclose(
compiled(base, idx, vals), model(base, idx, vals), atol=1e-4
)
def test_where_dynamic_shape_no_dim_mismatch_panic(device: torch.device) -> None:
"""`torch.where` over inputs whose shape derives from a SymInt:
two translator paths can produce `a*8` vs `8*a` for the same dim,
which trips the dim-equality assert in luminal-core's `Sub` /
`Add` without canonical ordering in `dim_arith`.
"""
class WhereModel(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.where(x > 0, x, y)
model = WhereModel().to(device)
compiled = _compile(model)
for batch in [4, 7, 11, 4]:
x = torch.randn(batch, 16, device=device)
y = torch.randn(batch, 16, device=device)
assert torch.allclose(compiled(x, y), model(x, y), atol=1e-4)

View File

@@ -230,7 +230,6 @@ def test_hf_llama_decode_loop_static(device: torch.device):
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
"""Decode loop on real Llama3.2-1B with pretrained weights.
@@ -286,7 +285,6 @@ def _gpu_mem(label):
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
@@ -338,7 +336,6 @@ def test_hf_llama3_full(device: torch.device):
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_large_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
@@ -365,7 +362,7 @@ def test_hf_llama3_large_full(device: torch.device):
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
assert torch.allclose(out.logits, ref.logits, atol=1e-4), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
@@ -420,7 +417,6 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
@pytest.mark.slow
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama38b_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
@@ -447,7 +443,7 @@ def test_hf_llama38b_full(device: torch.device):
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
assert torch.allclose(out.logits, ref.logits, atol=1e-4), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)

View File

@@ -0,0 +1,47 @@
"""Pin luminal's Rust `TorchDType` enum to PyTorch's PT2 schema.
The PT2 export pipeline wire-serializes dtypes as `u32` codes drawn from
`torch._export.serde.schema.ScalarType`. luminal mirrors that enum in
`crates/luminal_python/rust/src/torch_dtype.rs` and depends on the
discriminants matching exactly. If PyTorch renumbers, adds, or removes a
variant, this test fails loudly at CI time — better than a silent
miscompile at runtime.
"""
from torch._export.serde.schema import ScalarType
# `_torch_dtype_codes` is the pyo3-exported map `{variant_name: pt2_code}`.
from luminal.luminal import _torch_dtype_codes
def test_rust_variants_match_pytorch():
"""Every Rust variant must agree with PyTorch's code for the same name."""
rust = _torch_dtype_codes()
pt = {v.name: v.value for v in ScalarType}
mismatches = []
for name, code in rust.items():
if name not in pt:
mismatches.append(f"{name}: luminal={code}, pytorch=<missing variant>")
elif pt[name] != code:
mismatches.append(f"{name}: luminal={code}, pytorch={pt[name]}")
assert not mismatches, (
"torch_dtype.rs and PyTorch's ScalarType have drifted:\n "
+ "\n ".join(mismatches)
)
def test_no_pytorch_variants_missing_from_rust():
"""Surface new PyTorch variants so we know to extend the Rust enum.
Failure here doesn't necessarily indicate a bug — it just means
PyTorch added a dtype (e.g. a new float8 variant) and luminal should
decide whether to mirror it. Update `TorchDType::ALL` in
`torch_dtype.rs` plus the `TryFrom` impls to resolve.
"""
rust = _torch_dtype_codes()
missing = [v.name for v in ScalarType if v.name not in rust]
assert not missing, (
"PyTorch ScalarType variants not mirrored in luminal::TorchDType: "
f"{missing}. Extend TorchDType::ALL in torch_dtype.rs and decide "
"whether each maps to a luminal DType variant."
)

View File

@@ -37,7 +37,7 @@ use std::fs::File;
use std::io::BufWriter;
use std::time::Instant;
use luminal::graph::BuildSearchSpaceOptions;
use luminal::graph::CompileOptions;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use rand::{Rng, SeedableRng, rngs::StdRng};
@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
.unwrap_or(default)
}
fn search_options() -> CompileOptions {
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
}
fn env_f32(name: &str, default: f32) -> f32 {
std::env::var(name)
.ok()
@@ -159,11 +163,9 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
s.parse::<usize>()
.map_err(|_| std::env::VarError::NotPresent)
}) {
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_gib(g),
);
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
} else {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
}
let ctx = CudaContext::new(0).unwrap();
@@ -189,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
println!("Compiling text encoder...");
let t0 = Instant::now();
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
runtime = cx.search(runtime, search_options());
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
println!("Encoding prompt...");
@@ -301,11 +303,9 @@ fn run_full_pipeline(
s.parse::<usize>()
.map_err(|_| std::env::VarError::NotPresent)
}) {
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_gib(g),
);
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
} else {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
}
let ctx = CudaContext::new(0).unwrap();
@@ -349,10 +349,9 @@ fn run_full_pipeline(
{
use rand::SeedableRng;
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
let opts = luminal::graph::SearchOptions::new(env_usize("SEARCH_ITERS", 5));
runtime = cx.search_options(runtime, opts, &mut rng);
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
} else {
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
runtime = cx.search(runtime, search_options());
}
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
@@ -409,11 +408,9 @@ fn run_full_pipeline(
s.parse::<usize>()
.map_err(|_| std::env::VarError::NotPresent)
}) {
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_gib(g),
);
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
} else {
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
}
let ctx = CudaContext::new(0).unwrap();
@@ -421,7 +418,7 @@ fn run_full_pipeline(
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
runtime.set_data(latent_in, vae_input);
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
runtime = cx.search(runtime, search_options());
runtime.execute(&cx.dyn_map);
let img = runtime.get_f32(out);
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'

View File

@@ -720,6 +720,10 @@ mod tests {
out
}
fn one_search() -> CompileOptions {
CompileOptions::default().search_graph_limit(1)
}
#[test]
fn conv2d_bias_matches_reference() {
let mut cx = Graph::default();
@@ -746,8 +750,8 @@ mod tests {
},
);
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -765,8 +769,8 @@ mod tests {
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 - 11.0).collect();
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.execute(&cx.dyn_map);
@@ -787,10 +791,10 @@ mod tests {
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 - 11.0).collect();
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(ctx.default_stream());
rt.set_data(input_t, input);
rt = cx.search(rt, 1);
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);
@@ -820,8 +824,8 @@ mod tests {
},
);
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -859,12 +863,12 @@ mod tests {
},
);
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(ctx.default_stream());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
rt = cx.search(rt, 1);
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);

View File

@@ -53,9 +53,20 @@ fn main() {
k_out.output();
v_out.output();
}
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
let build_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(build_options);
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
@@ -69,22 +80,12 @@ fn main() {
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim('s', search_s);
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, search_graphs);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -36,14 +36,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -68,114 +70,11 @@ pub struct Gemma {
impl Gemma {
pub fn init(cx: &mut Graph) -> Self {
let mut w = vec![];
for l in 0..LAYERS {
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(GemmaLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
input_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
cx,
),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
});
}
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
lm_head,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -185,11 +84,7 @@ impl Gemma {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -226,6 +121,114 @@ struct GemmaLayer {
rope_scaling_factor: f32,
}
impl GemmaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
input_layernorm: layer_norm(cx, l, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = gemma_rotary_embeddings(
qk_norm(q, self.q_norm, N_HEADS),
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
qk_norm(k, self.k_norm, N_KV_HEADS),
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = x + self.post_attention_layernorm.forward(attn_proj);
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
(
x + self.post_feedforward_layernorm.forward(mlp_out),
k_cache_out,
v_cache_out,
)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
/// This produces far fewer e-graph nodes than the tanh-based expansion.
#[allow(clippy::excessive_precision)]
@@ -363,59 +366,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl GemmaLayer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm + RoPE
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
let q_rope = gemma_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
k_normed,
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
// O projection + post-attention norm + residual
let attn_proj = attn_out.matmul(self.o_proj.t());
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
let x = x + attn_normed;
// Pre-feedforward norm + MLP + post-feedforward norm + residual
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
(x + mlp_normed, k_cache_out, v_cache_out)
}
}

View File

@@ -49,9 +49,20 @@ fn main() {
k_out.output();
v_out.output();
}
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
let build_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(build_options);
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
@@ -65,27 +76,15 @@ fn main() {
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim('s', search_s);
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(
runtime,
SearchOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(search_graphs)
.profile_timeout(Duration::from_secs(2));
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);

View File

@@ -83,20 +83,16 @@ impl KVCache {
let mut v_caches = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let k = cx
.named_tensor(
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
let v = cx
.named_tensor(
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
v_caches.push(persist(
cx,
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
}
Self {
k_caches,
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
impl Gemma4MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let gate = cx
.named_tensor(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let up = cx
.named_tensor(
format!("model.layers.{layer}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{layer}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
(spec.q_dim, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist();
let v_proj = spec.has_v_proj.then(|| {
cx.named_tensor(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist()
});
let o_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
(HIDDEN, spec.q_dim),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_norm.weight"),
spec.head_dim,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_norm.weight"),
spec.head_dim,
)
.persist();
let layer_scalar = cx
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
.persist();
let router_scale = cx
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
.persist();
let router_proj = cx
.named_tensor(
format!("model.layers.{layer}.router.proj.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let per_expert_scale = cx
.named_tensor(
format!("model.layers.{layer}.router.per_expert_scale"),
NUM_EXPERTS,
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.gate_up_proj"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist()
.as_dtype(DType::Bf16);
let down_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.down_proj"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist()
.as_dtype(DType::Bf16);
layers.push(Gemma4Layer {
spec,
gate,
up,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
layer_scalar,
input_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm_1: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
cx,
),
post_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
cx,
),
pre_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
cx,
),
moe: Gemma4SparseMoE {
router_scale,
router_proj,
per_expert_scale,
gate_up_weights,
down_weights,
},
});
}
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
Self {
embedding,
lm_head,
layers,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS)
.map(|layer| Gemma4Layer::init(cx, layer))
.collect(),
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -287,11 +127,7 @@ impl Gemma4MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (layer_idx, layer) in self.layers.iter().enumerate() {
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
down_weights: GraphTensor,
}
impl Gemma4Layer {
fn init(cx: &mut Graph, layer: usize) -> Self {
let spec = layer_spec(layer);
Self {
spec,
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
v_proj: spec
.has_v_proj
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
moe: Gemma4SparseMoE {
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
gate_up_weights: layer_tensor(
cx,
layer,
"experts.gate_up_proj",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
layer,
"experts.down_proj",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
}
@@ -505,81 +499,6 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl Gemma4Layer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
}

View File

@@ -17,7 +17,7 @@ const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
const MAX_SEQ_LEN: usize = 4096;
const GEN_TOKENS: usize = 500;
const SEARCH_GRAPHS: usize = 500;
const SEARCH_TRIALS: usize = 1;
const SEARCH_TRIALS: usize = 10;
const SEARCH_KEEP_BEST: usize = 4;
const SEARCH_MEMORY_MIB: usize = 2048;
const SEARCH_SEED: u64 = 0;
@@ -290,12 +290,21 @@ fn main() {
cx.set_dim('s', 1);
cx.set_dim('c', 1);
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
let search_s = 16.min(max_prefill).max(2);
let build_options = CompileOptions::default()
.max_memory_mib(SEARCH_MEMORY_MIB)
.dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
println!("Building E-Graph...");
let egraph_start = std::time::Instant::now();
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
);
cx.build_search_space::<CudaRuntime>(build_options);
println!(
" E-Graph build: {:.2} s",
egraph_start.elapsed().as_secs_f64()
@@ -318,15 +327,6 @@ fn main() {
println!("Compiling...");
let compile_start = std::time::Instant::now();
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim('s', search_s);
cx.set_dim('c', search_s);
runtime.set_data(input, vec![1; search_s]);
@@ -338,13 +338,11 @@ fn main() {
println!(" Search trials: {SEARCH_TRIALS}");
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(
runtime,
SearchOptions::new(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -111,125 +111,18 @@ impl Llama {
config: LlamaConfig,
fp8_linears: bool,
) -> Self {
let mut layers = Vec::with_capacity(config.layers);
for l in 0..config.layers {
layers.push(LlamaLayer {
config,
up: linear_weight(
cx,
format!("model.layers.{l}.mlp.up_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
up_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.up_proj"),
fp8_linears,
),
gate: linear_weight(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
gate_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
fp8_linears,
),
down: linear_weight(
cx,
format!("model.layers.{l}.mlp.down_proj"),
(config.hidden, config.intermediate),
fp8_linears,
),
down_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.down_proj"),
fp8_linears,
),
q_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
q_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
fp8_linears,
),
k_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
k_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
fp8_linears,
),
v_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
v_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
fp8_linears,
),
o_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
o_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
fp8_linears,
),
attn_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
config,
embedding: cx
.named_tensor(
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
)
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
.persist(),
lm_norm: LayerNorm::new(
config.hidden,
Some("model.norm.weight"),
None,
false,
1e-5,
embedding: persist(
cx,
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
),
layers: (0..config.layers)
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
.collect(),
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
}
}
@@ -243,12 +136,7 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let hidden = self.config.hidden;
let mut x = self.embedding.gather(
(input * hidden).expand_dim(1, hidden)
+ input.graph().arange(hidden).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, input, self.config.hidden);
let mut cache_outputs = Vec::with_capacity(self.config.layers);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
weight: GraphTensor,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
Self {
config,
up: layer_linear_weight(
cx,
l,
"mlp.up_proj",
(config.intermediate, config.hidden),
fp8,
),
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
gate: layer_linear_weight(
cx,
l,
"mlp.gate_proj",
(config.intermediate, config.hidden),
fp8,
),
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
down: layer_linear_weight(
cx,
l,
"mlp.down_proj",
(config.hidden, config.intermediate),
fp8,
),
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
q_proj: layer_linear_weight(
cx,
l,
"self_attn.q_proj",
(config.hidden, config.hidden),
fp8,
),
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
k_proj: layer_linear_weight(
cx,
l,
"self_attn.k_proj",
(config.kv_dim(), config.hidden),
fp8,
),
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
v_proj: layer_linear_weight(
cx,
l,
"self_attn.v_proj",
(config.kv_dim(), config.hidden),
fp8,
),
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
o_proj: layer_linear_weight(
cx,
l,
"self_attn.o_proj",
(config.hidden, config.hidden),
fp8,
),
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
attn_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.input_layernorm.weight"),
),
mlp_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn linear_weight(
cx: &mut Graph,
prefix: impl ToString,
@@ -325,6 +377,16 @@ fn linear_weight(
}
}
fn layer_linear_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
fp8: bool,
) -> GraphTensor {
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
}
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
if !fp8 {
return None;
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
})
}
fn layer_linear_scales(
cx: &mut Graph,
layer: usize,
suffix: &str,
fp8: bool,
) -> Option<Fp8LinearScales> {
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
}
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * hidden).expand_dim(1, hidden)
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
)
}
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
scale.expand_rhs(like.dims())
}
@@ -443,87 +526,3 @@ fn attention(
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}

View File

@@ -204,9 +204,20 @@ fn main() {
k_out.output();
v_out.output();
}
// Bucket s=1 (decode) vs s>1 (prefill/mixed). Each bucket gets its own
// optimized compilation — decode can select warp-parallel kernels while
// prefill can select tiled matmul / cuBLAS.
let max_prefill = (tokens_a.len().max(tokens_b.len()) + 16).next_power_of_two();
let build_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(16),
],
);
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(build_options);
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
@@ -220,18 +231,6 @@ fn main() {
}
println!("Compiling...");
// Bucket s=1 (decode) vs s>1 (prefill/mixed). Each bucket gets its own
// optimized compilation — decode can select warp-parallel kernels while
// prefill can select tiled matmul / cuBLAS.
let max_prefill = (tokens_a.len().max(tokens_b.len()) + 16).next_power_of_two();
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(16),
],
);
// Dummy data sized for the largest representative (s=16, c=16)
let search_s = 16;
let search_c = 16;
@@ -242,7 +241,8 @@ fn main() {
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, search_graphs);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Re-initialize KV cache after search (search consumes buffers)
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();

View File

@@ -25,8 +25,8 @@ pub struct PagedKVCache {
impl PagedKVCache {
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
@@ -44,78 +44,11 @@ pub struct Llama {
impl Llama {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
layers.push(LlamaLayer {
up: cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
gate: cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
down: cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist(),
q_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
k_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
v_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
o_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
embedding: cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -141,12 +74,8 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &PagedKVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let mut x = self.embedding.gather(
(input * HIDDEN).expand_dim(1, HIDDEN)
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut cache_outputs = vec![];
let mut x = token_embedding(self.embedding, input);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
@@ -177,6 +106,99 @@ struct LlamaLayer {
mlp_rms: LayerNorm,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
1e-5,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
@@ -264,44 +286,3 @@ fn paged_attention(
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// Apply RoPE before scattering into cache
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -147,26 +147,11 @@ where
k_out.output();
v_out.output();
}
println!("Building E-Graph...");
cx.build_search_space::<R>();
println!("Loading weights...");
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * config.max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..config.layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(config.max_seq_len);
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
let mut compile_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
@@ -183,12 +168,28 @@ where
DimBucket::new(1, max_decode_p).representative(decode_p_representative),
]
};
cx.set_dim_buckets('p', &p_buckets);
compile_options = compile_options.dim_buckets('p', &p_buckets);
println!("Building E-Graph...");
cx.build_search_space::<R>(compile_options);
println!("Loading weights...");
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * config.max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..config.layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', search_s);
cx.set_dim('p', 0);
runtime.set_i32_data(input.id, vec![1; search_s]);
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, config.search_graphs);
let search_options = CompileOptions::default().search_graph_limit(config.search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..config.layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);

View File

@@ -34,14 +34,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(layers);
let mut v_caches = Vec::with_capacity(layers);
for l in 0..layers {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -63,105 +65,10 @@ impl Qwen {
layers <= LAYERS,
"requested {layers} layers, but model has {LAYERS}"
);
let mut w = vec![];
for l in 0..layers {
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(QwenLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..layers).map(|l| QwenLayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -172,11 +79,7 @@ impl Qwen {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(self.layers.len());
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -209,6 +112,90 @@ struct QwenLayer {
mlp_rms: LayerNorm,
}
impl QwenLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = qwen_rotary_embeddings(qk_norm(q, self.q_norm, N_HEADS), pos_ids, N_HEADS);
let k_rope =
qwen_rotary_embeddings(qk_norm(k, self.k_norm, N_KV_HEADS), pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// Per-head RMS normalization for QK-norm.
/// Input: [seq, dim] where dim = n_heads * HEAD_DIM
/// split_dims to [seq, n_heads, HEAD_DIM], RMS norm over last axis, multiply by weight, merge back.
@@ -331,36 +318,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl QwenLayer {
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm: per-head RMS normalization
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
// RoPE
let q_rope = qwen_rotary_embeddings(q_normed, pos_ids, N_HEADS);
let k_rope = qwen_rotary_embeddings(k_normed, pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -50,9 +50,20 @@ fn main() {
k_out.output();
v_out.output();
}
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
let build_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(build_options);
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
@@ -66,23 +77,13 @@ fn main() {
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim('s', search_s);
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(runtime, SearchOptions::new(search_graphs), &mut rng);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -29,17 +29,19 @@ pub struct KVCache {
impl KVCache {
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -58,111 +60,11 @@ pub struct Qwen3MoE {
impl Qwen3MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
let router = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_up_weights"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist();
let down_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_weights"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist();
layers.push(Qwen3MoELayer {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
moe: QwenMoE {
router,
gate_up_weights: gate_up_weights.as_dtype(DType::Bf16),
down_weights: down_weights.as_dtype(DType::Bf16),
},
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers,
lm_norm,
lm_head,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| Qwen3MoELayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
}
}
@@ -172,11 +74,7 @@ impl Qwen3MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -214,6 +112,39 @@ struct QwenMoE {
}
impl Qwen3MoELayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
moe: QwenMoE {
router: layer_weight(cx, l, "mlp.gate", (NUM_EXPERTS, HIDDEN)),
gate_up_weights: layer_tensor(
cx,
l,
"mlp.gate_up_weights",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
l,
"mlp.down_weights",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
mut x: GraphTensor,
@@ -247,6 +178,51 @@ impl Qwen3MoELayer {
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
impl QwenMoE {
fn forward(&self, x: GraphTensor) -> GraphTensor {
let n = x.dims().len(); // 2 for [s, H]

View File

@@ -11,8 +11,8 @@ fn main() {
display_graph(&cx);
// Compile
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

View File

@@ -63,9 +63,18 @@ fn main() {
k_out.output();
v_out.output();
}
let prompt: Vec<u32> = vec![TOKEN_SOT, TOKEN_NO_TIMESTAMPS];
let max_prefill = prompt.len().max(2);
let build_options = CompileOptions::default().dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(max_prefill),
],
);
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(build_options);
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
@@ -82,22 +91,13 @@ fn main() {
// Set the mel spectrogram once.
runtime.set_data(mel_tensor, mel_data.clone());
let prompt: Vec<u32> = vec![TOKEN_SOT, TOKEN_NO_TIMESTAMPS];
println!("Compiling...");
let max_prefill = prompt.len().max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(max_prefill),
],
);
cx.set_dim('s', max_prefill);
cx.set_dim('p', 0);
runtime.set_data(input, vec![1i32; max_prefill]);
runtime.set_data(pos_ids, (0..max_prefill as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, search_graphs);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Reset the KV caches and re-set the mel after search (which executes test runs).
for i in 0..N_TEXT_LAYER {

View File

@@ -33,6 +33,14 @@ fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
x.matmul(w.t())
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
fn conv1d_bias(
@@ -90,27 +98,13 @@ struct AttentionWeights {
impl AttentionWeights {
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
Self {
q_proj: cx
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
.persist(),
q_bias: cx
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
.persist(),
k_proj: cx
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
.persist(),
v_proj: cx
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
.persist(),
v_bias: cx
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
.persist(),
out_proj: cx
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
.persist(),
out_bias: cx
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
.persist(),
q_proj: persist(cx, format!("{prefix}.q_proj.weight"), (dim, dim)),
q_bias: persist(cx, format!("{prefix}.q_proj.bias"), dim),
k_proj: persist(cx, format!("{prefix}.k_proj.weight"), (dim, dim)),
v_proj: persist(cx, format!("{prefix}.v_proj.weight"), (dim, dim)),
v_bias: persist(cx, format!("{prefix}.v_proj.bias"), dim),
out_proj: persist(cx, format!("{prefix}.out_proj.weight"), (dim, dim)),
out_bias: persist(cx, format!("{prefix}.out_proj.bias"), dim),
}
}
}
@@ -125,6 +119,14 @@ fn merge_heads(x: GraphTensor) -> GraphTensor {
x.transpose(0, 1).merge_dims(1, 2)
}
fn embedding_lookup(embedding: GraphTensor, ids: GraphTensor) -> GraphTensor {
let seq = ids.dims1();
embedding.gather(
(ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
)
}
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
let q = linear_with_bias(x, w.q_proj, w.q_bias);
@@ -239,18 +241,10 @@ impl EncoderLayer {
N_AUDIO_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_AUDIO_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
}
}
@@ -295,18 +289,10 @@ impl DecoderLayer {
N_TEXT_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_TEXT_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
}
}
@@ -346,14 +332,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
for l in 0..N_TEXT_LAYER {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -376,27 +364,23 @@ pub struct WhisperEncoder {
impl WhisperEncoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
conv1_w: cx
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
.persist(),
conv1_b: cx
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
.persist(),
conv2_w: cx
.named_tensor(
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
)
.persist(),
conv2_b: cx
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
.persist(),
positional_embedding: cx
.named_tensor(
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
)
.persist(),
conv1_w: persist(
cx,
"model.encoder.conv1.weight",
(N_AUDIO_STATE, N_MELS * 3),
),
conv1_b: persist(cx, "model.encoder.conv1.bias", N_AUDIO_STATE),
conv2_w: persist(
cx,
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
),
conv2_b: persist(cx, "model.encoder.conv2.bias", N_AUDIO_STATE),
positional_embedding: persist(
cx,
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
),
layers: (0..N_AUDIO_LAYER)
.map(|i| EncoderLayer::new(i, cx))
.collect(),
@@ -427,15 +411,16 @@ pub struct WhisperDecoder {
impl WhisperDecoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
embed_tokens: cx
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
.persist(),
embed_positions: cx
.named_tensor(
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
)
.persist(),
embed_tokens: persist(
cx,
"model.decoder.embed_tokens.weight",
(N_VOCAB, N_TEXT_STATE),
),
embed_positions: persist(
cx,
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
),
layers: (0..N_TEXT_LAYER)
.map(|i| DecoderLayer::new(i, cx))
.collect(),
@@ -450,18 +435,8 @@ impl WhisperDecoder {
xa: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
// Token embedding gather
let mut x = self.embed_tokens.gather(
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
// Positional embedding gather (using pos_ids)
let pos_emb = self.embed_positions.gather(
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
x += pos_emb;
let mut x = embedding_lookup(self.embed_tokens, token_ids);
x += embedding_lookup(self.embed_positions, pos_ids);
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
for (i, layer) in self.layers.iter().enumerate() {

View File

@@ -318,7 +318,7 @@ fn main() {
println!("Building E-Graph...");
let t0 = Instant::now();
cx.build_search_space::<CudaRuntime>();
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
println!(" built E-Graph in {:?}", t0.elapsed());
println!("Loading weights...");
@@ -335,7 +335,8 @@ fn main() {
println!("Compiling (search_graphs={search_graphs})...");
let t0 = Instant::now();
runtime = cx.search(runtime, search_graphs);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
println!(" search took {:?}", t0.elapsed());
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)

18
src/bin/examples-perf.rs Normal file
View File

@@ -0,0 +1,18 @@
use std::{
env,
path::Path,
process::{Command, ExitCode},
};
fn main() -> ExitCode {
let repo_root = env!("CARGO_MANIFEST_DIR");
let script = Path::new(repo_root).join("ci/examples_perf.py");
let status = Command::new("python3")
.arg(script)
.args(env::args_os().skip(1))
.current_dir(repo_root)
.status()
.expect("failed to run python3 ci/examples_perf.py");
ExitCode::from(status.code().unwrap_or(1) as u8)
}

View File

@@ -1,8 +1,8 @@
use std::fmt::Display;
use std::fmt::{Debug, Display};
/// Supported dtypes
/// This is undergoing development. Our goal is to be as explicit as possible about dtype behavior.
#[derive(Clone, Copy, Debug, PartialEq, Default)]
#[derive(Clone, Copy, PartialEq, Default)]
pub enum DType {
/// 32-bit float (8e23m)
#[default]
@@ -20,6 +20,14 @@ pub enum DType {
/// 32-bit signed integer
Int,
/// 64-bit signed integer.
///
/// Debug-formats as `"Int64"` (not `"I64"`) because the egglog optimizer
/// uses `{:?}` to serialize `DType` into rule strings and has a built-in
/// primitive sort named `I64` for integer literals in shape expressions;
/// emitting `"I64"` would shadow that primitive and panic the egraph
/// loader with `UnboundFunction("I64", ...)`.
I64,
/// 4-bit signed integer
I4,
/// 4-bit unsigned integer
@@ -54,6 +62,37 @@ pub enum DType {
F4E2M1,
}
impl Debug for DType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Mostly identical to the derived Debug, except `I64 -> "Int64"` to
// avoid clashing with egglog's primitive `I64` sort (see the variant
// docstring above).
let name = match self {
DType::F32 => "F32",
DType::F64 => "F64",
DType::F16 => "F16",
DType::Bf16 => "Bf16",
DType::TF32 => "TF32",
DType::Int => "Int",
DType::I64 => "Int64",
DType::I4 => "I4",
DType::U4 => "U4",
DType::I8 => "I8",
DType::U8 => "U8",
DType::I16 => "I16",
DType::U16 => "U16",
DType::Bool => "Bool",
DType::F8UE8M0 => "F8UE8M0",
DType::F8E4M3 => "F8E4M3",
DType::F8E5M2 => "F8E5M2",
DType::F6E2M3 => "F6E2M3",
DType::F6E3M2 => "F6E3M2",
DType::F4E2M1 => "F4E2M1",
};
write!(f, "{}", name)
}
}
impl Display for DType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
@@ -68,7 +107,7 @@ impl DType {
/// Use `ShapeTracker::required_total_bytes()` to compute byte sizes for a tensor.
pub fn bits(&self) -> usize {
match self {
DType::F64 => 64,
DType::F64 | DType::I64 => 64,
DType::F32 | DType::Int => 32,
DType::TF32 => 19,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 16,

View File

@@ -13,7 +13,7 @@ use petgraph::stable_graph::NodeIndex;
use rustc_hash::FxHashMap;
use crate::dtype::DType;
use crate::graph::Graph;
use crate::graph::{CompileOptions, Graph};
use crate::hlir::{NativeData, NativeRuntime, Output};
use crate::op::Runtime;
@@ -38,9 +38,21 @@ pub trait DynBackend {
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType);
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>);
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32>;
fn get_output_f16(&self, _node: NodeIndex) -> Vec<half::f16> {
panic!("get_output_f16 not supported by '{}'", self.name());
}
fn get_output_bf16(&self, _node: NodeIndex) -> Vec<half::bf16> {
panic!("get_output_bf16 not supported by '{}'", self.name());
}
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
panic!("get_output_i32 not supported by '{}'", self.name());
}
fn get_output_i64(&self, _node: NodeIndex) -> Vec<i64> {
panic!("get_output_i64 not supported by '{}'", self.name());
}
fn get_output_f64(&self, _node: NodeIndex) -> Vec<f64> {
panic!("get_output_f64 not supported by '{}'", self.name());
}
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
panic!("get_output_bool not supported by '{}'", self.name());
}
@@ -133,7 +145,7 @@ pub fn compile_backend<Rt: Runtime + 'static>(
// survives cross-binary type identity mismatches with external plugins).
let label_map = build_label_map(graph);
graph.build_search_space::<Rt>();
graph.build_search_space::<Rt>(CompileOptions::default());
let mut rt = init()?;
@@ -162,7 +174,10 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, args.search_iters);
let mut rt = graph.search(
rt,
CompileOptions::default().search_graph_limit(args.search_iters),
);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);
@@ -215,6 +230,7 @@ pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
DType::I64 => unsafe { as_bytes(vec![1i64; n_elements]) },
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
@@ -232,13 +248,11 @@ pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
}
match dtype {
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
DType::F64 => {
let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
}
DType::F64 => NativeData::F64(unsafe { from_bytes(bytes) }),
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
DType::I64 => NativeData::I64(unsafe { from_bytes(bytes) }),
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
@@ -278,18 +292,80 @@ impl DynBackend for NativeDynBackend {
}
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
let data = self.output_buffer(node);
data.to_f32_vec()
match self.output_buffer(node) {
NativeData::F32(v) => v.clone(),
other => panic!(
"get_output_f32: buffer dtype is {:?}, expected F32. \
Add a `Cast(DType::F32)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_f16(&self, node: NodeIndex) -> Vec<half::f16> {
match self.output_buffer(node) {
NativeData::F16(v) => v.clone(),
other => panic!(
"get_output_f16: buffer dtype is {:?}, expected F16. \
Add a `Cast(DType::F16)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_bf16(&self, node: NodeIndex) -> Vec<half::bf16> {
match self.output_buffer(node) {
NativeData::Bf16(v) => v.clone(),
other => panic!(
"get_output_bf16: buffer dtype is {:?}, expected Bf16. \
Add a `Cast(DType::Bf16)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
let data = self.output_buffer(node);
data.to_i32_vec()
match self.output_buffer(node) {
NativeData::Int(v) => v.clone(),
other => panic!(
"get_output_i32: buffer dtype is {:?}, expected Int (i32). \
Add a `Cast(DType::Int)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
match self.output_buffer(node) {
NativeData::I64(v) => v.clone(),
other => panic!(
"get_output_i64: buffer dtype is {:?}, expected I64. \
Add a `Cast(DType::I64)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
match self.output_buffer(node) {
NativeData::F64(v) => v.clone(),
other => panic!(
"get_output_f64: buffer dtype is {:?}, expected F64. \
Add a `Cast(DType::F64)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
let data = self.output_buffer(node);
data.to_bool_vec()
match self.output_buffer(node) {
NativeData::Bool(v) => v.clone(),
other => panic!(
"get_output_bool: buffer dtype is {:?}, expected Bool. \
Add a `Cast(DType::Bool)` before the Output.",
std::mem::discriminant(other)
),
}
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {

View File

@@ -2,6 +2,7 @@ use std::sync::LazyLock;
use super::api::*;
use crate::shape::{self, ToShape};
use rustc_hash::FxHashSet;
// ---- Sort classes (pub const) ----
@@ -55,6 +56,12 @@ pub fn peq(a: Term, b: Term) -> Term {
pub fn pneq(a: Term, b: Term) -> Term {
neq(a, b)
}
pub fn interval_lower(e: Term) -> Term {
app(&func("lower", &["expr"]), vec![e])
}
pub fn interval_upper(e: Term) -> Term {
app(&func("upper", &["expr"]), vec![e])
}
// ---- Egglog function applications ----
@@ -228,9 +235,13 @@ pub struct BaseSorts {
// DType variants
pub f32_dt: SortDef,
pub f64_dt: SortDef,
pub f16_dt: SortDef,
pub bf16_dt: SortDef,
pub int_dt: SortDef,
/// Egglog sort for `DType::I64`. Named `"Int64"` (not `"I64"`) to avoid
/// shadowing egglog's built-in `I64` primitive sort.
pub int64_dt: SortDef,
pub bool_dt: SortDef,
pub f4e2m1_dt: SortDef,
pub f8e4m3_dt: SortDef,
@@ -319,9 +330,11 @@ impl BaseSorts {
row_major: sort(ELIST, "RowMajor", &[("list", ELIST)]),
f32_dt: sort(DTYPE, "F32", &[]),
f64_dt: sort(DTYPE, "F64", &[]),
f16_dt: sort(DTYPE, "F16", &[]),
bf16_dt: sort(DTYPE, "Bf16", &[]),
int_dt: sort(DTYPE, "Int", &[]),
int64_dt: sort(DTYPE, "Int64", &[]),
bool_dt: sort(DTYPE, "Bool", &[]),
f4e2m1_dt: sort(DTYPE, "F4E2M1", &[]),
f8e4m3_dt: sort(DTYPE, "F8E4M3", &[]),
@@ -385,9 +398,11 @@ impl BaseSorts {
&self.remove_nth_from_end,
&self.row_major,
&self.f32_dt,
&self.f64_dt,
&self.f16_dt,
&self.bf16_dt,
&self.int_dt,
&self.int64_dt,
&self.bool_dt,
&self.f4e2m1_dt,
&self.f8e4m3_dt,
@@ -412,6 +427,38 @@ pub fn dtype(e: Term) -> Term {
app(&func("dtype", &["inp"]), vec![e])
}
pub fn interval_facts_egglog(
intervals: &shape::DynDimIntervals,
vars: impl IntoIterator<Item = char>,
) -> String {
let mut all_vars = FxHashSet::default();
all_vars.extend(intervals.keys().copied());
all_vars.extend(vars);
let mut all_vars = all_vars.into_iter().collect::<Vec<_>>();
all_vars.sort_unstable();
let mut out = String::new();
for var in all_vars {
let interval = intervals
.get(&var)
.copied()
.unwrap_or_else(shape::DimInterval::unbounded);
let var_expr = mvar(str(&var.to_string()));
out.push_str(&format!(
"(set {} {})\n",
term_to_egglog(&interval_lower(var_expr.clone())),
interval.min
));
out.push_str(&format!(
"(set {} {})\n",
term_to_egglog(&interval_upper(var_expr)),
interval.max
));
}
out
}
// ---- Normalized Op helpers ----
/// Build an `(Op kind inputs)` IR term.
@@ -460,11 +507,19 @@ pub fn new_op_call(kind_sort: &SortDef, input_names: &[&str]) -> (Args, Term) {
(args, op)
}
pub fn base_expression_egglog() -> String {
base_expression_egglog_impl(false)
}
pub fn base_expression_egglog_with_intervals() -> String {
base_expression_egglog_impl(true)
}
/// Generate the egglog program equivalent to `base.egg`.
///
/// This builds the Expression, EList, and DType datatypes along with all
/// algebraic rewrites, replacement rules, and list helper functions.
pub fn base_expression_egglog() -> String {
fn base_expression_egglog_impl(use_interval_analysis: bool) -> String {
let s = BaseSorts::new();
// Build the program
@@ -475,12 +530,29 @@ pub fn base_expression_egglog() -> String {
// Rulesets
p.add_ruleset("expr");
if use_interval_analysis {
p.add_ruleset("interval_expr");
}
p.add_ruleset("dtype_prop");
p.add_ruleset("cleanup");
p.add_ruleset("post_cleanup");
// Register all sorts
s.register(&mut p);
if use_interval_analysis {
p.add_function(FunctionDef {
name: "lower".to_string(),
args: vec![EXPRESSION.name.to_string()],
ret: I64.name.to_string(),
merge: Some("(max old new)".to_string()),
});
p.add_function(FunctionDef {
name: "upper".to_string(),
args: vec![EXPRESSION.name.to_string()],
ret: I64.name.to_string(),
merge: Some("(min old new)".to_string()),
});
}
// ---- Algebraic rewrites ----
// Commutativity
@@ -725,6 +797,182 @@ pub fn base_expression_egglog() -> String {
.ruleset("expr"),
);
if use_interval_analysis {
// ---- Interval analysis and interval-guarded simplifications ----
p.add_rule(
Rule::new()
.fact(peq(v("?e"), num(v("?n"))))
.set(interval_lower(v("?e")), v("?n"))
.set(interval_upper(v("?e")), v("?n"))
.ruleset("interval_expr")
.name("interval-num-exact"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), add(v("?a"), v("?b"))),
peq(v("?lo_a"), interval_lower(v("?a"))),
peq(v("?lo_b"), interval_lower(v("?b"))),
peq(v("?sum"), padd(v("?lo_a"), v("?lo_b"))),
])
.set(interval_lower(v("?e")), v("?sum"))
.when(vec![
pgte(v("?lo_a"), i64(0)),
pgte(v("?lo_b"), i64(0)),
pgte(psub(i64(i64::MAX), v("?lo_b")), v("?lo_a")),
])
.ruleset("interval_expr")
.name("interval-add-lower-nonnegative"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), add(v("?a"), v("?b"))),
peq(v("?hi_a"), interval_upper(v("?a"))),
peq(v("?hi_b"), interval_upper(v("?b"))),
peq(v("?sum"), padd(v("?hi_a"), v("?hi_b"))),
])
.set(interval_upper(v("?e")), v("?sum"))
.when(vec![
plt(v("?hi_a"), i64(i64::MAX)),
plt(v("?hi_b"), i64(i64::MAX)),
pgte(psub(i64(i64::MAX), v("?hi_b")), v("?hi_a")),
])
.ruleset("interval_expr")
.name("interval-add-upper-finite"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), min(v("?a"), v("?b"))),
peq(v("?lo_a"), interval_lower(v("?a"))),
peq(v("?lo_b"), interval_lower(v("?b"))),
])
.set(interval_lower(v("?e")), pmin(v("?lo_a"), v("?lo_b")))
.ruleset("interval_expr")
.name("interval-min-lower"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), min(v("?a"), v("?b"))),
peq(v("?hi_a"), interval_upper(v("?a"))),
peq(v("?hi_b"), interval_upper(v("?b"))),
])
.set(interval_upper(v("?e")), pmin(v("?hi_a"), v("?hi_b")))
.ruleset("interval_expr")
.name("interval-min-upper"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), max(v("?a"), v("?b"))),
peq(v("?lo_a"), interval_lower(v("?a"))),
peq(v("?lo_b"), interval_lower(v("?b"))),
])
.set(interval_lower(v("?e")), pmax(v("?lo_a"), v("?lo_b")))
.ruleset("interval_expr")
.name("interval-max-lower"),
);
p.add_rule(
Rule::new()
.facts(vec![
peq(v("?e"), max(v("?a"), v("?b"))),
peq(v("?hi_a"), interval_upper(v("?a"))),
peq(v("?hi_b"), interval_upper(v("?b"))),
])
.set(interval_upper(v("?e")), pmax(v("?hi_a"), v("?hi_b")))
.ruleset("interval_expr")
.name("interval-max-upper"),
);
p.add_rule(
rewrite("interval-lt-true", lt(v("?x"), num(v("?n"))), num(i64(1)))
.when(vec![
peq(v("?hi"), interval_upper(v("?x"))),
plt(v("?hi"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite("interval-lt-false", lt(v("?x"), num(v("?n"))), num(i64(0)))
.when(vec![
peq(v("?lo"), interval_lower(v("?x"))),
pgte(v("?lo"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite("interval-gte-true", gte(v("?x"), num(v("?n"))), num(i64(1)))
.when(vec![
peq(v("?lo"), interval_lower(v("?x"))),
pgte(v("?lo"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite(
"interval-gte-false",
gte(v("?x"), num(v("?n"))),
num(i64(0)),
)
.when(vec![
peq(v("?hi"), interval_upper(v("?x"))),
plt(v("?hi"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite(
"interval-min-right-identity",
min(v("?x"), num(v("?n"))),
v("?x"),
)
.when(vec![
peq(v("?hi"), interval_upper(v("?x"))),
pgte(v("?n"), v("?hi")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite(
"interval-max-right-identity",
max(v("?x"), num(v("?n"))),
v("?x"),
)
.when(vec![
peq(v("?lo"), interval_lower(v("?x"))),
pgte(v("?lo"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite("interval-mod-small", modd(v("?x"), num(v("?n"))), v("?x"))
.when(vec![
pgte(v("?n"), i64(1)),
peq(v("?lo"), interval_lower(v("?x"))),
peq(v("?hi"), interval_upper(v("?x"))),
pgte(v("?lo"), i64(0)),
plt(v("?hi"), v("?n")),
])
.ruleset("interval_expr"),
);
p.add_rule(
rewrite(
"interval-div-small",
div(v("?x"), num(v("?n"))),
num(i64(0)),
)
.when(vec![
pgte(v("?n"), i64(1)),
peq(v("?lo"), interval_lower(v("?x"))),
peq(v("?hi"), interval_upper(v("?x"))),
pgte(v("?lo"), i64(0)),
plt(v("?hi"), v("?n")),
])
.ruleset("interval_expr"),
);
}
// `div-div`, restricted to nested constant divisors only. The original
// unconstrained form `(a/b)/c → a/(b*c)` produces a new `div` whose
// denominator matches the same rule again as soon as `a` is itself a

View File

@@ -235,18 +235,26 @@ fn egglog_ruleset_declarations() -> String {
.join("\n")
}
fn egglog_main_cycle_phases(cycle: usize) -> Vec<EgglogSchedulePhase> {
fn expr_schedule(use_interval_analysis: bool) -> &'static str {
if use_interval_analysis {
"(saturate (seq expr interval_expr))"
} else {
"(saturate expr)"
}
}
fn egglog_main_cycle_phases(cycle: usize, use_interval_analysis: bool) -> Vec<EgglogSchedulePhase> {
vec![EgglogSchedulePhase {
name: format!("cycle {cycle:03} main"),
schedule: egglog_main_schedule(),
schedule: egglog_main_schedule(use_interval_analysis),
}]
}
fn egglog_final_phases() -> Vec<EgglogSchedulePhase> {
fn egglog_final_phases(use_interval_analysis: bool) -> Vec<EgglogSchedulePhase> {
vec![
EgglogSchedulePhase {
name: "final expr".to_string(),
schedule: "(saturate expr)".to_string(),
schedule: expr_schedule(use_interval_analysis).to_string(),
},
EgglogSchedulePhase {
name: "cleanup".to_string(),
@@ -263,14 +271,16 @@ fn egglog_final_phases() -> Vec<EgglogSchedulePhase> {
]
}
fn egglog_main_schedule() -> String {
fn egglog_main_schedule(use_interval_analysis: bool) -> String {
let expr = expr_schedule(use_interval_analysis);
// Producer rules create raw alternatives that downstream fusion consumes.
// Fusion grow/merge only consumes Kernel*/FusionEnd alternatives, so keeping
// producer discovery saturated before fusion reaches the same fixed point
// while avoiding repeated expensive pair-discovery scans during growth.
"(saturate (seq
format!(
"(saturate (seq
(saturate (seq
(saturate expr)
{expr}
(saturate dtype_prop)
(run matmul_flatten)
(run kernel_lower)
@@ -282,19 +292,19 @@ fn egglog_main_schedule() -> String {
(run fusion_pair)
))
(saturate (seq
(saturate expr)
{expr}
(saturate dtype_prop)
(run fusion_grow)
(run fusion_merge)
))
))"
.to_string()
)
}
fn egglog_schedule_program() -> String {
let mut schedules = vec![format!("(run-schedule {})", egglog_main_schedule())];
let mut schedules = vec![format!("(run-schedule {})", egglog_main_schedule(false))];
schedules.extend(
egglog_final_phases()
egglog_final_phases(false)
.into_iter()
.map(|phase| format!("(run-schedule {})", phase.schedule)),
);
@@ -302,9 +312,22 @@ fn egglog_schedule_program() -> String {
}
fn egglog_setup_with(program: &str, parts: &OpTextParts) -> String {
egglog_setup_with_options(program, parts, false)
}
fn egglog_setup_with_options(
program: &str,
parts: &OpTextParts,
use_interval_analysis: bool,
) -> String {
let base_program = if use_interval_analysis {
base::base_expression_egglog_with_intervals()
} else {
base::base_expression_egglog()
};
[
egglog_ruleset_declarations(),
base::base_expression_egglog(),
base_program,
parts.op_defs.clone(),
parts.cleanups.clone(),
base::base_cleanup_egglog(),
@@ -1130,6 +1153,19 @@ pub fn run_egglog_with_report_and_late_passes(
run_egglog_with_report_parts(program, root, &op_parts)
}
#[tracing::instrument(skip_all)]
pub fn run_egglog_with_report_late_passes_and_interval_analysis(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
late_passes: &[LateEgglogPass],
use_interval_analysis: bool,
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
let op_parts = OpTextParts::new_with_late_passes(ops, cleanup, late_passes);
run_egglog_with_report_parts_impl(program, root, &op_parts, use_interval_analysis)
}
/// Same as [`run_egglog_with_report`], but takes pre-computed [`OpTextParts`].
/// Useful when a caller runs many egglog invocations with the same op set
/// and wants to factor the op-derived text work out of a parallel loop.
@@ -1139,6 +1175,15 @@ pub fn run_egglog_with_report_parts(
program: &str,
root: &str,
op_parts: &OpTextParts,
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
run_egglog_with_report_parts_impl(program, root, op_parts, false)
}
fn run_egglog_with_report_parts_impl(
program: &str,
root: &str,
op_parts: &OpTextParts,
use_interval_analysis: bool,
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
#[cfg(debug_assertions)]
{
@@ -1158,7 +1203,7 @@ pub fn run_egglog_with_report_parts(
let full_start = std::time::Instant::now();
let setup_text_start = std::time::Instant::now();
let setup_code = egglog_setup_with(program, op_parts);
let setup_code = egglog_setup_with_options(program, op_parts, use_interval_analysis);
let setup_text_elapsed = setup_text_start.elapsed();
let setup_lines = setup_code.lines().count();
let mut egraph = egglog::EGraph::default();
@@ -1197,7 +1242,7 @@ pub fn run_egglog_with_report_parts(
let mut reached_fixed_point = false;
for cycle in 1..=MAIN_SCHEDULE_MAX_CYCLES {
let mut cycle_updated = false;
for phase in egglog_main_cycle_phases(cycle) {
for phase in egglog_main_cycle_phases(cycle, use_interval_analysis) {
cycle_updated |= run_schedule_phase(&mut egraph, &mut phases, &phase)?;
}
if egraph.num_tuples() > MAIN_SCHEDULE_MAX_TUPLES {
@@ -1217,7 +1262,7 @@ pub fn run_egglog_with_report_parts(
"egglog saturation did not reach a fixed point within {MAIN_SCHEDULE_MAX_CYCLES} cycles"
)));
}
for phase in egglog_final_phases() {
for phase in egglog_final_phases(use_interval_analysis) {
run_schedule_phase(&mut egraph, &mut phases, &phase)?;
}
for phase in &op_parts.late_phases {
@@ -1432,6 +1477,26 @@ pub fn run_egglog_with_late_passes(
.map(|(egraph, _)| egraph)
}
#[tracing::instrument(skip_all)]
pub fn run_egglog_with_late_passes_and_interval_analysis(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
late_passes: &[LateEgglogPass],
use_interval_analysis: bool,
) -> Result<SerializedEGraph, egglog::Error> {
run_egglog_with_report_late_passes_and_interval_analysis(
program,
root,
ops,
cleanup,
late_passes,
use_interval_analysis,
)
.map(|(egraph, _)| egraph)
}
/// Same as [`run_egglog`] but takes pre-computed [`OpTextParts`], so the
/// whole function is `Send`. Used by the parallel grouped-egraphs build.
#[tracing::instrument(skip_all)]
@@ -1480,9 +1545,13 @@ pub fn extract_expr_list<'a>(
pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DType {
match egraph.enodes[node].0.as_str() {
"F32" => DType::F32,
"F64" => DType::F64,
"F16" => DType::F16,
"Bf16" => DType::Bf16,
"Int" => DType::Int,
// `"Int64"` rather than `"I64"` to avoid colliding with egglog's
// built-in I64 primitive (see `DType::I64` docstring).
"Int64" => DType::I64,
"Bool" => DType::Bool,
"F4E2M1" => DType::F4E2M1,
"F6E2M3" => DType::F6E2M3,

Some files were not shown because too many files have changed in this diff Show More