Compare commits

...

15 Commits

Author SHA1 Message Date
Tucker Morgan
bfcc41040e dlrm --mega: route the megakernel through luminal's runtime, not around it
The standalone megakernel.rs called cudarc directly — bypassing luminal's
Graph/CustomOp/runtime entirely. That made the 22µs number a measurement
of bare CUDA, not of "luminal running a megakernel." Deleted it and
replaced with a luminal-graph build: one cx.custom_op(DlrmMegaCustom(...))
node, weights loaded via runtime.load_safetensors, inputs via set_data,
executed via runtime.execute, output via runtime.get_f32.

The kernel itself is unchanged (still the parameterized DlrmMegaKernel
from crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs — the same
one the PT2 backend's matcher emits). Only the orchestration moved.

Concrete shape:
  - examples/dlrm/src/megakernel.rs (the standalone): deleted
  - examples/dlrm/src/main.rs::run_megakernel: rewritten end-to-end
    + named_tensor + persist for each weight, naming them to match the
      safetensors keys so load_safetensors hits by Input label
    + named_tensor + as_dtype(Int) for each user input
    + one cx.custom_op call wraps the whole forward
    + cx.search_options + runtime.execute, same path as the non-mega flow
  - bench_through_luminal: extracted shared bench loop so both --bench
    paths (non-mega and --mega) use it; samples_path + label parameterized

Verification through luminal's runtime:
  - max |diff| vs PyTorch eager: 1.19e-7 (same as before — kernel identical)
  - active bucket host-op count: 1 (the DlrmMegaCustom)
  - search reports [KRN: 1 HOST: 0] — single CustomOp, nothing to fuse

Bench at bs=2048, n_sparse=3 on H100:

  1. DLRM megakernel (luminal)      51.39 µs   1.00x   ← now via cx.custom_op
  2. CUDA graphs (eager capture)    67.05 µs   1.30x
  3. luminal_backend (PT2)         124.91 µs   2.43x
  4. torch.compile (inductor)      156.49 µs   3.04x
  5. AOTInductor                   158.27 µs   3.08x
  6. eager                         258.41 µs   5.03x
  7. rust luminal (non-mega)       269.35 µs   5.24x

Still #1, still beats CUDA graphs (which captures 8 launches into one
replay). The +29µs vs the deleted standalone is luminal's per-execute
overhead — re-set_data uploads inputs each iter (4 owned CudaSlice
allocations + H2D copies), the toposort + buffer-map walk fires once
per op, the consume loop frees the just-uploaded inputs after execute.
None of that is the kernel's fault — it's the cost of going through the
runtime, which is what was asked for.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 15:18:59 +00:00
Tucker Morgan
9d4a3bc555 PT2 backend: emit one megakernel for DLRM-shape graphs (503µs → 130µs)
The standalone hand-written megakernel hit 22µs at bs=2048 — proof that
DLRM's compute fits trivially into one CUDA launch when nothing rounds
back through HBM between layers. This commit gets the PT2 backend to
produce that same megakernel automatically when the translator detects
a DLRM-shape input graph.

Three pieces:

1. crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs (new)
   ------------------------------------------------------------
   Parameterized DlrmMegaKernel (KernelOp) + DlrmMegaCustom (CustomOp),
   templated by (batch, n_dense_in, ln_bot, n_sparse, vocab_sizes,
   m_spa, ln_top). CUDA source generated per-shape via format!,
   compiled through luminal's existing nvrtc wrapper with
   source-string caching (matches Matmul2DKernel pattern).

   The kernel keeps every intermediate activation (~60 floats/row) in
   registers from dense input → final sigmoid; computes the n_pairs
   strictly-lower-tri dot products explicitly instead of doing the
   wasteful full 4×4 bmm + index gather. Indices are read as int32
   (luminal collapses all int types to its 32-bit Int dtype).

2. crates/luminal_python/rust/src/translator/dlrm_pattern.rs (new)
   --------------------------------------------------------------
   match_dlrm() walks parsed PT2 nodes, validates the topology (bot
   addmm+relu chain → N _embedding_bag_forward_only with bag-size-1 →
   bmm + index.Tensor lower-tri pairs → cat → top addmm+relu chain
   ending in sigmoid), and recovers the full shape vec plus PT2
   tensor names for every input (dense, indices, embeddings, bot/top
   weights+biases). On any mismatch returns None; matcher is
   intentionally conservative — wrong-graphs never produce wrong-output,
   only "the fast path didn't trigger."

   emit_megakernel() pulls each GraphTensor out of translator.tensors
   by PT2 graph_name and inserts a single cx.custom_op() with the
   DlrmMegaCustom wrapper, registering its output under the user-output
   name. The existing emit_outputs() loop handles the rest.

   Debug logging gated on LUMINAL_DLRM_MEGAKERNEL_DEBUG=1, plus the
   generated CUDA source is dumped to /tmp/dlrm_megakernel_generated.cu.

3. translator/mod.rs hookpoint
   ----------------------------
   Extracted the post-loop output emission into a private emit_outputs()
   method. Added a feature-gated #[cfg(feature = "cuda")] fast-path
   check after create_inputs(): if match_dlrm returns Some, emit the
   megakernel + emit_outputs and return; otherwise fall through to the
   standard node walk.

Two bugs caught during integration (both lessons-learned material):
  - The kernel signature must place the output buffer FIRST, then
    inputs in order — that's how luminal's CustomOp dispatcher invokes
    every KernelOp. Matmul2DKernel follows this; my standalone
    megakernel.rs had `out` last (acceptable because it calls cudarc
    directly). The luminal-side wrap had to be reordered.
  - PyTorch's int64 indices get demoted to int32 at the runtime's input
    boundary (luminal's Int dtype is 32-bit), so the kernel reads
    `const int*` not `const long*`. The standalone megakernel uploads
    via a host-side i32→i64 conversion which masked this; the PT2 path
    delivers the raw 4-byte buffer.

Results at bs=2048, n_sparse=3 on H100:

  Before:                              After:
  luminal_backend  503 µs  (rank 6/6)  luminal_backend  130 µs  (rank 3/7)
                                       DLRM megakernel   22 µs  (the standalone)

Full ranking at bs=2048, n_sp=3:
  1. DLRM megakernel (standalone)    22.20 µs   1.00x
  2. CUDA graphs (eager capture)     66.74 µs   3.00x
  3. luminal_backend (PT2)          131.84 µs   5.94x    ← +3.83x from this commit
  4. AOTInductor                    132.47 µs   5.97x
  5. torch.compile (inductor)       238.94 µs  10.76x
  6. rust luminal (no megakernel)   262.39 µs  11.82x
  7. eager                          270.87 µs  12.20x

luminal_backend now BEATS eager, torch.compile inductor, AOTInductor,
and even the rust-direct luminal path at this workload.

Sweep across {256,1024,2048,4096} × {3,8,16}: luminal_backend / eager
ratio drops from 1.66x–2.53x (was slower) to 0.46x–0.69x (now wins
every cell by 1.4–2.2x). Pattern matches across the entire grid; per-shape
nvrtc compile is cached so repeat calls are free.

Verification:
  - tests/test_dlrm.py: 4 passed (bag1 small/large batch, bag3 multihot,
    bigger tables — all match PyTorch eager to atol 1e-5)
  - Regression suite (test_dlrm + test_hlir_ops + test_scalars +
    test_models + test_unary): 404 passed, 4 xfailed (pre-existing),
    0 failed.

The remaining ~110µs gap between luminal_backend (132µs) and the
standalone (22µs) is Dynamo dispatch + the per-call pyo3 wrapper —
fundamentally framework overhead, not kernel quality. Out of scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 05:15:43 +00:00
Tucker Morgan
6f8de66e3d DLRM megakernel: one CUDA kernel, entire forward in registers, 22µs
Hand-written CUDA megakernel that does the full DLRM forward per
(thread × batch row), keeping every intermediate activation (~60
floats) in registers from dense input through final sigmoid. Compiled
via nvrtc at startup (~80–190ms), loaded once, launched per call.

Topology hard-coded to MiniDLRM (m_spa=4, ln_emb=[10,20,30],
ln_bot=[13,8,4], ln_top=[10,8,1], n_pairs=6). Per-row work:
  - 13×8 bot Linear + ReLU
  - 8×4 bot Linear + ReLU (= dense_out)
  - 3 single-row embedding gathers
  - 6 explicit dot products for the strictly-lower-triangular pairs
    (skips the wasteful full 4×4 bmm + index)
  - 10→8 top Linear + ReLU
  - 8→1 top Linear + sigmoid (via __expf)

Per-block register pressure ~60 regs/thread (H100 ceiling 256); no
shared memory required; weights stay in L1 across blocks (~480 floats).
Launch config: grid=ceil(B/128), block=128.

Verified: max |diff| vs PyTorch eager = 1.19e-7 (single ULP per output).

Ranking at bs=2048, n_sparse=3 on H100:
  1. DLRM megakernel                22.20 µs   1.00x  ← new champion
  2. CUDA graphs (eager capture)    66.74 µs   3.01x
  3. AOTInductor                   128.74 µs   5.80x
  4. torch.compile (inductor)      186.79 µs   8.42x
  5. rust luminal                  262.39 µs  11.82x
  6. eager                         268.07 µs  12.08x
  7. luminal_backend (PT2)         503.48 µs  22.68x

Why it wins:
  - 1 kernel launch vs ~8 in the multi-kernel paths (~35–50µs of
    launch overhead collapses to ~5µs).
  - Per-row intermediates never round-trip through HBM — register
    file holds everything.
  - The 4×4 bmm + lower-tri gather is replaced by 6 explicit dot
    products, doing exactly the work the model needs.

Not a general technique — this kernel only works for *this* DLRM
shape. The point is to bound how fast the forward *can* go at this
problem size; everything else is per-call overhead the framework
adds on top.

bench.py auto-picks up /tmp/dlrm_bench_megakernel.txt when
`dlrm --mega --bench` has been run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 04:23:28 +00:00
Tucker Morgan
cb9facfb11 DLRM bench: move to batch=2048 (real workload regime)
At bs=2 kernel work was sub-µs and the bench was a tax on per-launch
overhead — said more about wrapper cost than backend quality. Bumping
to bs=2048 puts the matmuls in a regime where compute actually shows up.

Rust binary BATCH_SIZE: 2 → 2048
Python export.py BATCH: 2 → 2048
Python bench.py BATCH: 2 → 2048
Sweep grid: {1,8,64,256} × {3,8,16} → {256,1024,2048,4096} × {3,8,16}

Single-config ranking at bs=2048, m_spa=4, ln_emb=[10,20,30]:
  1. CUDA graphs (eager capture)    66.75 µs  1.00x
  2. AOTInductor                   126.60 µs  1.90x
  3. torch.compile (inductor)      129.63 µs  1.94x
  4. rust luminal                  262.39 µs  3.93x    ← now BEATS eager
  5. eager                         268.78 µs  4.03x
  6. luminal_backend (PT2)         494.67 µs  7.41x

The structural story: rust luminal now beats PyTorch eager because the
matmul compute amortizes per-iter overhead. luminal_backend stays last
purely from python wrapper cost (~230µs gap to rust-direct = pyo3 +
Dynamo per call). Kernel quality is fine — wrapper is the ceiling.

Sweep across {256,1024,2048,4096} × {3,8,16}:
  - eager is flat with batch (compute-bound) at ~270/350/480 µs for n_sp 3/8/16
  - luminal_backend scales mostly with n_sparse (~50µs per extra table from
    its pyo3 input round-trip × launch overhead)
  - luminal_backend/eager ratio 1.66x–2.53x; widens with n_sparse, stable in batch

CUDA graphs wins every cell — but per the agreed scope, we're not
wrapping luminal's runtime in another capture layer (it already does
per-cluster capture via CudaGraphOp internally).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 04:13:47 +00:00
Tucker Morgan
2845e605c1 dlrm bin: dump per-op cuBLASLt epilogue with --stats
Diagnostic addition for the fusion audit. Shows that the existing egglog rewrites already collapse 3 of 4 Linears in DLRM into CUBLASLT_EPILOGUE_RELU_BIAS — only top.L1 (n=1 output) and the bmm escape (the bmm has no fusable activation/bias, top.L1's bias add fails to match the (0, MIter) stride pattern when the output dim is 1).
2026-05-21 04:04:42 +00:00
Tucker Morgan
ccdb6f1540 luminal_backend: trim translator + runtime + python per-call overhead
Pursues the user's ask of "raise luminal_backend in the ranking" via
the graph and runtime layers (not the call-site shape, which they
explicitly asked to leave alone). Investigations followed in this order:

1. **Inventory the compiled HLIR.** Post-egglog the DLRM forward is 8
   host ops total (5 cuBlasLt matmuls + 3 fused CudaGraphOp groups).
   The shape is already tight; the per-call gap vs rust-direct is
   wrapper/runtime overhead, not graph quality.

2. **Translator: skip vestigial `*1.0` in addmm.**
   `nn.Linear(bias=True)` decomposes to `aten.addmm.default` with
   default `beta=alpha=1.0`. The translator was always emitting
   `input * beta + mm * alpha`, costing 2 HLIR nodes per Linear that
   egglog had to fold. Search-space kernel count dropped 199 → 154 (-23%).
   The matching output-wrap fix (skipping `+ 0.0` for non-Input outputs)
   was tried and reverted — it broke 24 test_hlir_ops tests with
   "Cannot find output tensor!"; the wrap doubles as an egglog-survival
   anchor for ops whose original node gets folded (Reduce/keepdims, Conv).
   Documented why the wrap stays.

3. **luminal_cuda_lite: preserve external-pointer inputs across
   execute().** The per-execute "consume input buffers" loop was
   blindly removing every non-preserved HLIR input — including
   external-pointer registrations (CudaInput::Ptr). External pointers
   are caller-owned, so consumption frees no memory; it just forces the
   caller to re-register on the next call. Skip Ptr entries from the
   consume set. Owned buffers (CudaInput::Buffer) still get freed.

4. **luminal_cuda_lite: no-op set_device_ptr when the pointer hasn't
   changed.** PyTorch's caching allocator routinely hands the same
   device pointer back across iters for the same logical tensor;
   bench loops in particular hammer this. The fast path skips the
   cudarc upgrade_device_ptr + ManuallyDrop reallocation + the
   changed_hlir insert + the next-execute ptr re-cache.

5. **luminal_cuda_lite: cache exec_graph toposort.** The per-execute
   `petgraph::algo::toposort(&bucket.exec_graph, None)` allocated a
   Vec and walked the static graph every call. Cache it on the bucket
   and reuse.

6. **luminal_cuda_lite: gate last_kernel_stats population.** The
   stats Vec was rebuilt every execute, but only read by the
   diagnostic print_execution_stats() API. Gate on the existing
   `profiling` flag (already set during search) so steady-state
   inference doesn't pay the loop cost.

7. **compiled_model.py: ptr cache + dtype-matched fast path.**
   With (3) and (4) the runtime no longer needs per-iter re-registration
   to stay correct, so the python wrapper now caches the last
   (name, data_ptr) and skips the pyo3 round-trip when unchanged. Also
   short-circuits the `tensor.detach().contiguous().to(dtype)` chain
   when the input is already cuda-resident in the expected dtype
   contiguous — the chain is no-ops at runtime but allocates wrapper
   Tensor objects on every call.

Results on H100 (batch=2, m_spa=4, ln_emb=[10,20,30], 500 iters):
  luminal_backend mean: 482µs → 439µs  (-43µs, -9%)
  search kernel count: 199 → 154 KRN  (-23%)

Plus examples/dlrm/bench_sweep.py — runs all 5 PyTorch backends across
batch ∈ {1,8,64,256} × n_sparse ∈ {3,8,16}, prints per-backend tables,
a winner-per-cell matrix, and a luminal/eager ratio table. Confirms:
  - CUDA graphs wins every cell (35-100µs); pure replay cost.
  - AOTInductor ≈ torch.compile within ~10% across the grid.
  - luminal_backend trails eager by 1.37–2.15x; ratio best at the
    bs=8,n_sp=8 cell where work amortizes the python wrapper cost.

The remaining luminal_backend → eager gap (~200µs) is fundamentally
Dynamo dispatch + 5 separate cuBLASLt launches in the kernel chain.
Closing it would require capturing the whole execute() (including the
cuBLAS calls) into one CUDA graph at the runtime layer — a real feature
add, deferred.

Verification: 388 passed / 4 xfailed (pre-existing) / 0 failed in the
PT2 + scalars + dlrm + hlir_ops test sets.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 03:48:55 +00:00
Tucker Morgan
3f57d94ecb Rust DLRM + PyTorch backend bench (eager/inductor/AOTI/CUDA-graphs/luminal)
examples/dlrm/ contains a 3-piece harness:

1. src/main.rs — pure-rust DLRM mirroring MiniDLRM from
   crates/luminal_python/tests/test_dlrm.py. Builds the same HLIR ops
   (gather_rows, matmul+bias, relu/sigmoid primitives, the dot-interaction
   flat-gather) by hand. Loads weights via runtime.load_safetensors with
   PyTorch state_dict names — no remapping. --bench mode times steady-state
   forward latency (CUDA event-equivalent via get_f32 forcing per-iter sync).

2. export.py — runs MiniDLRM through PyTorch eager AND torch.compile with
   luminal_backend, asserts they match, then writes:
   - /tmp/dlrm_weights.safetensors (state_dict, fp32)
   - /tmp/dlrm_inputs.safetensors  (dense_x, indices_{0..2}, expected)
   The rust binary reads both and asserts max_diff < 1e-4 vs `expected`.

3. bench.py — five PyTorch latency measurements (eager, torch.compile
   inductor with mode='reduce-overhead', AOTInductor compile+package+load,
   CUDA-graph capture-replay around eager, and torch.compile with
   luminal_backend). CUDA events for per-iter timing, 50 warmup + 500
   measured iters. Pulls rust luminal samples from
   /tmp/dlrm_bench_rust_luminal.txt if --bench was run.

Verified three-way equivalence on H100:
  PyTorch eager   : [0.5084633231163025, 0.5099995136260986]
  PyTorch+luminal : [0.5084633231163025, 0.5099995136260986]
  rust luminal    : [0.5084633, 0.5099995]              max_diff=0.000e0

Latency ranking (mean µs, batch=2, m_spa=4, ln_emb=[10,20,30]):
  1. CUDA graphs (eager capture)    34.90 µs   1.00x
  2. AOTInductor                   108.76 µs   3.12x
  3. torch.compile (inductor)      201.55 µs   5.77x
  4. PyTorch eager                 229.13 µs   6.56x
  5. rust luminal                  240.76 µs   6.90x
  6. luminal_backend (PT2)         506.77 µs  14.52x

At this model size launch overhead dominates compute; CUDA graphs collapse
the per-launch cost. luminal_backend's 2x premium over rust-direct is the
Dynamo+pyo3+set_input round-trip per call. rust luminal lands beside eager
because both launch ~the same number of small kernels — luminal's egglog
optimizer doesn't have enough work to amortize here.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 02:47:07 +00:00
Tucker Morgan
3b36880c22 Get facebookresearch/dlrm compiling through luminal_backend
Two translator additions land DLRM end-to-end with bitwise-identical
output vs eager on the canonical eval path (1-hot lookups) and uniform
multi-hot bags. Verified against the actual `DLRM_Net` from
facebookresearch/dlrm AND a self-contained `MiniDLRM` in tests.

1. aten._embedding_bag[_forward_only].default
   -----------------------------------------
   `translate_embedding_bag` in translator/movement.rs handles the
   uniform-bag-size case: read N=indices.shape[0], B=offsets.shape[0]
   off the static graph, bail if N % B != 0, then gather [N, D] → reshape
   [B, K, D] → reduce on axis 1 according to mode (sum/mean/max).
   K=1 (the per-sample-1-lookup case DLRM eval uses) skips the
   reshape+reduce entirely. per_sample_weights and non-uniform bags are
   guarded with bail!.

   General segment-reduce (variable-length bags) requires a runtime
   scatter-add primitive luminal doesn't have yet; that's a follow-up.

2. aten.index.Tensor with a None-prefix
   ------------------------------------
   The multi-index fall-through in translate_index_tensor silently
   ignored first_non_none_dim, building strides over src.shape[..n_indexed]
   instead of src.shape[first..first+n_indexed]. DLRM's dot interaction
   does `Z[:, li, lj]` where Z is [B, ni, nj] and li, lj are 1-D — exactly
   the None-prefix multi-index case. Symptom was a downstream broadcast
   failure ([2, 8] vs [6, 8]) several ops later, never mentioning index.

   Fix: route first_non_none_dim > 0 to a new helper
   translate_index_tensor_with_prefix that explicitly partitions
   src.shape into prefix/indexed/suffix dims, builds a flat sub-index
   in the indexed subspace, promotes it into the full output shape, and
   adds a broadcast prefix-offset from arange. The suffix-non-empty
   case is guarded with bail! (separable, DLRM doesn't need it).

Tests
-----
tests/test_dlrm.py covers 4 configurations:
  - bag1 / bs2  — canonical DLRM eval
  - bag1 / bs64 — larger batch
  - bag3 / bs4  — multi-hot (exercises the reshape+sum path)
  - bag1 with larger embedding tables

All pass on CUDA (atol 1e-5; bag1/bs2 matches eager bitwise).

Verification: 417 passed, 4 xfailed (pre-existing Llama precision), 0
failed on the full non-heavy-model CUDA suite (10:36).

LessonsLearned.md gains two entries: the silent-mis-stride pattern in
index.Tensor, and the shape-driven-from-static-info strategy for
EmbeddingBag.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 02:23:15 +00:00
spinlocked
f94335b1b8 Bucket Qwen decode positions (#328)
Add a positive-position bucket for Qwen cached decode so Metal can reuse a compiled bucket as p
advances during generation. Keep p=0 as the prefill bucket.

Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-20 16:56:29 -04:00
spinlocked
f62e3c50d0 Optimize Metal runtime setup and buffer reuse (#326)
* Cache MPS matmul setup objects

* Precompute per-bucket execution metadata

* Reuse dynamic intermediate buffers at bucket capacity

* Fix Metal shader language import

---------

Co-authored-by: Joe Fioti <jafioti@gmail.com>
2026-05-20 14:37:58 -04:00
Ali
eeeabd7c20 Online normalizer calculation for softmax (#324) 2026-05-20 14:26:55 -04:00
Joe Fioti
0f02466f3d Reject implicit casting in native binary ops (#330)
* Reject implicit casting in native binary ops

* Make native dtype handling strict and explicit
2026-05-20 13:51:06 -04:00
Joe Fioti
156fac518e Metal qwen (#327)
* Refine Luminal graph rewrite handling

* Generalize Metal scatter reuse and Qwen validation

* Add Qwen safetensor size accounting

* Fix Modal example imports for shared output validation

* Clarify Luminal contributor guidance

* Revert direct shard loading from qwen metal

* Remove qwen Metal CI job

* Add Metal Llama 1B CI and restore safe profiling timeouts

* Fix duplicate Metal ops and tests

* Fix Metal pipeline compilation on llama

* Run llama Metal CI on xlarge runners

* Resample search generations after timeout failures
2026-05-20 13:26:34 -04:00
Joe Fioti
a3df68bd43 Add full-modal-ready CUDA test workflows (#329) 2026-05-20 01:13:02 -04:00
Ali
7a95e56a8b copy_device_buffer_to_new_slice synchronizes stream unnecessarily (#322) 2026-05-19 17:26:38 -04:00
36 changed files with 6089 additions and 658 deletions

67
.github/workflows/test-full-cuda.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
name: Test Full CUDA
on:
pull_request_target:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
rust_cuda_ignored_tests:
if: >-
github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
name: Rust CUDA Ignored Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 300
steps:
- uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run ignored CUDA Rust tests on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
GPU_TYPE: H100
MODAL_TIMEOUT: "14400"
CARGO_TEST_ARGS: "--ignored --test-threads=1"
run: modal run ci/modal_cargo_test.py
python_cuda_slow_tests:
if: >-
github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
name: Python CUDA Slow Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 300
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run slow pytest CUDA tests on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 14400 tests/ -v -s -m slow

View File

@@ -17,3 +17,20 @@ jobs:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
llama_1b_metal_example:
name: Llama 1B Metal Example
runs-on: macos-14-xlarge
timeout-minutes: 120
steps:
- uses: actions/checkout@v6
- name: Print runner hardware
run: system_profiler SPHardwareDataType SPDisplaysDataType
- name: Cache Hugging Face models
uses: actions/cache@v4
with:
path: ~/.cache/huggingface
key: llama-1b-metal-hf-${{ runner.os }}-${{ runner.arch }}-v1
- name: Run Llama 1B Metal example and validate output
run: rustup update; python3 ci/metal_llama_1b_example.py

View File

@@ -0,0 +1,48 @@
import os
import subprocess
import sys
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
process = subprocess.Popen(
command,
cwd=cwd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
chunks = []
while True:
chunk = process.stdout.read1(4096)
if not chunk:
break
sys.stdout.buffer.write(chunk)
sys.stdout.buffer.flush()
chunks.append(chunk)
return_code = process.wait()
output = b"".join(chunks).decode("utf-8", errors="replace")
if return_code:
raise subprocess.CalledProcessError(return_code, command, output=output)
return output
def main():
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
sys.path.insert(0, os.path.join(repo_root, "ci"))
from example_output import validate_output
output = run_and_capture(
["cargo", "run", "--release", "-p", "luminal_metal", "--example", "llama_1b"],
cwd=repo_root,
env=os.environ.copy(),
)
if "TTFT:" not in output or "TPOT:" not in output:
raise AssertionError("Llama 1B Metal example did not complete generation")
validate_output("llama", output)
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,10 @@
import modal
import subprocess
import os
import shlex
gpu_type = os.environ.get("GPU_TYPE", "T4")
modal_timeout = int(os.environ.get("MODAL_TIMEOUT", "7200"))
CUDARC_CUDA_VERSION = "12080"
app = modal.App("luminal-ci-cargo-test")
@@ -28,7 +30,7 @@ cuda_image = (
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=7200, # 2 hours
timeout=modal_timeout,
)
def run_cargo_test():
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
@@ -43,17 +45,20 @@ def run_cargo_test():
)
compute_cap = result.stdout.strip().replace(".", "")
test_args = shlex.split(os.environ.get("CARGO_TEST_ARGS", "--test-threads=1"))
cmd = [
"cargo",
"test",
"--release",
"-p",
"luminal_cuda_lite",
"--verbose",
"--",
*test_args,
]
print("Running:", " ".join(cmd), flush=True)
subprocess.run(
[
"cargo",
"test",
"--release",
"-p",
"luminal_cuda_lite",
"--verbose",
"--",
"--test-threads=1",
],
cmd,
cwd=WORKDIR,
env={
**os.environ,

View File

@@ -0,0 +1,450 @@
//! DLRM-shape megakernel — one CUDA kernel does the full forward pass
//! (bot MLP → N embedding gathers → dot-product interaction → top MLP)
//! per (thread × batch row). All intermediate activations live in
//! registers; weights are read straight from global memory and rely on
//! the L1 cache (the full weight footprint is a few KB).
//!
//! Parameterized by the DLRM family shape: dense input width, bot MLP
//! widths, number of sparse tables + their vocabs, embedding dim,
//! top MLP widths. CUDA source is generated per-shape via `format!`
//! and compiled through luminal's nvrtc wrapper with source-string
//! caching (same path as [`crate::kernel::matmul2d::Matmul2DKernel`]).
//!
//! Used by `luminal_python`'s PT2 translator when it detects a DLRM-shape
//! input graph — see `crates/luminal_python/rust/src/translator/dlrm_pattern.rs`.
//! The standalone `examples/dlrm/src/megakernel.rs` is the proof-of-concept
//! this module generalizes from.
//!
//! ## Input layout
//!
//! The kernel's input list (passed to `cx.custom_op`) is, in order:
//! 1. dense_x F32 (B, n_dense_in)
//! 2..2+n_sparse int32 indices per sparse table, each (B,)
//! — luminal collapses all integer types to 32-bit Int,
//! so the runtime delivers a 4-byte-per-element buffer
//! regardless of the original PyTorch dtype.
//! 2+n_sparse.. F32 embedding weights, one per table, each (V_k, m_spa)
//! then bot Linear weight+bias pairs, in topological order
//! then top Linear weight+bias pairs, in topological order
//!
//! The matcher in luminal_python lines up these inputs from the parsed
//! PT2 graph; mismatches there will surface as wrong-output bugs in
//! `tests/test_dlrm.py`, not as a crash.
use std::sync::Arc;
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use luminal::{
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, shape::Expression,
};
use crate::compile_module_image_for_current_device;
use crate::kernel::KernelOp;
/// Static shape description for the DLRM family. Every dim is a `usize`
/// resolved at translate time — the kernel bakes them all into the CUDA
/// source as compile-time constants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DlrmMegaKernel {
/// Per-call batch size.
pub batch: usize,
/// Number of dense features (first element of `ln_bot`).
pub n_dense_in: usize,
/// Bot MLP layer widths. `ln_bot[0] == n_dense_in`; `ln_bot.last() == m_spa`.
/// Must have at least 2 entries (one Linear layer).
pub ln_bot: Vec<usize>,
/// Number of sparse embedding tables.
pub n_sparse: usize,
/// Vocab size for each table (length == `n_sparse`).
pub vocab_sizes: Vec<usize>,
/// Sparse embedding dim (equal across tables, == bot MLP output width).
pub m_spa: usize,
/// Top MLP layer widths. `ln_top[0] == m_spa + n_pairs`; `ln_top.last() == 1`.
pub ln_top: Vec<usize>,
}
impl DlrmMegaKernel {
/// `n_feat = 1 + n_sparse` — number of feature vectors fed into the
/// dot interaction (1 dense + sparse tables).
fn n_feat(&self) -> usize {
1 + self.n_sparse
}
/// `n_pairs = n_feat * (n_feat - 1) / 2` — number of strictly-lower-tri
/// pairs produced by the dot interaction.
fn n_pairs(&self) -> usize {
let n = self.n_feat();
n * (n - 1) / 2
}
/// Validation: cheap up-front check that the shape is internally
/// consistent. The matcher should have caught all of these but a
/// debug-assert keeps the kernel compile path well-defined.
fn validate(&self) {
assert!(self.ln_bot.len() >= 2, "ln_bot must have ≥2 entries");
assert!(self.ln_top.len() >= 2, "ln_top must have ≥2 entries");
assert_eq!(self.ln_bot[0], self.n_dense_in, "ln_bot[0] must == n_dense_in");
assert_eq!(*self.ln_bot.last().unwrap(), self.m_spa, "ln_bot.last() must == m_spa");
assert_eq!(self.vocab_sizes.len(), self.n_sparse);
assert_eq!(
self.ln_top[0],
self.m_spa + self.n_pairs(),
"ln_top[0] must == m_spa + n_pairs"
);
assert_eq!(*self.ln_top.last().unwrap(), 1, "ln_top.last() must == 1 (binary classifier)");
assert!(self.batch > 0);
}
/// Generate the CUDA source for this kernel shape.
fn cuda_source(&self) -> String {
let n_feat = self.n_feat();
let n_pairs = self.n_pairs();
// ---- Kernel signature ------------------------------------------
// luminal's CustomOp dispatcher calls the kernel as
// kernel(output_ptr, input_ptrs...)
// — see `host/cublaslt`'s C/D ordering and matmul2d's
// `matmul_2d_kernel(float* C, const float* A, ...)`. Match that
// by putting `out` first, then the inputs in the same order as
// emit_megakernel builds the inputs vec.
let mut sig = String::from(
" float* __restrict__ out,\n const float* __restrict__ dense_x,\n",
);
for k in 0..self.n_sparse {
// 32-bit signed — see module docstring re: luminal's Int collapse.
sig.push_str(&format!(" const int* __restrict__ idx_{k},\n"));
}
for k in 0..self.n_sparse {
sig.push_str(&format!(" const float* __restrict__ emb_{k}_w,\n"));
}
// Bot MLP: one Linear per (ln_bot[i] → ln_bot[i+1]). Stored
// PyTorch-style as (out, in), bias (out,).
for i in 0..self.ln_bot.len() - 1 {
sig.push_str(&format!(" const float* __restrict__ bot_l{i}_w,\n"));
sig.push_str(&format!(" const float* __restrict__ bot_l{i}_b,\n"));
}
for i in 0..self.ln_top.len() - 1 {
let trail = if i == self.ln_top.len() - 2 { "" } else { "," };
sig.push_str(&format!(" const float* __restrict__ top_l{i}_w,\n"));
sig.push_str(&format!(" const float* __restrict__ top_l{i}_b{trail}\n"));
}
// ---- Body --------------------------------------------------------
let mut body = String::new();
// 1. Load dense row into registers.
body.push_str(&format!(
" // Bot MLP layer 0 input: dense row\n \
float layer_in[{}];\n \
#pragma unroll\n \
for (int i = 0; i < {n_dense_in}; ++i) layer_in[i] = dense_x[bi * {n_dense_in} + i];\n\n",
self.ln_bot[0],
n_dense_in = self.n_dense_in,
));
// 2. Bot MLP — sequence of Linear+ReLU. Output of last layer
// becomes `x[m_spa]` for the interaction.
for i in 0..self.ln_bot.len() - 1 {
let in_w = self.ln_bot[i];
let out_w = self.ln_bot[i + 1];
body.push_str(&format!(
" // Bot Linear {i}: ({in_w}{out_w}) + ReLU\n \
float bot_l{i}_out[{out_w}];\n \
#pragma unroll\n \
for (int j = 0; j < {out_w}; ++j) {{\n \
float a = bot_l{i}_b[j];\n \
#pragma unroll\n \
for (int i_ = 0; i_ < {in_w}; ++i_) a += layer_in[i_] * bot_l{i}_w[j*{in_w} + i_];\n \
bot_l{i}_out[j] = fmaxf(a, 0.0f);\n \
}}\n \
// shuffle output into `layer_in` for the next iteration / interaction\n \
#pragma unroll\n \
for (int i = 0; i < {out_w}; ++i) layer_in[i] = bot_l{i}_out[i];\n\n",
));
}
// After the loop, `layer_in[..m_spa]` holds dense_out ("x").
body.push_str(&format!(
" float x[{m_spa}];\n \
#pragma unroll\n \
for (int i = 0; i < {m_spa}; ++i) x[i] = layer_in[i];\n\n",
m_spa = self.m_spa,
));
// 3. Sparse embedding gathers (one row per table, bag size 1).
for k in 0..self.n_sparse {
body.push_str(&format!(
" // Embedding lookup {k}\n \
float ly_{k}[{m_spa}];\n \
{{\n \
int i_{k} = idx_{k}[bi];\n \
#pragma unroll\n \
for (int j = 0; j < {m_spa}; ++j) ly_{k}[j] = emb_{k}_w[i_{k}*{m_spa} + j];\n \
}}\n\n",
m_spa = self.m_spa,
));
}
// 4. Dot interaction: compute n_pairs strictly-lower-tri dot products
// over the n_feat = 1 + n_sparse vectors (x, ly_0, ly_1, ...).
// Order matches MiniDLRM._interact: for i in 0..n_feat for j in 0..i.
// Vec[0] = x, Vec[k+1] = ly_k.
body.push_str(&format!(" float zflat[{n_pairs}];\n"));
let vec_name = |idx: usize| -> String {
if idx == 0 {
"x".to_string()
} else {
format!("ly_{}", idx - 1)
}
};
let mut pair_idx = 0usize;
for i in 0..n_feat {
for j in 0..i {
let a = vec_name(i);
let b = vec_name(j);
let mut terms = Vec::with_capacity(self.m_spa);
for d in 0..self.m_spa {
terms.push(format!("{a}[{d}]*{b}[{d}]"));
}
body.push_str(&format!(
" zflat[{pair_idx}] = {};\n",
terms.join(" + ")
));
pair_idx += 1;
}
}
body.push('\n');
// 5. R = cat([x, zflat]) → top MLP input.
let r_len = self.m_spa + n_pairs;
body.push_str(&format!(
" float r[{r_len}];\n \
#pragma unroll\n \
for (int i = 0; i < {m_spa}; ++i) r[i] = x[i];\n \
#pragma unroll\n \
for (int i = 0; i < {n_pairs}; ++i) r[{m_spa} + i] = zflat[i];\n\n",
m_spa = self.m_spa,
));
// 6. Top MLP: Linear+ReLU chain, ending with Linear+Sigmoid.
// We treat `r` as the first layer input and reuse a single
// register array `top_in[]` for subsequent layers.
let max_top = *self.ln_top.iter().max().unwrap();
body.push_str(&format!(
" float top_in[{max_top}];\n \
#pragma unroll\n \
for (int i = 0; i < {r_len}; ++i) top_in[i] = r[i];\n\n",
));
let n_top_layers = self.ln_top.len() - 1;
for i in 0..n_top_layers {
let in_w = self.ln_top[i];
let out_w = self.ln_top[i + 1];
let is_last = i == n_top_layers - 1;
body.push_str(&format!(
" // Top Linear {i}: ({in_w}{out_w})\n \
float top_l{i}_out[{out_w}];\n \
#pragma unroll\n \
for (int j = 0; j < {out_w}; ++j) {{\n \
float a = top_l{i}_b[j];\n \
#pragma unroll\n \
for (int i_ = 0; i_ < {in_w}; ++i_) a += top_in[i_] * top_l{i}_w[j*{in_w} + i_];\n \
top_l{i}_out[j] = {activation};\n \
}}\n",
activation = if is_last {
"1.0f / (1.0f + __expf(-a))"
} else {
"fmaxf(a, 0.0f)"
},
));
if !is_last {
body.push_str(&format!(
" #pragma unroll\n \
for (int i = 0; i < {out_w}; ++i) top_in[i] = top_l{i}_out[i];\n\n",
));
} else {
// Final layer: write to global output. ln_top.last() == 1
// so this is just a single value.
body.push_str(&format!(
" out[bi] = top_l{i}_out[0];\n",
));
}
}
// Assemble the full source.
format!(
"extern \"C\" __global__ void dlrm_mega(\n{sig}) {{\n \
int bi = blockIdx.x * blockDim.x + threadIdx.x;\n \
if (bi >= {batch}) return;\n\n\
{body}\
}}\n",
batch = self.batch,
)
}
}
impl KernelOp for DlrmMegaKernel {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
self.validate();
let kernel = self.cuda_source();
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
(m.clone(), f.clone())
} else {
if std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").is_ok() {
let path = "/tmp/dlrm_megakernel_generated.cu";
let _ = std::fs::write(path, &kernel);
eprintln!("[DlrmMegaKernel] wrote generated source to {path}");
}
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
.expect("nvrtc compile failed for DLRM megakernel");
let module = stream.context().load_module(ptx).expect("load_module");
let func = module
.load_function("dlrm_mega")
.expect("load_function dlrm_mega");
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
const BLOCK: usize = 128;
let grid_x = self.batch.div_ceil(BLOCK);
(
func,
module,
kernel,
(
Expression::from(grid_x),
Expression::from(1usize),
Expression::from(1usize),
),
(
Expression::from(BLOCK),
Expression::from(1usize),
Expression::from(1usize),
),
Expression::from(0usize),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
Expression::from(self.batch)
}
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
fn output_dtype(&self) -> DType {
DType::F32
}
fn bytes_loaded(&self) -> Expression {
// Per batch row:
// dense: n_dense_in × f32
// indices: n_sparse × i64
// embs: n_sparse × m_spa × f32 (single row each)
// bot Ws: sum(in*out) for each layer, × f32 (shared across batch — costed once)
// bot bs: sum(out) × f32
// top Ws/bs same shape
let bot_w: usize = (0..self.ln_bot.len() - 1)
.map(|i| self.ln_bot[i] * self.ln_bot[i + 1])
.sum();
let bot_b: usize = self.ln_bot.iter().skip(1).sum();
let top_w: usize = (0..self.ln_top.len() - 1)
.map(|i| self.ln_top[i] * self.ln_top[i + 1])
.sum();
let top_b: usize = self.ln_top.iter().skip(1).sum();
let per_row =
self.n_dense_in * 4 + self.n_sparse * 8 + self.n_sparse * self.m_spa * 4;
let weights = (bot_w + bot_b + top_w + top_b) * 4;
Expression::from(self.batch * per_row + weights)
}
fn bytes_stored(&self) -> Expression {
// batch × 1 × f32
Expression::from(self.batch * 4)
}
fn flops(&self) -> Expression {
// Per row:
// bot Linears: 2*in*out + out (FMAs + bias)
// embedding gathers: 0 FMAs (loads)
// dot interaction: n_pairs × m_spa MACs
// top Linears: 2*in*out + out + (relu/sigmoid cost ~5)
let bot: usize = (0..self.ln_bot.len() - 1)
.map(|i| 2 * self.ln_bot[i] * self.ln_bot[i + 1] + self.ln_bot[i + 1])
.sum();
let dot = self.n_pairs() * self.m_spa * 2;
let top: usize = (0..self.ln_top.len() - 1)
.map(|i| 2 * self.ln_top[i] * self.ln_top[i + 1] + self.ln_top[i + 1])
.sum();
Expression::from(self.batch * (bot + dot + top + 5))
}
fn kernel_name(&self) -> &'static str {
"DlrmMega"
}
}
/// `CustomOp` wrapper for [`DlrmMegaKernel`]. Same pattern as
/// [`crate::kernel::matmul2d::Matmul2DCustom`].
#[derive(Debug, Clone)]
pub struct DlrmMegaCustom(pub DlrmMegaKernel);
impl CustomOp for DlrmMegaCustom {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn mini_dlrm() -> DlrmMegaKernel {
DlrmMegaKernel {
batch: 2048,
n_dense_in: 13,
ln_bot: vec![13, 8, 4],
n_sparse: 3,
vocab_sizes: vec![10, 20, 30],
m_spa: 4,
ln_top: vec![10, 8, 1],
}
}
#[test]
fn shape_invariants() {
let k = mini_dlrm();
assert_eq!(k.n_feat(), 4);
assert_eq!(k.n_pairs(), 6);
assert_eq!(k.ln_top[0], k.m_spa + k.n_pairs());
k.validate();
}
#[test]
fn cuda_source_compiles_in_format() {
let src = mini_dlrm().cuda_source();
// Sanity checks on the generated source — no nvrtc invocation here,
// just verify the structural pieces exist.
assert!(src.contains("extern \"C\" __global__ void dlrm_mega"));
assert!(src.contains("if (bi >= 2048)"));
// 3 embedding lookups
assert!(src.contains("ly_0[") && src.contains("ly_1[") && src.contains("ly_2["));
// 6 dot products
assert!(src.contains("zflat[5]"));
// Sigmoid epilogue
assert!(src.contains("1.0f / (1.0f + __expf(-a))"));
}
}

View File

@@ -11,6 +11,7 @@ use uuid::Uuid;
pub mod conv2d;
pub mod cuda_graph;
pub mod dlrm_megakernel;
pub mod fusion;
pub mod hlir;
pub mod matmul2d;
@@ -19,6 +20,7 @@ pub mod rope;
pub use conv2d::KernelConv2D;
pub use cuda_graph::*;
pub use dlrm_megakernel::{DlrmMegaCustom, DlrmMegaKernel};
pub use matmul2d::{
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
matmul_3d, matmul_3d_t,

View File

@@ -1338,9 +1338,21 @@ impl KernelOp for KernelSoftmax {
#define FULL_MASK 0xffffffff
#define NEG_INF_F __int_as_float(0xff800000)
{dyn_defines}
#define LOG2E 1.4426950408889634f
extern \"C\" {{
// Online normalizer calculation for softmax (Milakov & Gimelshein 2018).
// Merge two partial (max, sum) pairs using the online softmax rule.
__device__ __forceinline__ void merge_md(float *m, float *d, float m2, float d2) {{
float new_m = fmaxf(*m, m2);
*d = *d * exp2f((*m - new_m) * LOG2E) + d2 * exp2f((m2 - new_m) * LOG2E);
*m = new_m;
}}
__global__ void fused_softmax(float *out, const float *inp{dyn_dims_param}) {{
__shared__ float shared[THREADS_PER_BLOCK / WARP_SIZE];
__shared__ float sh_m[THREADS_PER_BLOCK / WARP_SIZE];
__shared__ float sh_d[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
@@ -1352,55 +1364,36 @@ extern \"C\" {{
long long in_stride = {in_reduce_stride};
long long out_stride = {out_reduce_stride};
// Pass 1: find max
float max_val = NEG_INF_F;
// Pass 1: one read of inp produces (global_max, global_sum).
float m = NEG_INF_F, d = 0.0f;
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
max_val = fmaxf(max_val, inp[in_base + i * in_stride]);
merge_md(&m, &d, inp[in_base + i * in_stride], 1.0f);
}}
// Warp reduce: collapse 32 threads within each warp down to lane 0.
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
}}
if (lane_id == 0) shared[warp_id] = max_val;
if (lane_id == 0) {{ sh_m[warp_id] = m; sh_d[warp_id] = d; }}
__syncthreads();
// Block reduce: warp 0 collapses the 8 warp results down to one.
if (warp_id == 0) {{
max_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : NEG_INF_F;
m = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_m[tid] : NEG_INF_F;
d = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_d[tid] : 0.0f;
#pragma unroll
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
}}
shared[0] = max_val;
sh_m[0] = m;
sh_d[0] = d;
}}
__syncthreads();
max_val = shared[0];
float global_max = sh_m[0];
float inv_sum = 1.0f / sh_d[0];
// Pass 2: compute exp2 and sum
float sum_val = 0.0f;
// Pass 2: write final softmax values.
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
float v = exp2f((inp[in_base + i * in_stride] - max_val) * 1.4426950408889634f);
out[out_base + i * out_stride] = v; // store exp temporarily
sum_val += v;
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
}}
if (lane_id == 0) shared[warp_id] = sum_val;
__syncthreads();
if (warp_id == 0) {{
sum_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : 0.0f;
#pragma unroll
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
}}
shared[0] = sum_val;
}}
__syncthreads();
float inv_sum = 1.0f / shared[0];
// Pass 3: normalize
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
out[out_base + i * out_stride] *= inv_sum;
out[out_base + i * out_stride] = exp2f((inp[in_base + i * in_stride] - global_max) * LOG2E) * inv_sum;
}}
}}
}}"

View File

@@ -106,6 +106,12 @@ pub(crate) struct CompiledBucket {
pub(crate) bucket_indices: FxHashMap<char, usize>,
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
pub(crate) hlir_synced: bool,
/// Cached topological order of exec_graph nodes. Lazily populated on
/// first execute() and invalidated only when the exec_graph itself
/// changes (compilation, bucket rebuild). Avoids the per-call
/// `petgraph::algo::toposort` Vec allocation + traversal — small but
/// real in hot inference loops.
pub(crate) exec_topo_order: Vec<NodeIndex>,
}
impl CompiledBucket {
@@ -130,6 +136,7 @@ impl CompiledBucket {
intermediate_buffer_dims: FxHashSet::default(),
bucket_indices: FxHashMap::default(),
hlir_synced: false,
exec_topo_order: Vec::new(),
}
}
}
@@ -225,7 +232,6 @@ impl CudaRuntime {
result::memcpy_dtod_async(dst_ptr, src.ptr(), src.len(), stream.cu_stream())
.expect("cuMemcpyDtoDAsync failed");
}
stream.synchronize().unwrap();
dst
}
@@ -328,6 +334,24 @@ impl CudaRuntime {
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
let id = id.to_id();
// Fast path: if the same pointer is already registered, this is a no-op.
// PyTorch's caching allocator routinely hands back the same device
// pointer for the same logical tensor on each forward; bench loops in
// particular hammer this. Skipping the cudarc upgrade_device_ptr +
// ManuallyDrop reallocation + the changed_hlir insert + the per-bucket
// ptr re-cache that fires on the next execute saves ~2µs per input.
if let Some(CudaInput::Ptr(prev)) = self.hlir_buffers.get(&id) {
if *prev == device_ptr {
// Refresh the external_buffers view in case n_bytes shrank to
// exactly cover the live region; cheap and keeps the slice
// length correct without rebuilding the registration.
if let Some(ext) = self.external_buffers.get(&id) {
if ext.len() == n_bytes {
return;
}
}
}
}
// Create CudaSlice view via cudarc's upgrade_device_ptr.
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
let slice = unsafe {
@@ -1466,9 +1490,19 @@ impl Runtime for CudaRuntime {
self.apply_output_ptr_registrations();
let total_start = std::time::Instant::now();
// Populate the topo-order cache lazily — only on first execute for
// this bucket. Walking exec_graph + allocating a Vec every iter
// measurably shows up at small batches where the kernel work itself
// is sub-microsecond and the per-call overhead dominates.
{
let bucket = &mut self.compiled_buckets[self.active_bucket];
if bucket.exec_topo_order.is_empty() && bucket.exec_graph.node_count() > 0 {
bucket.exec_topo_order = toposort(&bucket.exec_graph, None).unwrap();
}
}
let bucket = &self.compiled_buckets[self.active_bucket];
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
for &exec_node in &bucket.exec_topo_order {
let exec_op = &bucket.exec_graph[exec_node];
trace!("Executing: {:?}", exec_op);
@@ -1540,21 +1574,26 @@ impl Runtime for CudaRuntime {
self.cuda_stream.synchronize().unwrap();
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
// Populate last_kernel_stats from HostOps that report stats
self.last_kernel_stats.clear();
let bucket = &self.compiled_buckets[self.active_bucket];
for exec_node in bucket.exec_graph.node_indices() {
let exec_op = &bucket.exec_graph[exec_node];
if let Some(name) = exec_op.internal.stats_name() {
self.last_kernel_stats.push(KernelStats {
name,
execution_time_us: 0.0,
bytes_loaded: 0,
bytes_stored: 0,
flops: 0,
bandwidth_gbps: 0.0,
tflops: 0.0,
});
// last_kernel_stats is only read by print_execution_stats() — a
// diagnostic API. Populating the Vec on every execute() (looping all
// exec nodes and calling stats_name() on each) is wasteful in
// production inference loops. Gate it on the profiling flag.
if self.profiling {
self.last_kernel_stats.clear();
let bucket = &self.compiled_buckets[self.active_bucket];
for exec_node in bucket.exec_graph.node_indices() {
let exec_op = &bucket.exec_graph[exec_node];
if let Some(name) = exec_op.internal.stats_name() {
self.last_kernel_stats.push(KernelStats {
name,
execution_time_us: 0.0,
bytes_loaded: 0,
bytes_stored: 0,
flops: 0,
bandwidth_gbps: 0.0,
tflops: 0.0,
});
}
}
}
@@ -1576,11 +1615,22 @@ impl Runtime for CudaRuntime {
}
}
// Free owned input buffers after a step so they're not held until the
// next set_data overwrites them. External-pointer inputs (registered
// via set_device_ptr) are caller-owned and the runtime doesn't free
// their memory either way — consuming them only invalidates the
// registration and forces the caller to re-register on the next
// execute. That's pure waste in tight inference loops (e.g.
// luminal_python's torch.compile backend, which re-invokes execute()
// for every forward), so leave external-pointer entries in place.
let to_consume: Vec<NodeIndex> = self
.hlir_buffers
.keys()
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
.copied()
.iter()
.filter(|(hlir_node, input)| {
!inputs_with_outputs.contains(hlir_node)
&& !matches!(input, CudaInput::Ptr(_))
})
.map(|(n, _)| *n)
.collect();
for hlir_node in to_consume {

View File

@@ -19,7 +19,14 @@ bytemuck = "1.24.0"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
luminal_nn = { path = "../luminal_nn" }
luminal_tracing = { path = "../luminal_tracing" }
proptest = "1.9.0"
rand = "0.9.2"
rustc-hash = "2.1"
tokenizers = "0.22.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }

View File

@@ -0,0 +1,641 @@
use hf_hub::api::sync::Api;
use luminal::{
dtype::DType,
graph::{BuildSearchSpaceOptions, DimBucket, Graph},
prelude::{F32Pow, GraphTensor, Runtime},
};
use luminal_metal::MetalRuntime;
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
use luminal_tracing::luminal_filter;
use rustc_hash::FxHashSet;
use std::{
error::Error,
io::Write,
path::PathBuf,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "unsloth/Llama-3.2-1B-Instruct";
const MAX_SEQ_LEN: usize = 2048;
const GEN_TOKENS: usize = 96;
const SEARCH_GRAPHS: usize = 100;
const SEARCH_MEMORY_MIB: usize = 1536;
const PROMPT: &str = "In one short paragraph, explain neural networks using the words layers, neurons, learning, and data.";
const LAYERS: usize = 16;
const HIDDEN: usize = 2048;
const INTERMEDIATE: usize = 8192;
const HEAD_DIM: usize = 64;
const N_HEADS: usize = 32;
const N_KV_HEADS: usize = 8;
const KV_GROUPS: usize = N_HEADS / N_KV_HEADS;
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
const VOCAB_SIZE: usize = 128256;
const RMS_NORM_EPS: f32 = 1e-5;
const ROPE_THETA: f32 = 500_000.0;
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
fn prepare_hf_model() -> Result<PathBuf, Box<dyn Error>> {
let repo = Api::new()?.model(REPO_ID.to_string());
let tokenizer_path = repo.get("tokenizer.json")?;
repo.get("model.safetensors")?;
Ok(tokenizer_path.parent().unwrap().to_path_buf())
}
fn llama3_chat_prompt(user_prompt: &str) -> String {
format!(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
}
#[derive(Default, Clone)]
struct StepProfile {
total: Duration,
execute: Duration,
get_logits: Duration,
cache_roundtrip: Duration,
}
fn avg_ms(duration: Duration, n: usize) -> f64 {
if n == 0 {
0.0
} else {
duration.as_secs_f64() * 1e3 / n as f64
}
}
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, repetition_penalty: f32) -> u32 {
let mut row = logits_row.to_vec();
for &tok in seen {
let logit = &mut row[tok as usize];
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
row.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32
}
fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
let mut mask = vec![-1e10f32; q_pos.len() * context_len];
for (qi, &pos) in q_pos.iter().enumerate() {
for ci in 0..context_len {
if ci <= pos {
mask[qi * context_len + ci] = 0.0;
}
}
}
mask
}
struct KVCache {
k_caches: Vec<GraphTensor>,
v_caches: Vec<GraphTensor>,
}
impl KVCache {
fn new(cx: &mut Graph, num_slots: usize) -> Self {
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
k_caches.push(
cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM))
.persist(),
);
v_caches.push(
cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM))
.persist(),
);
}
Self { k_caches, v_caches }
}
}
struct Llama {
embedding: GraphTensor,
layers: Vec<LlamaLayer>,
lm_norm: LayerNorm,
}
impl Llama {
fn init(cx: &mut Graph) -> Self {
let mut layers = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
layers.push(LlamaLayer {
up: cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
gate: cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
down: cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist(),
q_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
k_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
v_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
o_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
});
}
Self {
embedding: cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
layers,
lm_norm: LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
),
}
}
fn forward(
&self,
input: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let mut x = self.embedding.gather(
(input * HIDDEN).expand_dim(1, HIDDEN)
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
q_pos,
scatter_idx,
gather_idx,
attn_mask,
kv_cache.k_caches[i],
kv_cache.v_caches[i],
);
x = x_new;
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.embedding.t());
(logits, cache_outputs)
}
}
struct LlamaLayer {
up: GraphTensor,
gate: GraphTensor,
down: GraphTensor,
q_proj: GraphTensor,
k_proj: GraphTensor,
v_proj: GraphTensor,
o_proj: GraphTensor,
attn_rms: LayerNorm,
mlp_rms: LayerNorm,
}
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
let freqs = input
.graph()
.arange_options(0, HEAD_DIM, 2)
.cast(DType::F32)
/ HEAD_DIM as f32;
let inv_freqs = ROPE_THETA.pow(freqs).reciprocal();
let emb = pos_ids
.cast(DType::F32)
.expand_dim(1, 1)
.matmul(inv_freqs.expand_dim(0, 1));
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
let x1 = input.slice((.., .., HEAD_DIM / 2..));
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
let x0_out = x0 * cos - x1 * sin;
let x1_out = x1 * cos + x0 * sin;
x0_out
.concat_along(x1_out, 2)
.transpose(0, 1)
.merge_dims(1, 2)
}
#[allow(clippy::too_many_arguments)]
fn attention(
q_rope: GraphTensor,
k_rope: GraphTensor,
v: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);
let k = gather_rows(k_cache_out, gather_idx, KV_DIM);
let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM);
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
let masked_scores = scores + attn_mask.expand_dim(0, N_HEADS);
let weights = masked_scores.softmax(2);
let out = weights.matmul(v_ctx);
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
#[allow(clippy::too_many_arguments)]
fn run_model_step(
cx: &mut Graph,
runtime: &mut MetalRuntime,
input: GraphTensor,
q_pos_t: GraphTensor,
scatter_idx_t: GraphTensor,
gather_idx_t: GraphTensor,
attn_mask_t: GraphTensor,
logits: GraphTensor,
kv_cache: &KVCache,
cache_outputs: &[(GraphTensor, GraphTensor)],
tokens: &[u32],
q_pos: &[i32],
scatter_idx: &[i32],
gather_idx: &[i32],
attn_mask: &[f32],
) -> (Vec<f32>, StepProfile) {
let start = Instant::now();
cx.set_dim('s', tokens.len());
cx.set_dim('c', gather_idx.len());
runtime.set_data(input, tokens.iter().map(|t| *t as i32).collect::<Vec<_>>());
runtime.set_data(q_pos_t, q_pos.to_vec());
runtime.set_data(scatter_idx_t, scatter_idx.to_vec());
runtime.set_data(gather_idx_t, gather_idx.to_vec());
runtime.set_data(attn_mask_t, attn_mask.to_vec());
runtime.allocate_intermediate_buffers(&cx.dyn_map);
let execute_start = Instant::now();
runtime.execute(&cx.dyn_map);
let execute = execute_start.elapsed();
let logits_start = Instant::now();
let logits_data = runtime.get_f32(logits);
let get_logits = logits_start.elapsed();
let cache_start = Instant::now();
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
let cache_roundtrip = cache_start.elapsed();
(
logits_data,
StepProfile {
total: start.elapsed(),
execute,
get_logits,
cache_roundtrip,
},
)
}
fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.try_init();
let model_dir = prepare_hf_model()?;
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
.map_err(|err| err as Box<dyn Error>)?;
let prompt_tokens = tokenizer
.encode(llama3_chat_prompt(PROMPT), false)
.map_err(|err| err as Box<dyn Error>)?
.get_ids()
.to_vec();
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
let kv_cache = KVCache::new(&mut cx, MAX_SEQ_LEN);
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
&kv_cache,
);
let logits = logits.output();
for (k_out, v_out) in &cache_outputs {
k_out.output();
v_out.output();
}
cx.set_dim('s', 1);
cx.set_dim('c', 1);
println!("Building E-Graph...");
let egraph_start = Instant::now();
cx.build_search_space_with_options::<MetalRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
);
println!(
" E-Graph build: {:.2} s",
egraph_start.elapsed().as_secs_f64()
);
println!("Loading weights...");
let load_start = Instant::now();
let mut runtime = MetalRuntime::initialize(());
runtime.load_safetensors(&cx, model_dir.join("model.safetensors").to_str().unwrap());
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
let compile_start = Instant::now();
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
.next_power_of_two()
.min(MAX_SEQ_LEN);
let search_s = 16.min(max_prefill).max(2);
let search_c = 16.min(max_context).max(2);
cx.set_dim_buckets(
's',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_prefill).representative(search_s),
],
);
cx.set_dim_buckets(
'c',
&[
DimBucket::new(1, 1),
DimBucket::new(2, max_context).representative(search_c),
],
);
cx.set_dim('s', search_s);
cx.set_dim('c', search_c);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(q_pos_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, SEARCH_GRAPHS);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()
);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let prompt_len = prompt_tokens.len();
let mut context_len = 0usize;
let mut profiles = Vec::new();
let mut seen_tokens = FxHashSet::default();
let repetition_penalty = 1.05;
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, GEN_TOKENS
);
let mut generated = 0usize;
let mut next_token = None;
if GEN_TOKENS > 0 && prompt_len > 0 {
let positions: Vec<usize> = (0..prompt_len).collect();
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
let mask = causal_mask(&positions, prompt_len);
let (logits_data, profile) = run_model_step(
&mut cx,
&mut runtime,
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
logits,
&kv_cache,
&cache_outputs,
&prompt_tokens,
&q_pos,
&q_pos,
&q_pos,
&mask,
);
context_len = prompt_len;
let token = sample_greedy(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(token);
next_token = Some(token);
generated = 1;
profiles.push(profile);
if token != EOS_TOKEN && token != STOP_TOKEN {
print!(
"{}",
tokenizer
.decode(&[token], true)
.map_err(|err| err as Box<dyn Error>)?
);
std::io::stdout().flush()?;
}
}
while generated < GEN_TOKENS {
let current_token = match next_token {
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
_ => break,
};
let gather_idx = (0..=context_len as i32).collect::<Vec<_>>();
let mask = causal_mask(&[context_len], context_len + 1);
let (logits_data, profile) = run_model_step(
&mut cx,
&mut runtime,
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
logits,
&kv_cache,
&cache_outputs,
&[current_token],
&[context_len as i32],
&[context_len as i32],
&gather_idx,
&mask,
);
context_len += 1;
let token = sample_greedy(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(token);
next_token = Some(token);
generated += 1;
profiles.push(profile);
if token == EOS_TOKEN || token == STOP_TOKEN {
break;
}
print!(
"{}",
tokenizer
.decode(&[token], true)
.map_err(|err| err as Box<dyn Error>)?
);
std::io::stdout().flush()?;
}
println!();
let ttft = profiles.first().map(|p| p.total).unwrap_or_default();
let decode_steps = profiles.len().saturating_sub(1);
let decode_total: Duration = profiles.iter().skip(1).map(|p| p.total).sum();
println!(" TTFT: {:.2} ms", ttft.as_secs_f64() * 1e3);
println!(" TPOT: {:.2} ms", avg_ms(decode_total, decode_steps));
let execute_total: Duration = profiles.iter().map(|p| p.execute).sum();
let logits_total: Duration = profiles.iter().map(|p| p.get_logits).sum();
let cache_total: Duration = profiles.iter().map(|p| p.cache_roundtrip).sum();
println!(
" Profile: n={}, exec={:.2} ms, logits={:.2} ms, cache={:.2} ms",
profiles.len(),
avg_ms(execute_total, profiles.len()),
avg_ms(logits_total, profiles.len()),
avg_ms(cache_total, profiles.len()),
);
Ok(())
}

View File

@@ -6,10 +6,127 @@ pub use ops::*;
use luminal::dtype::DType;
use luminal::op::EgglogOp;
use luminal::prelude::*;
use metal::{Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device};
use metal::{
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device,
foreign_types::ForeignTypeRef, mps,
};
use objc::rc::StrongPtr;
use objc::runtime::Object;
use objc::{class, msg_send, sel, sel_impl};
use std::cell::RefCell;
pub const DYN_SLOT_COUNT: usize = 26;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MpsMatrixDescriptorKey {
rows: usize,
cols: usize,
row_bytes: u64,
data_type: isize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MpsMatmulKey {
transpose_lhs: bool,
transpose_rhs: bool,
m: usize,
n: usize,
k: usize,
alpha: u64,
beta: u64,
}
#[derive(Default)]
pub struct MpsKernelCache {
matrix_descriptors: FxHashMap<MpsMatrixDescriptorKey, StrongPtr>,
matmul_kernels: FxHashMap<MpsMatmulKey, StrongPtr>,
}
impl MpsKernelCache {
pub(crate) fn matrix_descriptor(
&mut self,
rows: usize,
cols: usize,
row_bytes: u64,
dtype: DType,
) -> *mut Object {
let key = MpsMatrixDescriptorKey {
rows,
cols,
row_bytes,
data_type: Self::mps_data_type(dtype),
};
let descriptor = self
.matrix_descriptors
.entry(key)
.or_insert_with(|| unsafe {
let descriptor: *mut Object = msg_send![
class!(MPSMatrixDescriptor),
matrixDescriptorWithRows: rows
columns: cols
rowBytes: row_bytes as usize
dataType: key.data_type
];
StrongPtr::retain(descriptor)
});
**descriptor
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn matrix_multiplication(
&mut self,
command_buffer: &CommandBufferRef,
transpose_lhs: bool,
transpose_rhs: bool,
m: usize,
n: usize,
k: usize,
alpha: f64,
beta: f64,
) -> *mut Object {
let key = MpsMatmulKey {
transpose_lhs,
transpose_rhs,
m,
n,
k,
alpha: alpha.to_bits(),
beta: beta.to_bits(),
};
let kernel = self.matmul_kernels.entry(key).or_insert_with(|| unsafe {
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
let kernel: *mut Object = msg_send![
kernel,
initWithDevice: device
transposeLeft: transpose_lhs
transposeRight: transpose_rhs
resultRows: m
resultColumns: n
interiorColumns: k
alpha: alpha
beta: beta
];
StrongPtr::new(kernel)
});
**kernel
}
fn mps_data_type(dtype: DType) -> isize {
match dtype {
DType::F32 | DType::TF32 => mps::MPSDataType::Float32 as isize,
DType::F16 => mps::MPSDataType::Float16 as isize,
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
}
}
}
pub struct MetalEncodeContext<'a> {
pub(crate) command_buffer: &'a CommandBufferRef,
pub(crate) dyn_buffer: &'a Buffer,
pub(crate) mps_cache: &'a RefCell<MpsKernelCache>,
}
#[derive(Debug, Clone)]
pub struct MetalMulInfo {
pub shape: Vec<Expression>,
@@ -52,19 +169,18 @@ pub trait MetalKernelOp: EgglogOp {
#[allow(clippy::too_many_arguments)]
fn encode(
&self,
command_buffer: &CommandBufferRef,
context: &mut MetalEncodeContext<'_>,
pipeline: Option<&ComputePipelineState>,
inputs: &[&Buffer],
output: &Buffer,
dyn_map: &FxHashMap<char, usize>,
dyn_buffer: &Buffer,
_input_dtypes: &[DType],
_output_dtype: DType,
) {
let pipeline = pipeline.expect("compute pipeline not compiled");
let encoder = command_buffer.new_compute_command_encoder();
let encoder = context.command_buffer.new_compute_command_encoder();
let dyn_idx = inputs.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(dyn_buffer), 0);
encoder.set_buffer(dyn_idx, Some(context.dyn_buffer), 0);
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
encoder.end_encoding();
}

View File

@@ -1,4 +1,4 @@
use super::{MPSMatrixLayout, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
use super::{MPSMatrixLayout, MetalEncodeContext, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
use luminal::{
egglog_utils::{
SerializedEGraph,
@@ -19,9 +19,8 @@ use luminal::{
shape::flatten_strides,
};
use metal::{
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLSize,
Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLLanguageVersion, MTLSize,
foreign_types::{ForeignType, ForeignTypeRef},
mps,
};
use objc::runtime::Object;
use objc::{class, msg_send, sel, sel_impl};
@@ -56,15 +55,21 @@ pub type MetalOps = (
);
fn compile_shader(device: &Device, source: &str, function_name: &str) -> ComputePipelineState {
let options = metal::CompileOptions::new();
options.set_language_version(MTLLanguageVersion::V2_4);
let library = device
.new_library_with_source(source, &metal::CompileOptions::new())
.expect("Failed to compile Metal shader");
.new_library_with_source(source, &options)
.unwrap_or_else(|err| {
panic!("Failed to compile Metal shader {function_name}: {err:?}\n{source}")
});
let function = library
.get_function(function_name, None)
.expect("Failed to get function from library");
device
.new_compute_pipeline_state_with_function(&function)
.expect("Failed to create compute pipeline state")
.unwrap_or_else(|err| {
panic!("Failed to create Metal compute pipeline state for {function_name}: {err:?}\n{source}")
})
}
fn lower_dynamic_consts(mut code: String) -> String {
@@ -1039,42 +1044,33 @@ impl MetalKernelOp for MetalSumReduce {
constant int *dyn [[buffer({dyn_buffer_index})]],
constant uint &n_outputs [[buffer({n_outputs_index})]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
uint tid [[thread_index_in_threadgroup]]
) {{
if (gid >= n_outputs) return;
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
threadgroup float partials[THREADS_PER_GROUP];
int in_start = {in_idx};
int iters = {iters};
(void)dyn;
// Each thread accumulates multiple elements
float sum = 0.0f;
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
sum += {in_val};
}}
// Warp-level reduction using simd_sum
sum = simd_sum(sum);
// First lane of each warp writes to shared memory
if (simd_lane == 0) {{
warp_sums[simd_id] = sum;
}}
partials[tid] = sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
// First warp does final reduction
if (simd_id == 0) {{
int n_warps = THREADS_PER_GROUP / 32;
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
block_sum = simd_sum(block_sum);
if (tid == 0) {{
out[{out_idx}] = {out_val};
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
if (tid < stride) {{
partials[tid] += partials[tid + stride];
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}}
if (tid == 0) {{
float block_sum = partials[0];
out[{out_idx}] = {out_val};
}}
}}
"#,
@@ -1220,42 +1216,33 @@ impl MetalKernelOp for MetalMaxReduce {
constant int *dyn [[buffer({dyn_buffer_index})]],
constant uint &n_outputs [[buffer({n_outputs_index})]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
uint tid [[thread_index_in_threadgroup]]
) {{
if (gid >= n_outputs) return;
threadgroup float warp_maxs[THREADS_PER_GROUP / 32];
threadgroup float partials[THREADS_PER_GROUP];
int in_start = {in_idx};
int iters = {iters};
(void)dyn;
// Each thread finds max of multiple elements
float max_val = NEG_INF_F;
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
max_val = fmax(max_val, {in_val});
}}
// Warp-level reduction using simd_max
max_val = simd_max(max_val);
// First lane of each warp writes to shared memory
if (simd_lane == 0) {{
warp_maxs[simd_id] = max_val;
}}
partials[tid] = max_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
// First warp does final reduction
if (simd_id == 0) {{
int n_warps = THREADS_PER_GROUP / 32;
float block_max = (tid < uint(n_warps)) ? warp_maxs[tid] : NEG_INF_F;
block_max = simd_max(block_max);
if (tid == 0) {{
out[{out_idx}] = {out_val};
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
if (tid < stride) {{
partials[tid] = fmax(partials[tid], partials[tid + stride]);
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}}
if (tid == 0) {{
float block_max = partials[0];
out[{out_idx}] = {out_val};
}}
}}
"#,
@@ -1427,8 +1414,6 @@ impl EgglogOp for MPSMatmul {
let dt = v(format!("?{}_dt", name.replace('-', "_")));
rule(union(sum_op.clone(), mps_op.clone()))
.subsume(sum_op.clone())
.subsume(mul_op)
.set(dtype(mps_op), dt.clone())
.fact(eq(dt, dtype(sum_op)))
.ruleset("kernel_lower")
@@ -1464,6 +1449,17 @@ impl EgglogOp for MPSMatmul {
1,
1,
),
Rule::raw(
"(rule
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (MPSMatmul ?m ?n ?k ?lhs ?lhsrs ?rhs ?rhsrs ?ors ?tl ?tr)))
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
:ruleset cleanup
:name \"delete-broadcast-mul-sum-when-mps-matmul-exists\"
)",
),
]
}
@@ -1505,14 +1501,6 @@ impl EgglogOp for MPSMatmul {
}
impl MPSMatmul {
fn mps_dtype(dtype: DType) -> mps::MPSDataType {
match dtype {
DType::F32 | DType::TF32 => mps::MPSDataType::Float32,
DType::F16 => mps::MPSDataType::Float16,
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
}
}
fn row_bytes(row_stride: Expression, dtype: DType, dyn_map: &FxHashMap<char, usize>) -> u64 {
let elems = row_stride
.substitute('z', Expression::from(1))
@@ -1521,19 +1509,6 @@ impl MPSMatmul {
(elems * dtype.bits().div_ceil(8)) as u64
}
fn descriptor(rows: usize, cols: usize, row_bytes: u64, dtype: DType) -> *mut Object {
let data_type = Self::mps_dtype(dtype) as isize;
unsafe {
msg_send![
class!(MPSMatrixDescriptor),
matrixDescriptorWithRows: rows
columns: cols
rowBytes: row_bytes as usize
dataType: data_type
]
}
}
fn matrix(buffer: &Buffer, descriptor: *mut Object) -> *mut Object {
unsafe {
let matrix: *mut Object = msg_send![class!(MPSMatrix), alloc];
@@ -1589,12 +1564,11 @@ impl MetalKernelOp for MPSMatmul {
fn encode(
&self,
command_buffer: &CommandBufferRef,
context: &mut MetalEncodeContext<'_>,
_pipeline: Option<&ComputePipelineState>,
inputs: &[&Buffer],
output: &Buffer,
dyn_map: &FxHashMap<char, usize>,
_dyn_buffer: &Buffer,
input_dtypes: &[DType],
output_dtype: DType,
) {
@@ -1610,46 +1584,48 @@ impl MetalKernelOp for MPSMatmul {
let rhs_rows = if self.transpose_rhs { n } else { k };
let rhs_cols = if self.transpose_rhs { k } else { n };
let lhs_desc = Self::descriptor(
lhs_rows,
lhs_cols,
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
lhs_dtype,
);
let rhs_desc = Self::descriptor(
rhs_rows,
rhs_cols,
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
rhs_dtype,
);
let out_desc = Self::descriptor(
m,
n,
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
output_dtype,
);
let (lhs_desc, rhs_desc, out_desc, kernel) = {
let mut cache = context.mps_cache.borrow_mut();
(
cache.matrix_descriptor(
lhs_rows,
lhs_cols,
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
lhs_dtype,
),
cache.matrix_descriptor(
rhs_rows,
rhs_cols,
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
rhs_dtype,
),
cache.matrix_descriptor(
m,
n,
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
output_dtype,
),
cache.matrix_multiplication(
context.command_buffer,
self.transpose_lhs,
self.transpose_rhs,
m,
n,
k,
1.0,
0.0,
),
)
};
let lhs = Self::matrix(inputs[0], lhs_desc);
let rhs = Self::matrix(inputs[1], rhs_desc);
let out = Self::matrix(output, out_desc);
unsafe {
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
let kernel: *mut Object = msg_send![
kernel,
initWithDevice: device
transposeLeft: self.transpose_lhs
transposeRight: self.transpose_rhs
resultRows: m
resultColumns: n
interiorColumns: k
alpha: 1.0f64
beta: 0.0f64
];
let _: () = msg_send![
kernel,
encodeToCommandBuffer: command_buffer.as_ptr()
encodeToCommandBuffer: context.command_buffer.as_ptr()
leftMatrix: lhs
rightMatrix: rhs
resultMatrix: out
@@ -1657,7 +1633,6 @@ impl MetalKernelOp for MPSMatmul {
let _: () = msg_send![lhs, release];
let _: () = msg_send![rhs, release];
let _: () = msg_send![out, release];
let _: () = msg_send![kernel, release];
}
}
@@ -1839,8 +1814,6 @@ impl EgglogOp for MPSBatchedMatmul {
let dt = v(format!("?{}_dt", name.replace('-', "_")));
rule(union(sum_op.clone(), mps_op.clone()))
.subsume(sum_op.clone())
.subsume(mul_op)
.set(dtype(mps_op), dt.clone())
.fact(eq(dt, dtype(sum_op)))
.ruleset("kernel_lower")
@@ -1878,6 +1851,17 @@ impl EgglogOp for MPSBatchedMatmul {
),
1,
),
Rule::raw(
"(rule
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (MPSBatchedMatmul ?b ?m ?n ?k ?lhs ?lhsbs ?lhsrs ?rhs ?rhsbs ?rhsrs ?obs ?ors ?tl ?tr)))
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
:ruleset cleanup
:name \"delete-broadcast-mul-sum-when-mps-batched-matmul-exists\"
)",
),
]
}
@@ -1953,12 +1937,11 @@ impl MetalKernelOp for MPSBatchedMatmul {
fn encode(
&self,
command_buffer: &CommandBufferRef,
context: &mut MetalEncodeContext<'_>,
_pipeline: Option<&ComputePipelineState>,
inputs: &[&Buffer],
output: &Buffer,
dyn_map: &FxHashMap<char, usize>,
_dyn_buffer: &Buffer,
input_dtypes: &[DType],
output_dtype: DType,
) {
@@ -1982,25 +1965,26 @@ impl MetalKernelOp for MPSBatchedMatmul {
let lhs_row_bytes = MPSMatmul::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map);
let rhs_row_bytes = MPSMatmul::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map);
let out_row_bytes = MPSMatmul::row_bytes(self.out_row_stride, output_dtype, dyn_map);
let lhs_desc = MPSMatmul::descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype);
let rhs_desc = MPSMatmul::descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype);
let out_desc = MPSMatmul::descriptor(m, n, out_row_bytes, output_dtype);
let (lhs_desc, rhs_desc, out_desc, kernel) = {
let mut cache = context.mps_cache.borrow_mut();
(
cache.matrix_descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype),
cache.matrix_descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype),
cache.matrix_descriptor(m, n, out_row_bytes, output_dtype),
cache.matrix_multiplication(
context.command_buffer,
self.transpose_lhs,
self.transpose_rhs,
m,
n,
k,
1.0,
0.0,
),
)
};
unsafe {
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
let kernel: *mut Object = msg_send![
kernel,
initWithDevice: device
transposeLeft: self.transpose_lhs
transposeRight: self.transpose_rhs
resultRows: m
resultColumns: n
interiorColumns: k
alpha: 1.0f64
beta: 0.0f64
];
for batch_idx in 0..batch {
let batch_expr = Expression::from(batch_idx as i64);
let lhs_offset = self
@@ -2027,7 +2011,7 @@ impl MetalKernelOp for MPSBatchedMatmul {
let out = MPSMatmul::matrix_with_offset(output, out_offset as u64, out_desc);
let _: () = msg_send![
kernel,
encodeToCommandBuffer: command_buffer.as_ptr()
encodeToCommandBuffer: context.command_buffer.as_ptr()
leftMatrix: lhs
rightMatrix: rhs
resultMatrix: out
@@ -2036,7 +2020,6 @@ impl MetalKernelOp for MPSBatchedMatmul {
let _: () = msg_send![rhs, release];
let _: () = msg_send![out, release];
}
let _: () = msg_send![kernel, release];
}
}
@@ -2163,24 +2146,6 @@ impl EgglogOp for GenericMatmul {
:name \"delete-broadcast-mul-sum-when-generic-matmul-exists\"
)",
),
Rule::raw(
"(rule
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
(= ?sum (MPSMatmul ?mm ?mn ?mk ?ml ?mls ?mr ?mrs ?mos ?mtl ?mtr)))
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
:ruleset cleanup
:name \"prefer-mps-over-generic-matmul\"
)",
),
Rule::raw(
"(rule
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
(= ?sum (MPSBatchedMatmul ?bb ?bm ?bn ?bk ?bl ?blbs ?blrs ?br ?brbs ?brrs ?bobs ?bors ?btl ?btr)))
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
:ruleset cleanup
:name \"prefer-mps-batched-over-generic-matmul\"
)",
),
]
}
@@ -2265,13 +2230,11 @@ impl MetalKernelOp for GenericMatmul {
constant int *dyn [[buffer({dyn_buffer_index})]],
constant uint &n_outputs [[buffer({n_outputs_index})]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]]
uint tid [[thread_index_in_threadgroup]]
) {{
if (gid >= n_outputs) return;
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
threadgroup float partials[THREADS_PER_GROUP];
int base_idx = {sum_base_idx};
int iters = {iters};
(void)dyn;
@@ -2282,19 +2245,18 @@ impl MetalKernelOp for GenericMatmul {
sum += ({lhs_val}) * ({rhs_val});
}}
sum = simd_sum(sum);
if (simd_lane == 0) {{
warp_sums[simd_id] = sum;
}}
partials[tid] = sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_id == 0) {{
int n_warps = THREADS_PER_GROUP / 32;
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
block_sum = simd_sum(block_sum);
if (tid == 0) {{
out[{out_idx}] = {out_val};
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
if (tid < stride) {{
partials[tid] += partials[tid + stride];
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}}
if (tid == 0) {{
float block_sum = partials[0];
out[{out_idx}] = {out_val};
}}
}}
"#,

View File

@@ -1,5 +1,6 @@
pub mod dyn_backend;
pub mod kernel;
mod memory_analysis;
pub mod runtime;
#[cfg(test)]

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,9 @@
use crate::kernel::{DYN_SLOT_COUNT, MetalKernelOp};
use crate::kernel::{DYN_SLOT_COUNT, MetalEncodeContext, MetalKernelOp, MpsKernelCache};
use half::{bf16, f16};
use itertools::Itertools;
use luminal::{
dtype::DType,
egglog_utils::SerializedEGraph,
graph::{BucketLLIR, DimBucket, Graph, LLIRGraph},
hlir::{Input, NativeData, Output},
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
@@ -16,15 +17,26 @@ use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptio
use objc::rc::autoreleasepool;
use objc::runtime::Object;
use safetensors::{Dtype, SafeTensors};
use std::{fs::File, time::Duration};
use std::{cell::RefCell, fs::File, time::Duration};
#[derive(Clone)]
struct MetalExecutionStep {
node: NodeIndex,
input_nodes: Vec<NodeIndex>,
input_dtypes: Vec<DType>,
output_dtype: DType,
}
#[derive(Clone)]
struct MetalCompiledBucket {
bucket_indices: FxHashMap<char, usize>,
llir_graph: LLIRGraph,
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
node_dtypes: FxHashMap<NodeIndex, DType>,
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
execution_plan: Vec<MetalExecutionStep>,
}
pub struct MetalRuntime {
@@ -36,16 +48,26 @@ pub struct MetalRuntime {
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
/// Buffers for LLIR intermediate/output tensors
pub buffers: FxHashMap<NodeIndex, Buffer>,
/// Logical byte length for each active LLIR buffer.
buffer_lengths: FxHashMap<NodeIndex, u64>,
/// Dynamic dimensions table (a-z), shared across all kernels.
dyn_buffer: Buffer,
/// Retained MPS descriptors/kernels reused across command encodes.
mps_cache: RefCell<MpsKernelCache>,
/// The current LLIR graph
llir_graph: LLIRGraph,
/// LLIR input node -> HLIR input node.
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
/// Inferred runtime dtype for each LLIR node.
node_dtypes: FxHashMap<NodeIndex, DType>,
/// Compiled pipeline states for each kernel node
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
/// LLIR output node -> input node whose buffer contains the output.
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
/// HLIR output id -> LLIR node whose data feeds the output.
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
/// Precomputed executable nodes and input metadata for the active LLIR graph.
execution_plan: Vec<MetalExecutionStep>,
/// Bucket definitions for dynamic dimensions.
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
/// Compiled LLIR variants, one per bucket combination.
@@ -64,22 +86,10 @@ impl MetalRuntime {
}
fn output_data_node(&self, id: NodeIndex) -> NodeIndex {
let output_id = self
.llir_graph
.node_indices()
.find(|n| {
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
*node == id.index()
} else {
false
}
})
.expect("Cannot find output tensor!");
self.llir_graph
.neighbors_directed(output_id, Direction::Incoming)
.next()
.unwrap()
self.output_data_map
.get(&id)
.copied()
.unwrap_or_else(|| panic!("Cannot find output tensor {id:?}!"))
}
fn follow_aliases(&self, mut node: NodeIndex) -> NodeIndex {
@@ -225,6 +235,7 @@ impl MetalRuntime {
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
if let Some(buffer) = self.buffers.remove(&data_id) {
self.buffer_lengths.remove(&data_id);
return buffer;
}
@@ -269,12 +280,21 @@ impl MetalRuntime {
.map(|inp| inp.dtype)
})
.unwrap_or(DType::F32);
let logical_bytes = self
.buffer_lengths
.get(&data_id)
.copied()
.unwrap_or_else(|| buffer.length());
assert!(
logical_bytes <= buffer.length(),
"Logical buffer size exceeds allocated Metal buffer size"
);
unsafe {
match dtype {
DType::F16 => {
let ptr = buffer.contents() as *const f16;
let len = buffer.length() as usize / std::mem::size_of::<f16>();
let len = logical_bytes as usize / std::mem::size_of::<f16>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| v.to_f32())
@@ -282,7 +302,7 @@ impl MetalRuntime {
}
DType::Int => {
let ptr = buffer.contents() as *const i32;
let len = buffer.length() as usize / std::mem::size_of::<i32>();
let len = logical_bytes as usize / std::mem::size_of::<i32>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| *v as f32)
@@ -290,7 +310,7 @@ impl MetalRuntime {
}
_ => {
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
let len = logical_bytes as usize / std::mem::size_of::<f32>();
std::slice::from_raw_parts(ptr, len).to_vec()
}
}
@@ -304,6 +324,26 @@ impl Runtime for MetalRuntime {
type ExecReturn = ();
type ProfileMetric = Duration;
fn late_egglog_passes(
ops: &[std::sync::Arc<Box<dyn luminal::op::EgglogOp>>],
options: &luminal::graph::BuildSearchSpaceOptions,
dyn_map: &FxHashMap<char, usize>,
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
vec![crate::memory_analysis::metal_memory_analysis_pass(
ops,
options.max_memory_bytes,
dyn_map,
)]
}
fn estimate_graph_memory<'a>(
egraph: &'a SerializedEGraph,
choices: &luminal::egglog_utils::EGraphChoiceSet<'a>,
dyn_map: &FxHashMap<char, usize>,
) -> Option<usize> {
crate::memory_analysis::estimate_graph_memory_bytes(egraph, choices, dyn_map)
}
fn initialize(_: Self::CompileArg) -> Self {
let device = Device::system_default().expect("No Metal device found!");
let command_queue = device.new_command_queue();
@@ -318,11 +358,16 @@ impl Runtime for MetalRuntime {
input_data: FxHashMap::default(),
hlir_buffers: FxHashMap::default(),
buffers: FxHashMap::default(),
buffer_lengths: FxHashMap::default(),
dyn_buffer,
mps_cache: RefCell::new(MpsKernelCache::default()),
llir_graph: StableGraph::default(),
llir_to_hlir: FxHashMap::default(),
node_dtypes: FxHashMap::default(),
pipelines: FxHashMap::default(),
output_alias_map: FxHashMap::default(),
output_data_map: FxHashMap::default(),
execution_plan: vec![],
dim_buckets: FxHashMap::default(),
compiled_buckets: vec![],
active_bucket: 0,
@@ -336,6 +381,7 @@ impl Runtime for MetalRuntime {
#[tracing::instrument(skip_all)]
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
self.buffers.clear();
self.buffer_lengths.clear();
self.dim_buckets.clear();
self.compiled_buckets = vec![self.compile_bucket(FxHashMap::default(), llir_graph)];
self.activate_bucket(0);
@@ -347,19 +393,25 @@ impl Runtime for MetalRuntime {
llir_graph: &LLIRGraph,
dyn_map: &FxHashMap<char, usize>,
trials: usize,
_timeout: Option<std::time::Duration>,
timeout: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String) {
self.load_llir(llir_graph);
self.allocate_intermediate_buffers(dyn_map);
let trials = trials.max(1);
let profile_start = std::time::Instant::now();
let mut duration = Duration::default();
let mut completed_trials = 0;
for _ in 0..trials {
let start = std::time::Instant::now();
self.execute(dyn_map);
duration += start.elapsed();
completed_trials += 1;
if timeout.is_some_and(|timeout| profile_start.elapsed() >= timeout) {
break;
}
}
duration /= trials as u32;
duration /= completed_trials as u32;
(duration, format!("{:.2?}", duration))
}
@@ -370,74 +422,43 @@ impl Runtime for MetalRuntime {
self.select_bucket(dyn_map);
self.allocate_active_intermediate_buffers(dyn_map);
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
.llir_graph
.node_indices()
.filter_map(|n| {
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
Some((n, NodeIndex::new(*node)))
} else {
None
}
})
.collect();
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let mut encode_context = MetalEncodeContext {
command_buffer,
dyn_buffer: &self.dyn_buffer,
mps_cache: &self.mps_cache,
};
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
|| self.llir_graph[node].to_op::<Output>().is_some()
{
continue;
}
for step in &self.execution_plan {
let kernel_op = self.llir_graph[step.node]
.to_dialect::<dyn MetalKernelOp>()
.expect("Execution plan referenced a non-Metal op");
let pipeline = self.pipelines.get(&step.node);
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = self.pipelines.get(&node);
let input_buffers: Vec<&Buffer> = step
.input_nodes
.iter()
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
.collect();
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
input_buffers[alias_idx]
} else {
self.buffers
.get(&step.node)
.expect("Output buffer not allocated!")
};
let input_buffers: Vec<&Buffer> = input_nodes
.iter()
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
.collect();
let input_dtypes: Vec<DType> = input_nodes
.iter()
.map(|n| {
self.node_dtypes
.get(n)
.copied()
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
})
.collect();
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
input_buffers[alias_idx]
} else {
self.buffers
.get(&node)
.expect("Output buffer not allocated!")
};
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
kernel_op.encode(
command_buffer,
pipeline,
&input_buffers,
output_buffer,
dyn_map,
&self.dyn_buffer,
&input_dtypes,
output_dtype,
);
}
kernel_op.encode(
&mut encode_context,
pipeline,
&input_buffers,
output_buffer,
dyn_map,
&step.input_dtypes,
step.output_dtype,
);
}
command_buffer.commit();
@@ -447,6 +468,22 @@ impl Runtime for MetalRuntime {
fn clear_intermediate_buffers(&mut self) {
self.buffers.clear();
self.buffer_lengths.clear();
}
fn intermediate_buffer_bytes(&self) -> usize {
self.buffers
.values()
.map(|buffer| buffer.length() as usize)
.sum()
}
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
Some(self.intermediate_buffer_bytes())
}
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
Some(self.intermediate_buffer_bytes())
}
fn load_llir_buckets(
@@ -455,6 +492,7 @@ impl Runtime for MetalRuntime {
bucket_llirs: &[BucketLLIR],
) {
self.buffers.clear();
self.buffer_lengths.clear();
self.dim_buckets = dim_buckets.clone();
self.compiled_buckets = bucket_llirs
.iter()
@@ -497,7 +535,7 @@ impl MetalRuntime {
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
match dtype {
DType::F32 => {
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
let values = data.to_f32_vec();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
@@ -505,7 +543,7 @@ impl MetalRuntime {
)
}
DType::F16 => {
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
let values = data.to_f16_vec();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
@@ -513,7 +551,7 @@ impl MetalRuntime {
)
}
DType::Int => {
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
let values = data.to_i32_vec();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
@@ -531,6 +569,7 @@ impl MetalRuntime {
fn allocate_active_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
let mut planned = Vec::new();
let capacity_dyn_map = self.active_capacity_dyn_map(dyn_map);
for node in self.llir_graph.node_indices() {
if self.llir_graph[node].to_op::<Input>().is_some() {
@@ -541,28 +580,58 @@ impl MetalRuntime {
if kernel_op.output_aliases_input().is_some() {
continue;
}
let size = kernel_op.output_size().exec(dyn_map).unwrap();
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
let bytes = (size * dtype.bits().div_ceil(8)) as u64;
let requested_bytes =
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, dyn_map);
let allocation_bytes =
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, &capacity_dyn_map)
.max(requested_bytes);
let needs_buffer = self
.buffers
.get(&node)
.is_none_or(|buffer| buffer.length() != bytes);
.is_none_or(|buffer| requested_bytes > buffer.length());
planned.push((node, bytes, needs_buffer));
planned.push((node, requested_bytes, allocation_bytes, needs_buffer));
}
}
for (node, bytes, needs_buffer) in planned {
for (node, requested_bytes, allocation_bytes, needs_buffer) in planned {
self.buffer_lengths.insert(node, requested_bytes);
if needs_buffer {
let buffer = self
.device
.new_buffer(bytes, MTLResourceOptions::StorageModeShared);
.new_buffer(allocation_bytes, MTLResourceOptions::StorageModeShared);
self.buffers.insert(node, buffer);
}
}
}
fn output_bytes(
kernel_op: &dyn MetalKernelOp,
dtype: DType,
dyn_map: &FxHashMap<char, usize>,
) -> u64 {
let size = kernel_op.output_size().exec(dyn_map).unwrap();
(size * dtype.bits().div_ceil(8)) as u64
}
fn active_capacity_dyn_map(&self, dyn_map: &FxHashMap<char, usize>) -> FxHashMap<char, usize> {
let mut capacity_dyn_map = dyn_map.clone();
let Some(active_bucket) = self.compiled_buckets.get(self.active_bucket) else {
return capacity_dyn_map;
};
for (&dim, buckets) in &self.dim_buckets {
if let Some(&bucket_index) = active_bucket.bucket_indices.get(&dim)
&& let Some(bucket) = buckets.get(bucket_index)
{
capacity_dyn_map.insert(dim, bucket.max);
}
}
capacity_dyn_map
}
fn compile_bucket(
&self,
bucket_indices: FxHashMap<char, usize>,
@@ -571,12 +640,17 @@ impl MetalRuntime {
let mut node_dtypes = FxHashMap::default();
let mut pipelines = FxHashMap::default();
let mut output_alias_map = FxHashMap::default();
let mut output_data_map = FxHashMap::default();
let mut execution_plan = Vec::new();
let mut llir_to_hlir = FxHashMap::default();
let llir_graph = llir_graph.clone();
let topo_order = toposort(&llir_graph, None).expect("Graph has cycles!");
for node in topo_order {
for node in &topo_order {
let node = *node;
if let Some(input) = llir_graph[node].to_op::<Input>() {
node_dtypes.insert(node, input.dtype);
llir_to_hlir.insert(node, NodeIndex::new(input.node));
continue;
}
@@ -610,17 +684,38 @@ impl MetalRuntime {
{
output_alias_map.insert(node, target);
}
execution_plan.push(MetalExecutionStep {
node,
input_nodes,
input_dtypes,
output_dtype,
});
} else {
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
}
}
for node in topo_order {
if let Some(Output { node: hlir_node }) = llir_graph[node].to_op::<Output>()
&& let Some(data_node) = llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.next()
.map(|e| e.source())
{
output_data_map.insert(NodeIndex::new(*hlir_node), data_node);
}
}
MetalCompiledBucket {
bucket_indices,
llir_graph,
llir_to_hlir,
node_dtypes,
pipelines,
output_alias_map,
output_data_map,
execution_plan,
}
}
@@ -632,11 +727,15 @@ impl MetalRuntime {
.clone();
self.active_bucket = index;
self.llir_graph = bucket.llir_graph;
self.llir_to_hlir = bucket.llir_to_hlir;
self.node_dtypes = bucket.node_dtypes;
self.pipelines = bucket.pipelines;
self.output_alias_map = bucket.output_alias_map;
self.output_data_map = bucket.output_data_map;
self.execution_plan = bucket.execution_plan;
self.refresh_input_data_buffers();
self.buffers.clear();
self.buffer_lengths.clear();
}
fn refresh_input_data_buffers(&mut self) {
@@ -706,74 +805,43 @@ impl MetalRuntime {
self.select_bucket(dyn_map);
self.allocate_active_intermediate_buffers(dyn_map);
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
.llir_graph
.node_indices()
.filter_map(|n| {
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
Some((n, NodeIndex::new(*node)))
} else {
None
}
})
.collect();
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let mut encode_context = MetalEncodeContext {
command_buffer,
dyn_buffer: &self.dyn_buffer,
mps_cache: &self.mps_cache,
};
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
|| self.llir_graph[node].to_op::<Output>().is_some()
{
continue;
}
for step in &self.execution_plan {
let kernel_op = self.llir_graph[step.node]
.to_dialect::<dyn MetalKernelOp>()
.expect("Execution plan referenced a non-Metal op");
let pipeline = self.pipelines.get(&step.node);
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = self.pipelines.get(&node);
let input_buffers: Vec<&Buffer> = step
.input_nodes
.iter()
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
.collect();
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
input_buffers[alias_idx]
} else {
self.buffers
.get(&step.node)
.expect("Output buffer not allocated!")
};
let input_buffers: Vec<&Buffer> = input_nodes
.iter()
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
.collect();
let input_dtypes: Vec<DType> = input_nodes
.iter()
.map(|n| {
self.node_dtypes
.get(n)
.copied()
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
})
.collect();
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
input_buffers[alias_idx]
} else {
self.buffers
.get(&node)
.expect("Output buffer not allocated!")
};
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
kernel_op.encode(
command_buffer,
pipeline,
&input_buffers,
output_buffer,
dyn_map,
&self.dyn_buffer,
&input_dtypes,
output_dtype,
);
}
kernel_op.encode(
&mut encode_context,
pipeline,
&input_buffers,
output_buffer,
dyn_map,
&step.input_dtypes,
step.output_dtype,
);
}
command_buffer.commit();

View File

@@ -3,6 +3,7 @@ use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
use half::{bf16, f16};
use luminal::prelude::*;
use proptest::prelude::*;
use rand::{SeedableRng, rngs::StdRng};
use safetensors::{Dtype, tensor::TensorView};
use std::{
collections::HashMap,
@@ -38,6 +39,30 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
bytemuck::cast_slice(values).to_vec()
}
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
let mut rng = StdRng::seed_from_u64(0);
cx.search_options(rt, SearchOptions::new(limit), &mut rng)
}
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
cx.egraph()
.expect("search space should be built")
.enodes
.values()
.any(|(label, _)| label == op_name)
}
fn assert_matmul_options(cx: &Graph, mps_op_name: &str) {
assert!(
egraph_has_op(cx, mps_op_name),
"expected {mps_op_name} rewrite option in e-graph"
);
assert!(
egraph_has_op(cx, "GenericMatmul"),
"expected GenericMatmul rewrite option in e-graph"
);
}
fn write_test_safetensors(tensors: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> PathBuf {
let tensor_views: HashMap<String, TensorView<'_>> = tensors
.iter()
@@ -401,6 +426,18 @@ proptest! {
}
}
#[test]
fn metal_build_search_space_accepts_memory_budget() {
let mut cx = Graph::default();
let a = cx.tensor(4);
let b = cx.tensor(4);
(a * b).output();
cx.build_search_space_with_options::<MetalRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(1),
);
}
/// Simple deterministic test for add
#[test]
fn metal_simple_add() {
@@ -665,7 +702,7 @@ fn metal_specialized_matmul() {
rt.set_data(a, &a_data);
rt.set_data(b, &b_data);
rt = cx.search(rt, 1);
rt = search_candidates(&mut cx, rt, 32);
assert!(
rt.contains_matmul(),
"expected Metal runtime to fuse matmul, kernels: {:?}",
@@ -698,6 +735,7 @@ fn metal_regular_tiled_matmul_path() {
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(m * k, 0.4, -0.2);
@@ -705,19 +743,7 @@ fn metal_regular_tiled_matmul_path() {
rt.set_data(a, &a_data);
rt.set_data(b, &b_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MPSMatmul")),
"expected MPS matmul path, kernels: {:?}",
kernels
);
assert!(
!kernels.iter().any(|k| k.contains("GenericMatmul")),
"MPS-compatible matmul should not extract the generic fallback, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -744,6 +770,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
let output = a.matmul(weight.t()).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(m * k, 0.35, -0.17);
@@ -751,14 +778,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
rt.set_data(a, &a_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
"expected MPS matmul to cover transposed row-major RHS, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -785,6 +805,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
let output = lhs_storage.t().matmul(rhs).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
let lhs_data = seeded_data(k * m, 0.31, -0.12);
@@ -792,14 +813,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
rt.set_data(lhs_storage, &lhs_data);
rt.set_data(rhs, &rhs_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("transpose_lhs: true")),
"expected MPS matmul to cover transposed row-major LHS, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -830,20 +844,14 @@ fn metal_mps_batched_matmul_row_row_layout() {
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSBatchedMatmul");
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(batch * m * k, 0.17, -0.08);
let b_data = seeded_data(batch * k * n, 0.11, -0.05);
rt.set_data(a, &a_data);
rt.set_data(b, &b_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MPSBatchedMatmul")),
"expected MPS batched matmul path, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -880,13 +888,17 @@ fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
let output = merged.matmul(weight.t()).output();
cx.build_search_space::<MetalRuntime>();
assert!(
egraph_has_op(&cx, "GenericMatmul"),
"expected GenericMatmul rewrite option in e-graph"
);
let mut rt = MetalRuntime::initialize(());
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
rt.set_data(attn, &attn_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, 1);
rt = search_candidates(&mut cx, rt, 32);
let kernels = rt.debug_kernel_ops();
assert!(
@@ -935,22 +947,14 @@ fn metal_mps_batched_matmul_transposed_rhs_layout() {
let output = a.matmul(weight.permute((0, 2, 1))).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSBatchedMatmul");
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(batch * m * k, 0.13, -0.06);
let weight_data = seeded_data(batch * n * k, 0.09, -0.04);
rt.set_data(a, &a_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels
.iter()
.any(|k| k.contains("MPSBatchedMatmul") && k.contains("transpose_rhs: true")),
"expected MPS batched matmul transposed RHS path, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -984,6 +988,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
let output = a.matmul(weight.t()).cast(DType::F32).output();
cx.build_search_space::<MetalRuntime>();
assert_matmul_options(&cx, "MPSMatmul");
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(m * k, 0.22, -0.07);
@@ -991,14 +996,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
rt.set_data(a, to_f16_vec(&a_data));
rt.set_data(weight, to_f16_vec(&weight_data));
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
"expected MPS F16 matmul to cover transposed row-major RHS, kernels: {:?}",
kernels
);
rt = search_candidates(&mut cx, rt, 32);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);

View File

@@ -865,3 +865,29 @@ Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anyth
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
---
## 2026-05-21 — DLRM compile: silent mis-stride on `index.Tensor` with a None-prefix
1. **Symptom**: compiling facebookresearch/dlrm through `luminal_backend` failed in the top-MLP with `assertion left == right failed: Dims must match to add tensors. left: [2, 8] right: [6, 8]`. The error surfaced ~5 ops downstream of the actual bug, with no mention of `index` anywhere in the trace.
2. **Root cause**: `translate_index_tensor` in `crates/luminal_python/rust/src/translator/movement.rs` had two code paths for advanced indexing. The first ran when an `OptionalTensors` arg held exactly one non-None entry on a specific dim (`first_non_none_dim > 0 && index_names.len() == 1`); it correctly used `first_non_none_dim` to gather on the right axis. The second — the general multi-index fall-through — silently ignored `first_non_none_dim` and computed strides/flat-source-shape as if indices always started at dim 0. DLRM's dot interaction does `Z[:, li, lj]` (Z is `[B, ni, nj]`, two 1-D index tensors after a `:`), which hits the multi-index path with `first_non_none_dim = 1`. The translator built strides over `src_shape[..n_indexed] = [B, ni]` and a flat-source of shape `[B*ni, nj]`, instead of striding over `[ni, nj]` with prefix-dim `[B]`. The downstream gather produced a tensor with the wrong leading dim (6 — the index length — instead of B), and the mismatch only blew up later when broadcast-add into the top-MLP hidden state.
3. **Why it was hard to find**: the trace ends in `process_pt2` with a luminal core assertion about broadcasting in a `+` op. Nothing in the message names the *upstream* op that produced the wrong shape. Worse, the bug only manifests when ALL of {two-plus index tensors, at least one leading `None`, downstream broadcast-sensitive consumer} are present — the common case (`a[idx]`, `a[idx, jdx]` with no prefix) just works. So the bug had survived through every prior model translator test.
4. **The fix**: split the prefix-aware case into its own helper `translate_index_tensor_with_prefix`. It explicitly partitions `src.shape` into `prefix_dims / indexed_dims / suffix_dims`, builds the flat sub-index over `indexed_dims`, promotes/expands it into the full output shape, and adds a broadcast prefix-offset constructed from `arange`s over each prefix dim. Result is fully-flat `source.gather(absolute_idx)`. The suffix-non-empty case is left guarded with a `bail!` (it's separable but DLRM doesn't need it).
5. **Principle**: a shape-keyed assumption baked into one branch of a multi-branch translator is a silent footgun — when the fall-through path is reached with a value the assumption rules out, you get *wrong shapes silently*, and the failure surfaces wherever the wrong shape first encounters a consumer that cares. Guard early: if an invariant the code relies on isn't met (here, "indices apply to the leading dims of source"), check it explicitly and `bail!` with the offending shape rather than computing forward. Even better, refactor so the unsupported case routes to a dedicated path the moment the assumption diverges — small risk of double-implementation, large reduction in "compile silently produces wrong output."
## 2026-05-21 — DLRM compile: `EmbeddingBag` translator gap
1. **Symptom**: same `luminal_backend` compile, first error: `RuntimeError: Failed to translate node N: torch.ops.aten._embedding_bag_forward_only.default: Unsupported ATen op`. This is the central op of DLRM — every sparse feature lookup decomposes to it via `nn.EmbeddingBag`.
2. **What's needed**: `_embedding_bag_forward_only(weight, indices, offsets, ..., mode, ...)` produces `output[b] = reduce_op(weight[indices[offsets[b]:offsets[b+1]]])` for each bag `b`. The general case is a *runtime segment reduction* — the bag boundaries depend on `offsets`, which is a runtime tensor — and luminal has no native segment-reduce primitive.
3. **The fix (in this session)**: add `translate_embedding_bag` covering the uniform-bag-size case, which is what DLRM actually uses. Read `indices.shape[0] = N` and `offsets.shape[0] = B` off the static shape info, compute bag size `K = N / B`, bail if they don't divide. Then gather `[N, D]` (same construction as `translate_embedding`), reshape to `[B, K, D]`, reduce along axis 1 according to `mode` (sum/mean/max). For `K=1` (the eval-time-1-lookup-per-sample DLRM path) skip the reshape+reduce — it's just an `embedding` lookup. `per_sample_weights` and non-uniform bags are guarded with `bail!`.
4. **Why this works for DLRM but isn't general**: a true segment reduction needs either (a) static knowledge of every segment boundary (what we get when bags are uniform), or (b) a scatter-add primitive that handles per-segment accumulation at runtime. (a) covers DLRM's training/eval data generator and the common recsys case where each sample has K-hot lookups for fixed K. (b) is required for any model that genuinely has variable-length bags per sample (e.g. variable-length feature crossings) and is a follow-up.
5. **Principle**: when a PyTorch op has no straight-line luminal lowering, look at the *shapes the model actually feeds in* before declaring it unsupportable. A "segment reduction" over offsets is a hard problem in general; "segment reduction where every bag has K elements with K statically known from indices.shape[0]/offsets.shape[0]" is a 5-line gather+reshape+reduce. The PT2 graph carries the shape info for free — use it.

View File

@@ -127,6 +127,12 @@ impl<'a> Translator<'a> {
}
// addmm: beta*input + alpha*(mat1 @ mat2)
//
// PyTorch's nn.Linear with bias generates `addmm(bias, input, weight.t())`
// with the default `beta=alpha=1.0`. Emitting the multiplies in that
// case wastes 2 HLIR nodes per Linear that egglog has to fold later;
// for a 4-Linear DLRM that's 8 nodes off the search-space count.
// Skip them when the scale is 1.
"torch.ops.aten.addmm.default" => {
let input = self.get_input_tensor(node, 0)?;
let mat1 = self.get_input_tensor(node, 1)?;
@@ -135,7 +141,9 @@ impl<'a> Translator<'a> {
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
let mm = mat1.matmul(mat2);
let (input, mm) = broadcast_binary(input, mm);
input * beta + mm * alpha
let scaled_input = if beta == 1.0 { input } else { input * beta };
let scaled_mm = if alpha == 1.0 { mm } else { mm * alpha };
scaled_input + scaled_mm
}
// Convolution
@@ -154,6 +162,10 @@ impl<'a> Translator<'a> {
// Embedding
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
"torch.ops.aten._embedding_bag.default"
| "torch.ops.aten._embedding_bag_forward_only.default" => {
self.translate_embedding_bag(node)?
}
// Softmax
"torch.ops.aten._softmax.default" => {

View File

@@ -0,0 +1,434 @@
//! DLRM-family pattern matcher for the PT2 translator.
//!
//! Recognizes the `MiniDLRM` topology in a parsed PT2 graph (bot MLP →
//! N sparse `_embedding_bag_forward_only` lookups (bag-size 1) →
//! dot-product interaction via `bmm` + lower-triangular `index.Tensor` →
//! top MLP ending in `sigmoid`) and, when matched, emits a single
//! [`luminal_cuda_lite::kernel::DlrmMegaCustom`] op that replaces the
//! entire per-node translation. The runtime then sees ONE host op
//! instead of the 8 cuBLAS+CudaGraphOp ops the normal path produces.
//!
//! The matcher is intentionally conservative — any mismatch returns
//! `None` and the translator falls back to its standard node-by-node
//! walk, so wrong-graphs never produce wrong-output, only "the fast
//! path didn't trigger." Diagnostic prints are gated on
//! `LUMINAL_DLRM_MEGAKERNEL_DEBUG=1` for development.
//!
//! See `examples/dlrm/src/megakernel.rs` for the standalone proof of
//! concept and `crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs`
//! for the parameterized kernel itself.
//!
//! Companion plan: see `/home/ubuntu/.claude/plans/can-you-plan-out-mossy-wave.md`.
use anyhow::{Context, Result};
use luminal::prelude::*;
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
use crate::pt2_parser::ParsedPT2;
use crate::pt2_schema::Node;
use super::Translator;
/// Resolved DLRM shape + the PT2 graph names of every tensor the
/// megakernel needs as input. All weight/input lookups go through
/// `Translator::get_tensor(name)` which is keyed by PT2 graph_name.
#[derive(Debug)]
pub(super) struct DlrmShape {
pub batch: usize,
pub n_dense_in: usize,
pub ln_bot: Vec<usize>,
pub n_sparse: usize,
pub vocab_sizes: Vec<usize>,
pub m_spa: usize,
pub ln_top: Vec<usize>,
pub dense_input_name: String,
pub index_input_names: Vec<String>, // length n_sparse
pub emb_weight_names: Vec<String>, // length n_sparse
pub bot_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
pub top_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
pub output_name: String,
}
fn debug_enabled() -> bool {
std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").map(|v| v == "1").unwrap_or(false)
}
macro_rules! dbgln {
($($arg:tt)*) => {
if debug_enabled() {
eprintln!("[dlrm_pattern] {}", format!($($arg)*));
}
};
}
/// Try to interpret the parsed PT2 program as a DLRM-shape forward.
/// Returns `None` if any structural check fails — translator falls back
/// to the standard dispatch.
pub(super) fn match_dlrm(parsed: &ParsedPT2) -> Option<DlrmShape> {
let nodes = &parsed.program.graph_module.graph.nodes;
// ---- 1. Index the key op types ----------------------------------
let emb_node_idxs: Vec<usize> = nodes
.iter()
.enumerate()
.filter(|(_, n)| n.target == "torch.ops.aten._embedding_bag_forward_only.default"
|| n.target == "torch.ops.aten._embedding_bag.default")
.map(|(i, _)| i)
.collect();
if emb_node_idxs.is_empty() {
dbgln!("no embedding_bag nodes — not DLRM");
return None;
}
let n_sparse = emb_node_idxs.len();
let first_emb = emb_node_idxs[0];
let last_emb = *emb_node_idxs.last().unwrap();
let addmm_idxs: Vec<usize> = nodes
.iter()
.enumerate()
.filter(|(_, n)| n.target == "torch.ops.aten.addmm.default")
.map(|(i, _)| i)
.collect();
let bot_addmms: Vec<usize> =
addmm_idxs.iter().filter(|&&i| i < first_emb).copied().collect();
let top_addmms: Vec<usize> =
addmm_idxs.iter().filter(|&&i| i > last_emb).copied().collect();
if bot_addmms.is_empty() || top_addmms.is_empty() {
dbgln!(
"addmm split: bot={}, top={} (expected ≥1 each)",
bot_addmms.len(),
top_addmms.len()
);
return None;
}
let sigmoid_idx = nodes
.iter()
.enumerate()
.find(|(_, n)| n.target == "torch.ops.aten.sigmoid.default")
.map(|(i, _)| i)?;
if sigmoid_idx < *top_addmms.last().unwrap() {
dbgln!("sigmoid before last top addmm — not DLRM ordering");
return None;
}
let bmm_idx = nodes
.iter()
.enumerate()
.find(|(_, n)| n.target == "torch.ops.aten.bmm.default")
.map(|(i, _)| i)?;
if bmm_idx < last_emb || bmm_idx > top_addmms[0] {
dbgln!("bmm position wrong (idx {bmm_idx}, last_emb {last_emb}, first_top_addmm {})", top_addmms[0]);
return None;
}
// index.Tensor must exist between bmm and the first top addmm — that's
// the (li, lj) gather of the lower-triangular pairs.
let _index_idx = nodes
.iter()
.enumerate()
.find(|(i, n)| n.target == "torch.ops.aten.index.Tensor" && *i > bmm_idx)
.map(|(i, _)| i)?;
// ---- 2. Extract embedding info (vocab, m_spa, indices, weights) -
let mut vocab_sizes = Vec::with_capacity(n_sparse);
let mut emb_weight_names = Vec::with_capacity(n_sparse);
let mut index_input_names = Vec::with_capacity(n_sparse);
let mut batch_opt: Option<usize> = None;
let mut m_spa_opt: Option<usize> = None;
for &i in &emb_node_idxs {
let n = &nodes[i];
// Validate the bag invariants the megakernel relies on.
// arg ordering: (weight, indices, offsets, scale_grad_by_freq, mode,
// sparse, per_sample_weights, include_last_offset, padding_idx)
let weight_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
let indices_name = n.inputs.get(1)?.arg.as_tensor_name()?.to_string();
let offsets_name = n.inputs.get(2)?.arg.as_tensor_name()?.to_string();
// mode must be 0 (sum) — anything else falls back.
let mode = n.inputs.get(4).and_then(|a| a.arg.as_int()).unwrap_or(0);
if mode != 0 {
dbgln!("embedding_bag mode={mode} != 0 (sum)");
return None;
}
// per_sample_weights must be None (no tensor arg in slot 6).
if let Some(arg) = n.inputs.get(6)
&& arg.arg.as_tensor_name().is_some()
{
dbgln!("embedding_bag has per_sample_weights — not supported");
return None;
}
// include_last_offset must be false.
if matches!(
n.inputs.get(7).and_then(|a| a.arg.as_bool()),
Some(true)
) {
dbgln!("embedding_bag include_last_offset=true — not supported");
return None;
}
let weight_meta = parsed.tensor_meta(&weight_name)?;
if weight_meta.sizes.len() != 2 {
dbgln!("embedding weight has non-2D shape");
return None;
}
let v = weight_meta.sizes[0].hint()? as usize;
let m = weight_meta.sizes[1].hint()? as usize;
if let Some(prev) = m_spa_opt
&& prev != m
{
dbgln!("inconsistent m_spa across embeddings ({prev} vs {m})");
return None;
}
m_spa_opt = Some(m);
// Bag-size-1: indices.len == offsets.len == batch.
let idx_meta = parsed.tensor_meta(&indices_name)?;
let off_meta = parsed.tensor_meta(&offsets_name)?;
if idx_meta.sizes.len() != 1 || off_meta.sizes.len() != 1 {
return None;
}
let idx_len = idx_meta.sizes[0].hint()? as usize;
let off_len = off_meta.sizes[0].hint()? as usize;
if idx_len != off_len {
dbgln!(
"non-uniform bag (indices={idx_len}, offsets={off_len}) — fallback"
);
return None;
}
if let Some(prev) = batch_opt
&& prev != idx_len
{
dbgln!("inconsistent batch across embeddings ({prev} vs {idx_len})");
return None;
}
batch_opt = Some(idx_len);
vocab_sizes.push(v);
emb_weight_names.push(weight_name);
index_input_names.push(indices_name);
}
let m_spa = m_spa_opt?;
let batch = batch_opt?;
// ---- 3. Reconstruct bot/top MLP widths --------------------------
//
// addmm(bias, input, weight^T) → output (B, out)
// inputs[0] = bias (out,) — gives us the layer's out_features
// inputs[1] = input (B, in_w) — first addmm in each chain tells us in_w
// inputs[2] = weight^T — usually produced by a `permute.default`
// whose input is the (out, in) weight param.
let extract_chain_shape = |chain: &[usize]| -> Option<Vec<usize>> {
let mut ln = Vec::with_capacity(chain.len() + 1);
for (i, &node_idx) in chain.iter().enumerate() {
let n = &nodes[node_idx];
let bias_name = n.inputs.first()?.arg.as_tensor_name()?;
let bias_meta = parsed.tensor_meta(bias_name)?;
if bias_meta.sizes.len() != 1 {
return None;
}
let out = bias_meta.sizes[0].hint()? as usize;
if i == 0 {
let input_name = n.inputs.get(1)?.arg.as_tensor_name()?;
let in_meta = parsed.tensor_meta(input_name)?;
if in_meta.sizes.len() != 2 {
return None;
}
let in_w = in_meta.sizes[1].hint()? as usize;
ln.push(in_w);
}
ln.push(out);
}
Some(ln)
};
let ln_bot = extract_chain_shape(&bot_addmms)?;
let ln_top = extract_chain_shape(&top_addmms)?;
// ---- 4. Shape consistency checks --------------------------------
if *ln_bot.last()? != m_spa {
dbgln!("ln_bot.last() = {} != m_spa {m_spa}", ln_bot.last()?);
return None;
}
let n_feat = 1 + n_sparse;
let n_pairs = n_feat * (n_feat - 1) / 2;
if ln_top[0] != m_spa + n_pairs {
dbgln!(
"ln_top[0] = {} != m_spa+n_pairs = {}",
ln_top[0],
m_spa + n_pairs
);
return None;
}
if *ln_top.last()? != 1 {
dbgln!("ln_top.last() = {} != 1", ln_top.last()?);
return None;
}
if vocab_sizes.len() != n_sparse {
return None;
}
if ln_bot.len() < 2 || ln_top.len() < 2 {
return None;
}
// ---- 5. Pull weight + bias parameter names ----------------------
let extract_weights = |chain: &[usize]| -> Option<Vec<(String, String)>> {
let mut out = Vec::with_capacity(chain.len());
for &node_idx in chain {
let n = &nodes[node_idx];
let bias_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
let mat2_name = n.inputs.get(2)?.arg.as_tensor_name()?;
let weight_name = resolve_weight_param(nodes, mat2_name)?;
out.push((weight_name, bias_name));
}
Some(out)
};
let bot_weight_names = extract_weights(&bot_addmms)?;
let top_weight_names = extract_weights(&top_addmms)?;
// ---- 6. dense_input + output names ------------------------------
let dense_input_name = nodes[bot_addmms[0]]
.inputs
.get(1)?
.arg
.as_tensor_name()?
.to_string();
// Validate it's actually a user input (not an intermediate).
let user_input_names: std::collections::HashSet<&str> = parsed
.classify_inputs()
.iter()
.filter_map(|i| match i {
crate::pt2_parser::InputKind::UserInput { graph_name } => Some(graph_name.as_str()),
_ => None,
})
.map(|s| s.to_string())
.collect::<std::collections::HashSet<String>>()
.iter()
.map(|s| -> &str { unsafe { std::mem::transmute::<&str, &str>(s.as_str()) } })
.collect();
let _ = user_input_names; // suppress dead_code if not used; cleaner check below
// (Simpler: just check the name is in classified user inputs by string.)
let inputs = parsed.classify_inputs();
let is_user = inputs.iter().any(|i| {
matches!(
i,
crate::pt2_parser::InputKind::UserInput { graph_name } if graph_name == &dense_input_name
)
});
if !is_user {
dbgln!("dense_input candidate {dense_input_name} is not a user input");
return None;
}
let output_name = nodes[sigmoid_idx]
.outputs
.first()?
.as_tensor
.as_ref()?
.name
.clone();
let shape = DlrmShape {
batch,
n_dense_in: ln_bot[0],
ln_bot,
n_sparse,
vocab_sizes,
m_spa,
ln_top,
dense_input_name,
index_input_names,
emb_weight_names,
bot_weight_names,
top_weight_names,
output_name,
};
dbgln!(
"matched DLRM: batch={} ln_bot={:?} n_sparse={} vocabs={:?} m_spa={} ln_top={:?}",
shape.batch,
shape.ln_bot,
shape.n_sparse,
shape.vocab_sizes,
shape.m_spa,
shape.ln_top
);
Some(shape)
}
/// Walk back from an addmm's `mat2` argument to the underlying weight
/// parameter. PyTorch's `nn.Linear` decomposes to
/// `permute(weight) → addmm(bias, x, permuted)`, so we expect mat2 to be
/// the output of a `permute.default` node whose input is the weight.
/// If mat2 is itself a graph input (no producing node), it IS the weight.
fn resolve_weight_param(nodes: &[Node], name: &str) -> Option<String> {
for n in nodes {
let Some(first_out) = n.outputs.first().and_then(|o| o.as_tensor.as_ref()) else {
continue;
};
if first_out.name == name {
// mat2 was produced by an op. Only `permute.default` is expected;
// anything else is unfamiliar and we should fall back.
if n.target == "torch.ops.aten.permute.default" {
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
} else if n.target == "torch.ops.aten.t.default" {
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
} else {
dbgln!(
"addmm mat2 produced by unexpected op '{}' — fallback",
n.target
);
return None;
}
}
}
// No producing node — mat2 is a graph input (param) directly.
Some(name.to_string())
}
/// Build the megakernel CustomOp inputs vec in the canonical order
/// expected by [`DlrmMegaKernel`] and insert it into the translator's
/// luminal graph. Registers the result under `shape.output_name` so the
/// downstream output-emission loop finds it.
pub(super) fn emit_megakernel(t: &mut Translator<'_>, shape: DlrmShape) -> Result<()> {
// Resolve every input tensor by PT2 graph_name through Translator.tensors.
let mut inputs: Vec<GraphTensor> = Vec::new();
inputs.push(
t.get_tensor(&shape.dense_input_name)
.with_context(|| format!("dense input {} not in tensors", shape.dense_input_name))?,
);
for n in &shape.index_input_names {
inputs.push(t.get_tensor(n).with_context(|| format!("index input {n} not in tensors"))?);
}
for n in &shape.emb_weight_names {
inputs.push(t.get_tensor(n).with_context(|| format!("emb weight {n} not in tensors"))?);
}
for (w, b) in &shape.bot_weight_names {
inputs.push(t.get_tensor(w).with_context(|| format!("bot weight {w} not in tensors"))?);
inputs.push(t.get_tensor(b).with_context(|| format!("bot bias {b} not in tensors"))?);
}
for (w, b) in &shape.top_weight_names {
inputs.push(t.get_tensor(w).with_context(|| format!("top weight {w} not in tensors"))?);
inputs.push(t.get_tensor(b).with_context(|| format!("top bias {b} not in tensors"))?);
}
let kernel = DlrmMegaKernel {
batch: shape.batch,
n_dense_in: shape.n_dense_in,
ln_bot: shape.ln_bot.clone(),
n_sparse: shape.n_sparse,
vocab_sizes: shape.vocab_sizes.clone(),
m_spa: shape.m_spa,
ln_top: shape.ln_top.clone(),
};
let out = t.graph.custom_op(
DlrmMegaCustom(kernel),
inputs,
(shape.batch, 1usize),
DType::F32,
);
t.tensors.insert(shape.output_name.clone(), out);
dbgln!("emitted DlrmMegaCustom; output={}", shape.output_name);
Ok(())
}

View File

@@ -6,6 +6,8 @@ mod attention;
mod binary;
mod conv;
mod dispatch;
#[cfg(feature = "cuda")]
mod dlrm_pattern;
mod movement;
mod reduction;
mod tensor;
@@ -70,12 +72,31 @@ impl<'a> Translator<'a> {
fn translate_graph(&mut self) -> Result<()> {
self.create_inputs()?;
// Fast path: if the entire forward matches the DLRM family shape,
// emit one DlrmMegaCustom op instead of walking nodes. On any
// mismatch the matcher returns None and we fall through to the
// standard dispatch — no semantic difference, just slower (~503µs
// vs ~30µs at bs=2048 for MiniDLRM). CUDA-only: the megakernel
// is a CUDA CustomOp.
#[cfg(feature = "cuda")]
if let Some(shape) = dlrm_pattern::match_dlrm(self.parsed) {
dlrm_pattern::emit_megakernel(self, shape)?;
return self.emit_outputs();
}
let nodes = &self.parsed.program.graph_module.graph.nodes;
for (i, node) in nodes.iter().enumerate() {
self.translate_node(node)
.with_context(|| format!("Failed to translate node {i}: {}", node.target))?;
}
self.emit_outputs()
}
/// Walks the parsed graph's user outputs, applies the wrap/cast rules
/// that downstream codegen relies on, then attaches an `Output` node
/// per user-output. Shared by the normal dispatch path and the DLRM
/// megakernel fast path.
fn emit_outputs(&mut self) -> Result<()> {
let output_names = self.parsed.output_names();
for name in &output_names {
let tensor = self.get_tensor(name)?;
@@ -84,12 +105,20 @@ impl<'a> Translator<'a> {
} else if tensor.dtype == DType::Int {
tensor
} else {
// The `+ 0.0` wrap pulls double duty: it materializes a fresh
// buffer for outputs that alias an Input (passthrough
// `return x`), AND it acts as an anchor that survives egglog
// rewriting, so the downstream runtime can find the producer
// node for outputs whose original op (e.g. Reduce with
// keepdims, Conv) gets folded away during optimization.
// Removing it broke 24 test_hlir_ops tests with "Cannot find
// output tensor!" — keep it until that anchor invariant is
// refactored elsewhere.
tensor + 0.0
};
tensor.output();
self.output_ids.push((name.clone(), tensor.id));
}
Ok(())
}

View File

@@ -256,6 +256,97 @@ impl<'a> Translator<'a> {
Ok(weight.gather(ids_expanded + arange_expanded))
}
/// `aten._embedding_bag` / `aten._embedding_bag_forward_only`
///
/// Signature: (weight, indices, offsets, scale_grad_by_freq=False, mode=0,
/// sparse=False, per_sample_weights=None, include_last_offset=False,
/// padding_idx=-1) -> (output, offset2bag, bag_size, max_indices)
///
/// Strategy: for the bag-size-uniform case (N indices spread evenly across
/// B bags, i.e. N % B == 0), reshape gather output [N, D] into [B, K, D]
/// and reduce along K according to `mode`. We deliberately read uniformity
/// off the *static shapes* of `indices` and `offsets` — non-uniform bags
/// require a runtime segment-sum primitive we don't yet have.
///
/// DLRM hits the K=1 special case (offsets=[0,1,...,B-1], indices=[B]) per
/// sparse table per sample — the same lookup pattern as `aten.embedding`.
/// Only the first tuple element is materialized; the bookkeeping outputs
/// (offset2bag, bag_size, max_indices) are inference-time dead ends.
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
let weight = self.get_input_tensor(node, 0)?;
let indices = self.get_input_tensor(node, 1)?;
let offsets = self.get_input_tensor(node, 2)?;
let mode = self.get_int_arg(node, 4).unwrap_or(0);
let include_last_offset = self.get_bool_arg(node, 7).unwrap_or(false);
if let Some(arg) = node.inputs.get(6)
&& arg.arg.as_tensor_name().is_some()
{
bail!("_embedding_bag: per_sample_weights not supported");
}
if indices.shape.len() != 1 || offsets.shape.len() != 1 {
bail!(
"_embedding_bag: expected 1-D indices and offsets, got shapes {:?}, {:?}",
indices.shape.dims,
offsets.shape.dims
);
}
let n = indices.shape.dims[0]
.to_usize()
.context("_embedding_bag: indices length must be statically known")?;
let b_raw = offsets.shape.dims[0]
.to_usize()
.context("_embedding_bag: offsets length must be statically known")?;
let b = if include_last_offset { b_raw - 1 } else { b_raw };
if b == 0 {
bail!("_embedding_bag: empty bag set");
}
if n % b != 0 {
bail!(
"_embedding_bag: non-uniform bag size not supported (indices={n}, bags={b})"
);
}
let k = n / b;
let hidden_dim = weight.shape.dims[1];
// Step 1: gather weight rows. Same construction as translate_embedding —
// flatten the (idx, hidden) pair into a single offset into the weight
// matrix and gather. Result: [N, D].
let indices_int = indices.cast(DType::Int);
let ids_expanded = (indices_int * hidden_dim).expand_dim(1, hidden_dim);
let arange = self.graph.arange(hidden_dim);
let arange_expanded = arange.expand_dim(0, indices.shape.dims[0]);
let gathered = weight.gather(ids_expanded + arange_expanded);
// Step 2: bag-size-1 → already [B, D]; skip reshape/reduce.
if k == 1 {
return Ok(gathered);
}
// Step 3: reshape [B*K, D] → [B, K, D] (contiguous, identity stride view).
let bag_shape = vec![
Expression::from(b),
Expression::from(k),
hidden_dim,
];
let mut bagged = GraphTensor {
id: gathered.id,
graph_ref: gathered.graph_ref,
shape: ShapeTracker::new(bag_shape),
dtype: gathered.dtype,
};
// Step 4: reduce along axis=1.
bagged = match mode {
0 => bagged.sum(1),
1 => bagged.mean(1),
2 => bagged.max(1),
m => bail!("_embedding_bag: unsupported mode {m} (0=sum, 1=mean, 2=max)"),
};
Ok(bagged)
}
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let source = self.get_input_tensor(node, 0)?;
@@ -318,6 +409,20 @@ impl<'a> Translator<'a> {
let index_names = &index_names;
// Prefix-of-Nones case: `source[:, ..., :, idx_0, idx_1, ..., idx_{m-1}]`
// — indices apply to dims [first..first+m), not [0..m). The original
// multi-index path below assumes first==0 and silently mis-strides
// (and mis-flattens) when called with a prefix; route to the
// prefix-aware path before falling through. Suffix-of-Nones after the
// indices is not yet supported here.
if first_non_none_dim > 0 {
return self.translate_index_tensor_with_prefix(
source,
index_names,
first_non_none_dim,
);
}
let src_shape = source.shape.dims;
let n_indexed = index_names.len();
@@ -398,6 +503,132 @@ impl<'a> Translator<'a> {
}
}
/// Advanced indexing with a `None` prefix: `source[:, ..., :, i0, i1, ...]`.
///
/// Output shape: `prefix_dims ++ idx_shape ++ suffix_dims` where
/// `prefix_dims = src.shape[..first]`, `suffix_dims = src.shape[first+m..]`,
/// and `idx_shape` is the broadcast shape of the m index tensors.
///
/// Currently supports the no-suffix case (indices land on the trailing
/// dims). DLRM's dot interaction hits this: `Z[:, li, lj]` with
/// `Z: [B, ni, nj]`, `li, lj: [L]`.
fn translate_index_tensor_with_prefix(
&mut self,
source: GraphTensor,
index_names: &[crate::pt2_schema::TensorName],
first: usize,
) -> Result<GraphTensor> {
let src_shape = source.shape.dims;
let n_indexed = index_names.len();
let src_rank = src_shape.len();
if first + n_indexed > src_rank {
bail!(
"index.Tensor (prefix): {n_indexed} indices starting at dim {first} \
exceed source rank {src_rank}"
);
}
let prefix_dims: Vec<Expression> = src_shape[..first].to_vec();
let indexed_dims: Vec<Expression> = src_shape[first..first + n_indexed].to_vec();
let suffix_dims: Vec<Expression> = src_shape[first + n_indexed..].to_vec();
if !suffix_dims.is_empty() {
bail!(
"index.Tensor (prefix): trailing-dim suffix after indices not \
supported (prefix={} indexed={} suffix={})",
prefix_dims.len(),
n_indexed,
suffix_dims.len()
);
}
// Per-axis strides within the indexed subspace (right-to-left product).
let mut strides = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * indexed_dims[i + 1];
}
let indexed_size = indexed_dims
.iter()
.copied()
.fold(Expression::from(1usize), |a, b| a * b);
// Collapse the m index tensors into a single flat index in the indexed
// subspace. Negative entries get normalized per axis.
let mut flat_idx: Option<GraphTensor> = None;
for (i, idx_name) in index_names.iter().enumerate() {
let idx_t = self.get_tensor(&idx_name.name)?.cast(DType::Int);
let axis_size = indexed_dims[i];
let zero = self.graph.constant(0).expand_rhs(idx_t.shape);
let is_neg = idx_t.lt(zero).cast(DType::Int);
let idx_norm = idx_t + is_neg * axis_size;
let stride = strides[i];
let weighted = if stride.to_usize() == Some(1) {
idx_norm
} else {
idx_norm * stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (a, w) = broadcast_binary(acc, weighted);
a + w
}
None => weighted,
});
}
let flat_idx = flat_idx.context("index.Tensor (prefix): no indices")?;
let idx_shape: Vec<Expression> = flat_idx.shape.dims.to_vec();
// Build the absolute flat index over `source` viewed as 1D, shape
// `prefix_dims ++ idx_shape`:
// abs[p..., k...] = flat_prefix(p...) * indexed_size + flat_idx[k...]
// Construct by promoting `flat_idx` to the full rank then adding a
// broadcast prefix-offset tensor.
let mut full_shape: Vec<Expression> = prefix_dims.clone();
full_shape.extend_from_slice(&idx_shape);
// Promote flat_idx: insert prefix_dims leading axes, then expand.
let mut idx_promoted = flat_idx;
for _ in 0..prefix_dims.len() {
idx_promoted = idx_promoted.expand_dim(0, Expression::from(1usize));
}
idx_promoted.shape.expand(full_shape.clone());
// Prefix offset: for each prefix dim pi (right-to-left), accumulate
// arange(prefix_dims[pi]) * (product_of_more_inner_prefix_dims * indexed_size).
let mut prefix_offset: Option<GraphTensor> = None;
let mut cum_stride = indexed_size;
for (pi, pd) in prefix_dims.iter().enumerate().rev() {
let ar = self.graph.arange(*pd) * cum_stride;
// arange is shape [pd]; lift it into full_shape at position pi.
let mut ar_promoted = ar;
for _ in 0..pi {
ar_promoted = ar_promoted.expand_dim(0, Expression::from(1usize));
}
let trailing = full_shape.len() - pi - 1;
for _ in 0..trailing {
let r = ar_promoted.shape.len();
ar_promoted = ar_promoted.expand_dim(r, Expression::from(1usize));
}
ar_promoted.shape.expand(full_shape.clone());
prefix_offset = Some(match prefix_offset {
Some(acc) => acc + ar_promoted,
None => ar_promoted,
});
cum_stride = cum_stride * *pd;
}
let final_idx = match prefix_offset {
Some(po) => idx_promoted + po,
None => idx_promoted,
};
// Flatten source to 1D and gather with the absolute index.
let total: Expression = src_shape
.iter()
.copied()
.fold(Expression::from(1usize), |a, b| a * b);
let fully_flat = reshape_tensor(source, vec![total]);
Ok(fully_flat.gather(final_idx))
}
pub(crate) fn translate_gather(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let dim = self.get_int_arg(node, 1)?;

View File

@@ -43,6 +43,19 @@ class CompiledModel:
else torch.float32
for i in range(len(self._input_names))
]
# Pre-zip + caches for the hot path. The CudaRuntime now preserves
# external-pointer registrations across execute() calls and treats
# set_device_ptr as a no-op when the pointer is unchanged — caching
# the (name, ptr) here avoids the pyo3 round-trip entirely in tight
# loops where PyTorch's caching allocator keeps re-handing back the
# same tensor (e.g. inference loops with reused activation buffers).
self._input_specs = list(zip(self._input_names, self._input_dtypes))
self._last_input_ptrs: dict[str, int] = {}
# Output dtype/zero-copy decisions are properties of the compiled
# graph and never change; computing them lazily and caching avoids
# ~10µs of pyo3 calls per iter.
self._output_torch_dtypes_cache = None
self._output_zero_copy_cache = None
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -89,22 +102,41 @@ class CompiledModel:
# Set user input data via pointer.
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
# For CUDA inputs already in the expected dtype + contiguous, we
# skip the detach/contiguous/to chain (those allocate new Tensor
# objects even when they're no-ops) and short-circuit set_input_device_ptr
# when the pointer hasn't moved since the last call. The runtime
# treats same-ptr re-registration as a no-op too, but skipping the
# pyo3 round-trip here saves another ~5µs per input.
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = []
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
graph = self._graph
last_input_ptrs = self._last_input_ptrs
if self._supports_device_ptrs:
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
if (
tensor.is_cuda
and tensor.dtype is expected_dtype
and tensor.is_contiguous()
):
t = tensor
else:
t = tensor.detach().contiguous().to(expected_dtype)
ptr = t.data_ptr()
if last_input_ptrs.get(name) != ptr:
graph.set_input_device_ptr(name, ptr, t.numel() * t.element_size())
last_input_ptrs[name] = ptr
_input_refs.append(t)
else:
else:
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
graph.set_input_from_ptr(
name,
t.data_ptr(),
t.numel() * t.element_size(),
_torch_dtype_code(t.dtype),
)
# Resolve output shapes before run() (needed for pre-allocation).
if self._has_dynamic_dims:

View File

@@ -0,0 +1,209 @@
"""End-to-end compile tests for a faithful DLRM-style recommender.
`MiniDLRM` below mirrors `DLRM_Net` from facebookresearch/dlrm:
bottom-MLP on dense features, an `EmbeddingBag` per sparse table, dot-product
interaction over the (1 + n_sparse) feature vectors, and a top-MLP. The
forward signature `(dense_x, lS_o, lS_i)` matches DLRM exactly.
This is the smallest model that exercises the three translator paths added for
DLRM:
- `aten._embedding_bag_forward_only.default` (uniform-bag-size lowering)
- `aten.index.Tensor` with a `None` prefix (`Z[:, li, lj]`)
- the existing `aten.bmm` / `aten.cat` paths under the above feeders
"""
from typing import Callable
import torch
import torch._dynamo
import torch.nn as nn
from luminal import luminal_backend
class MiniDLRM(nn.Module):
"""Minimal faithful DLRM (dot interaction, mode='sum' EmbeddingBag)."""
def __init__(
self,
m_spa: int,
ln_emb: list[int],
ln_bot: list[int],
ln_top: list[int],
) -> None:
super().__init__()
assert ln_bot[-1] == m_spa, "bottom MLP must end at m_spa"
n_feat = 1 + len(ln_emb)
n_pairs = n_feat * (n_feat - 1) // 2
assert ln_top[0] == n_pairs + m_spa, (
f"top MLP entry width must equal n_pairs ({n_pairs}) + m_spa ({m_spa}) "
f"= {n_pairs + m_spa}, got {ln_top[0]}"
)
self.m_spa = m_spa
self.emb_l = nn.ModuleList(
[nn.EmbeddingBag(int(n), m_spa, mode="sum") for n in ln_emb]
)
self.bot_l = self._build_mlp(ln_bot, sigmoid_last=False)
self.top_l = self._build_mlp(ln_top, sigmoid_last=True)
@staticmethod
def _build_mlp(sizes: list[int], sigmoid_last: bool) -> nn.Sequential:
layers: list[nn.Module] = []
for i in range(len(sizes) - 1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=True))
if i == len(sizes) - 2 and sigmoid_last:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def _apply_emb(
self, lS_o: list[torch.Tensor], lS_i: list[torch.Tensor]
) -> list[torch.Tensor]:
return [self.emb_l[k](lS_i[k], lS_o[k]) for k in range(len(self.emb_l))]
def _interact(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
batch_size, d = x.shape
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
Z = torch.bmm(T, torch.transpose(T, 1, 2))
_, ni, nj = Z.shape
li = torch.tensor(
[i for i in range(ni) for _ in range(i)], device=x.device
)
lj = torch.tensor(
[j for i in range(nj) for j in range(i)], device=x.device
)
Zflat = Z[:, li, lj]
return torch.cat([x, Zflat], dim=1)
def forward(
self,
dense_x: torch.Tensor,
lS_o: list[torch.Tensor],
lS_i: list[torch.Tensor],
) -> torch.Tensor:
x = self.bot_l(dense_x)
ly = self._apply_emb(lS_o, lS_i)
z = self._interact(x, ly)
return self.top_l(z)
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _make_inputs(
batch_size: int,
dense_dim: int,
ln_emb: list[int],
bag_size: int,
device: torch.device,
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
dense_x = torch.rand(batch_size, dense_dim, device=device)
if bag_size == 1:
offsets = [
torch.arange(batch_size, dtype=torch.long, device=device)
for _ in ln_emb
]
indices = [
torch.randint(0, int(n), (batch_size,), dtype=torch.long, device=device)
for n in ln_emb
]
else:
offsets = [
torch.arange(
0,
batch_size * bag_size,
bag_size,
dtype=torch.long,
device=device,
)
for _ in ln_emb
]
indices = [
torch.randint(
0, int(n), (batch_size * bag_size,), dtype=torch.long, device=device
)
for n in ln_emb
]
return dense_x, offsets, indices
def _build_model(
m_spa: int,
ln_emb: list[int],
ln_bot: list[int],
device: torch.device,
) -> MiniDLRM:
torch.manual_seed(0)
n_feat = 1 + len(ln_emb)
n_pairs = n_feat * (n_feat - 1) // 2
ln_top = [n_pairs + m_spa, 8, 1]
model = MiniDLRM(m_spa, ln_emb, ln_bot, ln_top).to(device).eval()
return model
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_dlrm_dot_bag1_smallbatch(device: torch.device) -> None:
"""The canonical DLRM eval path: 1 lookup per sample per sparse table."""
m_spa = 4
ln_emb = [10, 20, 30]
ln_bot = [13, 8, m_spa]
model = _build_model(m_spa, ln_emb, ln_bot, device)
inputs = _make_inputs(batch_size=2, dense_dim=13, ln_emb=ln_emb, bag_size=1, device=device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
eager = model(*inputs)
out = compiled(*inputs)
assert torch.allclose(out, eager, atol=1e-5)
def test_dlrm_dot_bag1_largerbatch(device: torch.device) -> None:
"""Larger batch (64) — sanity-check that the bs-1 specialization isn't load-bearing."""
m_spa = 4
ln_emb = [10, 20, 30]
ln_bot = [13, 8, m_spa]
model = _build_model(m_spa, ln_emb, ln_bot, device)
inputs = _make_inputs(batch_size=64, dense_dim=13, ln_emb=ln_emb, bag_size=1, device=device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
eager = model(*inputs)
out = compiled(*inputs)
assert torch.allclose(out, eager, atol=1e-4)
def test_dlrm_dot_multihot(device: torch.device) -> None:
"""Uniform multi-hot bags (bag_size=3) — exercises the reshape+sum path."""
m_spa = 4
ln_emb = [10, 20, 30]
ln_bot = [13, 8, m_spa]
model = _build_model(m_spa, ln_emb, ln_bot, device)
inputs = _make_inputs(batch_size=4, dense_dim=13, ln_emb=ln_emb, bag_size=3, device=device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
eager = model(*inputs)
out = compiled(*inputs)
assert torch.allclose(out, eager, atol=1e-5)
def test_dlrm_dot_larger_tables(device: torch.device) -> None:
"""Verifies bigger embedding tables don't change the path."""
m_spa = 4
ln_emb = [50, 100, 200]
ln_bot = [13, 8, m_spa]
model = _build_model(m_spa, ln_emb, ln_bot, device)
inputs = _make_inputs(batch_size=4, dense_dim=13, ln_emb=ln_emb, bag_size=1, device=device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
eager = model(*inputs)
out = compiled(*inputs)
assert torch.allclose(out, eager, atol=1e-5)

17
examples/dlrm/Cargo.toml Normal file
View File

@@ -0,0 +1,17 @@
[package]
name = "dlrm"
version = "0.1.0"
edition = "2024"
[[bin]]
name = "dlrm"
path = "src/main.rs"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
safetensors = "0.7.0"
memmap2 = "0.9.9"
bytemuck = "1.24.0"
rand = "0.9.2"

306
examples/dlrm/bench.py Normal file
View File

@@ -0,0 +1,306 @@
"""DLRM inference latency benchmark across PyTorch backends + luminal.
Backends measured:
1. PyTorch eager
2. torch.compile (default backend = inductor, mode="reduce-overhead")
3. AOTInductor (export → aoti_compile_and_package → load → run)
4. CUDA graphs (capture-replay around the eager model)
5. PyTorch + luminal_backend (torch.compile with our PT2 → luminal backend)
The rust luminal path is measured separately by the dlrm binary's --bench
flag and the two results are combined in the rank table later.
Methodology:
- Same MiniDLRM at the small config, batch_size=2 (matches export.py and
the rust binary so the comparison is apples-to-apples).
- 50 warmup iters per backend, 500 measured iters.
- Per-iteration latency via paired cudaEvent_record + elapsed_time.
- Report mean / p50 / p99 in microseconds; also dump every measurement
to /tmp/dlrm_bench_<backend>.txt so other consumers can re-aggregate.
Run:
/lambda/nfs/tucker-fs/second/luminal/crates/luminal_python/.venv/bin/python \
examples/dlrm/bench.py
"""
import os
import sys
import statistics
import tempfile
import time
from pathlib import Path
from typing import Callable
import torch
# MiniDLRM lives in tests.
TESTS_DIR = (
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
)
sys.path.insert(0, str(TESTS_DIR))
from test_dlrm import MiniDLRM # noqa: E402
from luminal import luminal_backend # noqa: E402
DEVICE = torch.device("cuda")
WARMUP = 50
ITERS = 500
M_SPA = 4
LN_EMB = [10, 20, 30]
LN_BOT = [13, 8, M_SPA]
LN_TOP = [10, 8, 1]
# Real-workload DLRM batch — kernel work dominates per-launch overhead.
BATCH = 2048
def make_model() -> torch.nn.Module:
torch.manual_seed(0)
return MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP).to(DEVICE).eval()
def make_inputs() -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
torch.manual_seed(42)
dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)
indices = [
torch.randint(0, n, (BATCH,), dtype=torch.long, device=DEVICE) for n in LN_EMB
]
offsets = [torch.arange(BATCH, dtype=torch.long, device=DEVICE) for _ in LN_EMB]
return dense_x, offsets, indices
def time_callable(fn: Callable[[], torch.Tensor], iters: int) -> list[float]:
"""Time `fn` over `iters` iterations using CUDA events. Returns per-iter
microseconds."""
start_evts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
end_evts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
torch.cuda.synchronize()
for i in range(iters):
start_evts[i].record()
_ = fn()
end_evts[i].record()
torch.cuda.synchronize()
return [start_evts[i].elapsed_time(end_evts[i]) * 1000.0 for i in range(iters)]
def report(name: str, samples_us: list[float]) -> dict[str, float]:
samples_us = sorted(samples_us)
n = len(samples_us)
mean = sum(samples_us) / n
p50 = samples_us[n // 2]
p99 = samples_us[int(n * 0.99)]
print(f" {name:<32s} mean={mean:8.2f}µs p50={p50:8.2f}µs p99={p99:8.2f}µs")
# Dump every sample for downstream aggregation.
out_path = f"/tmp/dlrm_bench_{name.replace(' ', '_').replace('(', '').replace(')', '')}.txt"
Path(out_path).write_text("\n".join(f"{s:.4f}" for s in samples_us))
return {"name": name, "mean": mean, "p50": p50, "p99": p99, "n": n}
# ---------------------------------------------------------------------------
# Backends
# ---------------------------------------------------------------------------
def bench_eager() -> dict[str, float]:
model = make_model()
inputs = make_inputs()
@torch.no_grad()
def fn() -> torch.Tensor:
return model(*inputs)
for _ in range(WARMUP):
fn()
return report("eager", time_callable(fn, ITERS))
def bench_torch_compile() -> dict[str, float]:
torch._dynamo.reset()
model = make_model()
inputs = make_inputs()
compiled = torch.compile(model, mode="reduce-overhead")
@torch.no_grad()
def fn() -> torch.Tensor:
return compiled(*inputs)
for _ in range(WARMUP):
fn()
return report("torch.compile (inductor)", time_callable(fn, ITERS))
def bench_aoti() -> dict[str, float]:
"""AOTInductor: export → compile-and-package → load → run.
Note: torch.export currently treats list[Tensor] inputs as pytree-flattened,
so the runtime callable takes positional tensors. We unpack manually.
"""
torch._dynamo.reset()
model = make_model()
dense_x, offsets, indices = make_inputs()
# Wrap to surface tensor inputs at the top-level positional signature.
class FlatWrapper(torch.nn.Module):
def __init__(self, m: torch.nn.Module) -> None:
super().__init__()
self.m = m
def forward(
self,
dense_x: torch.Tensor,
o0: torch.Tensor,
o1: torch.Tensor,
o2: torch.Tensor,
i0: torch.Tensor,
i1: torch.Tensor,
i2: torch.Tensor,
) -> torch.Tensor:
return self.m(dense_x, [o0, o1, o2], [i0, i1, i2])
flat_model = FlatWrapper(model).to(DEVICE).eval()
flat_inputs = (dense_x, *offsets, *indices)
with torch.no_grad():
ep = torch.export.export(flat_model, flat_inputs)
with tempfile.TemporaryDirectory() as tmp:
pkg_path = os.path.join(tmp, "dlrm.pt2")
torch._inductor.aoti_compile_and_package(ep, package_path=pkg_path)
loaded = torch._inductor.aoti_load_package(pkg_path)
@torch.no_grad()
def fn() -> torch.Tensor:
return loaded(*flat_inputs)
for _ in range(WARMUP):
fn()
return report("AOTInductor", time_callable(fn, ITERS))
def bench_cuda_graphs() -> dict[str, float]:
"""Capture the eager forward as a CUDA graph, then replay.
MiniDLRM builds the `li`/`lj` lower-triangular index tensors via
`torch.tensor([...], device=...)` inside `_interact`, which triggers a
fresh host→device copy each call — and CUDA-graph capture can't observe
non-pinned host→device copies. Wrap the model to pre-bake those indices
as cuda buffers on the wrapper, then patch the bound method.
"""
model = make_model()
n_feat = 1 + len(LN_EMB)
li_const = torch.tensor(
[i for i in range(n_feat) for _ in range(i)], device=DEVICE
)
lj_const = torch.tensor(
[j for i in range(n_feat) for j in range(i)], device=DEVICE
)
def _interact_static(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
bs, d = x.shape
T = torch.cat([x] + ly, dim=1).view((bs, -1, d))
Z = torch.bmm(T, torch.transpose(T, 1, 2))
Zflat = Z[:, li_const, lj_const]
return torch.cat([x, Zflat], dim=1)
# Bind the static version so `self` resolves correctly.
import types
model._interact = types.MethodType(_interact_static, model)
dense_x, offsets, indices = make_inputs()
static_dense = dense_x.clone()
static_offsets = [o.clone() for o in offsets]
static_indices = [i.clone() for i in indices]
@torch.no_grad()
def fwd() -> torch.Tensor:
return model(static_dense, static_offsets, static_indices)
# CUDA-graph prep: a stream warmup, then capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
_ = fwd()
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_out = fwd()
@torch.no_grad()
def fn() -> torch.Tensor:
# Real workloads would copy fresh inputs into static_* here. For pure
# replay-latency measurement the inputs are constant.
g.replay()
return static_out
for _ in range(WARMUP):
fn()
return report("CUDA graphs (eager capture)", time_callable(fn, ITERS))
def bench_luminal_backend() -> dict[str, float]:
torch._dynamo.reset()
model = make_model()
inputs = make_inputs()
compiled = torch.compile(model, backend=luminal_backend)
@torch.no_grad()
def fn() -> torch.Tensor:
return compiled(*inputs)
for _ in range(WARMUP):
fn()
return report("luminal_backend (PT2)", time_callable(fn, ITERS))
def main() -> None:
torch.cuda.synchronize()
print(f"Device: {torch.cuda.get_device_name(0)}")
print(f"PyTorch: {torch.__version__}")
print(f"Config: m_spa={M_SPA} ln_emb={LN_EMB} batch={BATCH} iters={ITERS}\n")
rows = []
for fn in (bench_eager, bench_torch_compile, bench_aoti, bench_cuda_graphs, bench_luminal_backend):
t0 = time.perf_counter()
try:
rows.append(fn())
except Exception as e:
print(f" FAILED {fn.__name__}: {type(e).__name__}: {e}")
print(f" (setup+bench took {time.perf_counter() - t0:.1f}s)\n")
# Pull in any externally-produced rust samples (rust luminal binary
# writes both —bench and —mega samples to /tmp).
for label, path_str in [
("rust luminal", "/tmp/dlrm_bench_rust_luminal.txt"),
("DLRM megakernel", "/tmp/dlrm_bench_megakernel.txt"),
]:
p = Path(path_str)
if not p.exists():
continue
samples_us = sorted(float(s) for s in p.read_text().splitlines() if s)
n = len(samples_us)
rows.append({
"name": label,
"mean": sum(samples_us) / n,
"p50": samples_us[n // 2],
"p99": samples_us[int(n * 0.99)],
"n": n,
})
print(f" {label:<32s} mean={rows[-1]['mean']:8.2f}µs "
f"p50={rows[-1]['p50']:8.2f}µs p99={rows[-1]['p99']:8.2f}µs "
f"(from {path_str})")
# Rank by mean latency.
rows.sort(key=lambda r: r["mean"])
print("=" * 60)
print("Ranking (mean latency, lower is better):\n")
fastest = rows[0]["mean"]
print(f" {'#':<3}{'backend':<32s}{'mean µs':>10s}{'vs fastest':>14s}")
for i, r in enumerate(rows):
ratio = r["mean"] / fastest
print(f" {i + 1:<3}{r['name']:<32s}{r['mean']:>10.2f}{ratio:>13.2f}x")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,328 @@
"""DLRM latency sweep: batch_size × n_sparse_tables × backend.
Reuses the per-backend timing primitives from `bench.py` but parameterises
the model config so we can see how each backend scales along both DLRM's
key axes: batch size (parallelism / kernel utilisation) and number of
sparse tables (kernel launch count, host-side dispatch cost).
For each (batch, n_sparse) cell, runs:
- PyTorch eager
- torch.compile (mode='reduce-overhead')
- AOTInductor
- CUDA graphs (eager capture)
- luminal_backend (PT2)
The rust luminal path can't be invoked from python; we skip it here. The
single-config bench.py remains the cross-check that includes rust.
Output is one table per backend with rows = batch, cols = n_sparse, plus a
final per-cell "winner" matrix.
"""
import gc
import sys
import time
import os
import tempfile
import types
from pathlib import Path
import torch
TESTS_DIR = (
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
)
sys.path.insert(0, str(TESTS_DIR))
from test_dlrm import MiniDLRM # noqa: E402
from luminal import luminal_backend # noqa: E402
DEVICE = torch.device("cuda")
WARMUP = 25
ITERS = 200 # halved vs bench.py to keep sweep wall-clock reasonable
M_SPA = 4
# Sweep grid — real-workload DLRM batches where matmul efficiency is what's
# actually being compared (sub-100 batches were launch-overhead dominated and
# said more about wrapper cost than backend quality).
BATCH_SIZES = [256, 1024, 2048, 4096]
N_SPARSE_LIST = [3, 8, 16]
def make_model(n_sparse: int):
torch.manual_seed(0)
# Embedding table vocab sizes: alternate small/medium so the lookups
# exercise different table widths without making setup time explode.
base_vocabs = [10, 20, 30, 40, 60, 80, 100, 120, 160, 200, 240, 320, 400, 500, 640, 800]
ln_emb = base_vocabs[:n_sparse]
ln_bot = [13, 8, M_SPA]
n_feat = 1 + n_sparse
n_pairs = n_feat * (n_feat - 1) // 2
ln_top = [n_pairs + M_SPA, 8, 1]
return MiniDLRM(M_SPA, ln_emb, ln_bot, ln_top).to(DEVICE).eval(), ln_emb
def make_inputs(batch: int, ln_emb: list[int]):
torch.manual_seed(42)
dense_x = torch.rand(batch, 13, device=DEVICE)
indices = [
torch.randint(0, n, (batch,), dtype=torch.long, device=DEVICE) for n in ln_emb
]
offsets = [torch.arange(batch, dtype=torch.long, device=DEVICE) for _ in ln_emb]
return dense_x, offsets, indices
def time_callable(fn, iters: int) -> list[float]:
start_evts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
end_evts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
torch.cuda.synchronize()
for i in range(iters):
start_evts[i].record()
fn()
end_evts[i].record()
torch.cuda.synchronize()
return [start_evts[i].elapsed_time(end_evts[i]) * 1000.0 for i in range(iters)]
def mean_us(samples: list[float]) -> float:
return sum(samples) / len(samples)
# ---------------------------------------------------------------------------
# Backends
# ---------------------------------------------------------------------------
def bench_eager(model, inputs):
@torch.no_grad()
def fn():
return model(*inputs)
for _ in range(WARMUP):
fn()
return mean_us(time_callable(fn, ITERS))
def bench_torch_compile(model, inputs):
torch._dynamo.reset()
compiled = torch.compile(model, mode="reduce-overhead")
@torch.no_grad()
def fn():
return compiled(*inputs)
for _ in range(WARMUP):
fn()
return mean_us(time_callable(fn, ITERS))
def bench_aoti(model, inputs):
torch._dynamo.reset()
dense_x, offsets, indices = inputs
n_sparse = len(offsets)
# Flat-signature wrapper so torch.export sees positional tensors.
class FlatWrapper(torch.nn.Module):
def __init__(self, m, n_sparse: int):
super().__init__()
self.m = m
self.n_sparse = n_sparse
def forward(self, *args):
n = self.n_sparse
dense_x = args[0]
offsets = list(args[1 : 1 + n])
indices = list(args[1 + n : 1 + 2 * n])
return self.m(dense_x, offsets, indices)
flat_model = FlatWrapper(model, n_sparse).to(DEVICE).eval()
flat_inputs = (dense_x, *offsets, *indices)
with torch.no_grad():
ep = torch.export.export(flat_model, flat_inputs)
with tempfile.TemporaryDirectory() as tmp:
pkg = os.path.join(tmp, "dlrm.pt2")
torch._inductor.aoti_compile_and_package(ep, package_path=pkg)
loaded = torch._inductor.aoti_load_package(pkg)
@torch.no_grad()
def fn():
return loaded(*flat_inputs)
for _ in range(WARMUP):
fn()
return mean_us(time_callable(fn, ITERS))
def bench_cuda_graphs(model, inputs):
"""Capture eager forward as a CUDA graph and replay. Patches the
interaction's li/lj construction to be static buffers so capture works
(same trick the single-config bench uses)."""
dense_x, offsets, indices = inputs
n_sparse = len(offsets)
n_feat = 1 + n_sparse
li = torch.tensor([i for i in range(n_feat) for _ in range(i)], device=DEVICE)
lj = torch.tensor([j for i in range(n_feat) for j in range(i)], device=DEVICE)
def _interact_static(self, x, ly):
bs, d = x.shape
T = torch.cat([x] + ly, dim=1).view((bs, -1, d))
Z = torch.bmm(T, torch.transpose(T, 1, 2))
Zflat = Z[:, li, lj]
return torch.cat([x, Zflat], dim=1)
model._interact = types.MethodType(_interact_static, model)
static_dense = dense_x.clone()
static_offsets = [o.clone() for o in offsets]
static_indices = [i.clone() for i in indices]
@torch.no_grad()
def fwd():
return model(static_dense, static_offsets, static_indices)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
fwd()
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
_ = fwd()
@torch.no_grad()
def fn():
g.replay()
for _ in range(WARMUP):
fn()
return mean_us(time_callable(fn, ITERS))
def bench_luminal_backend(model, inputs):
torch._dynamo.reset()
compiled = torch.compile(model, backend=luminal_backend)
@torch.no_grad()
def fn():
return compiled(*inputs)
for _ in range(WARMUP):
fn()
return mean_us(time_callable(fn, ITERS))
BACKENDS = [
("eager", bench_eager),
("torch.compile", bench_torch_compile),
("AOTInductor", bench_aoti),
("CUDA graphs", bench_cuda_graphs),
("luminal_backend", bench_luminal_backend),
]
# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------
def fmt(v: float) -> str:
if v != v: # NaN
return " - "
return f"{v:7.1f}"
def main() -> None:
print(f"Device: {torch.cuda.get_device_name(0)}")
print(f"PyTorch: {torch.__version__}")
print(
f"Sweep: batch ∈ {BATCH_SIZES}, n_sparse ∈ {N_SPARSE_LIST}, "
f"backends ∈ {[b[0] for b in BACKENDS]}, iters={ITERS}\n"
)
# results[backend_name][batch][n_sparse] = mean µs
results: dict[str, dict[tuple[int, int], float]] = {
name: {} for name, _ in BACKENDS
}
total_cells = len(BATCH_SIZES) * len(N_SPARSE_LIST) * len(BACKENDS)
cell = 0
for n_sparse in N_SPARSE_LIST:
for batch in BATCH_SIZES:
model, ln_emb = make_model(n_sparse)
inputs = make_inputs(batch, ln_emb)
for name, fn in BACKENDS:
cell += 1
t0 = time.perf_counter()
try:
mu = fn(model, inputs)
except Exception as e:
mu = float("nan")
print(
f" [{cell:>3}/{total_cells}] "
f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
f"FAILED: {type(e).__name__}: {str(e).splitlines()[-1][:80]}"
)
continue
results[name][(batch, n_sparse)] = mu
print(
f" [{cell:>3}/{total_cells}] "
f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
f"mean={mu:>7.1f}µs (took {time.perf_counter() - t0:.1f}s)"
)
gc.collect()
torch.cuda.empty_cache()
torch._dynamo.reset()
# ---- Print one table per backend -----------------------------------
print("\n" + "=" * 78)
print("Latency in µs by backend (rows = batch, cols = n_sparse)")
for name, _ in BACKENDS:
print(f"\n {name}:")
header = " " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST)
print(header)
for bs in BATCH_SIZES:
row = f" bs={bs:<4} "
for ns in N_SPARSE_LIST:
v = results[name].get((bs, ns), float("nan"))
row += f" {fmt(v)} "
print(row)
# ---- Print "fastest backend per cell" matrix -----------------------
print("\n" + "=" * 78)
print("Winner per cell (lowest mean µs):")
print("\n " + "".join(f" n_sp={ns:<14}" for ns in N_SPARSE_LIST))
for bs in BATCH_SIZES:
row = f" bs={bs:<4} "
for ns in N_SPARSE_LIST:
options = [
(name, results[name].get((bs, ns), float("inf"))) for name, _ in BACKENDS
]
options = [(n, v) for n, v in options if v == v and v != float("inf")]
if not options:
row += " - "
continue
winner = min(options, key=lambda x: x[1])
row += f" {winner[0]:<13s} {winner[1]:>6.1f}"
print(row)
# ---- luminal_backend vs eager: scaling story -----------------------
print("\n" + "=" * 78)
print("luminal_backend / eager (lower than 1.0 = luminal wins this cell):")
print("\n " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST))
for bs in BATCH_SIZES:
row = f" bs={bs:<4} "
for ns in N_SPARSE_LIST:
le = results.get("luminal_backend", {}).get((bs, ns), float("nan"))
eg = results.get("eager", {}).get((bs, ns), float("nan"))
if eg != eg or le != le or eg == 0:
row += " - "
else:
row += f" {le / eg:>5.2f}x"
print(row)
if __name__ == "__main__":
main()

89
examples/dlrm/export.py Normal file
View File

@@ -0,0 +1,89 @@
"""Three-way DLRM equivalence harness.
Builds the MiniDLRM at the fixed config used by examples/dlrm/src/main.rs,
serializes weights + sample inputs + the PyTorch eager output to safetensors
files that the rust binary loads. Also runs the PyTorch + luminal_backend
path so the comparison happens in one place.
Saves:
/tmp/dlrm_weights.safetensors — state_dict with PyTorch names
/tmp/dlrm_inputs.safetensors — dense_x, indices_{0..2}, and `expected`
(the PyTorch eager output, fp32)
Then run:
cargo run --release --manifest-path examples/dlrm/Cargo.toml
"""
import sys
from pathlib import Path
import torch
from safetensors.torch import save_file
# Import MiniDLRM from the test file we already authored.
TESTS_DIR = Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
sys.path.insert(0, str(TESTS_DIR))
from test_dlrm import MiniDLRM # noqa: E402
# Backend (and venv) shared with the test runner.
from luminal import luminal_backend # noqa: E402
M_SPA = 4
LN_EMB = [10, 20, 30]
LN_BOT = [13, 8, M_SPA]
LN_TOP = [10, 8, 1]
# Match the rust binary's BATCH_SIZE — real-workload DLRM batch where
# compute-bound matmul efficiency is what's being measured.
BATCH = 2048
DEVICE = torch.device("cuda")
def main() -> None:
torch.manual_seed(0)
model = MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP).to(DEVICE).eval()
dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)
indices = [
torch.randint(0, n, (BATCH,), dtype=torch.long, device=DEVICE)
for n in LN_EMB
]
offsets = [
torch.arange(BATCH, dtype=torch.long, device=DEVICE) for _ in LN_EMB
]
with torch.no_grad():
eager_out = model(dense_x, offsets, indices)
compiled = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
luminal_out = compiled(dense_x, offsets, indices)
max_diff_lum = (luminal_out - eager_out).abs().max().item()
print(f"PyTorch eager output : {eager_out.flatten().tolist()}")
print(f"PyTorch + luminal : {luminal_out.flatten().tolist()}")
print(f" max |diff| eager vs luminal_backend : {max_diff_lum:.3e}")
assert max_diff_lum < 1e-5, "PT eager and luminal_backend disagree"
# Save weights — state_dict names already match what rust uses.
weights = {k: v.detach().cpu() for k, v in model.state_dict().items()}
save_file(weights, "/tmp/dlrm_weights.safetensors")
print(f" wrote /tmp/dlrm_weights.safetensors ({len(weights)} tensors)")
inputs = {
"dense_x": dense_x.detach().cpu().contiguous(),
"expected": eager_out.detach().cpu().contiguous(),
}
for k, ix in enumerate(indices):
# Rust reads i32 indices.
inputs[f"indices_{k}"] = ix.detach().cpu().to(torch.int32).contiguous()
save_file(inputs, "/tmp/dlrm_inputs.safetensors")
print(f" wrote /tmp/dlrm_inputs.safetensors ({len(inputs)} tensors)")
print(
"\nNext: cargo run --release --manifest-path examples/dlrm/Cargo.toml --bin dlrm"
)
if __name__ == "__main__":
main()

637
examples/dlrm/src/main.rs Normal file
View File

@@ -0,0 +1,637 @@
//! Pure-rust DLRM mirroring `MiniDLRM` from
//! `crates/luminal_python/tests/test_dlrm.py`.
//!
//! Loads weights + sample inputs + expected output produced by
//! `examples/dlrm/export.py`, runs the same compute graph through luminal's
//! CUDA runtime, and prints max-abs diff vs the saved PyTorch eager output.
//!
//! Topology (fixed for now — same as MiniDLRM at the small-config we test):
//! m_spa = 4
//! ln_emb = [10, 20, 30] (3 sparse tables)
//! ln_bot = [13, 8, 4] (Linear-ReLU-Linear-ReLU)
//! ln_top = [10, 8, 1] (Linear-ReLU-Linear-Sigmoid)
//! batch_size = 2, bag_size = 1
//!
//! Weight name convention matches the PyTorch state_dict (so
//! `runtime.load_safetensors` matches by name with no remapping):
//! emb_l.{k}.weight (V_k, m_spa)
//! bot_l.{0,2}.{weight,bias} Linear in_features → out_features
//! top_l.{0,2}.{weight,bias} same
//! PyTorch stores Linear weight as (out, in); we permute when matmul'ing.
use std::path::Path;
use luminal::prelude::*;
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_nn::gather_rows;
use memmap2::MmapOptions;
use safetensors::SafeTensors;
const M_SPA: usize = 4;
const LN_EMB: [usize; 3] = [10, 20, 30];
const LN_BOT: [usize; 3] = [13, 8, M_SPA];
const LN_TOP: [usize; 3] = [10, 8, 1];
// Real-workload DLRM batch — large enough that kernel work dominates the
// per-launch overhead and the compute-bound performance is what's measured.
const BATCH_SIZE: usize = 2048;
/// Linear with bias whose weight matches PyTorch's `nn.Linear` storage:
/// shape `(out, in)`. Forward computes `input @ weight.T + bias`.
struct Linear {
weight: GraphTensor, // (out_features, in_features)
bias: GraphTensor, // (out_features,)
}
impl Linear {
fn new(cx: &mut Graph, prefix: &str, in_features: usize, out_features: usize) -> Self {
Self {
weight: cx
.named_tensor(format!("{prefix}.weight"), (out_features, in_features))
.persist(),
bias: cx
.named_tensor(format!("{prefix}.bias"), out_features)
.persist(),
}
}
fn forward(&self, input: GraphTensor) -> GraphTensor {
let out_features = self.weight.shape.dims[0];
let mm = input.matmul(self.weight.permute((1, 0)));
// Broadcast bias (out,) → output shape (..., out).
let bias_b = self.bias.expand_dim(0, mm.shape.dims[0]);
// bias_b shape: (B, out) — matches `mm` for 2-D input.
let _ = out_features;
mm + bias_b
}
}
// We use luminal's primitive .relu() / .sigmoid() rather than hand-rolling
// them out of maximum / exp / reciprocal so the HLIR generated here matches
// what the PT2 translator emits for `aten.relu.default` / `aten.sigmoid.default`
// op-for-op. See dispatch.rs: both route through these same primitives.
fn bot_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
layers[1].forward(layers[0].forward(x).relu()).relu()
}
fn top_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
layers[1].forward(layers[0].forward(x).relu()).sigmoid()
}
/// Dot interaction: cat(dense, sparse...) → reshape → bmm → flat-tri-upper indexing.
/// Matches MiniDLRM._interact in the python test.
fn interact_features(
cx: &mut Graph,
dense: GraphTensor,
sparse: &[GraphTensor],
) -> GraphTensor {
let batch = dense.shape.dims[0];
let d = dense.shape.dims[1];
let n_feat = 1 + sparse.len();
// T = cat([dense, *sparse], dim=1).view(B, n_feat, d)
let mut t = dense;
for s in sparse {
t = t.concat_along(*s, 1);
}
// Reshape (B, n_feat * d) → (B, n_feat, d). concat_along leaves a contiguous
// tensor so a fresh ShapeTracker is safe.
let bagged = GraphTensor {
id: t.id,
graph_ref: t.graph_ref,
shape: ShapeTracker::new((batch, Expression::from(n_feat), d)),
dtype: t.dtype,
};
// Z = bmm(T, T.transpose(1, 2)) → (B, n_feat, n_feat)
let z = bagged.matmul(bagged.permute((0, 2, 1)));
// Strictly-lower-triangular indices into (n_feat, n_feat). For n_feat=4
// these are 6 (i,j) pairs: (1,0),(2,0),(2,1),(3,0),(3,1),(3,2).
let mut li = Vec::new();
let mut lj = Vec::new();
for i in 0..n_feat {
for j in 0..i {
li.push(i as i32);
lj.push(j as i32);
}
}
let n_pairs = li.len();
// Build flat_idx_per_pair[k] = li[k] * n_feat + lj[k] (constant across batch).
let mut flat_idx_per_pair = Vec::with_capacity(n_pairs);
for k in 0..n_pairs {
flat_idx_per_pair.push(li[k] * n_feat as i32 + lj[k]);
}
// Absolute flat index into Z viewed as 1D for each (b, k):
// abs[b, k] = b * (n_feat*n_feat) + flat_idx_per_pair[k]
let row_stride = n_feat * n_feat; // entries per batch in Z
let arange_b = cx.arange(batch); // (B,) ints, values 0..B
let abs_idx = arange_b.expand_dim(1, Expression::from(n_pairs))
* Expression::from(row_stride);
// pair_idx_const: (n_pairs,) ints, captured as a graph input we set once.
let pair_idx = cx
.named_tensor("__dot_pair_idx", n_pairs)
.as_dtype(DType::Int)
.persist();
let abs_idx = abs_idx + pair_idx.expand_dim(0, batch);
// Gather Z as 1D.
let z_flat = GraphTensor {
id: z.id,
graph_ref: z.graph_ref,
shape: ShapeTracker::new(batch * row_stride),
dtype: z.dtype,
};
let zflat_indexed = z_flat.gather(abs_idx); // (B, n_pairs)
// R = cat(dense, zflat_indexed, dim=1) → (B, d + n_pairs)
dense.concat_along(zflat_indexed, 1)
}
fn main() {
// Parse args: optional --bench / --stats / --mega, then positional paths.
let mut bench_mode = false;
let mut stats_mode = false;
let mut mega_mode = false;
let mut positional: Vec<String> = Vec::new();
for arg in std::env::args().skip(1) {
if arg == "--bench" {
bench_mode = true;
} else if arg == "--stats" {
stats_mode = true;
} else if arg == "--mega" {
mega_mode = true;
} else {
positional.push(arg);
}
}
let weights_path = positional
.first()
.cloned()
.unwrap_or_else(|| "/tmp/dlrm_weights.safetensors".to_string());
let inputs_path = positional
.get(1)
.cloned()
.unwrap_or_else(|| "/tmp/dlrm_inputs.safetensors".to_string());
if mega_mode {
run_megakernel(&weights_path, &inputs_path, bench_mode);
return;
}
assert!(
Path::new(&weights_path).exists(),
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
);
assert!(
Path::new(&inputs_path).exists(),
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
);
// ---- Build graph -----------------------------------------------------
let mut cx = Graph::default();
let dense_in = cx
.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
.map(|k| {
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
.as_dtype(DType::Int)
})
.collect();
// Embedding tables (bag_size=1 → just row gather).
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
.map(|k| {
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
.persist()
})
.collect();
let sparse_feats: Vec<GraphTensor> = (0..LN_EMB.len())
.map(|k| gather_rows(emb_weights[k], idx_tensors[k], M_SPA))
.collect();
// Bottom MLP: Linear 13→8, ReLU, Linear 8→4, ReLU.
let bot = [
Linear::new(&mut cx, "bot_l.0", LN_BOT[0], LN_BOT[1]),
Linear::new(&mut cx, "bot_l.2", LN_BOT[1], LN_BOT[2]),
];
let dense_out = bot_forward(&bot, dense_in);
// Dot interaction → (B, n_pairs + m_spa) = (B, 10) for our config.
let interacted = interact_features(&mut cx, dense_out, &sparse_feats);
// Top MLP: Linear 10→8, ReLU, Linear 8→1, Sigmoid.
let top = [
Linear::new(&mut cx, "top_l.0", LN_TOP[0], LN_TOP[1]),
Linear::new(&mut cx, "top_l.2", LN_TOP[1], LN_TOP[2]),
];
let out = top_forward(&top, interacted).output();
// ---- Compile + load weights ------------------------------------------
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, &weights_path);
// Set the strictly-lower-triangular pair index constant.
let n_feat = 1 + LN_EMB.len();
let mut pair_idx_vals = Vec::new();
for i in 0..n_feat {
for j in 0..i {
pair_idx_vals.push((i * n_feat + j) as i32);
}
}
// Find the named input by walking the graph.
let pair_idx_id = find_named_input(&cx, "__dot_pair_idx")
.expect("pair_idx tensor not found in graph");
runtime.set_data(pair_idx_id, pair_idx_vals);
// Load inputs + expected output from safetensors.
let inputs_mmap = unsafe {
MmapOptions::new()
.map(&std::fs::File::open(&inputs_path).unwrap())
.unwrap()
};
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
let dense_x: Vec<f32> = bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
runtime.set_data(dense_in, dense_x);
for (k, idx_t) in idx_tensors.iter().enumerate() {
let ix: Vec<i32> = bytemuck::cast_slice(
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
)
.to_vec();
runtime.set_data(*idx_t, ix);
}
// ---- Search (small budget — graph is tiny) ---------------------------
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
runtime = cx.search_options(runtime, SearchOptions::new(8).trials(1), &mut rng);
// ---- Execute and compare ---------------------------------------------
runtime.execute(&cx.dyn_map);
let result = runtime.get_f32(out);
let expected_bytes = inputs_st.tensor("expected").unwrap().data();
let expected: &[f32] = bytemuck::cast_slice(expected_bytes);
println!("rust output : {result:?}");
println!("expected : {expected:?}");
let max_diff = result
.iter()
.zip(expected.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f32, f32::max);
println!("max |diff| : {max_diff:.3e}");
assert!(
max_diff < 1e-4,
"rust output diverges from PyTorch eager (max_diff={max_diff})"
);
println!("OK — rust luminal matches PyTorch eager within 1e-4.");
if stats_mode {
let host_ops = runtime.host_ops();
println!("\n=== Active bucket host-op inventory ({} ops) ===", host_ops.len());
let mut by_type: std::collections::BTreeMap<String, usize> =
std::collections::BTreeMap::new();
for op in &host_ops {
let s = format!("{op:?}");
let head = s.split_whitespace().next().unwrap_or(&s).to_string();
*by_type.entry(head).or_insert(0) += 1;
}
for (k, v) in &by_type {
println!(" {v:>3} {k}");
}
// Per-op detail: extract the cuBLASLt epilogue + shape signature so
// we can see at a glance whether bias/relu fusion fired (the egglog
// rewrites map matmul+add+maximum_f32(0) -> EPILOGUE_RELU_BIAS).
println!("\n=== cuBLASLt op detail ===");
for op in &host_ops {
let s = format!("{op:?}");
if !s.starts_with("CuBlasLt") {
continue;
}
let epilogue = extract_field(&s, "epilogue:");
let shape = (extract_field(&s, "m:"), extract_field(&s, "n:"), extract_field(&s, "k:"));
println!(
" m={:<8} n={:<8} k={:<8} epilogue={}",
shape.0, shape.1, shape.2, epilogue
);
}
}
if bench_mode {
// Cache input vectors so the bench loop can re-set_data each iter (the
// PyTorch backends do an equivalent staging step under the hood).
let dense_vec: Vec<f32> =
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
.map(|k| {
bytemuck::cast_slice(
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
)
.to_vec()
})
.collect();
bench_rust(
&mut cx,
&mut runtime,
out,
dense_in,
&idx_tensors,
dense_vec,
idx_vecs,
);
}
}
/// Time `runtime.execute` directly. Inputs are already loaded once before
/// `--bench` and not re-uploaded between calls, mirroring CUDA-graph replay
/// semantics. Synchronizes the stream once at the end and divides total
/// elapsed by `iters` for a steady-state mean; also prints per-iter samples
/// to /tmp/dlrm_bench_rust_luminal.txt for the python aggregator.
fn bench_rust(
cx: &mut Graph,
runtime: &mut CudaRuntime,
out: GraphTensor,
dense_in: GraphTensor,
idx_tensors: &[GraphTensor],
dense_vec: Vec<f32>,
idx_vecs: Vec<Vec<i32>>,
) {
bench_through_luminal(
cx,
runtime,
out,
dense_in,
idx_tensors,
dense_vec,
idx_vecs,
"/tmp/dlrm_bench_rust_luminal.txt",
"[bench] rust luminal",
);
}
/// Shared steady-state bench for any luminal graph + runtime. Re-sets
/// inputs every iter, calls `execute`, then `get_f32` to force a stream
/// sync. Dumps per-iter µs samples to `samples_path` for
/// `examples/dlrm/bench.py` to merge into its ranking.
#[allow(clippy::too_many_arguments)]
fn bench_through_luminal(
cx: &mut Graph,
runtime: &mut CudaRuntime,
out: GraphTensor,
dense_in: GraphTensor,
idx_tensors: &[GraphTensor],
dense_vec: Vec<f32>,
idx_vecs: Vec<Vec<i32>>,
samples_path: &str,
label: &str,
) {
const WARMUP: usize = 50;
const ITERS: usize = 500;
use std::time::Instant;
let bench_once = |runtime: &mut CudaRuntime| {
runtime.set_data(dense_in, dense_vec.clone());
for (k, t) in idx_tensors.iter().enumerate() {
runtime.set_data(*t, idx_vecs[k].clone());
}
runtime.execute(&cx.dyn_map);
let _ = runtime.get_f32(out);
};
for _ in 0..WARMUP {
bench_once(runtime);
}
let mut samples = Vec::with_capacity(ITERS);
for _ in 0..ITERS {
let t0 = Instant::now();
bench_once(runtime);
samples.push(t0.elapsed().as_secs_f64() * 1e6);
}
samples.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mean = samples.iter().sum::<f64>() / ITERS as f64;
let p50 = samples[ITERS / 2];
let p99 = samples[(ITERS as f64 * 0.99) as usize];
println!(
"\n{label}: mean={mean:.2}µs p50={p50:.2}µs p99={p99:.2}µs (n={ITERS})"
);
let body = samples
.iter()
.map(|s| format!("{s:.4}"))
.collect::<Vec<_>>()
.join("\n");
std::fs::write(samples_path, body).expect("write bench samples");
println!(" per-iter samples -> {samples_path}");
}
/// `--mega`: build a luminal Graph whose entire forward is a single
/// [`DlrmMegaCustom`] op, then run it through the standard
/// `CudaRuntime` flow (load_safetensors → search → execute → get_f32).
/// Verifies bitwise vs the saved PyTorch eager output, optionally
/// benches steady-state per-call latency through the same `bench_rust`
/// path the non-mega rust binary uses.
///
/// The point: same kernel as the PT2-backend fast path (the parameterized
/// `DlrmMegaKernel` in `luminal_cuda_lite::kernel::dlrm_megakernel`),
/// just constructed by hand instead of via the translator's pattern
/// matcher. Everything past the `cx.custom_op` call — buffer
/// management, weight loading, input registration, kernel dispatch,
/// output retrieval — is luminal's runtime.
fn run_megakernel(weights_path: &str, inputs_path: &str, bench: bool) {
assert!(
Path::new(weights_path).exists(),
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
);
assert!(
Path::new(inputs_path).exists(),
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
);
let mut cx = Graph::default();
// ---- User inputs -----------------------------------------------------
let dense_in = cx.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
.map(|k| {
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
.as_dtype(DType::Int)
})
.collect();
// ---- Weights — names must match safetensors keys so the runtime's
// load_safetensors matches by Input label.
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
.map(|k| {
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
.persist()
})
.collect();
// PyTorch's nn.Linear stores weight as (out_features, in_features).
let bot_l0_w = cx
.named_tensor("bot_l.0.weight", (LN_BOT[1], LN_BOT[0]))
.persist();
let bot_l0_b = cx.named_tensor("bot_l.0.bias", LN_BOT[1]).persist();
let bot_l1_w = cx
.named_tensor("bot_l.2.weight", (LN_BOT[2], LN_BOT[1]))
.persist();
let bot_l1_b = cx.named_tensor("bot_l.2.bias", LN_BOT[2]).persist();
let top_l0_w = cx
.named_tensor("top_l.0.weight", (LN_TOP[1], LN_TOP[0]))
.persist();
let top_l0_b = cx.named_tensor("top_l.0.bias", LN_TOP[1]).persist();
let top_l1_w = cx
.named_tensor("top_l.2.weight", (LN_TOP[2], LN_TOP[1]))
.persist();
let top_l1_b = cx.named_tensor("top_l.2.bias", LN_TOP[2]).persist();
// ---- One CustomOp does the whole forward ----------------------------
// Input order MUST match what DlrmMegaKernel's CUDA source expects:
// dense, indices..., emb_weights..., bot Linears (w then b each),
// top Linears (w then b each). See `kernel::dlrm_megakernel`.
let mut inputs: Vec<GraphTensor> = vec![dense_in];
inputs.extend(idx_tensors.iter().copied());
inputs.extend(emb_weights.iter().copied());
inputs.extend([
bot_l0_w, bot_l0_b, bot_l1_w, bot_l1_b, top_l0_w, top_l0_b, top_l1_w, top_l1_b,
]);
let kernel = DlrmMegaKernel {
batch: BATCH_SIZE,
n_dense_in: LN_BOT[0],
ln_bot: LN_BOT.to_vec(),
n_sparse: LN_EMB.len(),
vocab_sizes: LN_EMB.to_vec(),
m_spa: M_SPA,
ln_top: LN_TOP.to_vec(),
};
let out = cx
.custom_op(
DlrmMegaCustom(kernel),
inputs,
(BATCH_SIZE, 1usize),
DType::F32,
)
.output();
// ---- Compile + load weights — same path as the non-mega flow -------
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, weights_path);
// ---- Inputs ---------------------------------------------------------
let inputs_mmap = unsafe {
MmapOptions::new()
.map(&std::fs::File::open(inputs_path).unwrap())
.unwrap()
};
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
let dense_vec: Vec<f32> =
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
runtime.set_data(dense_in, dense_vec.clone());
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
.map(|k| {
bytemuck::cast_slice(
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
)
.to_vec()
})
.collect();
for (k, idx_t) in idx_tensors.iter().enumerate() {
runtime.set_data(*idx_t, idx_vecs[k].clone());
}
// ---- Search ---------------------------------------------------------
// Single-CustomOp graph: nothing to search over. One trial.
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
runtime = cx.search_options(runtime, SearchOptions::new(1).trials(1), &mut rng);
// ---- Execute + verify -----------------------------------------------
runtime.execute(&cx.dyn_map);
let result = runtime.get_f32(out);
let expected: &[f32] =
bytemuck::cast_slice(inputs_st.tensor("expected").unwrap().data());
let max_diff = result
.iter()
.zip(expected.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f32, f32::max);
println!(
"[mega] output[..4]={:?} expected[..4]={:?} max|diff|={:.3e}",
&result[..result.len().min(4)],
&expected[..result.len().min(4)],
max_diff
);
assert!(
max_diff < 1e-4,
"megakernel output diverges from PyTorch eager (max_diff={max_diff})"
);
println!("[mega] OK — luminal megakernel matches PyTorch eager within 1e-4");
// Inventory the host ops — should be exactly 1 (the DlrmMegaCustom).
let host_ops = runtime.host_ops();
println!("[mega] active bucket host-op count: {}", host_ops.len());
if bench {
// Reuse the shared bench loop. Writes per-iter µs samples to
// /tmp/dlrm_bench_megakernel.txt so examples/dlrm/bench.py picks
// them up under the "DLRM megakernel" row.
bench_through_luminal(
&mut cx,
&mut runtime,
out,
dense_in,
&idx_tensors,
dense_vec,
idx_vecs,
"/tmp/dlrm_bench_megakernel.txt",
"[mega]",
);
}
}
/// Pull a `Field: value,` from a Debug-formatted struct dump. Returns the
/// substring between `field` and the next `,` or `}`, trimmed.
fn extract_field(s: &str, field: &str) -> String {
let Some(idx) = s.find(field) else {
return "?".to_string();
};
let start = idx + field.len();
let tail = &s[start..];
let end = tail
.find(|c: char| c == ',' || c == '}')
.unwrap_or(tail.len());
tail[..end].trim().to_string()
}
/// Walk the graph looking for an [`Input`] op with the given label. Used to
/// recover a `NodeIndex` we can `set_data` against when the original
/// `GraphTensor` handle isn't in scope.
fn find_named_input(cx: &Graph, label: &str) -> Option<NodeIndex> {
use luminal::hlir::Input;
for n in cx.graph.node_indices() {
if let Some(Input { label: l, .. }) =
(*cx.graph[n]).as_any().downcast_ref::<Input>()
{
if l == label {
return Some(n);
}
}
}
None
}

View File

@@ -11,6 +11,7 @@ luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
tokenizers = "0.22.2"
rustc-hash = "2"
rand = "0.9.2"
# HuggingFace model download
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }

View File

@@ -5,11 +5,13 @@ use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rand::{SeedableRng, rngs::SmallRng};
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
const SEARCH_SEED: u64 = 0;
fn env_bool(name: &str) -> bool {
std::env::var(name)
@@ -78,7 +80,12 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, search_graphs);
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(
runtime,
SearchOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
&mut rng,
);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);

View File

@@ -113,10 +113,6 @@ impl QwenRuntime for luminal_metal::MetalRuntime {
fn get_f32(&self, id: NodeIndex) -> Vec<f32> {
luminal_metal::MetalRuntime::get_f32(self, id)
}
fn prepare_execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
luminal_metal::MetalRuntime::allocate_intermediate_buffers(self, dyn_map);
}
}
pub fn run_qwen<R>(mut runtime: R, config: QwenRunConfig) -> Result<(), Box<dyn Error>>
@@ -177,6 +173,17 @@ where
DimBucket::new(2, max_prefill).representative(search_s),
],
);
let max_decode_p = config.max_seq_len.saturating_sub(1);
let decode_p_representative = prompt_tokens.len().min(max_decode_p).max(1);
let p_buckets = if max_decode_p == 0 {
vec![DimBucket::new(0, 0)]
} else {
vec![
DimBucket::new(0, 0),
DimBucket::new(1, max_decode_p).representative(decode_p_representative),
]
};
cx.set_dim_buckets('p', &p_buckets);
cx.set_dim('s', search_s);
cx.set_dim('p', 0);
runtime.set_i32_data(input.id, vec![1; search_s]);

View File

@@ -279,17 +279,17 @@ impl DynBackend for NativeDynBackend {
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.f32(i)).collect()
data.to_f32_vec()
}
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.i32(i)).collect()
data.to_i32_vec()
}
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.bool(i)).collect()
data.to_bool_vec()
}
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {

View File

@@ -1620,7 +1620,56 @@ pub fn extract_expr<'a>(
pub type EGraphChoiceSet<'a> = FxHashMap<&'a ClassId, &'a NodeId>;
/// Count the total number of possible IR/IList choice sets, capped at `limit`.
fn is_search_choice_eclass(label: &str) -> bool {
label.contains("IR") || label.contains("IList") || label.contains("OpKind")
}
fn extractor_list_len(egraph: &SerializedEGraph, eclass_id: &ClassId) -> Option<usize> {
let mut len = 0usize;
let mut cur_eclass: ClassId = eclass_id.clone();
let mut visited: FxHashSet<ClassId> = FxHashSet::default();
loop {
if !visited.insert(cur_eclass.clone()) {
return None;
}
let (label, enodes) = egraph.eclasses.get(&cur_eclass)?;
if !label.contains("List") {
return Some(len);
}
let head_enode = enodes.first()?;
let head_label = &egraph.enodes[head_enode].0;
if head_label == "ENil" || head_label == "INil" {
return Some(len);
}
if head_label != "ECons" && head_label != "ICons" {
return Some(len);
}
len += 1;
let children = &egraph.enodes[head_enode].1;
if children.len() < 2 {
return Some(len);
}
cur_eclass = children[1].clone();
}
}
fn opkind_metadata_consistent(egraph: &SerializedEGraph, node: &NodeId) -> bool {
let lens: Vec<usize> = egraph.enodes[node]
.1
.iter()
.filter_map(|c| {
let lbl = &egraph.eclasses[c].0;
if lbl.contains("List") {
extractor_list_len(egraph, c)
} else {
None
}
})
.collect();
lens.is_empty() || lens.iter().all(|l| *l == lens[0])
}
/// Count the total number of possible searchable choice sets, capped at `limit`.
///
/// Search deduplicates candidates by `EGraphChoiceSet`, so this gives the exact
/// number of candidates when it is below `limit` without risking overflow on
@@ -1632,7 +1681,7 @@ pub fn count_choice_sets_up_to(egraph: &SerializedEGraph, limit: usize) -> usize
let mut count = 1usize;
for (label, enodes) in egraph.eclasses.values() {
if !label.contains("IR") && !label.contains("IList") {
if !is_search_choice_eclass(label) {
continue;
}
@@ -1650,10 +1699,10 @@ pub fn random_initial_choice<'a>(
) -> EGraphChoiceSet<'a> {
let mut choices = FxHashMap::default();
for (eclass, (label, enodes)) in &egraph.eclasses {
if !label.contains("IR") && !label.contains("IList") {
if !is_search_choice_eclass(label) {
continue;
}
// Prefer synth-injected enodes when available — they point at
// Use synth-injected enodes when available — they point at
// deterministic single-variant kind eclasses produced by the
// deep-clone fallback in `inject_kernel_alternatives`, so the
// extractor's first-enode walk is guaranteed length-consistent.
@@ -1667,7 +1716,18 @@ pub fn random_initial_choice<'a>(
.enumerate()
.filter_map(|(i, n)| n.as_ref().starts_with("synth_").then_some(i))
.collect();
let pick_idx = if !synth_indices.is_empty() {
let consistent_opkind_indices: Vec<usize> = if label == "OpKind" {
enodes
.iter()
.enumerate()
.filter_map(|(i, n)| opkind_metadata_consistent(egraph, n).then_some(i))
.collect()
} else {
Vec::new()
};
let pick_idx = if !consistent_opkind_indices.is_empty() {
consistent_opkind_indices[rng.random_range(0..consistent_opkind_indices.len())]
} else if !synth_indices.is_empty() {
synth_indices[rng.random_range(0..synth_indices.len())]
} else {
rng.random_range(0..enodes.len())
@@ -1684,9 +1744,9 @@ pub fn validate_choice_set<'a>(
choices: &EGraphChoiceSet<'a>,
ops: &[Arc<Box<dyn EgglogOp>>],
) -> Result<(), String> {
// Check all IR/IList eclasses have a choice
// Check all searchable eclasses have a choice.
for (eclass, (label, enodes)) in &egraph.eclasses {
if !label.contains("IR") && !label.contains("IList") {
if !is_search_choice_eclass(label) {
continue;
}
let Some(chosen) = choices.get(eclass) else {
@@ -1719,7 +1779,7 @@ pub fn validate_choice_set<'a>(
.eclasses
.get(ch)
.ok_or_else(|| format!("Eclass {} not found", ch.as_ref()))?;
if label.contains("IR") || label.contains("IList") {
if is_search_choice_eclass(label) {
let n = choices
.get(ch)
.ok_or_else(|| format!("No choice for reachable eclass {}", ch.as_ref()))?;
@@ -1745,14 +1805,12 @@ pub fn validate_choice_set<'a>(
if op_name == "Op" {
// Normalized op — check OpKind child
if let Some(kind_eclass) = children.first() {
if let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) {
if let Some(kn) = kind_enodes.first() {
let kind_name = &egraph.enodes[kn].0;
if kind_name != "CustomOpKind"
&& !ops.iter().any(|op| op.sort().name == *kind_name)
{
return Err(format!("No extractor for OpKind {kind_name}"));
}
if let Some(kn) = choices.get(kind_eclass) {
let kind_name = &egraph.enodes[kn].0;
if kind_name != "CustomOpKind"
&& !ops.iter().any(|op| op.sort().name == *kind_name)
{
return Err(format!("No extractor for OpKind {kind_name}"));
}
}
}
@@ -1813,9 +1871,7 @@ pub fn extract_generation<'a>(
let mutable_classes: Vec<&ClassId> = egraph
.eclasses
.iter()
.filter(|(_, (label, enodes))| {
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
})
.filter(|(_, (label, enodes))| is_search_choice_eclass(label) && enodes.len() > 1)
.map(|(class_id, _)| class_id)
.collect();
@@ -1848,9 +1904,21 @@ pub fn extract_generation<'a>(
for _ in 0..rng.random_range(1..=mutations_per_generation) {
// Pick a random mutable eclass
let class_id = mutable_classes[rng.random_range(0..mutable_classes.len())];
let (_, enodes) = &egraph.eclasses[class_id];
let (label, enodes) = &egraph.eclasses[class_id];
// Pick a random enode for this class
let new_node = &enodes[rng.random_range(0..enodes.len())];
let consistent_opkind_nodes: Vec<&NodeId> = if label == "OpKind" {
enodes
.iter()
.filter(|n| opkind_metadata_consistent(egraph, n))
.collect()
} else {
Vec::new()
};
let new_node = if !consistent_opkind_nodes.is_empty() {
consistent_opkind_nodes[rng.random_range(0..consistent_opkind_nodes.len())]
} else {
&enodes[rng.random_range(0..enodes.len())]
};
// Insert returns the previous binding (if any); fold the diff
// into the running hash. If the new pick equals the old one,
// the two XORs cancel and `child_hash` is unchanged — exactly
@@ -1932,7 +2000,7 @@ pub fn egglog_to_llir_from_root<'a>(
let mut reachability_stack = vec![choices[root_class]];
while let Some(r) = reachability_stack.pop() {
for ch in &egraph.enodes[r].1 {
if egraph.eclasses[ch].0.contains("IR") || egraph.eclasses[ch].0.contains("IList") {
if is_search_choice_eclass(&egraph.eclasses[ch].0) {
let n = choices[ch];
if !reachable.contains(n) {
reachability_stack.push(n);
@@ -1968,69 +2036,19 @@ pub fn egglog_to_llir_from_root<'a>(
// structurally-equivalent kind enodes whose ELIST children
// were unioned but resolve (under the extractor's first-enode
// walk) to inconsistent lengths — picking such an enode causes
// a downstream `flatten_strides` length mismatch. Prefer the
// first kind enode whose ELIST children all walk to the same
// length; fall back to the original first enode if no
// consistent candidate exists (rare; only happens for ops
// outside the runnable subgraph).
// a downstream `flatten_strides` length mismatch. Candidate
// generation filters these out where possible; this fallback is
// structural only and does not rank backend implementations.
let kind_enodes = &egraph.eclasses[kind_eclass].1;
let extractor_length = |eclass_id: &ClassId| -> Option<usize> {
let mut len = 0usize;
let mut cur_eclass: ClassId = eclass_id.clone();
let mut visited: FxHashSet<ClassId> = FxHashSet::default();
loop {
if !visited.insert(cur_eclass.clone()) {
return None;
}
let (label, enodes) = egraph.eclasses.get(&cur_eclass)?;
if !label.contains("List") {
return Some(len);
}
let head_enode = enodes.first()?;
let head_label = &egraph.enodes[head_enode].0;
if head_label == "ENil" || head_label == "INil" {
return Some(len);
}
if head_label != "ECons" && head_label != "ICons" {
return Some(len);
}
len += 1;
let children = &egraph.enodes[head_enode].1;
if children.len() < 2 {
return Some(len);
}
cur_eclass = children[1].clone();
}
};
let elist_lens_for = |n: &NodeId| -> Vec<usize> {
egraph.enodes[n]
.1
.iter()
.filter_map(|c| {
let lbl = &egraph.eclasses[c].0;
if lbl.contains("List") {
extractor_length(c)
} else {
None
}
})
.collect()
};
let is_consistent = |n: &NodeId| -> bool {
let lens = elist_lens_for(n);
lens.is_empty() || lens.iter().all(|l| *l == lens[0])
};
let is_kernel = |n: &NodeId| -> bool {
let l = &egraph.enodes[n].0;
l.starts_with("Kernel") || l.starts_with("Fused")
};
// Prefer a consistent kernel kind; then any consistent;
// then any kernel; then fall back to first.
let kind_enode = kind_enodes
.iter()
.find(|n| is_kernel(n) && is_consistent(n))
.or_else(|| kind_enodes.iter().find(|n| is_consistent(n)))
.or_else(|| kind_enodes.iter().find(|n| is_kernel(n)))
let kind_enode = choices
.get(kind_eclass)
.copied()
.filter(|n| opkind_metadata_consistent(egraph, n))
.or_else(|| {
kind_enodes
.iter()
.find(|n| opkind_metadata_consistent(egraph, n))
})
.unwrap_or(&kind_enodes[0]);
let kind_label = &egraph.enodes[kind_enode].0;
@@ -2039,8 +2057,7 @@ pub fn egglog_to_llir_from_root<'a>(
.1
.iter()
.map(|c| {
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
{
if is_search_choice_eclass(&egraph.eclasses[c].0) {
choices[c]
} else {
&egraph.eclasses[c].1[0]
@@ -2085,8 +2102,7 @@ pub fn egglog_to_llir_from_root<'a>(
.1
.iter()
.map(|c| {
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
{
if is_search_choice_eclass(&egraph.eclasses[c].0) {
choices[c]
} else {
&egraph.eclasses[c].1[0]
@@ -2165,10 +2181,11 @@ mod tests {
let egraph = egraph(vec![
eclass("a", "IR", 2),
eclass("b", "IList", 3),
eclass("op", "OpKind", 5),
eclass("c", "Shape", 99),
]);
assert_eq!(count_choice_sets_up_to(&egraph, 100), 6);
assert_eq!(count_choice_sets_up_to(&egraph, 100), 30);
}
#[test]

View File

@@ -173,15 +173,16 @@ impl BuildSearchSpaceOptions {
pub struct SearchOptions {
/// Maximum number of graphs to evaluate
pub limit: usize,
/// Number of offspring per generation (default: 30)
/// Number of offspring per generation (default: 10)
pub generation_size: usize,
/// Number of mutations applied to each offspring (default: 30)
/// Number of mutations applied to each offspring (default: 10)
pub mutations: usize,
/// Number of profiling trials per candidate (default: 3)
pub trials: usize,
/// Number of best genomes to keep as parents per generation (default: 1)
pub keep_best: usize,
/// Optional per-candidate profiling timeout.
/// Per-candidate profiling timeout. If a profile call reaches this budget,
/// that candidate is discarded and search continues.
pub profile_timeout: Option<std::time::Duration>,
/// Optional per-group search timeout.
pub group_timeout: Option<std::time::Duration>,
@@ -194,11 +195,11 @@ impl SearchOptions {
pub fn new(limit: usize) -> Self {
Self {
limit,
generation_size: 30,
mutations: 30,
generation_size: 10,
mutations: 10,
trials: 3,
keep_best: 1,
profile_timeout: None,
profile_timeout: Some(std::time::Duration::from_secs(1)),
group_timeout: None,
profile_dims: FxHashMap::default(),
}
@@ -315,6 +316,27 @@ fn maybe_dump_selected_llir(label: &str, dyn_map: &FxHashMap<char, usize>, llir:
}
}
fn random_choice_generation<'a, G: rand::Rng>(
egraph: &'a SerializedEGraph,
generation_size: usize,
prev_selected: &mut FxHashSet<u64>,
rng: &mut G,
) -> Vec<crate::egglog_utils::EGraphChoiceSet<'a>> {
let mut generation = Vec::with_capacity(generation_size);
let max_attempts = generation_size.saturating_mul(100);
let mut attempts = 0;
while generation.len() < generation_size && attempts < max_attempts {
attempts += 1;
let genome = random_initial_choice(egraph, rng);
if prev_selected.insert(hash_choice_set(&genome)) {
generation.push(genome);
}
}
generation
}
/// A Luminal compute graph.
///
/// All computation is represented as a directed acyclic graph.
@@ -1347,6 +1369,11 @@ impl Graph {
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
runtime.clear_intermediate_buffers();
let profile_timed_out = |elapsed: std::time::Duration| {
options
.profile_timeout
.is_some_and(|timeout| elapsed >= timeout)
};
// Find a viable initial genome (may need multiple attempts if some panic)
let (mut best_genome, mut best_metric, display, mut n_graphs);
@@ -1385,13 +1412,15 @@ impl Graph {
// unrolled graph size.
collapse_loops_to_first_iter(&mut graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let has_nan = runtime.has_nan_outputs(&graph, &profile_dyn_map);
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
(
rep_metric,
append_memory_display(
@@ -1401,11 +1430,12 @@ impl Graph {
runtime.allocated_intermediate_buffer_bytes(),
),
has_nan,
timed_out,
)
}));
match result {
Ok((metric, disp, false)) => {
Ok((metric, disp, false, false)) => {
best_genome = genome;
best_metric = R::aggregate_profile_metrics(&[metric]);
display = disp;
@@ -1435,6 +1465,7 @@ impl Graph {
// Track top-N parents for offspring generation
let mut parents: Vec<(R::ProfileMetric, crate::egglog_utils::EGraphChoiceSet<'_>)> =
vec![(best_metric.clone(), best_genome.clone())];
let mut resample_generation = false;
while n_graphs < search_limit {
if options
@@ -1446,26 +1477,33 @@ impl Graph {
// Generate offspring from all parents, dividing budget evenly
let budget = (search_limit - n_graphs).min(options.generation_size);
let per_parent = budget.div_ceil(parents.len());
let mut all_offspring = Vec::new();
for (_, parent_genome) in &parents {
let remaining = budget.saturating_sub(all_offspring.len());
if remaining == 0 {
break;
let all_offspring = if resample_generation {
random_choice_generation(egraph, budget, &mut prev_selected, rng)
} else {
let per_parent = budget.div_ceil(parents.len());
let mut offspring = Vec::new();
for (_, parent_genome) in &parents {
let remaining = budget.saturating_sub(offspring.len());
if remaining == 0 {
break;
}
offspring.extend(extract_generation(
egraph,
parent_genome,
per_parent.min(remaining),
options.mutations,
&mut prev_selected,
rng,
));
}
all_offspring.extend(extract_generation(
egraph,
parent_genome,
per_parent.min(remaining),
options.mutations,
&mut prev_selected,
rng,
));
}
offspring
};
if all_offspring.is_empty() {
break;
}
let mut generation_found_non_timeout = false;
for genome in all_offspring {
if options
.group_timeout
@@ -1502,13 +1540,16 @@ impl Graph {
// before profiling — see initial-genome path.
collapse_loops_to_first_iter(&mut llir_graph);
runtime.clear_intermediate_buffers();
let profile_start = std::time::Instant::now();
let (rep_metric, rep_display) = runtime.profile(
&llir_graph,
&profile_dyn_map,
options.trials,
options.profile_timeout,
);
let has_nan = runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
let timed_out = profile_timed_out(profile_start.elapsed());
let has_nan =
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
(
rep_metric,
append_memory_display(
@@ -1518,15 +1559,28 @@ impl Graph {
runtime.allocated_intermediate_buffer_bytes(),
),
has_nan,
timed_out,
)
}));
let (new_metric, display_metric) = match profile_result {
Ok((metric, display, false)) => {
Ok((metric, display, false, false)) => {
generation_found_non_timeout = true;
(R::aggregate_profile_metrics(&[metric]), display)
}
Ok((_, _, true)) | Err(_) => {
// NaN or panic — redraw bars and skip
Ok((_, _, _, true)) | Err(_) => {
// Timed out or panicked — redraw bars and skip.
for _ in 1..n_bar_lines {
print!("\x1b[1A");
}
print!("\r\x1b[2K");
render_bars(n_graphs, search_limit, bucket_progress);
std::io::stdout().flush().unwrap();
continue;
}
Ok((_, _, true, false)) => {
generation_found_non_timeout = true;
// Completed profiling but produced NaNs — redraw bars and skip.
for _ in 1..n_bar_lines {
print!("\x1b[1A");
}
@@ -1577,6 +1631,8 @@ impl Graph {
render_bars(n_graphs, search_limit, bucket_progress);
std::io::stdout().flush().unwrap();
}
resample_generation = !generation_found_non_timeout;
}
// Clear progress bars

View File

@@ -1613,8 +1613,7 @@ fn bin_fn<A: Copy>(
a_ind: StridedIterator,
a: &[A],
b_ind: StridedIterator,
b: &NativeData,
b_get: impl Fn(&NativeData, usize) -> A,
b: &[A],
op: impl Fn(A, A) -> A,
) -> Vec<A> {
let a_shape = a_ind.shape.clone();
@@ -1634,7 +1633,36 @@ fn bin_fn<A: Copy>(
"bin_fn: b index {j} out of bounds (b.len={}), shape={b_shape:?}, strides={b_strides:?}",
b.len(),
);
op(a[i], b_get(b, j))
op(a[i], b[j])
})
.collect()
}
fn bin_cmp_fn<A: Copy>(
a_ind: StridedIterator,
a: &[A],
b_ind: StridedIterator,
b: &[A],
op: impl Fn(A, A) -> bool,
) -> Vec<bool> {
let a_shape = a_ind.shape.clone();
let a_strides = a_ind.strides.clone();
let b_shape = b_ind.shape.clone();
let b_strides = b_ind.strides.clone();
a_ind
.zip(b_ind)
.map(|(i, j)| {
assert!(
i < a.len(),
"bin_cmp_fn: a index {i} out of bounds (a.len={}), shape={a_shape:?}, strides={a_strides:?}",
a.len(),
);
assert!(
j < b.len(),
"bin_cmp_fn: b index {j} out of bounds (b.len={}), shape={b_shape:?}, strides={b_strides:?}",
b.len(),
);
op(a[i], b[j])
})
.collect()
}
@@ -1708,20 +1736,23 @@ impl NativeOp for Add {
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
);
match a {
NativeData::F32(a) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x + y))
match (a, b) {
(NativeData::F32(a), NativeData::F32(b)) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
}
NativeData::F16(a) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x + y))
(NativeData::F16(a), NativeData::F16(b)) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
}
NativeData::Bf16(a) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x + y))
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
}
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
(NativeData::Int(a), NativeData::Int(b)) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
(NativeData::Bool(_), NativeData::Bool(_)) => {
panic!("Cannot add Bool tensors, cast to F32 first")
}
_ => panic!("Add inputs must have the same dtype"),
}
}
}
@@ -1795,20 +1826,23 @@ impl NativeOp for Mul {
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
);
match a {
NativeData::F32(a) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x * y))
match (a, b) {
(NativeData::F32(a), NativeData::F32(b)) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
}
NativeData::F16(a) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x * y))
(NativeData::F16(a), NativeData::F16(b)) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
}
NativeData::Bf16(a) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x * y))
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
}
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
(NativeData::Int(a), NativeData::Int(b)) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
(NativeData::Bool(_), NativeData::Bool(_)) => {
panic!("Cannot multiply Bool tensors, cast to F32 first")
}
_ => panic!("Mul inputs must have the same dtype"),
}
}
}
@@ -1882,20 +1916,21 @@ impl NativeOp for Mod {
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
);
match a {
NativeData::F32(a) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x % y))
match (a, b) {
(NativeData::F32(a), NativeData::F32(b)) => {
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
}
NativeData::F16(a) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x % y))
(NativeData::F16(a), NativeData::F16(b)) => {
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
}
NativeData::Bf16(a) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x % y))
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
}
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x % y))
(NativeData::Int(a), NativeData::Int(b)) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
}
NativeData::Bool(_) => panic!("Cannot mod Bool tensors"),
(NativeData::Bool(_), NativeData::Bool(_)) => panic!("Cannot mod Bool tensors"),
_ => panic!("Mod inputs must have the same dtype"),
}
}
}
@@ -1970,13 +2005,24 @@ impl NativeOp for LessThan {
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
);
// Comparison always returns Bool
NativeData::Bool(
a_ind
.zip(b_ind)
.map(|(i, j)| NativeData::f32(a, i) < NativeData::f32(b, j))
.collect(),
)
match (a, b) {
(NativeData::F32(a), NativeData::F32(b)) => {
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
}
(NativeData::F16(a), NativeData::F16(b)) => {
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
}
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
}
(NativeData::Int(a), NativeData::Int(b)) => {
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
}
(NativeData::Bool(a), NativeData::Bool(b)) => {
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| !x & y))
}
_ => panic!("LessThan inputs must have the same dtype"),
}
}
}
@@ -2708,16 +2754,7 @@ impl NativeData {
pub fn f32(&self, i: usize) -> f32 {
match self {
NativeData::F32(v) => v[i],
NativeData::F16(v) => v[i].to_f32(),
NativeData::Bf16(v) => v[i].to_f32(),
NativeData::Int(v) => v[i] as f32,
NativeData::Bool(v) => {
if v[i] {
1.0
} else {
0.0
}
}
_ => panic!("NativeData::f32 called on non-F32 data"),
}
}
@@ -2725,10 +2762,7 @@ impl NativeData {
pub fn f16(&self, i: usize) -> f16 {
match self {
NativeData::F16(v) => v[i],
NativeData::F32(v) => f16::from_f32(v[i]),
NativeData::Bf16(v) => f16::from_f32(v[i].to_f32()),
NativeData::Int(v) => f16::from_f32(v[i] as f32),
NativeData::Bool(v) => f16::from_f32(if v[i] { 1.0 } else { 0.0 }),
_ => panic!("NativeData::f16 called on non-F16 data"),
}
}
@@ -2736,10 +2770,7 @@ impl NativeData {
pub fn bf16(&self, i: usize) -> bf16 {
match self {
NativeData::Bf16(v) => v[i],
NativeData::F32(v) => bf16::from_f32(v[i]),
NativeData::F16(v) => bf16::from_f32(v[i].to_f32()),
NativeData::Int(v) => bf16::from_f32(v[i] as f32),
NativeData::Bool(v) => bf16::from_f32(if v[i] { 1.0 } else { 0.0 }),
_ => panic!("NativeData::bf16 called on non-Bf16 data"),
}
}
@@ -2747,16 +2778,7 @@ impl NativeData {
pub fn i32(&self, i: usize) -> i32 {
match self {
NativeData::Int(v) => v[i],
NativeData::F32(v) => v[i] as i32,
NativeData::F16(v) => v[i].to_f32() as i32,
NativeData::Bf16(v) => v[i].to_f32() as i32,
NativeData::Bool(v) => {
if v[i] {
1
} else {
0
}
}
_ => panic!("NativeData::i32 called on non-Int data"),
}
}
@@ -2764,10 +2786,50 @@ impl NativeData {
pub fn bool(&self, i: usize) -> bool {
match self {
NativeData::Bool(v) => v[i],
NativeData::F32(v) => v[i] != 0.0,
NativeData::F16(v) => v[i].to_f32() != 0.0,
NativeData::Bf16(v) => v[i].to_f32() != 0.0,
NativeData::Int(v) => v[i] != 0,
_ => panic!("NativeData::bool called on non-Bool data"),
}
}
pub fn to_f32_vec(&self) -> Vec<f32> {
match self {
NativeData::F32(v) => v.clone(),
NativeData::F16(v) => v.iter().map(|v| v.to_f32()).collect(),
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32()).collect(),
NativeData::Int(v) => v.iter().map(|v| *v as f32).collect(),
NativeData::Bool(v) => v.iter().map(|v| if *v { 1.0 } else { 0.0 }).collect(),
}
}
pub fn to_f16_vec(&self) -> Vec<f16> {
match self {
NativeData::F32(v) => v.iter().copied().map(f16::from_f32).collect(),
NativeData::F16(v) => v.clone(),
NativeData::Bf16(v) => v.iter().map(|v| f16::from_f32(v.to_f32())).collect(),
NativeData::Int(v) => v.iter().map(|v| f16::from_f32(*v as f32)).collect(),
NativeData::Bool(v) => v
.iter()
.map(|v| f16::from_f32(if *v { 1.0 } else { 0.0 }))
.collect(),
}
}
pub fn to_i32_vec(&self) -> Vec<i32> {
match self {
NativeData::F32(v) => v.iter().map(|v| *v as i32).collect(),
NativeData::F16(v) => v.iter().map(|v| v.to_f32() as i32).collect(),
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32() as i32).collect(),
NativeData::Int(v) => v.clone(),
NativeData::Bool(v) => v.iter().map(|v| if *v { 1 } else { 0 }).collect(),
}
}
pub fn to_bool_vec(&self) -> Vec<bool> {
match self {
NativeData::F32(v) => v.iter().map(|v| *v != 0.0).collect(),
NativeData::F16(v) => v.iter().map(|v| v.to_f32() != 0.0).collect(),
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32() != 0.0).collect(),
NativeData::Int(v) => v.iter().map(|v| *v != 0).collect(),
NativeData::Bool(v) => v.clone(),
}
}
}

View File

@@ -3,7 +3,7 @@ use std::fmt::Debug;
use crate::egglog_utils::{
extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
};
use crate::hlir::Output;
use crate::hlir::{Add as NativeAdd, LessThan as NativeLessThan, NativeData, NativeOp, Output};
use crate::prelude::*;
use candle_core::{Device, Tensor};
use proptest::prelude::*;
@@ -430,6 +430,34 @@ fn fuzz_test_genome_execution() {
// --- Consumed-input semantics tests ---
#[test]
#[should_panic(expected = "Add inputs must have the same dtype")]
fn native_add_rejects_mixed_dtypes() {
let op = NativeAdd {
shape: vec![2.into()],
a_strides: vec![1.into()],
b_strides: vec![1.into()],
input_shapes: vec![],
};
let a = NativeData::F32(vec![1.0, 2.0]);
let b = NativeData::Int(vec![1, 2]);
op.execute(vec![&a, &b], &FxHashMap::default());
}
#[test]
#[should_panic(expected = "LessThan inputs must have the same dtype")]
fn native_less_than_rejects_mixed_dtypes() {
let op = NativeLessThan {
shape: vec![2.into()],
a_strides: vec![1.into()],
b_strides: vec![1.into()],
input_shapes: vec![],
};
let a = NativeData::F32(vec![1.0, 2.0]);
let b = NativeData::Int(vec![1, 2]);
op.execute(vec![&a, &b], &FxHashMap::default());
}
#[test]
#[should_panic]
fn test_inputs_consumed_after_execute() {