Compare commits

...

343 Commits

Author SHA1 Message Date
Austin Glover
7ee5b54438 idiomatic changes 2026-04-06 23:05:06 +00:00
Austin Glover
389c05abeb auto rebuild, silence onnx, cache images. 2026-04-06 22:43:24 +00:00
Austin Glover
dcc2c9cbb4 hf tests 2026-04-06 22:42:58 +00:00
Austin Glover
a9af4c3923 test deps 2026-04-06 22:42:27 +00:00
Austin Glover
3092d0d68b skill slop 2026-04-06 22:29:34 +00:00
Austin Glover
8a2bd714ac test classes wip 2026-04-02 01:32:37 +00:00
Austin Glover
54a26a044c save codex data on container restart 2026-04-02 01:31:23 +00:00
Austin Glover
5a0d3f87cc Merge remote-tracking branch 'origin/main' into pytest-classes
# Conflicts:
#	crates/luminal_python/pyproject.toml
#	crates/luminal_python/tests/conftest.py
#	crates/luminal_python/tests/generate_llama38b_artifacts.py
#	crates/luminal_python/tests/test_llama3.py
2026-04-02 01:18:40 +00:00
Joe Fioti
a28b755245 Merge pull request #259 from luminal-ai/tucker_shared_pytorch_memory 2026-04-01 12:51:15 -07:00
Tucker Morgan
fd83534e53 Remove dead logical.rs stub from luminal_cuda_lite
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:58:12 +00:00
Tucker Morgan
b5d984c3fa Move KernelExp/KernelSigmoid to other_ops.rs and remove logical intermediaries
hlir.rs should only contain 1:1 HLIR op analogues. KernelExp and KernelSigmoid
are fused kernels, so they belong in other_ops.rs. Also removed the redundant
logical::Exp and logical::Sigmoid intermediary ops since the kernel ops match
HLIR patterns directly via their direct-fusion egglog rules.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:46:17 +00:00
Tucker Morgan
64a5ca41b5 Merge remote-tracking branch 'origin/main' into tucker_shared_pytorch_memory 2026-04-01 16:45:16 +00:00
Joe Fioti
9bda47714a Merge pull request #256 from luminal-ai/asglover/modal_ci_ready 2026-04-01 05:21:02 -07:00
Austin Glover
9e513b6589 Fix git safe.directory for pre-commit in CUDA clippy container
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:32:40 -07:00
Austin Glover
a62d728bd7 Fix CUDA clippy container image to luminal-docker
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:21:36 -07:00
Austin Glover
4114714d3f Rename clippy workflow to cuda-clippy and fix container image
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:17:47 -07:00
Austin Glover
6191597571 Remove Modal CUDA clippy job, now handled by T4 runner
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:10:17 -07:00
Austin Glover
253cd95ab0 Run clippy on T4 runner with CUDA container for full lint coverage
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:05:05 -07:00
Austin Glover
d7e396ba5b Gate Modal CI on 'modal-ready' label and convert CUDA tests to Modal
- Gate test-cuda.yml and test-python-cuda.yml behind 'modal-ready' label
- Convert CUDA clippy and unit tests from self-hosted runner to Modal
- Add ci/modal_cargo_test.py and ci/modal_cargo_clippy.py runners

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 16:03:17 -07:00
Joe Fioti
1a53626716 Merge pull request #260 from luminal-ai/nvidia-devcontainer-args 2026-03-31 15:55:21 -07:00
Austin Glover
4329d68adc Merge main and resolve workflow conflicts
Resolve conflicts from main's pre-commit migration and Modal pytest runner.
Split new lint jobs (ruff, ruff-format, metal-clippy) into individual files
and update test-python-cuda to use Modal runner from main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 15:44:29 -07:00
Tucker Morgan
989e7e2d44 Fixing native tests 2026-03-31 21:27:34 +00:00
Austin Glover
4f0a3ab102 Merge branch 'main' into nvidia-devcontainer-args 2026-03-31 13:52:26 -07:00
Tucker Morgan
019972cdd4 Fixing ruff lint issue 2026-03-31 20:46:17 +00:00
Tucker Morgan
d7a3f468bd Ruff formatting 2026-03-31 20:44:23 +00:00
Tucker Morgan
c504fbf8a1 Merge cleanip 2026-03-31 20:41:40 +00:00
Austin Glover
648720caf9 force a 12.8 cuda version of torch. 2026-03-31 20:33:44 +00:00
Tucker Morgan
625be7f4da Merge origin/main into tucker_shared_pytorch_memory
Resolved conflicts:
- other_ops.rs: kept kernel_rewrite import, dropped unused compile_kernel
- lib.rs: kept weight_device_ptrs param, added validate_backend call
- runtime.rs: accepted two-phase CUDA init helpers from main
- compiled_model.py: kept weight_refs/user_indices/is_cuda fields
- pt2.py: kept original_weights tracking for zero-copy
- test_llama3.py: kept xfail + device param for dynamic test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 20:27:30 +00:00
Austin Glover
21ed7ef31f gitignore claude , codex 2026-03-31 20:24:00 +00:00
Joe Fioti
6e94f80c9e Merge pull request #255 from luminal-ai/modal-pytest-runnner
Modal pytest runner
2026-03-31 12:33:49 -07:00
Tucker Morgan
c2a17a4854 Removing uneeded qwen3 moe test file 2026-03-31 19:04:29 +00:00
Austin Glover
386b3df983 try recommended nvidia container args 2026-03-31 18:57:14 +00:00
Tucker Morgan
5c60f1d768 Fixing up small things for review 2026-03-31 18:24:16 +00:00
Tucker Morgan
4c51e3ea84 Cargo fmt: 2026-03-31 16:44:03 +00:00
Tucker Morgan
846551aa6f Cargo clippy 2026-03-31 16:42:30 +00:00
Tucker Morgan
c26076bc75 Cargo fmt 2026-03-31 16:38:09 +00:00
Tucker Morgan
871629b770 fmt and clippy 2026-03-31 16:35:13 +00:00
Tucker Morgan
c6dfa9c62f Unify ONNX/PT2 compilation paths and extract shared helpers
Restructure so both ONNX and PT2 paths follow the same call flow:
  lib.rs (thin PyO3 wrapper)
    → onnx_translator.rs / pt2_compiled_model.rs (format-specific translate + compile)
      → compiled_graph.rs::parse_graph (shared backend pipeline)

Rust changes:
- Create onnx_translator.rs with compile_onnx() and translate_onnx()
  (moved from compiled_graph.rs and lib.rs)
- compiled_graph.rs now only contains shared code (GraphTranslation,
  WeightData, CompiledGraph, parse_graph)
- Cache label_map in CompiledGraph for O(1) set_weight_* lookups
- Move weight_device_ptrs into WeightData.device_ptrs
- Add search_iters param to process_onnx (parity with PT2)
- Fix .unwrap() → ? error propagation in ONNX file loading
- lib.rs reduced to thin PyO3 registration layer

Python changes:
- Extract _collect_weight_pointers(), _detect_backend(),
  _load_cpu_weights() shared helpers in main.py
- Both ONNX and PT2 paths use the same helpers
- Centralize _register_cache_serialization() in __init__.py
- CompiledModel: add input_names override, keep user_indices for
  torch.compile lifted-param filtering

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 23:29:15 +00:00
Tucker Morgan
90e3a915d7 Cargo fmt 2026-03-30 22:20:41 +00:00
Tucker Morgan
56cb237aa2 removing uneeded prints 2026-03-30 22:20:32 +00:00
Tucker Morgan
a2c42b35c8 Cleaning up qwen tests 2026-03-30 21:36:30 +00:00
Tucker Morgan
898204b2dd setting test right 2026-03-30 17:51:32 +00:00
Tucker Morgan
2c1a7f087f removing uneeded logs 2026-03-30 17:36:26 +00:00
Austin Glover
112d064700 remove unnecessary ignore 2026-03-28 00:39:34 +00:00
Austin Glover
c51c36fbcb add node for mcp servers 2026-03-28 00:37:46 +00:00
Austin Glover
ee372d464e ignore codex 2026-03-28 00:37:34 +00:00
Austin Glover
1bef1344d1 pytest native approach to caching 2026-03-28 00:36:41 +00:00
Austin Glover
2e27c29b47 Gate Modal CI on 'modal-ready' label and split workflows into one-job-per-file
Modal examples now only run on PRs when the 'modal-ready' label is applied,
preventing expensive GPU runs on every push. Split test.yml and lint.yml
into individual workflow files for clearer CI organization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-27 15:45:40 -07:00
Austin Glover
8d41c491fd test prototype 2026-03-27 01:36:37 +00:00
Austin Glover
64f390a833 silence maturin logs 2026-03-27 01:36:23 +00:00
Austin Glover
8d20581f38 testing infra 2026-03-27 01:35:56 +00:00
Austin Glover
bfd4ae9b27 temp 2026-03-27 01:35:49 +00:00
Tucker Morgan
92e4260f1e Fixing weight stripping issues 2026-03-26 21:31:48 +00:00
Tucker Morgan
662a564efc Cleaning up a set of changes 2026-03-26 18:39:57 +00:00
Tucker Morgan
1761dc6b66 Missed a directory 2026-03-26 18:05:28 +00:00
Tucker Morgan
da71273d7e Getting LLama tests closer to proper passing 2026-03-26 18:05:13 +00:00
Austin Glover
39122672b4 clippy fixes 2026-03-26 01:26:51 +00:00
Austin Glover
d866ba6407 plz 2026-03-26 01:05:50 +00:00
Austin Glover
9a0fb453ed plz 2026-03-26 01:00:39 +00:00
Austin Glover
dab60f0b21 simplify cuda clippy 2026-03-26 00:54:21 +00:00
Austin Glover
1ea872bd2a allow lots of args 2026-03-25 23:44:11 +00:00
Austin Glover
90a66ac704 metal clippy 2026-03-25 23:38:33 +00:00
Austin Glover
2b94ba0b71 retry 2026-03-25 23:26:56 +00:00
Austin Glover
2ed65d5386 precommit 2026-03-25 22:50:21 +00:00
Austin Glover
336d49c147 pre-commit stuff 2026-03-25 22:50:12 +00:00
Austin Glover
1ff5840a76 disable TF32 2026-03-25 22:41:42 +00:00
Austin Glover
bc94b10648 ruff changes 2026-03-25 22:34:45 +00:00
Tucker Morgan
7c921d03a8 Working weight sharing in both onnx and pt 2026-03-25 21:27:14 +00:00
Austin Glover
4e46051617 skip slow tests 2026-03-25 20:46:40 +00:00
Austin Glover
a55952d591 skip test_hf_llama3_1b_decode_loop_dynamic 2026-03-25 20:42:39 +00:00
Tucker Morgan
679aa7e092 Fixing up the onnx and fx parsing layer to share more of their code paths 2026-03-25 17:25:00 +00:00
Tucker Morgan
3dd2be2fb2 First pass of the new memory model 2026-03-25 15:59:06 +00:00
Austin Glover
c290e266f7 move dependency back to candle 2026-03-25 04:02:08 +00:00
Austin Glover
be3d8aa064 fix file permission issue w/ workaround 2026-03-25 04:01:53 +00:00
Austin Glover
84e0c842a1 ignore venv? 2026-03-25 03:37:37 +00:00
Austin Glover
403fd36b1f hugging face caching, all files on host 2026-03-25 02:06:28 +00:00
Austin Glover
651a4c2aee pass HF token 2026-03-25 01:45:09 +00:00
Austin Glover
41ddd244ef lowkirkenuenly 2026-03-25 00:54:09 +00:00
Austin Glover
0dbee87a8c Merge branch 'modal-pytest-runnner' of https://github.com/luminal-ai/luminal into modal-pytest-runnner 2026-03-25 00:41:24 +00:00
Austin Glover
194b8adfa5 satisfy clippy 2026-03-25 00:31:44 +00:00
Austin Glover
86e616800d Merge branch 'main' into modal-pytest-runnner 2026-03-24 17:23:51 -07:00
Austin Glover
683205121d add lots of stuff for creating profiling outputs 2026-03-25 00:17:05 +00:00
Austin Glover
6d653e854d add profiling stuff 2026-03-25 00:16:38 +00:00
Austin Glover
08397b566d compile to cubin 2026-03-25 00:16:25 +00:00
Austin Glover
5310335256 force cuda 12.8 for modal 2026-03-25 00:12:55 +00:00
Austin Glover
638765b62b add artifacts 2026-03-25 00:12:35 +00:00
Austin Glover
3850b3a533 chatgpt 2026-03-25 00:12:12 +00:00
Austin Glover
a4c84c6cf5 force CUDA 12.8 on modal 2026-03-25 00:12:00 +00:00
Joe Fioti
f32161d43b Merge pull request #254 from luminal-ai/tucker_fx_parsing_integration
Add PT2/FX graph parsing integration with bucketed compilation
2026-03-24 14:07:11 -07:00
Austin Glover
da83d51b27 extend timeout 2026-03-24 01:07:40 +00:00
Austin Glover
29a3ffa3e3 wip 2026-03-23 23:22:38 +00:00
Austin Glover
97e358916a simplify runner 2026-03-23 23:20:51 +00:00
Tucker Morgan
631c1b53d7 Fixing bad test config 2026-03-23 22:48:05 +00:00
Austin Glover
02449a6bea move CICD to go through modal 2026-03-23 22:18:08 +00:00
Tucker Morgan
6c4597102e Fixing test 2026-03-23 22:02:41 +00:00
Austin Glover
b077cfdb76 reduce diff 2026-03-23 21:50:47 +00:00
Tucker Morgan
869b519e39 Fixing up tests 2026-03-23 21:50:40 +00:00
Tucker Morgan
2b831c9f25 cargo fmt 2026-03-23 21:10:44 +00:00
Tucker Morgan
f35a950496 Running a review to clean things up 2026-03-23 21:08:15 +00:00
Tucker Morgan
9ab0e1472c Fmt and Clippy 2026-03-23 18:18:08 +00:00
Tucker Morgan
88f2601d5e Preclippy push 2026-03-23 18:05:25 +00:00
Tucker Morgan
b0ebdcba8c Preclipply fix 2026-03-23 18:04:55 +00:00
Tucker Morgan
0ab124194b Pre commit cleanup 2026-03-23 17:57:51 +00:00
Austin Glover
7f042ae615 Merge remote-tracking branch 'origin/main' into tucker-branch-austin-wip
# Conflicts:
#	crates/luminal_python/.gitignore
#	crates/luminal_python/rust/src/ops_parse/unary.rs
2026-03-23 17:52:46 +00:00
Joe Fioti
082d9c48bd renamed progress bar ui 2026-03-23 05:57:03 +00:00
Tucker Morgan
251e9526f3 Merge remote-tracking branch 'origin/main' into tucker_fx_parsing_integration 2026-03-22 21:34:49 +00:00
Tucker Morgan
41b3774ec2 Getting the errors on most models to pass 1e-5 2026-03-22 21:32:02 +00:00
Joe Fioti
3fdb464f5a fmt 2026-03-21 18:58:02 +00:00
Austin Glover
a3a4fd94ec ignore uv.lock 2026-03-20 23:55:55 +00:00
Austin Glover
5446dccb04 clean up conf test 2026-03-20 23:47:40 +00:00
Austin Glover
8e6535563e clear errors 2026-03-20 23:46:57 +00:00
Austin Glover
bdc923aa50 add logic to prevent unclear errors 2026-03-20 23:46:38 +00:00
Austin Glover
ea67742b3b ignore uv.lock 2026-03-20 23:45:52 +00:00
Austin Glover
149e570f26 modal in dev, not sure if the module_name is necessary? 2026-03-20 23:45:14 +00:00
Tucker Morgan
ac52098d5c Fixing brioken llama issue 2026-03-20 23:06:16 +00:00
Austin Glover
01946ecd10 modal 2026-03-20 22:57:02 +00:00
Tucker Morgan
ef70fee204 Work to get the early version of llama to work 2026-03-20 22:24:06 +00:00
Austin Glover
a6fea110dc ignore 2026-03-20 20:59:20 +00:00
Austin Glover
2d4ebb2cb6 sure 2026-03-20 20:52:15 +00:00
Austin Glover
5adb875b04 ignore cargo-local 2026-03-20 20:50:55 +00:00
Joe Fioti
2adfcfa70e bucketed compilation 2026-03-20 18:33:15 +00:00
Tucker Morgan
65600e8730 Cargo fmt 2026-03-20 17:45:40 +00:00
Tucker Morgan
4c4f39b4af A quick review pass 2026-03-20 17:45:06 +00:00
Tucker Morgan
49b9209ad0 All basic tests passing for both Onnx and FX 2026-03-20 00:08:57 +00:00
Austin Glover
dea5df51dd wip 2026-03-19 23:59:12 +00:00
Austin Glover
eb6a6c2174 ignore luminal artifacts 2026-03-19 23:59:05 +00:00
Austin Glover
8864ef31fb ignore claude 2026-03-19 23:58:23 +00:00
Austin Glover
38c98a8835 fix devcontainers 2026-03-19 23:58:13 +00:00
Tucker Morgan
31b5fd886d First pass, copy over scartch pad work, easy test fixes 2026-03-19 22:57:32 +00:00
Joe Fioti
04b2753aa8 fixed test 2026-03-19 15:54:21 -07:00
Joe Fioti
dbb5282fd6 removed kimi submodule 2026-03-19 15:48:00 -07:00
Joe Fioti
8e315c62df Merge pull request #240 from luminal-ai/tucker_pytorch_integegration_layer
PyTorch Integration Layer (luminal_python)
2026-03-19 13:40:46 -07:00
Tucker Morgan
f53d990581 Adding Flatten 2026-03-19 20:18:15 +00:00
Tucker Morgan
2c8ecba6a5 Flatten function on ShapeTracker and GraphTensor 2026-03-19 18:51:30 +00:00
Tucker Morgan
1644cce031 Clippy and fmt 2026-03-19 18:20:57 +00:00
Tucker Morgan
b2bb455b30 Removing as many forced materlizaing as possibling 2026-03-19 18:18:55 +00:00
Tucker Morgan
8628b1425a Actually removing the uv lock 2026-03-19 16:36:01 +00:00
Tucker Morgan
ca66609d6f cargo fmt 2026-03-18 23:39:03 +00:00
Tucker Morgan
c50e122ac1 making scatter element and gather elemnets to be simpler 2026-03-18 23:37:19 +00:00
Tucker Morgan
272acabd0c Small fixes from comments 2026-03-18 23:16:27 +00:00
Tucker Morgan
f772c0529a Removing GPU dashboard 2026-03-18 22:28:49 +00:00
Tucker Morgan
c3e1f568ea Wrong name for docket imag 2026-03-18 21:55:16 +00:00
Tucker Morgan
eb3dd02836 updating docker image 2026-03-18 21:44:47 +00:00
Tucker Morgan
8a9f85b0ce Cargo fmt fixes 2026-03-18 21:27:06 +00:00
Tucker Morgan
372501e527 Fixing clippy and review issues 2026-03-18 21:23:55 +00:00
Tucker Morgan
cda12a6d84 Clippy and fmt fixes 2026-03-18 21:04:33 +00:00
Tucker Morgan
566fb00ed2 Fixes in bad tests 2026-03-18 21:03:20 +00:00
Tucker Morgan
ecb78a2635 Fixing up issues left over from a merge 2026-03-18 18:56:39 +00:00
Tucker Morgan
f401ffb900 Merge remote-tracking branch 'origin/main' into tucker_pytorch_integegration_layer 2026-03-18 16:36:08 +00:00
Joe Fioti
cd94000140 Merge pull request #253 from luminal-ai/matmul-flattening-minimal
Matmul flattening rules for cuBLAS matching
2026-03-18 09:35:42 -07:00
Austin Glover
1f1636e188 cargo fmt
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 01:20:55 +00:00
Austin Glover
371fa8491a Update matmul flattening rules for Op() wrapper syntax
Adapt the 3 egglog flattening rules (squeeze, batch_merge_a_contig,
batch_merge_b_contig) to use the new Op(...) (ICons ...) wrapper
syntax from the HLIR refactor, so they match the current egglog
representation of Mul and Sum nodes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 01:04:37 +00:00
Austin Glover
6d1fe67b66 flatten matmuls 2026-03-18 00:58:55 +00:00
Austin Glover
189d1e2594 match on flattened matmuls 2026-03-18 00:58:55 +00:00
Austin Glover
c7acfb9794 flatten matmul when applicable 2026-03-18 00:58:33 +00:00
Tucker Morgan
fe6af5290a Merge remote-tracking branch 'origin/main' into tucker_pytorch_integegration_layer
# Conflicts:
#	.github/workflows/test.yml
#	crates/luminal_cuda/src/block/ops.rs
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg
#	crates/luminal_cuda/src/kernel/other_ops.rs
#	crates/luminal_cuda_lite/src/kernel/hlir.rs
2026-03-17 17:19:04 +00:00
Joe Fioti
8b8669c744 Merge pull request #252 from luminal-ai/hlir_enhancements
tweaks
2026-03-16 21:54:09 -07:00
Joe Fioti
7af771b999 tweaks 2026-03-17 04:46:27 +00:00
Joe Fioti
133757f187 Merge pull request #250 from xiaoniaoyouhuajiang/feature/metal-mega-kernel
metal backend optimization for transformer
2026-03-13 16:16:31 -07:00
xiaoniaoyouhuajiang
07ee241b25 merge upstream work 2026-03-13 17:01:07 +08:00
Joe Fioti
958331ab6c Merge pull request #251 from luminal-ai/hlir_enhancements
Hlir enhancements
2026-03-12 22:28:33 -07:00
Joe Fioti
340199d4a8 fixed cuda testss 2026-03-13 05:16:09 +00:00
Joe Fioti
f17a95e673 metal fixes 2026-03-12 20:24:07 -07:00
Joe Fioti
6bb576e711 more test fixes 2026-03-13 03:10:37 +00:00
Joe Fioti
744e4d767a changed actions 2026-03-13 03:08:29 +00:00
Joe Fioti
c940161f25 fmt 2026-03-12 23:27:12 +00:00
Joe Fioti
3aa2c309f5 clippy 2026-03-12 23:24:13 +00:00
Joe Fioti
8a0592646b added qwen3 moe 2026-03-12 23:10:32 +00:00
Joe Fioti
68ce81e52b Merge remote-tracking branch 'origin/main' into hlir_enhancements 2026-03-12 22:37:07 +00:00
Joe Fioti
39789404f4 Merge pull request #248 from luminal-ai/modal-ci
Add CI workflows: lint, CUDA tests, Modal llama
2026-03-12 15:28:34 -07:00
Joe Fioti
8c53234966 Merge branch 'main' into modal-ci 2026-03-12 14:45:30 -07:00
Joe Fioti
71eca945cb changed luminal_cuda_lite 2026-03-12 21:34:13 +00:00
Joe Fioti
8da130ae1c Merge remote-tracking branch 'origin/main' into hlir_enhancements 2026-03-12 20:33:09 +00:00
Joe Fioti
fef6a45c9c paged attention llama example 2026-03-12 20:29:39 +00:00
Joe Fioti
c6763a69ba fixed metal ci 2026-03-12 11:24:45 -07:00
Joe Fioti
30caca106c fixed metal ci 2026-03-12 11:16:45 -07:00
Joe Fioti
6c90bb5059 metal ci/cd 2026-03-12 11:01:04 -07:00
xiaoniaoyouhuajiang
82189cd602 threadgroup staging 2026-03-12 18:56:05 +08:00
xiaoniaoyouhuajiang
cc5e0a639d support simdgroup micro-kernel 2026-03-12 18:23:50 +08:00
xiaoniaoyouhuajiang
8dc05233cb support tiled matmul for gemm 2026-03-12 17:32:19 +08:00
xiaoniaoyouhuajiang
0ab9947292 matmul family template & match pipeline 2026-03-12 16:56:02 +08:00
Joe Fioti
f11ba3a388 added tensor.repeat 2026-03-11 22:35:32 -07:00
Austin Glover
a346e503db make it so draft PRs don't trigger 2026-03-11 18:16:10 +00:00
xiaoniaoyouhuajiang
6bbf244924 support f16 for metal 2026-03-11 14:02:48 +08:00
Joe Fioti
a8505668ac converted gemma and qwen 2026-03-11 03:39:51 +00:00
Austin Glover
a0b237c424 use new images 2026-03-11 02:03:39 +00:00
Austin Glover
8fabacd17e manually override build 2026-03-11 01:50:21 +00:00
Austin Glover
df0128ad04 pass example str 2026-03-11 01:28:04 +00:00
Austin Glover
de55e67594 lock to cuda 13 2026-03-11 01:22:22 +00:00
Austin Glover
cdff26755f fmt 2026-03-11 01:17:01 +00:00
Austin Glover
9f11b7e24a try matrix formulation 2026-03-11 01:15:54 +00:00
Austin Glover
27344a0e45 updating CI 2026-03-11 00:57:56 +00:00
Joe Fioti
e9e6f824a1 fixed memory leak 2026-03-10 23:08:28 +00:00
Austin Glover
d894eeae50 remove caching 2026-03-10 22:45:25 +00:00
Joe Fioti
da078b5bdd fixed in-place scatter for llama, cleaned up cleanup rules 2026-03-10 21:26:35 +00:00
Tucker Morgan
f156265ff4 Add TileMatmul broadcast constraints and expand qwen image tests
Add (= ?a_n_stride (MNum 0)) and (= ?b_m_stride (MNum 0)) constraints
to TileMatmulSplitK and TileMatmulFullSplit to prevent matching
element-wise Mul+Sum patterns (e.g. x*x from LayerNorm) as matmuls.
Also refactor qwen image tests to use torch.compile backend for
transformer tests, add medium configs, and add full production-scale
test cases.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-10 18:04:54 +00:00
Joe Fioti
b34f104cea normalized egglog ops 2026-03-10 04:49:41 +00:00
Austin Glover
1873e26185 fix: add Modal environment to modal-llama workflow for secrets access
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 00:22:14 +00:00
Austin Glover
384a426ba3 modal prototype 2026-03-09 23:58:47 +00:00
Tucker Morgan
d25654a0ec Small fixes to qwen image 2026-03-09 20:54:34 +00:00
Tucker Morgan
579daa1a57 Merge origin/main into tucker_pytorch_integegration_layer
Resolved conflicts:
- other_ops.rs: kept parallelized reduction loop from branch
- base.rs: kept existing commutativity rules, avoided duplicate add-comm

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-09 17:48:29 +00:00
Tucker Morgan
4fa8a92086 Pushed from only the luminal_python crate 2026-03-09 17:46:25 +00:00
Tucker Morgan
bf7debc1d6 Various fixes after a merge 2026-03-09 17:45:54 +00:00
Joe Fioti
a49c970029 hlir attn 2026-03-09 06:05:24 +00:00
xiaoniaoyouhuajiang
1d2db8f88f Merge branch 'feature/metal-mega-kernel' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-mega-kernel 2026-03-09 10:06:33 +08:00
Tucker Morgan
4fff906b8d Merge branch 'tucker_pytorch_integegration_layer' of https://github.com/luminal-ai/luminal into tucker_pytorch_integegration_layer
# Conflicts:
#	.github/workflows/test.yml
#	crates/luminal_python/LessonsLearned.md
#	crates/luminal_python/run_tests_cuda.sh
2026-03-08 02:10:05 +00:00
Tucker Morgan
85b49018a3 KernelEmbed failure fix 2026-03-08 02:02:52 +00:00
Joe Fioti
a82c530ae7 consumable inputs 2026-03-07 20:28:27 +00:00
Joe Fioti
8fdfac19e1 fixed model examples 2026-03-07 19:00:47 +00:00
Joe Fioti
c61dfa0a13 merge dims 2026-03-07 00:59:53 +00:00
Tucker Morgan
3815a6de67 Working on dynamic caching handling from transformers 2026-03-06 22:48:22 +00:00
Tucker Morgan
375c54b641 Other part of i64 converstion 2026-03-06 22:48:22 +00:00
Tucker Morgan
1dd8853a60 Fixing up Num expression in Term to use i64 not i32 2026-03-06 22:48:22 +00:00
Tucker Morgan
050eeba815 Fixing borken topk on cuda 2026-03-06 22:48:22 +00:00
Tucker Morgan
05e0a2fc31 Fixing llama3 decode loop so it odes not crash 2026-03-06 22:48:21 +00:00
Tucker Morgan
af7d0b002f Working LLama3 1b instruct 2026-03-06 22:48:21 +00:00
Tucker Morgan
f6c3b68f86 Fixing broken tests 2026-03-06 22:48:21 +00:00
Tucker Morgan
afc17eac31 Fix for failing reduction matmul tests, adjusting z stride 2026-03-06 22:47:20 +00:00
Tucker Morgan
5330ab9159 Dynamic Shapes 2026-03-06 22:47:20 +00:00
Tucker Morgan
ea2cd9da45 Scatter ND cuda test fix 2026-03-06 22:47:20 +00:00
Tucker Morgan
92957228f5 inlining parsing functions into dispatch.rs 2026-03-06 22:47:19 +00:00
Tucker Morgan
4152c8a732 Locking onnx opset to 20, reject all models that are not that explict model version 2026-03-06 22:47:19 +00:00
Tucker Morgan
d95e847dba removing unsupported datatype constants 2026-03-06 22:47:19 +00:00
Tucker Morgan
153ac59773 removing eprint 2026-03-06 22:47:19 +00:00
Tucker Morgan
bcd1fe673a fixing onehot test 2026-03-06 22:47:19 +00:00
Tucker Morgan
99d65ae9df Fixing Slice2d test 2026-03-06 22:47:19 +00:00
Tucker Morgan
af9442bfa0 Llama3 2026-03-06 22:47:18 +00:00
Tucker Morgan
3414a9fe6c Pushing local changes 2026-03-06 22:47:18 +00:00
Tucker Morgan
d3a60de7fa Just creating a nice point for me 2026-03-06 22:47:18 +00:00
Tucker Morgan
63bad7d5a2 Gather Elements, LayerNorm, Gemm, IsNan, and Expand 2026-03-06 22:47:18 +00:00
Tucker Morgan
69e143f33d Adjusting how we determine where to compile the model 2026-03-06 22:47:18 +00:00
Tucker Morgan
ca9113b8a8 Fixing up uses of to axes 2026-03-06 22:47:17 +00:00
Tucker Morgan
4e27a9fb31 Refactor parsing functions to reuse as code 2026-03-06 22:47:17 +00:00
Tucker Morgan
ad51ad88cd Ceil and Cargo fmt 2026-03-06 22:47:17 +00:00
Tucker Morgan
9afc41e0a1 Reduce Mean 2026-03-06 22:47:17 +00:00
Tucker Morgan
0cbe874492 ReduceMin handling and tests 2026-03-06 22:47:17 +00:00
Tucker Morgan
987ac5b5ec Trilu trill handling and tests 2026-03-06 22:47:17 +00:00
Tucker Morgan
d4faf80cf9 Added common condtionals, Not, And, Or, Xoe, LessOrEqual, GreaterOrEqual 2026-03-06 22:47:17 +00:00
Tucker Morgan
8ce60f61ec Soft max function and tests 2026-03-06 22:47:16 +00:00
Tucker Morgan
3a9c945e71 Min and max handling 2026-03-06 22:47:16 +00:00
Tucker Morgan
765ddd5070 Adding where node handling and tests 2026-03-06 22:47:16 +00:00
Tucker Morgan
4178b38eb0 Added pow handling and testing 2026-03-06 22:47:16 +00:00
Tucker Morgan
516faaa205 Fixed mod tests failing by changing from using mod operator to fmod function 2026-03-06 22:47:16 +00:00
Tucker Morgan
5033efc8d7 Clip and Equal Nodes 2026-03-06 22:47:16 +00:00
Tucker Morgan
de659b4a06 Removing unsued code 2026-03-06 22:47:16 +00:00
Tucker Morgan
612c76bd7c Rust format issues 2026-03-06 22:47:15 +00:00
Tucker Morgan
2230b28751 Trigger CI 2026-03-06 22:47:15 +00:00
Tucker Morgan
ae3e2c9331 Disable mod test 2026-03-06 22:47:15 +00:00
Tucker Morgan
f09c2ddd68 Fix all 30 clippy errors in luminal_python crate
- Remove unused imports (onnx_protobuf::Message, protobuf::MessageField,
  std::sync::Arc, get_dtype_for_onnx_value)
- Gate cuda-only code with #[cfg(feature = "cuda")]
  (transpose_weight_data, Arc import, runtime.rs)
- Prefix unused variables/params with _ (_known_values, _input_tensor_names)
- Remove needless return statements in binary.rs and unary.rs (13 total)
- Replace never-loop pattern in compiled_graph.rs with direct panic
- Replace result = result * x with result *= x in movement.rs
- Replace manual memcpy loop with copy_from_slice in movement.rs
- Replace needless range loops with iterator form in reduction.rs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 22:47:15 +00:00
Tucker Morgan
dae508f26c Add maturin to dev dependency-groups so uv run can find it in CI
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 22:47:15 +00:00
Tucker Morgan
9b970e9d01 Fixing up fmt issues that was casuing ci/cd to fail 2026-03-06 22:47:15 +00:00
Tucker Morgan
c14c7f4be4 Added Abs, Sigmoid, Relu, and Tanh, and tests for them 2026-03-06 22:47:15 +00:00
Tucker Morgan
f60147f4a7 Adding tests to cicd 2026-03-06 22:47:14 +00:00
Tucker Morgan
2f2ba2fb1f breaking up ops parse into different files, added pytest random 2026-03-06 22:46:03 +00:00
Tucker Morgan
695df9cb29 Removing unsued parameters, adding in trace! 2026-03-06 22:46:02 +00:00
Tucker Morgan
97af33e671 Sum and Max Reduce 2026-03-06 22:46:02 +00:00
Tucker Morgan
c1824a1e8d Squeeze parsing and tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
96cfdd9fcf Fixed broken cuda tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
e4d2be9c89 Fixing a bunch of changes 2026-03-06 22:46:02 +00:00
Tucker Morgan
10ea5be27d Fixed native tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
3d5c939a7f fixing issue with floor op for mod 2026-03-06 22:46:02 +00:00
Tucker Morgan
4ba5b3b121 Handling the less node and tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
2dc8e23496 Reshape, Shape, ops parsing added, tests as well 2026-03-06 22:46:01 +00:00
Tucker Morgan
9798fd0e38 Added Mod and tests for Mod op 2026-03-06 22:46:01 +00:00
Tucker Morgan
cf36d45b77 Fixing up cuda cast tests 2026-03-06 22:46:01 +00:00
Tucker Morgan
dccda51f30 Adding handling constant nodes 2026-03-06 22:46:01 +00:00
Tucker Morgan
34e9f6f5e8 Added transpose op and tests 2026-03-06 22:46:01 +00:00
Tucker Morgan
ba1aa9575c Cleaning up test files 2026-03-06 22:46:01 +00:00
Tucker Morgan
77a68d3091 Adding a way to select which backend is used for compile 2026-03-06 22:46:01 +00:00
Tucker Morgan
aaaa476352 Adding Sqrt node handling 2026-03-06 22:46:00 +00:00
Tucker Morgan
879e4203af Extending to use cuda, and fixing up luminal_python project structure so it follows best practices 2026-03-06 22:46:00 +00:00
tucker-luminal
f5185b86d6 Adding cuda support 2026-03-06 22:46:00 +00:00
tucker-luminal
7bf2543909 Add and Sub first pass
Add and sub first pass, pushing this so I can test on a cuda enviroment
2026-03-06 22:46:00 +00:00
Tucker Morgan
00b797e281 Working on dynamic caching handling from transformers 2026-03-06 22:45:05 +00:00
Joe Fioti
41f6a64746 clippy fix 2026-03-06 22:38:08 +00:00
Joe Fioti
2556bfa90b merged 2026-03-06 22:35:15 +00:00
Joe Fioti
32756f04c1 Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-03-06 22:31:34 +00:00
Joe Fioti
eaf8ba9219 fixed z-strides 2026-03-06 22:31:30 +00:00
Joe Fioti
b37c06d9b9 fmt 2026-03-06 11:52:27 -08:00
Joe Fioti
dd85e23e60 removed shapetracker on hlir edges 2026-03-06 11:48:45 -08:00
xiaoniaoyouhuajiang
a0a162049e make dyn dim = n_inputs +1 2026-03-06 17:11:36 +08:00
Tucker Morgan
26ffc23c12 Other part of i64 converstion 2026-03-05 23:55:49 +00:00
Tucker Morgan
df3f6b539b Fixing up Num expression in Term to use i64 not i32 2026-03-05 23:31:21 +00:00
Tucker Morgan
4d2bb35e9e Fixing borken topk on cuda 2026-03-05 21:54:24 +00:00
Tucker Morgan
fa6991d8fb Fixing llama3 decode loop so it odes not crash 2026-03-05 20:10:10 +00:00
Tucker Morgan
45a4e8c617 Working LLama3 1b instruct 2026-03-04 20:52:54 +00:00
Tucker Morgan
c6ee7e2f21 Fixing broken tests 2026-03-03 22:03:21 +00:00
Tucker Morgan
335ec78d3e Fix for failing reduction matmul tests, adjusting z stride 2026-03-03 20:11:39 +00:00
Tucker Morgan
7f760ad847 Dynamic Shapes 2026-03-03 18:54:42 +00:00
Tucker Morgan
a2e9b5209f Scatter ND cuda test fix 2026-03-03 18:54:42 +00:00
Tucker Morgan
d18e93d671 inlining parsing functions into dispatch.rs 2026-03-03 18:54:42 +00:00
Tucker Morgan
5a432c717c Locking onnx opset to 20, reject all models that are not that explict model version 2026-03-03 18:54:41 +00:00
Tucker Morgan
a15ae2d41c removing unsupported datatype constants 2026-03-03 18:54:41 +00:00
Tucker Morgan
2a03239c8d removing eprint 2026-03-03 18:54:41 +00:00
Tucker Morgan
9b11ce16ba fixing onehot test 2026-03-03 18:54:41 +00:00
Tucker Morgan
428d28e307 Fixing Slice2d test 2026-03-03 18:54:41 +00:00
Tucker Morgan
8f4680e7c0 Llama3 2026-03-03 18:54:41 +00:00
Tucker Morgan
4cb7c2c337 Pushing local changes 2026-03-03 18:54:40 +00:00
Tucker Morgan
38be807c11 Just creating a nice point for me 2026-03-03 18:54:40 +00:00
Tucker Morgan
26902c3cbf Gather Elements, LayerNorm, Gemm, IsNan, and Expand 2026-03-03 18:54:40 +00:00
Tucker Morgan
8550e98370 Adjusting how we determine where to compile the model 2026-03-03 18:54:40 +00:00
Tucker Morgan
b47aa9f4c5 Fixing up uses of to axes 2026-03-03 18:54:39 +00:00
Tucker Morgan
0264290091 Refactor parsing functions to reuse as code 2026-03-03 18:54:39 +00:00
Tucker Morgan
d3c8bbb838 Ceil and Cargo fmt 2026-03-03 18:54:39 +00:00
Tucker Morgan
09af10df2d Reduce Mean 2026-03-03 18:54:39 +00:00
Tucker Morgan
a3bc233cc9 ReduceMin handling and tests 2026-03-03 18:54:39 +00:00
Tucker Morgan
472148ec2e Trilu trill handling and tests 2026-03-03 18:54:39 +00:00
Tucker Morgan
202ee2e4b4 Added common condtionals, Not, And, Or, Xoe, LessOrEqual, GreaterOrEqual 2026-03-03 18:52:46 +00:00
Tucker Morgan
a52c90de55 Soft max function and tests 2026-03-03 18:52:46 +00:00
Tucker Morgan
d0fc4e528b Min and max handling 2026-03-03 18:52:46 +00:00
Tucker Morgan
3c51d2cd6e Adding where node handling and tests 2026-03-03 18:52:46 +00:00
Tucker Morgan
a0783896a1 Added pow handling and testing 2026-03-03 18:52:46 +00:00
Tucker Morgan
64700aa2a8 Fixed mod tests failing by changing from using mod operator to fmod function 2026-03-03 18:52:45 +00:00
Tucker Morgan
c7325f5590 Clip and Equal Nodes 2026-03-03 18:52:45 +00:00
Tucker Morgan
e1f8d2366f Removing unsued code 2026-03-03 18:52:45 +00:00
Tucker Morgan
bfee79d764 Rust format issues 2026-03-03 18:52:45 +00:00
Tucker Morgan
aab5db7d10 Trigger CI 2026-03-03 18:52:45 +00:00
Tucker Morgan
840ab6abd2 Disable mod test 2026-03-03 18:52:45 +00:00
Tucker Morgan
4079eebf1f Fix all 30 clippy errors in luminal_python crate
- Remove unused imports (onnx_protobuf::Message, protobuf::MessageField,
  std::sync::Arc, get_dtype_for_onnx_value)
- Gate cuda-only code with #[cfg(feature = "cuda")]
  (transpose_weight_data, Arc import, runtime.rs)
- Prefix unused variables/params with _ (_known_values, _input_tensor_names)
- Remove needless return statements in binary.rs and unary.rs (13 total)
- Replace never-loop pattern in compiled_graph.rs with direct panic
- Replace result = result * x with result *= x in movement.rs
- Replace manual memcpy loop with copy_from_slice in movement.rs
- Replace needless range loops with iterator form in reduction.rs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 18:52:44 +00:00
Tucker Morgan
0ad7b4e509 Add maturin to dev dependency-groups so uv run can find it in CI
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 18:52:44 +00:00
Tucker Morgan
60206e362c Fixing up fmt issues that was casuing ci/cd to fail 2026-03-03 18:52:44 +00:00
Tucker Morgan
ce72c5223b Added Abs, Sigmoid, Relu, and Tanh, and tests for them 2026-03-03 18:52:44 +00:00
Tucker Morgan
d0d958d69d Adding tests to cicd 2026-03-03 18:52:44 +00:00
Tucker Morgan
e554532108 breaking up ops parse into different files, added pytest random 2026-03-03 18:52:44 +00:00
Tucker Morgan
4ebb762724 Removing unsued parameters, adding in trace! 2026-03-03 18:52:43 +00:00
Tucker Morgan
734787b7c4 Sum and Max Reduce 2026-03-03 18:52:43 +00:00
Tucker Morgan
97e7e6e3f2 Squeeze parsing and tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
adf1e6be8b Fixed broken cuda tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
dcf49e3a61 Fixing a bunch of changes 2026-03-03 18:52:43 +00:00
Tucker Morgan
a3ff32d7a7 Fixed native tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
8760818756 fixing issue with floor op for mod 2026-03-03 18:52:42 +00:00
Tucker Morgan
e4c153b388 Handling the less node and tests 2026-03-03 18:52:42 +00:00
Tucker Morgan
6470b74a9d Reshape, Shape, ops parsing added, tests as well 2026-03-03 18:52:42 +00:00
Tucker Morgan
e2cf926e5a Added Mod and tests for Mod op 2026-03-03 18:52:42 +00:00
Tucker Morgan
97766c59cf Fixing up cuda cast tests 2026-03-03 18:52:42 +00:00
Tucker Morgan
379ea51474 Adding handling constant nodes 2026-03-03 18:52:41 +00:00
Tucker Morgan
4e87658460 Added transpose op and tests 2026-03-03 18:52:41 +00:00
Tucker Morgan
2b16cdea24 Cleaning up test files 2026-03-03 18:52:41 +00:00
Tucker Morgan
bc7d5e8b14 Adding a way to select which backend is used for compile 2026-03-03 18:52:41 +00:00
Tucker Morgan
4e15740876 Adding Sqrt node handling 2026-03-03 18:52:41 +00:00
Tucker Morgan
53a056ba8a Extending to use cuda, and fixing up luminal_python project structure so it follows best practices 2026-03-03 18:52:41 +00:00
tucker-luminal
615f43ed05 Adding cuda support 2026-03-03 18:52:40 +00:00
tucker-luminal
5847810f52 Add and Sub first pass
Add and sub first pass, pushing this so I can test on a cuda enviroment
2026-03-03 18:52:40 +00:00
181 changed files with 32676 additions and 10418 deletions

View File

@@ -0,0 +1,130 @@
---
name: aoti-debug
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
---
# AOTInductor Debugging
Debug errors when compiling and deploying PyTorch models with AOTInductor.
## First Step: Always Check Device and Shape Matching
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
```python
# Compilation -- note the device and shapes
model = MyModel().eval().cuda()
inp = torch.randn(2, 10, device="cuda")
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
# Loading -- device type MUST match compilation
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
# Inference -- device and shapes MUST match
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
```
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
## Current vs Deprecated API
### Current API (use this)
```python
torch._inductor.aoti_compile_and_package() # compile
torch._inductor.aoti_load_package() # load (auto-detects device)
```
### Deprecated API (migrate away)
```python
torch._export.aot_compile() # deprecated
torch._export.aot_load() # deprecated
```
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
## Common Error Patterns
### Device Mismatch Segfault
**Symptom**: Segfault, exception, or crash during load or execution.
**Example errors**:
- `The specified pointer resides on host memory and is not registered with any CUDA device`
- Crash during constant loading
- `Expected out tensor to have device cuda:0, but got cpu instead`
**Solution**: Ensure compile and load use the same device type.
### Input Device Mismatch at Runtime
**Symptom**: RuntimeError during model execution.
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
```
Produces actionable messages like:
```
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
```
## Debugging CUDA Illegal Memory Access (IMA)
### Step 1: Sanity Checks
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
```
Both flags take effect at **compile time** (codegen time).
### Step 2: Make IMA Deterministic
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
Both take effect at **runtime**.
### Step 3: Identify the Problematic Kernel
```bash
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
```
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
To inspect inputs to specific kernels:
```bash
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
```
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
## Environment Variables Reference
| Variable | When | Purpose |
|---|---|---|
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
## Common Sources of Issues
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.

View File

@@ -0,0 +1,195 @@
---
name: pt2-debug
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
---
# PyTorch 2 Compile Debugging
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
## Diagnostic Environment Variables
Pick the right diagnostic based on the error:
| Command | When to use |
|---|---|
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
## Error Triage
Classify the failure and jump to the right section:
| Error Pattern | Category |
|---|---|
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
## Graph Breaks
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
**Diagnose:**
```bash
TORCH_LOGS="graph_breaks" python script.py
```
**Common causes:**
- Data-dependent control flow
- Unsupported Python builtins
- In-place ops on inputs, unsupported dtypes
- Calls to non-traceable functions
**Fix approaches:**
1. Read the graph break message to identify the unsupported operation
2. Check for a decomposition or supported alternative
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
## Backend Compiler Failures
`BackendCompilerFailed` means Inductor crashed during compilation.
**Diagnose with the minifier:**
```bash
# Generate minifier launcher
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
# Run the minifier to get minimal failing graph
python minifier_launcher.py minify
# Run the minimized reproduction
python minifier_launcher.py run
```
**Then inspect:**
```bash
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
```
## Recompilation Issues
Excessive recompilation from guards that are too specific, causing cache misses.
**Diagnose:**
```bash
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
```
**Key config:**
```python
torch._dynamo.config.recompile_limit # default: 8
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
```
**Common causes:**
- Changing tensor shapes without marking them dynamic
- Python scalar values that change between calls
- Global state mutations between calls
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
- Fix the source of guard instability
## Accuracy Issues
Compiled model produces different numerical results than eager mode.
**Diagnose:**
```bash
# Compares compiled vs eager with fp64 reference, dumps repro on failure
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```
**Fix approach:**
1. Get minimal failing graph from the minifier
2. Compare eager vs compiled output at fp64 precision
3. Binary search through ops to find the diverging operation
4. Check for known issues: reduction order, fused kernels, dtype promotions
## Internal Dynamo Errors
`InternalTorchDynamoError` indicates a bug in Dynamo.
**Diagnose:**
```bash
TORCHDYNAMO_VERBOSE=1 python script.py
# or equivalently:
TORCH_LOGS="+dynamo" python script.py
```
**Debug interactively:**
```bash
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
```
## Runtime Crashes
Segfaults and CUDA illegal memory access during execution of compiled code.
**Make crash deterministic:**
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
**Add NaN checks to find the first bad kernel:**
```bash
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
```
**Inductor sync debugging:**
```python
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
```
**Fix approach:**
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
2. Check input shapes, devices, dtypes
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
5. Dynamic shapes are historically a common source of IMA
## Triton Kernel Failures
Triton assertion failures or index-out-of-bounds in generated kernels.
**Diagnose:**
```bash
TORCH_LOGS="output_code,schedule" python script.py
```
**Fix approach:**
1. Get the generated Triton kernel from `output_code` logs
2. Check index computations for off-by-one or wrong stride calculations
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
4. Check if fusion decisions created invalid index combinations
## Distinguish Trace-Time vs Runtime
Many bugs come from confusing these:
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
- **Runtime**: Real tensors, real Python calls.
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
## Using the Minifier
The minifier reduces a failing graph to the smallest reproduction:
```bash
# For compilation failures (level 2)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
python minifier_launcher.py minify
python minifier_launcher.py run
# For accuracy failures (level 4)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```

View File

@@ -0,0 +1,134 @@
---
name: ruff
description:
Guide for using ruff, the extremely fast Python linter and formatter. Use this
when linting, formatting, or fixing Python code.
---
# ruff
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
isort, Black, pyupgrade, autoflake, and dozens of other tools.
## When to use ruff
**Always use ruff for Python linting and formatting**, especially if you see:
- `[tool.ruff]` section in `pyproject.toml`
- A `ruff.toml` or `.ruff.toml` configuration file
However, avoid making unnecessary changes:
- **Don't format unformatted code** - If `ruff format --diff` shows changes
throughout an entire file, the project likely isn't using ruff for formatting.
Skip formatting to avoid obscuring actual changes.
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
relevant to the code you're changing. Only apply fixes to files you're
modifying unless the user explicitly asks for broader fixes.
## How to invoke ruff
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
you use the pinned version
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
one-off checks
- `ruff ...` - Use if ruff is installed globally
## Commands
### Linting
```bash
ruff check . # Check all files in current directory
ruff check path/to/file.py # Check specific file
ruff check --fix . # Auto-fix fixable violations
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
ruff check --watch . # Watch for changes and re-lint
ruff check --select E,F . # Only check specific rules
ruff check --ignore E501 . # Ignore specific rules
ruff rule E501 # Explain a specific rule
ruff linter # List available linters
```
### Formatting
```bash
ruff format . # Format all files
ruff format path/to/file.py # Format specific file
ruff format --check . # Check if files are formatted (no changes)
ruff format --diff . # Show formatting diff without applying
```
## Configuration
Ruff is configured in `pyproject.toml` or `ruff.toml`:
```toml
# pyproject.toml
[tool.ruff.lint]
select = ["E", "F", "I", "UP"] # Enable specific rule sets
ignore = ["E501"] # Ignore specific rules
[tool.ruff.lint.isort]
known-first-party = ["myproject"]
```
## Migrating from other tools
### Black → ruff format
```bash
black . → ruff format .
black --check . → ruff format --check .
black --diff . → ruff format --diff .
```
### Flake8 → ruff check
```bash
flake8 . → ruff check .
flake8 --select E,F . → ruff check --select E,F .
flake8 --ignore E501 . → ruff check --ignore E501 .
```
### isort → ruff check
```bash
isort . → ruff check --select I --fix .
isort --check . → ruff check --select I .
isort --diff . → ruff check --select I --diff .
```
## Common patterns
### Apply lint fixes before formatting
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
structure (e.g., reordering imports), which formatting then cleans up.
```bash
ruff check --fix .
ruff format .
```
### Applying and reviewing unsafe fixes
Ruff categorizes some auto-fixes as "unsafe" because they may change code
behavior, not just style. For example, removing unused imports could break code
that relies on side effects.
```bash
ruff check --fix --unsafe-fixes --diff . # Preview changes first
ruff check --fix --unsafe-fixes . # Apply changes
```
**Always review changes before applying `--unsafe-fixes`:**
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
- Verify the fix doesn't violate those assumptions in your code
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ruff/

135
.agents/skills/ty/SKILL.md Normal file
View File

@@ -0,0 +1,135 @@
---
name: ty
description:
Guide for using ty, the extremely fast Python type checker and language
server. Use this when type checking Python code or setting up type checking in
Python projects.
---
# ty
ty is an extremely fast Python type checker and language server. It replaces
mypy, Pyright, and other type checkers.
## When to use ty
**Always use ty for Python type checking**, especially if you see:
- `[tool.ty]` section in `pyproject.toml`
- A `ty.toml` configuration file
## How to invoke ty
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
use the pinned version or when ty is installed globally and you are in a
project so the virtual environment is updated.
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
checks
## Commands
### Type checking
```bash
ty check # Check all files in current directory
ty check path/to/file.py # Check specific file
ty check src/ # Check specific directory
```
### Rule configuration
```bash
ty check --error possibly-unresolved-reference # Treat as error
ty check --warn division-by-zero # Treat as warning
ty check --ignore unresolved-import # Disable rule
```
### Python version targeting
```bash
ty check --python-version 3.12 # Check against Python 3.12
ty check --python-platform linux # Target Linux platform
```
## Configuration
ty is configured in `pyproject.toml` or `ty.toml`:
```toml
# pyproject.toml
[tool.ty.environment]
python-version = "3.12"
[tool.ty.rules]
possibly-unresolved-reference = "warn"
division-by-zero = "error"
[tool.ty.src]
include = ["src/**/*.py"]
exclude = ["**/migrations/**"]
[tool.ty.terminal]
output-format = "full"
error-on-warning = false
```
### Per-file overrides
Use overrides to apply different rules to specific files, such as relaxing rules
for tests or scripts that have different typing requirements than production
code:
```toml
[[tool.ty.overrides]]
include = ["tests/**", "**/test_*.py"]
[tool.ty.overrides.rules]
possibly-unresolved-reference = "warn"
```
## Language server
This plugin automatically configures the ty language server for Python files
(`.py` and `.pyi`).
## Migrating from other tools
### mypy → ty
```bash
mypy . → ty check
mypy --strict . → ty check --error-on-warning
mypy path/to/file.py → ty check path/to/file.py
```
### Pyright → ty
```bash
pyright . → ty check
pyright path/to/file.py → ty check path/to/file.py
```
## Common patterns
### Don't add ignore comments
Fix type errors instead of suppressing them. Only add ignore comments when
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
prefer rule-specific ignores:
```python
# Good: rule-specific ignore
x = undefined_var # ty: ignore[possibly-unresolved-reference]
# Bad: blanket ty ignore
x = undefined_var # ty: ignore
# Bad: tool agnostic blanket ignore
x = undefined_var # type: ignore
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ty/

182
.agents/skills/uv/SKILL.md Normal file
View File

@@ -0,0 +1,182 @@
---
name: uv
description:
Guide for using uv, the Python package and project manager. Use this when
working with Python projects, scripts, packages, or tools.
---
# uv
uv is an extremely fast Python package and project manager. It replaces pip,
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
## When to use uv
**Always use uv for Python work**, especially if you see:
- The `uv.lock` file
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
Don't use uv in projects managed by other tools:
- Poetry projects (identifiable by `poetry.lock` file)
- PDM projects (identifiable by `pdm.lock` file)
## Choosing the right workflow
### Scripts
**Use when:** Running single Python files and standalone scripts.
**Key commands:**
```bash
uv run script.py # Run a script
uv run --with requests script.py # Run with additional packages
uv add --script script.py requests # Add dependencies inline to the script
```
### Projects
**Use when:** There is a `pyproject.toml` or `uv.lock`
**Key commands:**
```bash
uv init # Create new project
uv add requests # Add dependency
uv remove requests # Remove dependency
uv sync # Install from lockfile
uv run <command> # Run commands in environment
uv run python -c "" # Run Python in project environment
uv run -p 3.12 <command> # Run with specific Python version
```
### Tools
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
installation.
**Key commands:**
```bash
uvx <tool> <args> # Run a tool without installation
uvx <tool>@<version> <args> # Run a specific version of a tool
```
**Important:**
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
well-known tools.
- Only use `uv tool install` only when specifically requested by the user.
### Pip interface
**Use when:** Legacy workflows with `requirements.txt` or manual environment
management, no `uv.lock` present.
**Key commands:**
```bash
uv venv
uv pip install -r requirements.txt
uv pip compile requirements.in -o requirements.txt
uv pip sync requirements.txt
# Platform independent resolution
uv pip compile --universal requirements.in -o requirements.txt
```
**Important:**
- Don't use the pip interface unless clearly needed.
- Don't introduce new `requirements.txt` files.
- Prefer `uv init` for new projects.
## Migrating from other tools
### pyenv → uv python
```bash
pyenv install 3.12 → uv python install 3.12
pyenv versions → uv python list --only-installed
pyenv local 3.12 → uv python pin 3.12
pyenv global 3.12 → uv python install 3.12 --default
```
### pipx → uvx
```bash
pipx run ruff → uvx ruff
pipx install ruff → uv tool install ruff
pipx upgrade ruff → uv tool upgrade ruff
pipx list → uv tool list
```
### pip and pip-tools → uv pip
```bash
pip install package → uv pip install package
pip install -r req.txt → uv pip install -r req.txt
pip freeze → uv pip freeze
pip-compile req.in → uv pip compile req.in
pip-sync req.txt → uv pip sync req.txt
virtualenv .venv → uv venv
```
## Common patterns
### Don't use pip in uv projects
```bash
# Bad
pip install requests
# Good
uv add requests
```
### Don't run python directly
```bash
# Bad
python script.py
# Good
uv run script.py
```
```bash
# Bad
python -c "..."
# Good
uv run python -c "..."
```
```bash
# Bad
python3.12 -c "..."
# Good
uvx python@3.12 -c "..."
```
### Don't manually manage environments in uv projects
```bash
# Bad
python -m venv .venv
source .venv/bin/activate
# Good
uv run <command>
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/uv/llms.txt
The documentation links to specific pages for each of these workflows.

View File

@@ -1,6 +1,13 @@
{
"name": "Luminal (CPU)",
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
"initializeCommand": "touch .env",
"runArgs": [
"--env-file", ".env"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
@@ -10,10 +17,17 @@
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && if [ -f .env ]; then . .env; if [ -n \"$GH_TOKEN\" ]; then echo \"$GH_TOKEN\" | gh auth login --with-token 2>/dev/null; fi; fi",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
@@ -29,6 +43,7 @@
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",

View File

@@ -1,9 +1,17 @@
{
"name": "Luminal (CUDA)",
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
"initializeCommand": "touch .env",
"runArgs": [
"--gpus=all"
"--env-file",
".env",
"--runtime=nvidia",
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
@@ -13,10 +21,17 @@
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && if [ -f .env ]; then . .env; if [ -n \"$GH_TOKEN\" ]; then echo \"$GH_TOKEN\" | gh auth login --with-token 2>/dev/null; fi; fi",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
@@ -32,6 +47,7 @@
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",

30
.github/workflows/cuda-clippy.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: CUDA Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
cuda_clippy:
name: CUDA Clippy
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Mark workspace as safe for git
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-clippy --all-files

23
.github/workflows/fmt.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Fmt
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-fmt --all-files

25
.github/workflows/metal-clippy.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: Metal Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_clippy:
name: Metal Clippy
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual cargo-clippy-metal --all-files

45
.github/workflows/modal-examples.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: Modal Examples
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
modal_example:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 70
strategy:
fail-fast: false
matrix:
example: [llama, gemma, qwen, qwen3_moe]
gpu:
- { type: "A100-80GB" }
# To add more GPUs, just append another entry:
# - { type: "H100" }
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
EXAMPLE: ${{ matrix.example }}
GPU_TYPE: ${{ matrix.gpu.type }}
run: modal run ci/modal_example.py

23
.github/workflows/ruff-format.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff Format
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff_format:
name: Ruff Format
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-format --all-files

23
.github/workflows/ruff.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff:
name: Ruff
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-check --all-files

24
.github/workflows/test-core.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Test Core
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
core_unit_test:
name: Core Unit Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose

35
.github/workflows/test-cuda.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Test CUDA
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
cuda_unit_test:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Cuda Unit Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run CUDA tests on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
run: modal run ci/modal_cargo_test.py

19
.github/workflows/test-metal.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: Test Metal
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_unit_test:
name: Metal Unit Tests
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1

47
.github/workflows/test-python-cuda.yml vendored Normal file
View File

@@ -0,0 +1,47 @@
name: Test Python CUDA
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
python_cuda_tests:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Python CUDA Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 60
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run pytest with CUDA backend on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
- name: Upload Modal pytest profiling artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
retention-days: 7
if-no-files-found: warn

View File

@@ -0,0 +1,28 @@
name: Test Python Native
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
python_native_tests:
name: Python Native Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 45
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Update Rust toolchain
run: rustup update
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml
- name: Run pytest
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"

View File

@@ -1,89 +0,0 @@
name: Test
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
env:
CARGO_TERM_COLOR: always
jobs:
core_unit_test:
name: Core Unit Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --verbose
clippy:
name: Clippy
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run clippy
run: rustup update; cargo clippy --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Format
run: cargo fmt --all --check
cuda_unit_test:
name: Cuda Unit Tests
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Detect GPU compute capability
run: |
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
- name: Run CUDA crate tests
run: cargo test -p luminal_cuda --verbose -- --test-threads=1
# cuda_llama: # disabled because t4 doesn't have enough memory for full precision llama. re-enable when we can run on larger machines or use 8-bit precision
# name: Cuda Llama
# runs-on: cuda_t4_runner
# timeout-minutes: 30
# env:
# CUDA_HOME: /usr/local/cuda-12.8
# LD_LIBRARY_PATH: /usr/local/cuda-12.8/lib64
# steps:
# - uses: actions/checkout@v6
# - name: Install system deps
# run: |
# sudo apt-get update
# sudo apt-get install -y --no-install-recommends \
# protobuf-compiler \
# cuda-nvrtc-12-8
# - name: Install Rust
# run: |
# curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal
# echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
# - name: Update Rust
# run: rustup update
# - name: Install uv
# run: curl -LsSf https://astral.sh/uv/install.sh | sh
# - name: Download Llama
# working-directory: examples/llama
# run: uv run --script setup/setup.py
# - name: Run Llama
# working-directory: examples/llama
# run: SEARCH=1 cargo run --release

18
.gitignore vendored
View File

@@ -1,6 +1,9 @@
/target
/crates/**/target
/examples/**/target
.claude-project
.claude-memory
.codex
*.env
.claude/
@@ -15,6 +18,10 @@ Cargo.lock
*.gguf
.claude-project
.claude-memory
.codex
*.pftrace
*.safetensors
*.safetensors.index.json
@@ -22,3 +29,14 @@ tokenizer.json
**/.cache
**/proptest-regressions
opencode.json
# Python build artifacts
*.so
*.pyd
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
uv.lock

38
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,38 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.5
hooks:
- id: ruff-check
name: ruff check
files: ^crates/luminal_python/.*\.py$
- id: ruff-format
name: ruff format
files: ^crates/luminal_python/.*\.py$
- repo: local
hooks:
- id: cargo-fmt
name: cargo fmt
entry: cargo fmt --all --check
language: system
pass_filenames: false
files: \.(rs|toml)$
- id: cargo-clippy
name: cargo clippy
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
- id: cargo-clippy-metal
name: cargo clippy metal
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
stages: [manual]
- id: cargo-clippy-cuda-lite
name: cargo clippy cuda_lite
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
stages: [manual]

View File

@@ -3,9 +3,9 @@
## Structure
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
## Testing Instructions
- Find the CI plan in the .github/workflows folder.
- Currently running `cargo test` in luminal_metal and luminal_cuda require access to an Apple and Nvidia GPU respectively.
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.

34
CLAUDE.md Normal file
View File

@@ -0,0 +1,34 @@
# Luminal
## Package Management
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
- Use `uv sync` to sync the Python environment
- Never use pip, pip-tools, poetry, or conda
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
- Never generate requirements.txt
## Code Execution
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
## Building the Python Package (Maturin)
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
## Pre-commit
- Run with: `uv run pre-commit run --all-files`
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
## Testing
- **Rust tests**: `cargo test -p <crate_name>`
- **Python tests**: `cd crates/luminal_python && uv run pytest`
- `./run_test.sh` — native backend
- `./run_tests_cuda.sh` — CUDA backend
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions

View File

@@ -37,8 +37,8 @@ lru = "0.16.2"
edition = "2024"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
candle-nn = "0.9.2-alpha.1"
candle-core = "0.9.2"
candle-nn = "0.9.2"
ordered-float = "5.1.0"
proptest = "1.9.0"
@@ -46,11 +46,12 @@ proptest = "1.9.0"
members = [
"examples/*",
"crates/luminal_nn",
"crates/luminal_cuda",
"crates/luminal_cuda_lite",
"crates/luminal_metal",
"crates/luminal_tracing",
"crates/luminal_bench",
"crates/luminal_python/rust",
]
[patch.crates-io]
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }

Binary file not shown.

67
ci/modal_cargo_test.py Normal file
View File

@@ -0,0 +1,67 @@
import modal
import subprocess
import os
import sys
gpu_type = os.environ.get("GPU_TYPE", "T4")
CUDARC_CUDA_VERSION = "12080"
app = modal.App("luminal-ci-cargo-test")
WORKDIR = "/workspace/luminal"
cuda_image = (
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
.apt_install("protobuf-compiler")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=1800, # 30 minutes
)
def run_cargo_test():
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
subprocess.run(["nvidia-smi"], check=True)
# Detect GPU compute capability
result = subprocess.run(
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
capture_output=True,
text=True,
check=True,
)
compute_cap = result.stdout.strip().replace(".", "")
subprocess.run(
[
"cargo", "test",
"-p", "luminal_cuda_lite",
"--verbose",
"--",
"--test-threads=1",
],
cwd=WORKDIR,
env={
**os.environ,
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"CUDA_COMPUTE_CAP": compute_cap,
},
check=True,
)
@app.local_entrypoint()
def main():
run_cargo_test.remote()

67
ci/modal_example.py Normal file
View File

@@ -0,0 +1,67 @@
import modal
import subprocess
import os
example = os.environ.get("EXAMPLE", "llama")
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
CUDARC_CUDA_VERSION = "12080"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
app = modal.App(f"luminal-ci-{example}")
hf_cache = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
WORKDIR = "/workspace/luminal"
cuda_image = (
modal.Image.from_registry(
"nvcr.io/nvidia/pytorch:25.03-py3"
)
.apt_install("protobuf-compiler")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=3600, # 60 minutes
volumes={
HF_CACHE_PATH: hf_cache,
},
)
def run_example(example: str):
"""Build and run a luminal example on a Modal GPU."""
subprocess.run(["nvidia-smi"], check=True)
subprocess.run(
["cargo", "run", "--release"],
cwd=f"{WORKDIR}/examples/{example}",
env={
**os.environ,
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"HF_HOME": HF_CACHE_PATH,
},
check=True,
)
hf_cache.commit()
@app.local_entrypoint()
def main():
run_example.remote(example)

View File

@@ -1,252 +0,0 @@
use itertools::Itertools;
use luminal::{prelude::FxHashMap, shape::Expression};
#[derive(Debug, PartialEq, Eq)]
enum CStructType {
Float,
FloatArr(usize),
Int,
IntArr(usize),
Long,
LongArr(usize),
Bool,
BoolArr(usize),
Ptr,
PtrArr(usize),
Bytes(usize),
}
#[derive(Debug)]
pub struct CStruct<'a> {
buf: Vec<u8>,
max_align: usize,
struct_types: Vec<(String, CStructType)>,
expressions: Option<&'a FxHashMap<Expression, i32>>,
pub(crate) recorded_expressions: Vec<Expression>,
}
impl<'a> CStruct<'a> {
pub fn new(expressions: Option<&'a FxHashMap<Expression, i32>>) -> Self {
Self {
max_align: 1,
struct_types: vec![],
buf: vec![],
expressions,
recorded_expressions: vec![],
}
}
fn align_to(&mut self, align: usize) {
self.max_align = self.max_align.max(align);
let len = self.buf.len();
let rem = len % align;
if rem != 0 {
let pad = align - rem;
self.buf.extend(std::iter::repeat_n(0u8, pad));
}
}
pub fn int(mut self, name: impl ToString, v: i32) -> Self {
self.struct_types.push((name.to_string(), CStructType::Int));
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn int_arr(mut self, name: impl ToString, vs: &[i32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::IntArr(vs.len())));
self.align_to(4);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn expr(mut self, name: impl ToString, v: impl Into<Expression>) -> Self {
if let Some(expressions) = self.expressions {
self.struct_types.push((name.to_string(), CStructType::Int));
let v = expressions[&v.into()];
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
} else {
self.recorded_expressions.push(v.into());
}
self
}
pub fn expr_arr(mut self, name: impl ToString, vs: &[Expression]) -> Self {
if let Some(expressions) = self.expressions {
self.struct_types
.push((name.to_string(), CStructType::IntArr(vs.len())));
self.align_to(4);
for &v in vs {
let v = expressions[&v];
self.buf.extend_from_slice(&v.to_ne_bytes());
}
} else {
self.recorded_expressions.extend(vs.iter().copied());
}
self
}
pub fn long(mut self, name: impl ToString, v: i64) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Long));
self.align_to(8);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn long_arr(mut self, name: impl ToString, vs: &[i64]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::LongArr(vs.len())));
self.align_to(8);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn float(mut self, name: impl ToString, v: f32) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Float));
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn float_arr(mut self, name: impl ToString, vs: &[f32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::FloatArr(vs.len())));
self.align_to(4);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn bool(mut self, name: impl ToString, v: bool) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Bool));
self.align_to(1);
self.buf.push(if v { 1 } else { 0 });
self
}
pub fn bool_arr(mut self, name: impl ToString, vs: &[bool]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::BoolArr(vs.len())));
self.align_to(1);
for &v in vs {
self.buf.push(if v { 1 } else { 0 });
}
self
}
pub fn ptr_const_f32(mut self, name: impl ToString, p: *const f32) -> Self {
self.struct_types.push((name.to_string(), CStructType::Ptr));
let ptr_size = std::mem::size_of::<usize>(); // usually 8
let ptr_align = ptr_size;
self.align_to(ptr_align);
let addr = p as usize;
let bytes = addr.to_ne_bytes();
self.buf.extend_from_slice(&bytes[..ptr_size]);
self
}
pub fn ptr_mut_f32(self, name: impl ToString, p: *mut f32) -> Self {
self.ptr_const_f32(name, p as *const f32)
}
pub fn ptr_const_f32_arr(mut self, name: impl ToString, p: &[*const f32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::PtrArr(p.len())));
let ptr_size = std::mem::size_of::<usize>(); // usually 8
let ptr_align = ptr_size;
self.align_to(ptr_align);
for &p in p {
let addr = p as usize;
let bytes = addr.to_ne_bytes();
self.buf.extend_from_slice(&bytes[..ptr_size]);
}
self
}
/// Returns the current size of the buffer after alignment for a pointer field.
/// Useful for computing field offsets.
pub fn current_size(&self) -> usize {
let ptr_align = std::mem::size_of::<usize>();
let len = self.buf.len();
let rem = len % ptr_align;
if rem != 0 {
len + (ptr_align - rem)
} else {
len
}
}
/// Pad the struct size to a multiple of max_align.
pub fn finish_struct(mut self) -> Vec<u8> {
assert!(
self.expressions.is_some(),
"Can only create cstruct bytes when expression map is provided!"
);
let align = self.max_align;
if align > 1 {
let len = self.buf.len();
let rem = len % align;
if rem != 0 {
let pad = align - rem;
self.buf.extend(std::iter::repeat_n(0u8, pad));
}
}
self.buf
}
/// Returns (size, alignment) of the struct.
pub fn size_and_align(&self) -> (usize, usize) {
let align = self.max_align;
let len = self.buf.len();
let rem = len % align;
let size = if rem != 0 { len + (align - rem) } else { len };
(size, align)
}
/// Insert a raw byte field (e.g., another struct).
/// `align` must be the alignment of the nested struct.
pub fn bytes(mut self, align: usize, name: impl ToString, data: &[u8]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Bytes(data.len())));
self.align_to(align);
self.buf.extend_from_slice(data);
self
}
}
impl std::fmt::Display for CStruct<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = self
.struct_types
.iter()
.map(|(name, ty)| match ty {
CStructType::Bool => format!("bool {name};"),
CStructType::BoolArr(l) => format!("bool {name}[{l}];"),
CStructType::Float => format!("float {name};"),
CStructType::FloatArr(l) => format!("float {name}[{l}];"),
CStructType::Int => format!("int {name};"),
CStructType::IntArr(l) => format!("int {name}[{l}];"),
CStructType::Long => format!("long {name};"),
CStructType::LongArr(l) => format!("long {name}[{l}];"),
CStructType::Ptr => format!("float* {name};"),
CStructType::PtrArr(l) => format!("float* {name}[{l}];"),
CStructType::Bytes(l) => format!("char payload[{l}];"),
})
.join("\n");
write!(f, "{s}")
}
}

View File

@@ -1,327 +0,0 @@
const int N_OPS = 0;
const int N_TIMING_SLOTS = 0;
const int N_TASKS = 0; // Rendered at compile time
//%n_barriers_const%
enum OpCode {
//%extra_op_codes%
};
//%extra_op_structs%
union Payload {
//%extra_op_payloads%
};
struct Task {
OpCode op;
int range;
int remaining;
int in_dep_a_stride;
int in_dep_a_base;
int in_dep_b_stride;
int in_dep_b_base;
int in_dep_c_stride;
int in_dep_c_base;
int out_dep_stride;
int out_dep_base;
int source_indices[6];
int out_index;
Payload payload;
};
struct SMEvent {
unsigned long long start;
unsigned long long stop;
int event;
};
//%constants%
__device__ __noinline__ int eval_expression(int expression, int const_z) {
switch (expression) {
//%expr_fns%
}
}
__device__ __forceinline__ unsigned long long read_globaltimer() {
unsigned long long t;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
return t;
}
//%extra_op_functions%
//%extra_prologue_functions%
__device__ __forceinline__ void nanosleep(unsigned int cycles) {
asm volatile("nanosleep.u32 %0;" ::"r"(cycles));
}
__device__ __forceinline__ int atomic_load_acquire(int *addr) {
int val;
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(val) : "l"(addr));
return val;
}
struct NextTask {
int current;
int task_idx;
};
// Lock-free task fetching using atomicSub for claiming (reduces CAS contention)
// remaining encoding:
// -1 = uninitialized
// > 0 = iterations remaining (atomicSub to claim, iteration = old - 1)
// <= 0 = exhausted
__device__ inline bool fetch_next_task(Task *tasks, int num_tasks, int *head,
NextTask *out) {
while (true) {
int idx = atomic_load_acquire(head);
if (idx >= num_tasks)
return false;
Task *t = &tasks[idx];
int remaining = atomicAdd(&t->remaining, 0);
// Handle uninitialized task - one CAS to initialize
if (remaining == -1) {
int range = eval_expression(t->range, 0);
atomicCAS(&t->remaining, -1, range);
continue;
}
// Task already exhausted, advance head
if (remaining <= 0) {
atomicMax(head, idx + 1);
continue;
}
// Claim via atomicSub - guaranteed to make progress, no CAS retry
int old = atomicSub(&t->remaining, 1);
if (old > 0) {
out->task_idx = idx;
out->current = old - 1;
if (old == 1) {
atomicMax(head, idx + 1);
}
// DEBUG: This path indicates successful task claim
return true;
}
// Race: exhausted between check and atomicSub, advance head
atomicMax(head, idx + 1);
}
}
__device__ inline void record_event(SMEvent *__restrict__ timings,
int *event_idx, int event_type) {
if (*event_idx < N_TIMING_SLOTS) {
unsigned long long now = read_globaltimer();
if (*event_idx > 0) { // record the end of the previous op
timings[*event_idx - 1].stop = now;
}
timings[*event_idx].start = now;
timings[*event_idx].stop = 0ull;
timings[*event_idx].event = event_type;
(*event_idx)++;
}
}
extern "C" {
// Kernel params: internal buffers in order, then dyn_dims
// tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
__global__ void worker_kernel(
Task* __restrict__ tasks,
int* __restrict__ head,
int* __restrict__ ready,
int* __restrict__ queue_lock,
SMEvent* __restrict__ timings,
unsigned long long* __restrict__ start_times,
float* const* buffers,
int* __restrict__ dyn_dims
) {
// Constants N_TASKS and N_BARRIERS are baked into the kernel string
// Note: Reset is now done on host side in pre_execute
// All buffers (head, queue_lock, ready, tasks) are pre-initialized
// DEBUG: Count tasks fetched (use queue_lock as counter since it's not being used)
// Note: queue_lock is in internal_bufs[3]
__shared__ NextTask nt;
__shared__ int done;
__shared__ int dep_out;
__shared__ bool run_a_prologue;
__shared__ bool run_b_prologue;
__shared__ bool run_c_prologue;
__shared__ bool stop_wait_loop;
__shared__ float scratchpad[8192]; // 32 KB scratchpad
__shared__ const float* source_ptrs[6];
__shared__ float* out_ptr;
int recorded_event = 0;
timings += blockIdx.x * N_TIMING_SLOTS;
if (threadIdx.x == 0) {
start_times[blockIdx.x] = read_globaltimer();
}
while (true) {
if (threadIdx.x == 0) {
record_event(timings, &recorded_event, 0); // Record issue start
done = !fetch_next_task(tasks, N_TASKS, head, &nt);
}
__syncthreads();
if (done)
break;
const Task *t = &tasks[nt.task_idx];
// Resolve buffer pointers from indices
if (threadIdx.x == 0) {
source_ptrs[0] = buffers[t->source_indices[0]];
source_ptrs[1] = buffers[t->source_indices[1]];
source_ptrs[2] = buffers[t->source_indices[2]];
source_ptrs[3] = buffers[t->source_indices[3]];
source_ptrs[4] = buffers[t->source_indices[4]];
source_ptrs[5] = buffers[t->source_indices[5]];
out_ptr = buffers[t->out_index];
}
__syncthreads();
int dep_a = 0;
int dep_b = 0;
int dep_c = 0;
// Thread 0 calculates dependencies and waits for inputs
if (threadIdx.x == 0) {
// Note: atomic_load_acquire provides visibility for ready array
dep_a = (t->in_dep_a_base == -1
? 0
: (eval_expression(t->in_dep_a_base, 0) +
eval_expression(t->in_dep_a_stride, nt.current)));
dep_b = (t->in_dep_b_base == -1
? 0
: (eval_expression(t->in_dep_b_base, 0) +
eval_expression(t->in_dep_b_stride, nt.current)));
dep_c = (t->in_dep_c_base == -1
? 0
: (eval_expression(t->in_dep_c_base, 0) +
eval_expression(t->in_dep_c_stride, nt.current)));
dep_out = eval_expression(t->out_dep_base, 0) +
eval_expression(t->out_dep_stride, nt.current);
// Increment the output barrier to signal an op is in-flight
atomicAdd(&ready[dep_out], 1);
record_event(timings, &recorded_event, 1); // Record wait start
// Wait on input dependencies and run prologues as inputs become ready
run_a_prologue = false;
run_b_prologue = false;
run_c_prologue = false;
stop_wait_loop = false;
}
__syncthreads();
bool a_done = false, b_done = false, c_done = false, tmp;
// Optimize: if deps are same, reuse atomic load result
const bool ab_same = (dep_a == dep_b);
const bool ac_same = (dep_a == dep_c);
const bool bc_same = (dep_b == dep_c);
while (true) {
if (threadIdx.x == 0) {
// Derive x_done and run_x_prologue with optimized atomic loads
if (!a_done) {
tmp = atomic_load_acquire(&ready[dep_a]) <= 0;
if (tmp) {
run_a_prologue = true;
a_done = true;
// Propagate to same deps
if (ab_same) {
run_b_prologue = true;
b_done = true;
}
if (ac_same) {
run_c_prologue = true;
c_done = true;
}
}
}
if (!b_done && !ab_same) {
tmp = atomic_load_acquire(&ready[dep_b]) <= 0;
if (tmp) {
run_b_prologue = true;
b_done = true;
if (bc_same) {
run_c_prologue = true;
c_done = true;
}
}
}
if (!c_done && !ac_same && !bc_same) {
tmp = atomic_load_acquire(&ready[dep_c]) <= 0;
if (tmp) {
run_c_prologue = true;
c_done = true;
}
}
if (a_done && b_done && c_done)
stop_wait_loop = true;
}
__syncthreads();
// Early exit if all dependencies satisfied (skip prologue checks)
if (stop_wait_loop)
break;
if (run_a_prologue) {
switch (t->op) {
//%prologue_a_calls%
}
if (threadIdx.x == 0) {
run_a_prologue = false;
}
}
if (run_b_prologue) {
switch (t->op) {
//%prologue_b_calls%
}
if (threadIdx.x == 0) {
run_b_prologue = false;
}
}
if (run_c_prologue) {
switch (t->op) {
//%prologue_c_calls%
}
if (threadIdx.x == 0) {
run_c_prologue = false;
}
}
__syncthreads();
}
if (threadIdx.x == 0)
record_event(timings, &recorded_event,
t->op + 2); // Record main op, ends Wait
// Execute main operation
switch (t->op) {
//%extra_op_calls%
}
__syncthreads();
// Arrive at output barrier
if (threadIdx.x == 0) {
__threadfence();
atomicSub(&ready[dep_out], 1);
}
}
if (threadIdx.x == 0 && recorded_event > 0) {
timings[recorded_event - 1].stop = read_globaltimer();
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,82 +0,0 @@
//! Compiles BlockOp subgraphs into KernelOp (MegakernelOp).
use std::sync::Arc;
use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
use luminal::{
graph::LLIRGraph,
op::LLIROp,
prelude::{
FxHashMap, FxHashSet, NodeIndex,
petgraph::{Direction, visit::EdgeRef},
},
};
use tracing::{Level, span};
use crate::{kernel::KernelOp, runtime::partition_marked_convex};
use super::{BlockOp, MegakernelOp};
/// Compile all BlockOp subgraphs in the LLIR graph into MegakernelOps.
///
/// This function:
/// 1. Finds all BlockOp nodes in the graph
/// 2. Partitions them into convex subgraphs
/// 3. For each subgraph, creates a MegakernelOp (which implements KernelOp)
/// 4. Adds the megakernel node to the llir_graph with appropriate edges
///
/// Returns mappings needed for the kernel compilation phase:
/// - `megakernel_to_blocks`: Maps each megakernel node to the BlockOp nodes it contains
/// (used to include block op nodes in the kernel's inputs for buffer pointer collection)
#[allow(clippy::type_complexity)]
pub fn block_to_kernel(
llir_graph: &mut LLIRGraph,
cuda_stream: &Arc<CudaStream>,
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> FxHashMap<NodeIndex, Vec<NodeIndex>> {
let _span = span!(Level::TRACE, "block_to_kernel").entered();
let block_ops_in_graph = llir_graph
.node_indices()
.filter(|n| llir_graph[*n].to_dialect::<dyn BlockOp>().is_some())
.collect::<FxHashSet<_>>();
if block_ops_in_graph.is_empty() {
return FxHashMap::default();
}
let mut megakernel_to_blocks: FxHashMap<NodeIndex, Vec<NodeIndex>> = FxHashMap::default();
for subgraph in partition_marked_convex(llir_graph, &block_ops_in_graph).unwrap() {
// Create MegakernelOp which implements KernelOp
let megakernel_op = MegakernelOp::new(llir_graph, &subgraph, cuda_stream, kernel_cache);
// Add megakernel node to llir_graph as a KernelOp
let megakernel_node =
llir_graph.add_node(LLIROp::new(Box::new(megakernel_op) as Box<dyn KernelOp>));
// Find external inputs: nodes outside subgraph that have edges into subgraph
// These edges establish exec_graph dependencies (megakernel waits for inputs)
let external_inputs: FxHashSet<NodeIndex> = subgraph
.iter()
.flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.filter(|src| !subgraph.contains(src))
})
.collect();
// Add edges from external inputs to megakernel node
// Note: We don't add edges TO external consumers because the original
// block op -> consumer edges still exist and will be used for exec_graph ordering
for input in &external_inputs {
llir_graph.add_edge(*input, megakernel_node, ());
}
// Map megakernel node to all block op nodes it contains
megakernel_to_blocks.insert(megakernel_node, subgraph.into_iter().collect());
}
megakernel_to_blocks
}

View File

@@ -1,71 +0,0 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)

View File

@@ -1,71 +0,0 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)

View File

@@ -1,71 +0,0 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)

View File

@@ -1,71 +0,0 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)

View File

@@ -1,127 +0,0 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Gate/Up split via Iota+Gather (slice semantics)
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; We use wildcards (?_xxx) for shapes/strides we don't extract.
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Iota ?gu_io ?gu_iota_base_range))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Mul ?gu_mul_base_shape ?topk_idx ?gu_mul_base_a_stride ?gu_iota_base ?gu_mul_base_b_stride ?gu_mul_base_out_stride))
; t53: Cast to F32
(= ?gu_cast_base (Cast ?gu_mul_base ?gu_cast_base_size (F32)))
; t54: Iota for within-expert index
(= ?gu_iota_within (Iota (MIter) ?gu_iota_within_range))
; t55: Cast within to F32
(= ?gu_cast_within (Cast ?gu_iota_within ?gu_cast_within_size (F32)))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Add ?gu_add_shape ?gu_cast_base ?gu_add_a_stride ?gu_cast_within ?gu_add_b_stride ?gu_add_out_stride))
; t57: Cast to Int
(= ?gu_cast_idx (Cast ?gu_add_idx ?gu_cast_idx_size (Int)))
; t58: Gather gate_up weights
(= ?gu_gathered (Gather ?gu_cast_idx ?gu_gather_idx_shape ?gu_gather_idx_stride ?gate_up_w ?gu_gather_data_shape ?gu_gather_data_stride))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Cast ?gu_gathered ?gu_f32_size (F32)))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Mul ?gu_matmul_mul_shape ?x ?gu_matmul_a_stride ?gu_f32 ?gu_matmul_b_stride ?gu_matmul_mul_out_stride))
; t61: SumReduce over K dimension
(= ?gu_matmul (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_mul ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride))
; ===== Up slice via Iota+Gather =====
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Iota ?up_iota_expr ?up_iota_range))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Gather ?up_iota ?up_gather_idx_shape ?up_gather_idx_stride ?gu_matmul ?up_gather_data_shape ?up_gather_data_stride))
; ===== SwiGLU: silu(gate) * up =====
; t64: Constant(-1)
(= ?neg1 (Constant -1.000000))
; t65: gate * -1
(= ?neg_gate (Mul ?silu_shape1 ?gu_matmul ?silu_a_stride1 ?neg1 ?silu_b_stride1 ?silu_out_stride1))
; t66: Constant(log2e)
(= ?log2e (Constant 1.442695))
; t67: neg_gate * log2e
(= ?scaled (Mul ?silu_shape2 ?neg_gate ?silu_a_stride2 ?log2e ?silu_b_stride2 ?silu_out_stride2))
; t68: exp2
(= ?exp2_val (Exp2 ?silu_shape3 ?scaled ?silu_in_stride3 ?silu_out_stride3))
; t69: Constant(1)
(= ?one (Constant 1.000000))
; t70: exp2 + 1
(= ?plus1 (Add ?silu_shape4 ?exp2_val ?silu_a_stride4 ?one ?silu_b_stride4 ?silu_out_stride4))
; t71: recip
(= ?sigmoid (Recip ?silu_shape5 ?plus1 ?silu_in_stride5 ?silu_out_stride5))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Mul ?silu_shape6 ?gu_matmul ?silu_a_stride6 ?sigmoid ?silu_b_stride6 ?silu_out_stride6))
; t73: silu(gate) * up
(= ?swiglu_out (Mul ?swiglu_shape ?silu_out ?swiglu_a_stride ?up_slice ?swiglu_b_stride ?swiglu_out_stride))
; ===== Down expert gather =====
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Iota ?dn_io ?dn_iota_base_range))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Mul ?dn_mul_base_shape ?topk_idx ?dn_mul_base_a_stride ?dn_iota_base ?dn_mul_base_b_stride ?dn_mul_base_out_stride))
; t76: Cast to F32
(= ?dn_cast_base (Cast ?dn_mul_base ?dn_cast_base_size (F32)))
; t77: Iota for within-expert index
(= ?dn_iota_within (Iota (MIter) ?dn_iota_within_range))
; t78: Cast within to F32
(= ?dn_cast_within (Cast ?dn_iota_within ?dn_cast_within_size (F32)))
; t79: Add base + within
(= ?dn_add_idx (Add ?dn_add_shape ?dn_cast_base ?dn_add_a_stride ?dn_cast_within ?dn_add_b_stride ?dn_add_out_stride))
; t80: Cast to Int
(= ?dn_cast_idx (Cast ?dn_add_idx ?dn_cast_idx_size (Int)))
; t81: Gather down weights
(= ?dn_gathered (Gather ?dn_cast_idx ?dn_gather_idx_shape ?dn_gather_idx_stride ?down_w ?dn_gather_data_shape ?dn_gather_data_stride))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_f32 (Cast ?dn_gathered ?dn_f32_size (F32)))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Mul ?dn_matmul_mul_shape ?swiglu_out ?dn_matmul_a_stride ?dn_f32 ?dn_matmul_b_stride ?dn_matmul_mul_out_stride))
; t84: SumReduce
(= ?dn_matmul (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_mul ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Mul ?weighted_shape ?dn_matmul ?weighted_a_stride ?topk_vals ?weighted_b_stride ?weighted_out_stride))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Sum ?output_shape ?output_k ?weighted ?output_in_stride ?output_k_stride ?output_out_stride))
)
(
(let ?glumoe (GLUMoE ?x ?topk_idx ?topk_vals ?gate_up_w ?down_w
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
)

View File

@@ -1,237 +0,0 @@
use std::sync::Arc;
use crate::{
cuda_dtype,
kernel::KernelOp,
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
};
use cudarc::{
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
nvrtc::CompileOptions,
};
use itertools::Itertools;
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, ELIST, EXPRESSION, IR},
extract_dtype, extract_expr, extract_expr_list,
},
op::*,
prelude::*,
};
pub type Ops = (KernelMeanReduce,);
#[derive(Default, Debug, Clone)]
pub struct KernelMeanReduce {
out_shape: Vec<Expression>,
iters: Expression,
in_stride: Vec<Expression>,
iter_stride: Expression,
out_stride: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for KernelMeanReduce {
fn sort(&self) -> SortDef {
sort(
IR,
"KernelMean",
&[
("shape", ELIST),
("iters", EXPRESSION),
("inp", IR),
("strides", ELIST),
("iter_stride", EXPRESSION),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw("
(rule
(
(= ?sum (Sum ?out_shape ?iters ?inp ?in_stride ?iter_stride ?sum_out_stride))
(= ?iota (Iota ?iters ?one))
(= ?cast (Cast ?iota ?one (F32)))
(= ?recip (Recip ?r_shape ?cast ?r_in_strides ?r_out_strides))
(= ?result (Mul ?shape ?sum ?sum_strides ?recip ?recip_strides ?out_strides))
(= ?dty (dtype ?inp))
)
(
(union ?result (KernelMean ?out_shape ?iters ?inp ?in_stride ?iter_stride ?out_strides ?dty))
)
:name \"kernel mean reduce\"
)
")]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&self,
egraph: &'a SerializedEGraph,
children: &[&'a ENodeId],
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap(),
iters: extract_expr(egraph, children[1], expr_cache).unwrap(),
in_stride: extract_expr_list(egraph, children[3], list_cache, expr_cache).unwrap(),
iter_stride: extract_expr(egraph, children[4], expr_cache).unwrap(),
out_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap(),
dtype: extract_dtype(egraph, children[6]),
}) as Box<dyn KernelOp>),
vec![children[2]],
)
}
}
impl KernelOp for KernelMeanReduce {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self
.out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.iters.dyn_vars())
.chain(self.iter_stride.dyn_vars())
.collect::<FxHashSet<_>>();
let dtype = cuda_dtype(self.dtype);
let includes = dtype_includes(&[self.dtype]);
let n_outputs: Expression = self.out_shape.iter().copied().product();
let threads_per_block = 256; // 8 warps per block
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let kernel = format!(
"{includes}
#define WARP_SIZE 32
#define THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
{dyn_defines}
extern \"C\" {{
__global__ void reduce_mean_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
__shared__ {dtype} warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
int warp_id = tid / WARP_SIZE;
long long in_start = {in_index};
long long iters = {iters};
long long iter_stride = {iter_stride};
{dtype} sum = 0;
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
sum += in[in_start + i * iter_stride];
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
sum += __shfl_down_sync(FULL_MASK, sum, s);
}}
if (lane_id == 0) {{
warp_sums[warp_id] = sum;
}}
__syncthreads();
if (warp_id == 0) {{
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
{dtype} block_sum = tid < cnt ? warp_sums[tid] : 0;
#pragma unroll
for (int s = cnt / 2; s > 0; s /= 2) {{
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
}}
if (tid == 0) {{
out[{out_index}] = ({dtype})(block_sum / (float)iters);
}}
}}
}}
}}",
dtype = dtype,
in_index = flatten_strides(&self.out_shape, &self.in_stride).to_kernel(),
out_index = flatten_strides(&self.out_shape, &self.out_stride).to_kernel(),
iters = self.iters.to_kernel(),
iter_stride = self
.iter_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel(),
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("reduce_mean_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
(
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()), // grid
(threads_per_block.into(), 1.into(), 1.into()), // blocks
32.into(), // shmem size
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.out_shape.iter().copied().product()
}
fn output_bytes(&self) -> Expression {
(self.output_size() * self.dtype.bits()).ceil_div(8)
}
fn bytes_loaded(&self) -> Expression {
(self.out_shape.iter().copied().product::<Expression>() * self.iters * self.dtype.bits())
.ceil_div(8)
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
let n_outputs: Expression = self.out_shape.iter().copied().product();
n_outputs * self.iters + n_outputs
}
fn kernel_name(&self) -> &'static str {
"MeanReduce"
}
}

View File

@@ -1,57 +0,0 @@
pub mod block;
pub mod host;
pub mod kernel;
pub mod logical;
pub mod runtime;
use std::sync::Arc;
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::driver::CudaContext;
use luminal::dtype::DType;
fn cuda_dtype(dtype: DType) -> &'static str {
match dtype {
DType::F64 => "double",
DType::F32 => "float",
DType::F16 => "half",
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",
DType::U8 => "unsigned char",
DType::Bool => "unsigned char",
DType::F8E4M3 => "__nv_fp8_e4m3",
DType::F8E5M2 => "__nv_fp8_e5m2",
DType::F8UE8M0 => "__nv_fp8_e8m0",
DType::F6E2M3 => "__nv_fp6_e2m3",
DType::F6E3M2 => "__nv_fp6_e3m2",
DType::F4E2M1 => "__nv_fp4_e2m1",
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
}
}
/// Returns the bandwidth of the device in GB/s
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 273,
"NVIDIA H100 PCIe" => 2_000,
"NVIDIA H100 SXM" => 3_350,
_ => return None,
})
}
/// Returns the bandwidth of the device in TFLOPs
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 125, // forced to use tf32 flops
"NVIDIA H100 PCIe" => 756,
"NVIDIA H100 SXM" => 989,
_ => return None,
})
}

View File

@@ -1,71 +0,0 @@
use std::fmt::Debug;
use luminal::{
egglog_utils::{
api::{Rule, SortDef},
base::OP_SORTS,
},
op::EgglogOp,
};
pub type Ops = (Exp, Sigmoid);
#[derive(Debug, Default)]
pub struct Exp;
impl EgglogOp for Exp {
fn sort(&self) -> SortDef {
OP_SORTS.unary("Exp")
}
fn cleanup(&self) -> bool {
true
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
(= ?exp_const (Constant 1.442695))
(= ?mul (Mul ?shape ?x ?x_stride ?exp_const ?const_stride ?intermediate_stride))
(= ?exp2 (Exp2 ?shape ?mul ?intermediate_stride ?out_stride))
(= ?dt (dtype ?x))
)
(
(let ?exp (Exp ?shape ?x ?x_stride ?out_stride))
(union ?exp2 ?exp)
(set (dtype ?exp) ?dt)
)
)",
)]
}
}
#[derive(Default, Debug, Clone)]
pub struct Sigmoid;
impl EgglogOp for Sigmoid {
fn sort(&self) -> SortDef {
OP_SORTS.unary("Sigmoid")
}
fn cleanup(&self) -> bool {
true
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw("(rule
(
(= ?neg_input (Mul ?input_range ?input ?input_stride (Constant -1.0) ?const_stride ?intermediate_stride))
(= ?exp (Exp ?input_range ?neg_input ?intermediate_stride ?exp_stride))
(= ?plus_one (Add ?input_range ?exp ?exp_stride (Constant 1.0) ?const_stride ?plus_one_stride))
(= ?sig_out (Recip ?input_range ?plus_one ?plus_one_stride ?out_stride))
(= ?dt (dtype ?input))
)
(
(let ?sig (Sigmoid ?input_range ?input ?input_stride ?out_stride))
(union ?sig_out ?sig)
(set (dtype ?sig) ?dt)
)
:name \"sigmoid\"
)")]
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
[package]
name = "luminal_cuda"
name = "luminal_cuda_lite"
version = "0.2.0"
edition = "2024"
description = "Cuda compiler for luminal"
@@ -26,7 +26,7 @@ libc = "0.2"
colorize = "*"
[dev-dependencies]
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
candle-core = { version = "0.9.2", features = ["cuda"] }
proptest = "1.9.0"
rand = "0.9.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -1,4 +1,4 @@
## luminal_cuda
## luminal_cuda_lite
This crate contains the CUDA backend for Luminal.
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
### Architecture
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.

View File

@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, IR, STRING},
base::{EXPRESSION, OP_KIND, STRING},
extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -74,11 +74,9 @@ impl Default for CuBlasSgemmV2 {
impl EgglogOp for CuBlasSgemmV2 {
fn sort(&self) -> SortDef {
sort(
IR,
OP_KIND,
"cublasSgemmV2",
&[
("a", IR),
("b", IR),
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
@@ -91,6 +89,10 @@ impl EgglogOp for CuBlasSgemmV2 {
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
@@ -104,25 +106,26 @@ impl EgglogOp for CuBlasSgemmV2 {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
children: &[&'a ENodeId],
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[children[5]].0;
let b_layout_str = &egraph.enodes[children[6]].0;
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
let extracted_state = Self {
m,
@@ -139,7 +142,7 @@ impl EgglogOp for CuBlasSgemmV2 {
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, vec![children[0], children[1]])
(extracted, input_enodes)
}
fn cleanup(&self) -> bool {

View File

@@ -12,10 +12,13 @@
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -34,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -52,9 +55,7 @@
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -62,7 +63,8 @@
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,10 +12,13 @@
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -34,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -52,9 +55,7 @@
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -62,7 +63,8 @@
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,10 +12,13 @@
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -34,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -52,9 +55,7 @@
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -62,7 +63,8 @@
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,10 +12,13 @@
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -34,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -52,9 +55,7 @@
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -62,7 +63,8 @@
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -0,0 +1,133 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "T"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × column-major"
)

View File

@@ -0,0 +1,133 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "T"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × row-major"
)

View File

@@ -0,0 +1,133 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × column-major"
)

View File

@@ -0,0 +1,139 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; In broadcast [batch, m, n, k] space:
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape: [batch, m, n]
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k]
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k]
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "N"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc (contiguous output per batch)
?batch ; batch_count
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × row-major"
)

View File

@@ -4,7 +4,7 @@ use luminal::{
dtype::DType,
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, EXPRESSION, IR, STRING},
base::{DTYPE, EXPRESSION, OP_KIND, STRING},
extract_dtype, extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -45,6 +45,10 @@ pub struct CuBlasLt {
lda: Expression,
ldb: Expression,
ldc: Expression,
batch_count: Expression,
stride_a: Expression,
stride_b: Expression,
stride_c: Expression,
dtype: DType,
cublaslt: OnceLock<Arc<CudaBlasLT>>,
}
@@ -56,11 +60,15 @@ impl Default for CuBlasLt {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
a_layout: cublasOperation_t::CUBLAS_OP_N,
b_layout: cublasOperation_t::CUBLAS_OP_T,
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
batch_count: 1.into(),
stride_a: 0.into(),
stride_b: 0.into(),
stride_c: 0.into(),
dtype: DType::F32,
cublaslt: OnceLock::new(),
}
@@ -70,11 +78,9 @@ impl Default for CuBlasLt {
impl EgglogOp for CuBlasLt {
fn sort(&self) -> SortDef {
sort(
IR,
OP_KIND,
"cublaslt",
&[
("a", IR),
("b", IR),
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
@@ -83,17 +89,48 @@ impl EgglogOp for CuBlasLt {
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
("batch_count", EXPRESSION),
("stride_a", EXPRESSION),
("stride_b", EXPRESSION),
("stride_c", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
// (not the Mul eclass), so they survive the cascade.
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
]
}
@@ -101,28 +138,35 @@ impl EgglogOp for CuBlasLt {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
children: &[&'a ENodeId],
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[children[5]].0;
let b_layout_str = &egraph.enodes[children[6]].0;
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
// Extract batch parameters
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
// Extract dtype from egglog
let dtype = extract_dtype(egraph, children[10]);
let dtype = extract_dtype(egraph, kind_children[12]);
let extracted_state = Self {
m,
@@ -133,6 +177,10 @@ impl EgglogOp for CuBlasLt {
lda,
ldb,
ldc,
batch_count,
stride_a,
stride_b,
stride_c,
dtype,
cublaslt: OnceLock::new(),
};
@@ -140,7 +188,7 @@ impl EgglogOp for CuBlasLt {
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, vec![children[0], children[1]])
(extracted, input_enodes)
}
fn cleanup(&self) -> bool {
@@ -209,15 +257,24 @@ impl HostOp for CuBlasLt {
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
let m = self.m.exec(dyn_map).unwrap() as u64;
let n = self.n.exec(dyn_map).unwrap() as u64;
let k = self.k.exec(dyn_map).unwrap() as u64;
use crate::cudarc::cublaslt::sys::{
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
};
// GEMM parameters — resolve z→1 for element stride before exec
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = self.lda.exec(dyn_map).unwrap() as i64;
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
// Get CUDA types based on dtype
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
@@ -242,20 +299,28 @@ impl HostOp for CuBlasLt {
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
// Debug tracing
trace!(
"buffer_validation {}=={},{}=={},{}=={}",
a_buf.len(),
m * k * element_size,
b_buf.len(),
k * n * element_size,
c_buf.len(),
m * n * element_size
);
// Clamp leading dimensions to minimum valid values.
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
// dimension may be 0 in the egglog representation, but cuBLAS requires
// lda >= rows_of_A and ldb >= rows_of_B.
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N {
m
} else {
k
};
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N {
k
} else {
n
};
let lda = std::cmp::max(lda, a_ld_min as i64);
let ldb = std::cmp::max(ldb, b_ld_min as i64);
let ldc = std::cmp::max(ldc, m as i64);
let _span = span!(
Level::TRACE,
"cuBLASLT",
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
)
.entered();
@@ -312,6 +377,26 @@ impl HostOp for CuBlasLt {
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
// Set batched GEMM attributes if batch_count > 1
if batch_count > 1 {
for (desc, stride) in [(a_desc, stride_a), (b_desc, stride_b), (c_desc, stride_c)] {
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i32>(),
)
.result()?;
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i64>(),
)
.result()?;
}
}
// Create preference and set workspace size
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
@@ -338,7 +423,6 @@ impl HostOp for CuBlasLt {
.result()?;
if algo_count == 0 {
// Cleanup before returning error
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
@@ -347,7 +431,6 @@ impl HostOp for CuBlasLt {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
// All dtypes use F32 scale type for alpha/beta
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
cublasLtMatmul(
@@ -362,7 +445,7 @@ impl HostOp for CuBlasLt {
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
c_desc, // D layout same as C
c_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
@@ -383,7 +466,8 @@ impl HostOp for CuBlasLt {
}
fn output_size(&self) -> Expression {
self.m * self.n
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
}
fn output_bytes(&self) -> Expression {

View File

@@ -0,0 +1,128 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Gate/Up split via Iota+Gather (slice semantics)
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; We use wildcards (?_xxx) for shapes/strides we don't extract.
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
; t53: Cast to F32
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
; t54: Iota for within-expert index
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
; t55: Cast within to F32
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
; t57: Cast to Int
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
; t58: Gather gate_up weights
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
; t61: SumReduce over K dimension
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
; ===== Up slice via Iota+Gather =====
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
; ===== SwiGLU: silu(gate) * up =====
; t64: Constant(-1)
(= ?neg1 (Op (Constant -1.000000) (INil)))
; t65: gate * -1
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
; t66: Constant(log2e)
(= ?log2e (Op (Constant 1.442695) (INil)))
; t67: neg_gate * log2e
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
; t68: exp2
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
; t69: Constant(1)
(= ?one (Op (Constant 1.000000) (INil)))
; t70: exp2 + 1
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
; t71: recip
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
; t73: silu(gate) * up
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
; ===== Down expert gather =====
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
; t76: Cast to F32
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
; t77: Iota for within-expert index
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
; t78: Cast within to F32
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
; t79: Add base + within
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
; t80: Cast to Int
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
; t81: Gather down weights
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
; t84: SumReduce
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
)
(
(let ?glumoe (Op (GLUMoE
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range)
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
)

View File

@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, IR},
base::{EXPRESSION, OP_KIND},
extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -12,6 +12,7 @@ use luminal::{
};
use crate::{
compile_module_image_for_current_device,
cudarc::{
cublas::sys::cublasOperation_t,
cublaslt::{
@@ -30,7 +31,6 @@ use crate::{
driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
nvrtc::{CompileOptions, compile_ptx_with_opts},
},
host::HostOp,
};
@@ -146,17 +146,7 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
}
}
"#;
let ptx = compile_ptx_with_opts(
src,
CompileOptions {
include_paths: vec![
"/usr/local/cuda/include".to_string(),
"/usr/include".to_string(),
],
..Default::default()
},
)
.unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
let swiglu = module.load_function("swiglu_bf16").unwrap();
@@ -168,14 +158,9 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
impl EgglogOp for GLUMoE {
fn sort(&self) -> SortDef {
sort(
IR,
OP_KIND,
"GLUMoE",
&[
("x", IR),
("topk_idx", IR),
("topk_vals", IR),
("gate_up_w", IR),
("down_w", IR),
("gu_io", EXPRESSION),
("dn_io", EXPRESSION),
("gu_matmul_k", EXPRESSION),
@@ -187,6 +172,10 @@ impl EgglogOp for GLUMoE {
)
}
fn n_inputs(&self) -> usize {
5
}
fn early_rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
}
@@ -194,17 +183,18 @@ impl EgglogOp for GLUMoE {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
children: &[&'a ENodeId],
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let gu_io = extract_expr(egraph, children[5], expr_cache).unwrap();
let dn_io = extract_expr(egraph, children[6], expr_cache).unwrap();
let gu_matmul_k = extract_expr(egraph, children[7], expr_cache).unwrap();
let dn_matmul_k = extract_expr(egraph, children[8], expr_cache).unwrap();
let output_k = extract_expr(egraph, children[9], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, children[10], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, children[11], expr_cache).unwrap();
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let extracted = GLUMoE {
gu_io,
@@ -220,16 +210,7 @@ impl EgglogOp for GLUMoE {
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
(
op,
vec![
children[0],
children[1],
children[2],
children[3],
children[4],
],
)
(op, input_enodes)
}
fn cleanup(&self) -> bool {

View File

@@ -425,7 +425,7 @@ mod tests {
fn test_raw_function_extraction() {
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
@@ -448,7 +448,7 @@ mod tests {
use cudarc::driver::{CudaSlice, DevicePtr};
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
@@ -489,15 +489,16 @@ mod tests {
};
let size = 1024;
let mut cx = Graph::default();
let a = cx.tensor(size);
let b = cx.tensor(size);
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = ((a + b) * a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
@@ -520,15 +521,16 @@ mod tests {
};
let size = 2048;
let mut cx = Graph::default();
let a = cx.tensor(size);
let b = cx.tensor(size);
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = (a + b + a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
let mut results = Vec::new();
for _ in 0..5 {
@@ -559,13 +561,14 @@ mod tests {
let b = cx.tensor('s');
let c = (a + b).output();
let d = (c * a).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
@@ -601,12 +604,13 @@ mod tests {
let a = cx.tensor(size);
let b = cx.tensor(size);
let c = (a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
@@ -623,20 +627,21 @@ mod tests {
};
let size = 4096;
let mut cx = Graph::default();
let a = cx.tensor(size);
let b = cx.tensor(size);
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let mut result = a + b;
for _ in 0..5 {
result += a;
result *= b;
}
let output = result.output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
for _ in 0..10 {
rt.execute(&cx.dyn_map);

View File

@@ -173,9 +173,23 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
/// Returns the output buffer size in elements.
fn output_size(&self) -> Expression;
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.output_size().dyn_vars().into_iter().collect()
}
/// Returns the output buffer size in bytes (accounts for dtype).
fn output_bytes(&self) -> Expression;
/// Returns the DType of this kernel's output buffer.
/// Used by has_nan_outputs to interpret buffer bytes correctly.
/// Default: F32 (most kernels output float).
fn output_dtype(&self) -> DType {
DType::F32
}
/// Returns the number of bytes this kernel will load from global memory.
fn bytes_loaded(&self) -> Expression {
0.into()
@@ -244,6 +258,23 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
) {
}
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
fn output_aliases_input(&self) -> Option<usize> {
None
}
/// If this kernel's output is derived from one of its inputs (copy-then-modify
/// or in-place write), return that input index. Used by `resolve_data_node` to
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
/// roundtrip pattern.
///
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
/// (like Scatter which copies dest→output then scatters into it).
fn output_data_input(&self) -> Option<usize> {
self.output_aliases_input()
}
/// Returns indices of internal buffers containing timing data, if any.
/// Returns (timings_idx, start_times_idx, sm_count).
fn timing_buffer_indices(&self) -> Option<(usize, usize, usize)> {

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ use cudarc::driver::{
};
use itertools::Itertools;
use luminal::{
egglog_utils::{api::Rule, base::IR},
egglog_utils::{api::Rule, base::OP_KIND},
graph::LLIRGraph,
op::{EgglogOp, LLIROp},
prelude::{
@@ -26,6 +26,7 @@ use crate::{
kernel::{
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
destroy_cuda_event,
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
},
runtime::partition_marked_convex,
};
@@ -195,7 +196,7 @@ impl std::fmt::Debug for CudaGraphOp {
impl EgglogOp for CudaGraphOp {
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
luminal::egglog_utils::api::sort(IR, "CudaGraphOp", &[])
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
}
fn rewrites(&self) -> Vec<Rule> {
@@ -205,7 +206,8 @@ impl EgglogOp for CudaGraphOp {
fn extract<'a>(
&'a self,
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
_children: &[&'a luminal::prelude::ENodeId],
_kind_children: &[&'a luminal::prelude::ENodeId],
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
@@ -299,7 +301,9 @@ impl CudaGraphOp {
for kernel in state.kernels.iter_mut() {
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
}
// Internal buffer pointers changed, need to rebuild CUDA graph
}
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
if dyn_map_changed || needs_internal_realloc {
state.cuda_graph = None;
state.cuda_graph_exec = None;
state.node_to_graph_node.clear();
@@ -340,23 +344,33 @@ impl CudaGraphOp {
}
}
// Apply output-aliases-input
for kernel in state.kernels.iter() {
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
{
current_buffer_ptrs.insert(kernel.node, input_ptr);
}
}
// Always call pre_execute for each kernel to reset internal state
// (e.g., MegakernelOps need work queue, head, barriers, lock reset every execution)
for idx in 0..state.kernels.len() {
let kernel = &mut state.kernels[idx];
kernel.kernel_op.pre_execute(
stream,
&mut kernel.internal_bufs,
&mut kernel.constants,
&current_buffer_ptrs,
dyn_map,
);
}
// Check if we need to update the graph
let buffer_ptrs_changed = current_buffer_ptrs != state.last_buffer_ptrs;
let needs_update = dyn_map_changed || buffer_ptrs_changed;
if needs_update {
// Call pre_execute for each kernel
for idx in 0..state.kernels.len() {
let kernel = &mut state.kernels[idx];
kernel.kernel_op.pre_execute(
stream,
&mut kernel.internal_bufs,
&mut kernel.constants,
&current_buffer_ptrs,
dyn_map,
);
}
// Update kernel params
let dyn_dims_ptr = state
.dyn_dims_buffer
@@ -424,15 +438,9 @@ impl CudaGraphOp {
state.last_buffer_ptrs = current_buffer_ptrs;
}
// Sync before launch
stream.synchronize()?;
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
// Sync after launch
stream.synchronize()?;
Ok(())
}
@@ -589,7 +597,7 @@ impl Drop for CudaGraphOp {
fn drop(&mut self) {
let mut state = self.state.borrow_mut();
// Destroy timing events - extract ctx first to avoid borrow issues
// Destroy timing events first
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
if let Some(ctx) = ctx {
for event in state.timing_events.drain(..) {
@@ -597,22 +605,22 @@ impl Drop for CudaGraphOp {
}
}
// Forget dyn_dims buffer (managed by runtime)
if let Some(buf) = state.dyn_dims_buffer.take() {
std::mem::forget(buf);
}
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
// so it must be destroyed first to avoid dangling pointer issues.
drop(state.cuda_graph_exec.take());
drop(state.cuda_graph.take());
// Handle kernel resources
// Now safe to free dynamically allocated GPU buffers
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
// Constants point to __constant__ memory in the CUDA module,
// not dynamically allocated — must not be freed.
for kernel in state.kernels.iter_mut() {
// Forget constants (they point to __constant__ memory)
let constants = std::mem::take(&mut kernel.constants);
for (_k, v) in constants {
std::mem::forget(v);
}
// Forget internal buffers (managed by runtime)
for buf in kernel.internal_bufs.drain(..) {
std::mem::forget(buf);
}
}
}
}
@@ -632,7 +640,6 @@ pub fn kernel_to_host(
llir_graph: &mut LLIRGraph,
cuda_stream: &Arc<CudaStream>,
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
megakernel_to_blocks: &FxHashMap<NodeIndex, Vec<NodeIndex>>,
) {
let _span = span!(Level::TRACE, "kernel_to_host").entered();
@@ -660,11 +667,28 @@ pub fn kernel_to_host(
.filter(|n| subgraph.contains(n))
.collect();
let mut kernels = Vec::with_capacity(topo_order.len());
let mut all_dyn_dims = FxHashSet::default();
let mut all_buffer_nodes = FxHashSet::default();
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
.unwrap();
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
}
// Set global dyn dims ordering so compiles use consistent indices
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
global_dyn_dims.sort();
if !global_dyn_dims.is_empty() {
set_global_dyn_dims(global_dyn_dims.clone());
}
// Compile all kernels with global ordering for correct dyn_dims indices
let mut kernels = Vec::with_capacity(topo_order.len());
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
@@ -680,21 +704,6 @@ pub fn kernel_to_host(
.map(|e| e.source())
.collect_vec();
// If this is a megakernel, include all its block op nodes for buffer access
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node_idx) {
inputs.extend(block_nodes.iter().copied());
}
// Collect dyn dims used by this kernel
all_dyn_dims.extend(grid.0.dyn_vars());
all_dyn_dims.extend(grid.1.dyn_vars());
all_dyn_dims.extend(grid.2.dyn_vars());
all_dyn_dims.extend(block.0.dyn_vars());
all_dyn_dims.extend(block.1.dyn_vars());
all_dyn_dims.extend(block.2.dyn_vars());
all_dyn_dims.extend(shared_mem.dyn_vars());
all_dyn_dims.extend(kernel_op_ref.output_size().dyn_vars());
// Collect buffer nodes and sizes
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
let output_size = kernel_op_ref.output_size();
@@ -719,9 +728,19 @@ pub fn kernel_to_host(
));
}
// Sort dyn dims alphabetically for consistent buffer layout
let mut dyn_dims_order: Vec<char> = all_dyn_dims.into_iter().collect();
dyn_dims_order.sort();
// Get the possibly-extended global ordering (kernels may have discovered new dims)
let final_global = get_global_dyn_dims();
// Clear global ordering now that all kernels are compiled
clear_global_dyn_dims();
// Use the final global ordering if it was extended during compilation
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
final_order
} else {
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
dims.sort();
dims
};
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
@@ -744,14 +763,6 @@ pub fn kernel_to_host(
for kernel_node in &subgraph {
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
}
// Also track block op nodes inside megakernels
for kernel_node in &subgraph {
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
for block_node in block_nodes {
kernel_to_cuda_graph.insert(*block_node, cuda_graph_node);
}
}
}
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
// Find external inputs: nodes outside subgraph that have edges into subgraph
@@ -779,23 +790,15 @@ pub fn kernel_to_host(
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
// This ensures proper execution ordering when a kernel in one CudaGraphOp
// produces output consumed by a kernel (or BlockOp inside a megakernel) in another CudaGraphOp.
// produces output consumed by a kernel in another CudaGraphOp.
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
// Find all nodes that this subgraph produces output for (including BlockOp nodes in megakernels)
let mut all_producer_nodes: FxHashSet<NodeIndex> = subgraph.clone();
for kernel_node in subgraph {
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
all_producer_nodes.extend(block_nodes.iter().copied());
}
}
// Find external consumers that are kernels belonging to other CudaGraphOps
for producer_node in &all_producer_nodes {
for producer_node in subgraph {
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
let consumer = edge.target();
if all_producer_nodes.contains(&consumer) {
if subgraph.contains(&consumer) {
continue; // Same subgraph
}
// Check if consumer is a kernel in another CudaGraphOp

View File

@@ -0,0 +1,309 @@
pub mod host;
pub mod kernel;
pub mod runtime;
use std::{
ffi::{CStr, CString},
path::Path,
sync::Arc,
};
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::{
driver::{CudaContext, DriverError, sys as driver_sys},
nvrtc::{
Ptx,
result::{self as nvrtc_result, NvrtcError},
sys as nvrtc_sys,
},
};
use luminal::dtype::DType;
fn cuda_dtype(dtype: DType) -> &'static str {
match dtype {
DType::F64 => "double",
DType::F32 => "float",
DType::F16 => "half",
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",
DType::U8 => "unsigned char",
DType::Bool => "unsigned char",
DType::F8E4M3 => "__nv_fp8_e4m3",
DType::F8E5M2 => "__nv_fp8_e5m2",
DType::F8UE8M0 => "__nv_fp8_e8m0",
DType::F6E2M3 => "__nv_fp6_e2m3",
DType::F6E3M2 => "__nv_fp6_e3m2",
DType::F4E2M1 => "__nv_fp4_e2m1",
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
}
}
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
#[derive(Debug)]
pub(crate) enum CudaModuleImageCompileFailure {
ComputeCapability(DriverError),
Nvrtc {
stage: &'static str,
error: NvrtcError,
},
NoModuleImageProduced,
}
#[derive(Debug)]
pub(crate) struct CudaModuleImageCompileError {
pub target_arch: Option<String>,
pub driver_version: Option<i32>,
pub runtime_version: Option<i32>,
pub nvrtc_options: Vec<String>,
pub nvrtc_log: Option<String>,
pub failure: CudaModuleImageCompileFailure,
}
impl std::fmt::Display for CudaModuleImageCompileError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to compile CUDA module image")?;
if let Some(target_arch) = &self.target_arch {
write!(f, " for {target_arch}")?;
}
match &self.failure {
CudaModuleImageCompileFailure::ComputeCapability(error) => {
write!(f, ": failed to query compute capability: {error}")?;
}
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
write!(f, ": NVRTC {stage} failed: {error}")?;
}
CudaModuleImageCompileFailure::NoModuleImageProduced => {
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
}
}
if let Some(version) = self.driver_version {
write!(f, " | driver {}", format_cuda_version(version))?;
}
if let Some(version) = self.runtime_version {
write!(f, " | runtime {}", format_cuda_version(version))?;
}
if !self.nvrtc_options.is_empty() {
write!(f, " | options {:?}", self.nvrtc_options)?;
}
if let Some(log) = &self.nvrtc_log {
write!(f, " | log: {log}")?;
}
Ok(())
}
}
impl std::error::Error for CudaModuleImageCompileError {}
fn format_cuda_version(version: i32) -> String {
format!("{}.{}", version / 1000, (version % 1000) / 10)
}
fn cuda_nvrtc_include_paths() -> Vec<String> {
let mut include_paths = Vec::new();
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
if let Ok(root) = std::env::var(env_var) {
let path = format!("{root}/include");
if Path::new(&path).exists() && !include_paths.contains(&path) {
include_paths.push(path);
}
}
}
for path in CUDA_NVRTC_INCLUDE_PATHS {
let path = path.to_string();
if Path::new(&path).exists() && !include_paths.contains(&path) {
include_paths.push(path);
}
}
include_paths
}
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
let mut driver_version = 0;
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
.result()
.ok()
.map(|_| driver_version);
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
// resolves newer libcudart symbols that may not exist in the installed runtime.
(driver_version, None)
}
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
let mut options = cuda_nvrtc_include_paths()
.into_iter()
.map(|path| format!("--include-path={path}"))
.collect::<Vec<_>>();
options.push(format!("--gpu-architecture={target_arch}"));
options
}
fn build_module_image_compile_error(
target_arch: Option<String>,
driver_version: Option<i32>,
runtime_version: Option<i32>,
nvrtc_options: &[String],
nvrtc_log: Option<String>,
failure: CudaModuleImageCompileFailure,
) -> CudaModuleImageCompileError {
CudaModuleImageCompileError {
target_arch,
driver_version,
runtime_version,
nvrtc_options: nvrtc_options.to_vec(),
nvrtc_log,
failure,
}
}
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
if raw.is_empty() {
return None;
}
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
.to_string_lossy()
.trim_end_matches('\0')
.trim()
.to_string();
if log.is_empty() { None } else { Some(log) }
}
#[allow(clippy::slow_vector_initialization)]
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
let mut cubin_size = 0usize;
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
if cubin_size == 0 {
return Ok(Vec::new());
}
let mut cubin = Vec::with_capacity(cubin_size);
cubin.resize(cubin_size, 0);
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
}
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
ctx: &Arc<CudaContext>,
src: S,
) -> Result<Ptx, CudaModuleImageCompileError> {
let (driver_version, runtime_version) = cuda_driver_diagnostics();
let (major, minor) = ctx.compute_capability().map_err(|error| {
build_module_image_compile_error(
None,
driver_version,
runtime_version,
&[],
None,
CudaModuleImageCompileFailure::ComputeCapability(error),
)
})?;
let target_arch = format!("sm_{major}{minor}");
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
let source = CString::new(src.as_ref().as_bytes())
.expect("CUDA source code cannot contain null terminators");
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
build_module_image_compile_error(
Some(target_arch.clone()),
driver_version,
runtime_version,
&nvrtc_options,
None,
CudaModuleImageCompileFailure::Nvrtc {
stage: "create_program",
error,
},
)
})?;
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
let nvrtc_log = read_nvrtc_log(program);
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "compile_program",
error,
},
));
}
let nvrtc_log = read_nvrtc_log(program);
let cubin = match get_cubin(program) {
Ok(cubin) => cubin,
Err(error) => {
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "get_cubin",
error,
},
));
}
};
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "destroy_program",
error,
},
));
}
if cubin.is_empty() {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::NoModuleImageProduced,
));
}
Ok(Ptx::from_binary(cubin))
}
/// Returns the bandwidth of the device in GB/s
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 273,
"NVIDIA H100 PCIe" => 2_000,
"NVIDIA H100 SXM" => 3_350,
_ => return None,
})
}
/// Returns the bandwidth of the device in TFLOPs
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 125, // forced to use tf32 flops
"NVIDIA H100 PCIe" => 756,
"NVIDIA H100 SXM" => 989,
_ => return None,
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,344 @@
use crate::runtime::CudaRuntime;
use crate::tests::utilities::*;
use luminal::prelude::*;
use rand::{SeedableRng, rngs::SmallRng};
/// Helper: build a simple graph with dynamic dim 's' that does element-wise computation.
/// Returns (cx, input_node, output_node).
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
let mut cx = Graph::default();
let a = cx.tensor(('s', 4));
let b = (a + a).output();
(cx, a.id, b.id)
}
/// Helper: build a matmul graph with dynamic dim 's'.
/// Computes (s, K) @ (K, N) -> (s, N)
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
let mut cx = Graph::default();
let a = cx.tensor(('s', k));
let b = cx.tensor((k, n));
let c = a.matmul(b).output();
(cx, a.id, b.id, c.id)
}
#[test]
fn test_bucket_dispatch_simple() {
// Tests that bucketed compilation produces correct results for different dim values
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
// Set dummy input for search
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Test bucket 1: s=1
cx.set_dim('s', 1);
let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..4], &expected, 1e-5, 1e-5);
// Test bucket 2: s=3
cx.set_dim('s', 3);
let input_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..12], &expected, 1e-5, 1e-5);
}
#[test]
fn test_bucket_matmul_dynamic() {
// Tests matmul with bucketed dynamic dim
let Some(stream) = get_cuda_stream() else {
return;
};
let k = 8;
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 100, -1.0, 1.0);
let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Execute at s=1
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_s1 = rt.get_f32(c);
// Compute reference for s=1 (1xK @ KxN -> 1xN)
let mut expected_s1 = vec![0.0f32; n];
for j in 0..n {
for i in 0..k {
expected_s1[j] += a_data[i] * b_data[i * n + j];
}
}
assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);
// Execute at s=4
cx.set_dim('s', 4);
let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
rt.set_data(a, a_data_4.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_s4 = rt.get_f32(c);
// Compute reference for s=4 (4xK @ KxN -> 4xN)
let mut expected_s4 = vec![0.0f32; 4 * n];
for row in 0..4 {
for j in 0..n {
for i in 0..k {
expected_s4[row * n + j] += a_data_4[row * k + i] * b_data[i * n + j];
}
}
}
assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
}
#[test]
fn test_bucket_results_match_unbucketed() {
// Tests that bucketed results match non-bucketed results for the same graph
let Some(stream) = get_cuda_stream() else {
return;
};
let seed = 42u64;
// Non-bucketed run
let (mut cx1, a1, b1) = build_dynamic_add_graph();
cx1.set_dim('s', 3);
cx1.build_search_space::<CudaRuntime>();
let mut rt1 = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
rt1.set_data(a1, input_data.clone());
let mut rng1 = SmallRng::seed_from_u64(seed);
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
rt1.set_data(a1, input_data.clone());
rt1.execute(&cx1.dyn_map);
let result_unbucketed = rt1.get_f32(b1);
// Bucketed run with bucket that covers s=3
let (mut cx2, a2, b2) = build_dynamic_add_graph();
cx2.set_dim('s', 3);
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
cx2.build_search_space::<CudaRuntime>();
let mut rt2 = CudaRuntime::initialize(stream.clone());
rt2.set_data(a2, input_data.clone());
let mut rng2 = SmallRng::seed_from_u64(seed);
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
rt2.set_data(a2, input_data.clone());
rt2.execute(&cx2.dyn_map);
let result_bucketed = rt2.get_f32(b2);
// Results should match — same graph, same search seed, same dyn_map
assert_eq!(result_unbucketed.len(), result_bucketed.len());
assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
}
#[test]
#[should_panic(expected = "No bucket matches")]
fn test_bucket_out_of_range_panics() {
let Some(stream) = get_cuda_stream() else {
// Can't trigger panic without GPU, skip gracefully
panic!("No bucket matches dyn_map");
};
let (mut cx, a, _b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
// s=10 is outside all buckets — should panic
cx.set_dim('s', 10);
rt.set_data(a, vec![1.0f32; 40]);
rt.execute(&cx.dyn_map);
}
#[test]
fn test_bucket_no_buckets_backward_compat() {
// No buckets set → should behave identically to old path
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim('s', 2);
// No set_dim_buckets call
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
rt.set_data(a, input_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..8], &expected, 1e-5, 1e-5);
}
#[test]
fn test_bucket_representative_override() {
// Tests that custom representative works
let bucket = DimBucket::new(2, 32).representative(16);
assert_eq!(bucket.representative_value(), 16);
let bucket_default = DimBucket::new(2, 32);
assert_eq!(bucket_default.representative_value(), 17); // (2+32)/2 = 17
let exact = DimBucket::new(1, 1);
assert_eq!(exact.representative_value(), 1);
}
#[test]
fn test_bucket_switch_preserves_weights() {
// Tests that switching between buckets still sees the correct weight data
let Some(stream) = get_cuda_stream() else {
return;
};
let k = 4;
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 300, -1.0, 1.0);
let b_data = random_f32_vec(k * n, 301, -1.0, 1.0);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Execute with bucket 1 (s=1)
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_1a = rt.get_f32(c);
// Switch to bucket 2 (s=3)
cx.set_dim('s', 3);
let a_data_3 = random_f32_vec(3 * k, 302, -1.0, 1.0);
rt.set_data(a, a_data_3.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_3 = rt.get_f32(c);
// Switch back to bucket 1 (s=1) — weights should still work
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_1b = rt.get_f32(c);
// First and last s=1 results should match exactly
assert_close(&result_1a[..n], &result_1b[..n], 1e-6, 1e-6);
// Verify s=3 result correctness
let mut expected_3 = vec![0.0f32; 3 * n];
for row in 0..3 {
for j in 0..n {
for i in 0..k {
expected_3[row * n + j] += a_data_3[row * k + i] * b_data[i * n + j];
}
}
}
assert_close(&result_3[..3 * n], &expected_3, 1e-4, 1e-4);
}
#[test]
fn test_bucket_multiple_executions_same_bucket() {
// Tests multiple executions within the same bucket with different dim values
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
// Execute at different sizes within the same bucket
for s in [1, 2, 4, 8] {
cx.set_dim('s', s);
let n = s * 4;
let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
rt.set_data(a, input.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
assert_close(&result[..n], &expected, 1e-5, 1e-5);
}
}
#[test]
#[should_panic(expected = "Overlapping buckets")]
fn test_bucket_overlapping_ranges_panics() {
let mut cx = Graph::default();
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
}
#[test]
fn test_dim_bucket_contains() {
let b = DimBucket::new(2, 10);
assert!(!b.contains(1));
assert!(b.contains(2));
assert!(b.contains(5));
assert!(b.contains(10));
assert!(!b.contains(11));
// Exact bucket
let exact = DimBucket::new(3, 3);
assert!(!exact.contains(2));
assert!(exact.contains(3));
assert!(!exact.contains(4));
}

View File

@@ -0,0 +1,416 @@
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use rand::SeedableRng;
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice, validate_choice_set};
use crate::kernel::KernelOp;
use crate::runtime::CudaRuntime;
/// Helper: build search space and extract all possible kernel names across many random choices.
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
cx.build_search_space::<CudaRuntime>();
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
let custom_ops = &cx.custom_ops;
let mut all_names = Vec::new();
// Try many random extractions to cover both alternatives
for _ in 0..20 {
let choices = random_initial_choice(egraph, &mut rand::rng());
let mut list_cache = Default::default();
let mut expr_cache = Default::default();
let llir = egglog_to_llir(
egraph,
choices,
ops,
custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
for op in llir.node_weights() {
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
let name = k.kernel_name().to_string();
if !all_names.contains(&name) {
all_names.push(name);
}
}
}
}
all_names
}
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
/// the ConsumedBuffer (not in any other ICons).
#[test]
fn test_scatter_nocopy_selected_when_dest_unshared() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
// scatter src into dest at indexes
let _result = src.scatter(indexes, dest).output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should be available (dest is not shared)
assert!(
names.iter().any(|n| n == "ScatterNoCopy"),
"Expected ScatterNoCopy to be available but got: {:?}",
names
);
}
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
/// cleanup rule should fire, deleting the ConsumedBuffer. This makes KernelScatterNoCopy
/// invalid, so it should NOT appear in any extraction.
#[test]
fn test_scatter_nocopy_not_selected_when_dest_shared() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
// scatter src into dest at indexes
let scatter_result = src.scatter(indexes, dest);
// Also use dest directly in another op (add with itself) — this makes dest shared
let _dest_also_used = (dest + dest).output();
let _result = scatter_result.output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should NOT be available (dest is shared with the add op)
assert!(
!names.iter().any(|n| n == "ScatterNoCopy"),
"ScatterNoCopy should NOT be available when dest is shared, got: {:?}",
names
);
// Regular KernelScatter should be present
assert!(
names.iter().any(|n| n == "Scatter"),
"Expected Scatter but got: {:?}",
names
);
}
/// Actually execute the scatter and verify correctness.
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
#[test]
fn test_scatter_execution_correctness() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// dest: [0.0, 1.0, 2.0, 3.0, 4.0]
let dest = cx.tensor(5).persist();
// src: [10.0, 20.0, 30.0]
let src = cx.tensor(3).persist();
// indexes: [1, 3, 4]
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<CudaRuntime>();
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
// Try many random extractions to cover both Scatter and ScatterNoCopy
let mut rng = rand::rng();
let mut tested_scatter = false;
let mut tested_nocopy = false;
for _ in 0..50 {
let choices = random_initial_choice(egraph, &mut rng);
if validate_choice_set(egraph, &choices, ops).is_err() {
continue;
}
let mut list_cache = Default::default();
let mut expr_cache = Default::default();
let llir = egglog_to_llir(
egraph,
choices,
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
// Check which scatter variant was selected
let mut has_nocopy = false;
let mut has_scatter = false;
for op in llir.node_weights() {
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
match k.kernel_name() {
"ScatterNoCopy" => has_nocopy = true,
"Scatter" => has_scatter = true,
_ => {}
}
}
}
let mut rt = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir);
rt.set_data(dest, vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
rt.set_data(src, vec![10.0f32, 20.0, 30.0]);
rt.set_data(indexes, vec![1i32, 3, 4]);
rt.execute(&cx.dyn_map);
let actual = rt.get_f32(result);
let variant = if has_nocopy {
tested_nocopy = true;
"ScatterNoCopy"
} else if has_scatter {
tested_scatter = true;
"Scatter"
} else {
"Unknown"
};
assert_eq!(
actual, expected,
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
actual, expected
);
}
println!(
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
tested_scatter, tested_nocopy
);
assert!(
tested_nocopy,
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
);
}
/// Test the KV-cache round-trip pattern: scatter → remove_buffer → set_buffer → scatter again.
/// This mimics how the llama model uses scatter for KV cache updates.
#[test]
fn test_scatter_kv_cache_roundtrip() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// KV cache: [5] elements (simulating a small cache)
let cache_in = cx.named_tensor("cache", 5).persist();
// New value to scatter: [1] element
let src = cx.tensor(1).persist();
// Index: [1] element (position to write)
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// scatter src into cache at index position
let cache_out = src.scatter(indexes, cache_in);
// Also read the scatter output (simulates attention reading from cache)
let read_out = (cache_out + 0.0).output();
// Return cache for round-trip
let cache_output = cache_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
// Must set input data BEFORE search (profiler needs valid buffers)
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, 5);
// Print which scatter variant was selected
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Selected: {}", k.kernel_name());
}
}
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let read1 = rt.get_f32(read_out);
println!("After step 1 (scatter 10.0 at pos 0): {:?}", read1);
assert_eq!(
read1,
vec![10.0, 0.0, 0.0, 0.0, 0.0],
"Step 1 read_out mismatch"
);
// Round-trip: remove cache output buffer, set as new cache input
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 2: Scatter 20.0 at position 1
rt.set_data(src, vec![20.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let read2 = rt.get_f32(read_out);
println!("After step 2 (scatter 20.0 at pos 1): {:?}", read2);
assert_eq!(
read2,
vec![10.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 read_out mismatch"
);
// Round-trip again
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 3: Scatter 30.0 at position 2
rt.set_data(src, vec![30.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let read3 = rt.get_f32(read_out);
println!("After step 3 (scatter 30.0 at pos 2): {:?}", read3);
assert_eq!(
read3,
vec![10.0, 20.0, 30.0, 0.0, 0.0],
"Step 3 read_out mismatch"
);
}
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
/// Also verifies graph_break interaction.
#[test]
fn test_scatter_dual_cache_with_graph_break() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// Two caches (like K and V)
let k_cache = cx.named_tensor("k_cache", 5).persist();
let v_cache = cx.named_tensor("v_cache", 5).persist();
// Input values
let k_new = cx.tensor(1).persist();
let v_new = cx.tensor(1).persist();
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// Scatter into both caches
let k_out = k_new.scatter(indexes, k_cache);
let v_out = v_new.scatter(indexes, v_cache);
// Read both (simulates attention using the scattered caches)
let k_read = k_out + 0.0;
let v_read = v_out + 0.0;
// Compute something from the scattered values (simulates attention output)
let attn = k_read * v_read;
// Output everything
let attn_out = attn.output();
let k_cache_out = k_out.output();
let v_cache_out = v_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
// Use seeded search for deterministic scatter variant selection.
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_rng(rt, 5, &mut rng);
// Print selected variants
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Dual test selected: {}", k.kernel_name());
}
}
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let attn1 = rt.get_f32(attn_out);
println!("Attn step 1: {:?}", attn1);
// k=[2,0,0,0,0], v=[3,0,0,0,0], attn = k*v = [6,0,0,0,0]
assert_eq!(attn1, vec![6.0, 0.0, 0.0, 0.0, 0.0], "Step 1 attn mismatch");
// Round-trip
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 2: scatter k=4.0, v=5.0 at position 1
rt.set_data(k_new, vec![4.0f32]);
rt.set_data(v_new, vec![5.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let attn2 = rt.get_f32(attn_out);
println!("Attn step 2: {:?}", attn2);
// k=[2,4,0,0,0], v=[3,5,0,0,0], attn = k*v = [6,20,0,0,0]
assert_eq!(
attn2,
vec![6.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 attn mismatch"
);
// Round-trip
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 3: scatter k=6.0, v=7.0 at position 2
rt.set_data(k_new, vec![6.0f32]);
rt.set_data(v_new, vec![7.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let attn3 = rt.get_f32(attn_out);
println!("Attn step 3: {:?}", attn3);
// k=[2,4,6,0,0], v=[3,5,7,0,0], attn = k*v = [6,20,42,0,0]
assert_eq!(
attn3,
vec![6.0, 20.0, 42.0, 0.0, 0.0],
"Step 3 attn mismatch"
);
}

View File

@@ -1,5 +1,11 @@
pub mod utilities;
#[cfg(test)]
mod bucket_tests;
#[cfg(test)]
mod consumed_buffer_tests;
#[cfg(test)]
mod model_fuzz;
#[cfg(test)]
mod op_functional_tests;
#[cfg(test)]

View File

@@ -0,0 +1,685 @@
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
//!
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
//! reference to catch incorrect HLIR kernel fallback rewrites.
use luminal::prelude::*;
use super::utilities::{assert_close, fuzz_genomes, get_cuda_stream, random_f32_vec};
use crate::runtime::CudaRuntime;
/// Number of genomes to fuzz per test (higher than default GENOME_FUZZ_COUNT=20).
const FUZZ_COUNT: usize = 100;
// ============================================================================
// RMSNorm helper (used by all three models)
// ============================================================================
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
let normed = x.std_norm(x.shape.last_axis(), eps);
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
}
fn rms_norm_ref(
x: &candle_core::Tensor,
weight: &candle_core::Tensor,
eps: f64,
) -> candle_core::Tensor {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
let normed = x.broadcast_mul(&rsqrt).unwrap();
normed
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
.unwrap()
}
// ============================================================================
// SwiGLU MLP helper (used by all three models)
// ============================================================================
fn swiglu_mlp(
x: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
) -> GraphTensor {
let gate = x.matmul(w_gate.t()).swish();
let up = x.matmul(w_up.t());
(gate * up).matmul(w_down.t())
}
fn swiglu_mlp_ref(
x: &candle_core::Tensor,
w_gate: &candle_core::Tensor,
w_up: &candle_core::Tensor,
w_down: &candle_core::Tensor,
) -> candle_core::Tensor {
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
let up = x.matmul(&w_up.t().unwrap()).unwrap();
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
}
// ============================================================================
// Generic test functions
// ============================================================================
/// Test a SwiGLU MLP block at given dimensions with genome fuzzing.
fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
/// Test RMSNorm + matmul projection at given dimensions with genome fuzzing.
fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let norm_w = cx.tensor(hidden);
let proj_w = cx.tensor((proj_dim, hidden));
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), hidden, &device).unwrap();
let ref_proj =
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_norm, eps as f64);
let expected = normed.matmul(&ref_proj.t().unwrap()).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
/// Test a full transformer layer (norm -> proj -> norm -> MLP) without attention.
fn fuzz_layer_no_attn(
seq: usize,
hidden: usize,
intermediate: usize,
proj_dim: usize,
eps: f32,
seed: u64,
) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let attn_norm_w = cx.tensor(hidden);
let proj_w = cx.tensor((proj_dim, hidden));
let o_proj_w = cx.tensor((hidden, proj_dim));
let mlp_norm_w = cx.tensor(hidden);
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let normed = rms_norm(input, attn_norm_w, eps);
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
let x = input + proj_out;
let mlp_normed = rms_norm(x, mlp_norm_w, eps);
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
let out = (x + mlp_out).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let attn_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
let o_proj_data = random_f32_vec(hidden * proj_dim, seed + 3, -0.3, 0.3);
let mlp_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 4, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let gate_data = random_f32_vec(intermediate * hidden, seed + 5, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 6, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 7, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Candle reference
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_attn_norm =
candle_core::Tensor::from_vec(attn_norm_data.clone(), hidden, &device).unwrap();
let ref_proj =
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
let ref_o_proj =
candle_core::Tensor::from_vec(o_proj_data.clone(), (hidden, proj_dim), &device).unwrap();
let ref_mlp_norm =
candle_core::Tensor::from_vec(mlp_norm_data.clone(), hidden, &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, eps as f64);
let proj_out = normed
.matmul(&ref_proj.t().unwrap())
.unwrap()
.matmul(&ref_o_proj.t().unwrap())
.unwrap();
let x_ref = (&ref_input + proj_out).unwrap();
let mlp_normed = rms_norm_ref(&x_ref, &ref_mlp_norm, eps as f64);
let mlp_out = swiglu_mlp_ref(&mlp_normed, &ref_gate, &ref_up, &ref_down);
let expected_t = (x_ref + mlp_out).unwrap();
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-2, 2e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
2e-2,
2e-2,
FUZZ_COUNT,
seed,
);
}
/// Test a SwiGLU MLP with HLIR-only to specifically verify
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
// ============================================================================
// Llama-specific tests
// Llama 3 8B: HIDDEN=4096, INTERMEDIATE=14336, HEAD_DIM=128
// Using scaled-down dims that preserve architectural ratios
// ============================================================================
mod llama {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 256;
const INTERMEDIATE: usize = 896; // ~3.5x hidden, matching 14336/4096
const PROJ_DIM: usize = 256; // Q_DIM == HIDDEN for llama
const EPS: f32 = 1e-5;
#[test]
fn fuzz_llama_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
}
#[test]
fn fuzz_llama_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
}
#[test]
fn fuzz_llama_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
}
#[test]
fn fuzz_llama_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
}
#[test]
fn fuzz_llama_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
}
/// Force HLIR-only (no block ops) to specifically test the fallback path.
#[test]
fn fuzz_llama_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
}
}
// ============================================================================
// Gemma-specific tests
// Gemma 3 4B: HIDDEN=2560, INTERMEDIATE=10240, HEAD_DIM=256, Q_DIM=2048
// Key difference: Q_DIM != HIDDEN, and 4 extra RMSNorm layers per block
// ============================================================================
mod gemma {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 320; // divisible by 8 (N_HEADS)
const INTERMEDIATE: usize = 1280; // 4x hidden, matching 10240/2560
const Q_DIM: usize = 256; // scaled from 2048 (N_HEADS * HEAD_DIM)
const EPS: f32 = 1e-6;
#[test]
fn fuzz_gemma_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
}
#[test]
fn fuzz_gemma_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
}
#[test]
fn fuzz_gemma_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
}
/// Gemma has extra post-attention and post-feedforward norms.
#[test]
fn fuzz_gemma_layer_full_norms() {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let attn_norm_w = cx.tensor(HIDDEN);
let post_attn_norm_w = cx.tensor(HIDDEN);
let pre_ff_norm_w = cx.tensor(HIDDEN);
let post_ff_norm_w = cx.tensor(HIDDEN);
let proj_w = cx.tensor((Q_DIM, HIDDEN));
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
let normed = rms_norm(input, attn_norm_w, EPS);
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
let x = input + attn_normed;
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
let mlp_out = swiglu_mlp(ff_normed, w_gate, w_up, w_down);
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
let out = (x + mlp_normed).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 800u64;
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
let attn_norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let post_attn_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 2, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let pre_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 3, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let post_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 4, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(Q_DIM * HIDDEN, seed + 5, -0.3, 0.3);
let o_proj_data = random_f32_vec(HIDDEN * Q_DIM, seed + 6, -0.3, 0.3);
let gate_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 7, -0.3, 0.3);
let up_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 8, -0.3, 0.3);
let down_data = random_f32_vec(HIDDEN * INTERMEDIATE, seed + 9, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(post_attn_norm_w, post_attn_data.clone());
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
rt.set_data(post_ff_norm_w, post_ff_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Candle reference
let device = candle_core::Device::Cpu;
let t = |data: &[f32], shape: &[usize]| {
candle_core::Tensor::from_vec(data.to_vec(), shape, &device).unwrap()
};
let ref_input = t(&input_data, &[SEQ, HIDDEN]);
let ref_attn_norm = t(&attn_norm_data, &[HIDDEN]);
let ref_post_attn = t(&post_attn_data, &[HIDDEN]);
let ref_pre_ff = t(&pre_ff_data, &[HIDDEN]);
let ref_post_ff = t(&post_ff_data, &[HIDDEN]);
let ref_proj = t(&proj_data, &[Q_DIM, HIDDEN]);
let ref_o_proj = t(&o_proj_data, &[HIDDEN, Q_DIM]);
let ref_gate = t(&gate_data, &[INTERMEDIATE, HIDDEN]);
let ref_up = t(&up_data, &[INTERMEDIATE, HIDDEN]);
let ref_down = t(&down_data, &[HIDDEN, INTERMEDIATE]);
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, EPS as f64);
let proj_out = normed
.matmul(&ref_proj.t().unwrap())
.unwrap()
.matmul(&ref_o_proj.t().unwrap())
.unwrap();
let attn_normed = rms_norm_ref(&proj_out, &ref_post_attn, EPS as f64);
let x_ref = (&ref_input + attn_normed).unwrap();
let ff_normed = rms_norm_ref(&x_ref, &ref_pre_ff, EPS as f64);
let mlp_out = swiglu_mlp_ref(&ff_normed, &ref_gate, &ref_up, &ref_down);
let mlp_normed = rms_norm_ref(&mlp_out, &ref_post_ff, EPS as f64);
let expected_t = (x_ref + mlp_normed).unwrap();
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-2, 2e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(post_attn_norm_w, post_attn_data.clone());
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
rt.set_data(post_ff_norm_w, post_ff_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
2e-2,
2e-2,
FUZZ_COUNT,
seed,
);
}
#[test]
fn fuzz_gemma_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
}
/// Force HLIR-only to test fallback path with Gemma dimensions.
#[test]
fn fuzz_gemma_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
}
}
// ============================================================================
// Qwen-specific tests
// Qwen3-4B: HIDDEN=2560, INTERMEDIATE=9728, HEAD_DIM=128, Q_DIM=4096
// Key difference: Q_DIM > HIDDEN, tied embeddings (lm_head = embedding.t())
// ============================================================================
mod qwen {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 256;
const INTERMEDIATE: usize = 768; // ~3x hidden, matching 9728/2560
const Q_DIM: usize = 512; // scaled from 4096 (Q_DIM > HIDDEN)
const EPS: f32 = 1e-6;
#[test]
fn fuzz_qwen_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
}
#[test]
fn fuzz_qwen_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
}
#[test]
fn fuzz_qwen_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
}
/// Qwen uses tied embeddings: lm_head = embedding^T
#[test]
fn fuzz_qwen_lm_head() {
let Some(stream) = get_cuda_stream() else {
return;
};
const VOCAB: usize = 512;
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let norm_w = cx.tensor(HIDDEN);
let embedding = cx.tensor((VOCAB, HIDDEN));
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 1300u64;
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
let norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let emb_data = random_f32_vec(VOCAB * HIDDEN, seed + 2, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (SEQ, HIDDEN), &device).unwrap();
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), HIDDEN, &device).unwrap();
let ref_emb =
candle_core::Tensor::from_vec(emb_data.clone(), (VOCAB, HIDDEN), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_norm, EPS as f64);
let expected = normed.matmul(&ref_emb.t().unwrap()).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
#[test]
fn fuzz_qwen_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
}
#[test]
fn fuzz_qwen_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
}
/// Force HLIR-only to test fallback path with Qwen dimensions.
#[test]
fn fuzz_qwen_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
}
}

View File

@@ -24,8 +24,8 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
@@ -33,20 +33,20 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), &gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
}
#[test]
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), &gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
}
#[test]
@@ -115,7 +115,7 @@ proptest! {
let atol = 5.0 * eps;
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
}
// Unary ops tests
@@ -123,37 +123,37 @@ proptest! {
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
}
#[test]
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// log2(x) = ln(x) / ln(2)
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
}
#[test]
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
}
#[test]
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
}
#[test]
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
}
// Binary ops tests
@@ -166,11 +166,12 @@ proptest! {
#[test]
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
}
}
#[allow(dead_code)]
fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
let total = rows * cols;
@@ -275,17 +276,19 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_argsort(seed in any::<u64>()) {
run_argsort_test(5, 500, seed);
}
}
// NOTE: Argsort proptest disabled due to pre-existing bug where argsort output shape
// through e-graph compilation returns only `rows` elements instead of `rows * cols`.
// proptest! {
// #![proptest_config(ProptestConfig::with_cases(10))]
// #[test]
// fn test_argsort(seed in any::<u64>()) {
// run_argsort_test(5, 500, seed);
// }
// }
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
#[test]
#[allow(clippy::approx_constant, clippy::excessive_precision)]
pub fn test_cast_f16_edge_cases() {
use luminal::dtype::DType;
@@ -323,7 +326,7 @@ pub fn test_cast_f16_edge_cases() {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
&gen_edge_cases,
gen_edge_cases,
0,
);
}
@@ -349,7 +352,7 @@ proptest! {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
&gen_lambda,
gen_lambda,
seed,
);
}

View File

@@ -13,8 +13,8 @@ pub fn kernel_add_bandwidth_test() {
let size = 64 * 1024 * 1024;
let mut cx = Graph::default();
let a = cx.tensor(size);
let b = cx.tensor(size);
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let output = (a + b).output();
// Generate test data

View File

@@ -173,6 +173,7 @@ fn swiglu_mlp_ref(
}
/// CPU reference for one transformer layer
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &candle_core::Tensor,
attn_norm_w: &candle_core::Tensor,

View File

@@ -235,6 +235,7 @@ pub fn test_unary_cuda<T: TestDType>(
/// Base binary test function with input generators
/// Generic over dtype T - comparison happens in native precision.
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
#[allow(clippy::too_many_arguments)]
pub fn test_binary_cuda<T: TestDType>(
a_shape: impl ToShape,
b_shape: impl ToShape,
@@ -410,6 +411,7 @@ pub fn gen_slice_range(
/// produce incorrect computation.
///
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
#[allow(clippy::too_many_arguments)]
pub fn fuzz_genomes<T: TestDType>(
cx: &Graph,
stream: &Arc<cudarc::driver::CudaStream>,

View File

@@ -15,4 +15,8 @@ half = "2.7.1"
tracing = "0.1.43"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
proptest = "1.9.0"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }

View File

@@ -0,0 +1,227 @@
use super::{MetalMulInfo, MetalSumReduceInfo};
use luminal::prelude::*;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MetalMatmulFamily {
#[default]
Naive,
RegularTiled,
}
#[derive(Debug, Clone)]
pub struct MatmulDescriptor {
pub m: Expression,
pub n: Expression,
pub k: Expression,
pub batch_shape: Vec<Expression>,
pub lhs_strides: Vec<Expression>,
pub rhs_strides: Vec<Expression>,
pub out_strides: Vec<Expression>,
pub transpose_lhs: bool,
pub transpose_rhs: bool,
}
impl MatmulDescriptor {
pub fn from_mul_and_sum(
mul_info: &MetalMulInfo,
sum_info: &MetalSumReduceInfo,
) -> Option<Self> {
let zero = Expression::from(0);
let z = Expression::from('z');
let is_simple_2d_matmul = mul_info.shape.len() == 3
&& sum_info.shape.len() == 2
&& mul_info.a_strides.len() == 3
&& mul_info.b_strides.len() == 3
&& sum_info.strides.len() == 2
&& mul_info.shape[0] == sum_info.shape[0]
&& mul_info.shape[1] == sum_info.shape[1]
&& mul_info.shape[2] == sum_info.iters
&& mul_info.a_strides[1] == zero
&& mul_info.a_strides[2] == z
&& mul_info.b_strides[0] == zero
&& mul_info.b_strides[1] == z
&& sum_info.strides[1] == z
&& sum_info.iter_stride == z;
if !is_simple_2d_matmul {
return None;
}
Some(Self {
m: sum_info.shape[0],
n: sum_info.shape[1],
k: sum_info.iters,
batch_shape: Vec::new(),
lhs_strides: mul_info.a_strides.clone(),
rhs_strides: mul_info.b_strides.clone(),
out_strides: sum_info.strides.clone(),
transpose_lhs: false,
transpose_rhs: false,
})
}
}
#[derive(Debug, Clone)]
pub struct MatmulPlan {
pub family: MetalMatmulFamily,
pub m: Expression,
pub n: Expression,
pub k: Expression,
pub lda: Expression,
pub ldb: Expression,
pub ldd: Expression,
pub batch_size: u32,
pub batch_stride_a: u32,
pub batch_stride_b: u32,
pub batch_stride_d: u32,
pub bm: u16,
pub bn: u16,
pub bk: u16,
pub wm: u16,
pub wn: u16,
}
#[derive(Debug, Default, Clone, Copy)]
pub struct MetalMatmulPlanner;
impl MetalMatmulPlanner {
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
let family = if desc.batch_shape.is_empty()
&& desc.m.as_num().is_some_and(|m| m >= 32)
&& desc.n.as_num().is_some_and(|n| n >= 32)
&& desc.k.as_num().is_some_and(|k| k >= 32)
{
MetalMatmulFamily::RegularTiled
} else {
MetalMatmulFamily::Naive
};
MatmulPlan {
family,
m: desc.m,
n: desc.n,
k: desc.k,
lda: desc.lhs_strides[0],
ldb: desc.rhs_strides[2],
ldd: desc.out_strides[0],
batch_size: 1,
batch_stride_a: 0,
batch_stride_b: 0,
batch_stride_d: 0,
bm: 16,
bn: 16,
bk: 8,
wm: 2,
wn: 2,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn descriptor_recovers_simple_2d_matmul() {
let mul = MetalMulInfo {
shape: vec![
Expression::from(4),
Expression::from(8),
Expression::from(16),
],
a_strides: vec![
Expression::from('z') * 16,
Expression::from(0),
Expression::from('z'),
],
b_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 8,
],
output_strides: vec![
Expression::from('z') * 16,
Expression::from('z') * 8,
Expression::from('z'),
],
};
let sum = MetalSumReduceInfo {
shape: vec![Expression::from(4), Expression::from(8)],
strides: vec![Expression::from('z') * 8, Expression::from('z')],
iters: Expression::from(16),
iter_stride: Expression::from('z'),
};
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
assert_eq!(desc.m, Expression::from(4));
assert_eq!(desc.n, Expression::from(8));
assert_eq!(desc.k, Expression::from(16));
}
#[test]
fn planner_keeps_small_problems_on_naive_path() {
let desc = MatmulDescriptor {
m: Expression::from(4),
n: Expression::from(8),
k: Expression::from(16),
batch_shape: Vec::new(),
lhs_strides: vec![
Expression::from('z') * 16,
Expression::from(0),
Expression::from('z'),
],
rhs_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 8,
],
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
transpose_lhs: false,
transpose_rhs: false,
};
let planner = MetalMatmulPlanner;
let plan = planner.plan(&desc);
assert_eq!(plan.family, MetalMatmulFamily::Naive);
assert_eq!(plan.bm, 16);
assert_eq!(plan.bn, 16);
assert_eq!(plan.bk, 8);
assert_eq!(plan.wm, 2);
assert_eq!(plan.wn, 2);
assert_eq!(plan.lda, Expression::from('z') * 16);
assert_eq!(plan.ldb, Expression::from('z') * 8);
assert_eq!(plan.ldd, Expression::from('z') * 8);
}
#[test]
fn planner_promotes_large_problems_to_regular_tiled() {
let desc = MatmulDescriptor {
m: Expression::from(64),
n: Expression::from(64),
k: Expression::from(64),
batch_shape: Vec::new(),
lhs_strides: vec![
Expression::from('z') * 64,
Expression::from(0),
Expression::from('z'),
],
rhs_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 64,
],
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
transpose_lhs: false,
transpose_rhs: false,
};
let planner = MetalMatmulPlanner;
let plan = planner.plan(&desc);
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
assert_eq!(plan.bm, 16);
assert_eq!(plan.bn, 16);
assert_eq!(plan.bk, 8);
assert_eq!(plan.wm, 2);
assert_eq!(plan.wn, 2);
}
}

View File

@@ -1,15 +1,42 @@
mod matmul;
mod ops;
pub use matmul::*;
pub use ops::*;
use luminal::dtype::DType;
use luminal::op::EgglogOp;
use luminal::prelude::*;
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
pub const DYN_BUFFER_INDEX: u64 = 30;
pub const DYN_SLOT_COUNT: usize = 26;
#[derive(Debug, Clone)]
pub struct MetalMulInfo {
pub shape: Vec<Expression>,
pub a_strides: Vec<Expression>,
pub b_strides: Vec<Expression>,
pub output_strides: Vec<Expression>,
}
#[derive(Debug, Clone)]
pub struct MetalSumReduceInfo {
pub shape: Vec<Expression>,
pub strides: Vec<Expression>,
pub iters: Expression,
pub iter_stride: Expression,
}
pub trait MetalKernelOp: EgglogOp {
fn compile(&self, device: &Device) -> ComputePipelineState;
fn compile(
&self,
device: &Device,
input_dtypes: &[DType],
output_dtype: DType,
) -> ComputePipelineState;
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
input_dtypes.first().copied().unwrap_or(DType::F32)
}
fn output_size(&self) -> Expression;
@@ -37,6 +64,18 @@ pub trait MetalKernelOp: EgglogOp {
fn flops(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
0
}
fn mul_info(&self) -> Option<MetalMulInfo> {
None
}
fn sum_reduce_info(&self) -> Option<MetalSumReduceInfo> {
None
}
fn is_matmul(&self) -> bool {
false
}
}
luminal::impl_into_ops!(MetalKernelOp);

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,12 @@
#![allow(unexpected_cfgs)]
use crate::kernel::{MetalKernelOp, DYN_BUFFER_INDEX, DYN_SLOT_COUNT};
use crate::kernel::{
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
};
use half::f16;
use itertools::Itertools;
use luminal::{
dtype::DType,
graph::LLIRGraph,
hlir::{Input, Output},
hlir::{Input, NativeData, Output},
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
prelude::{
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
@@ -18,6 +20,8 @@ use std::time::Duration;
pub struct MetalRuntime {
device: Device,
command_queue: CommandQueue,
/// Host-side input tensors provided by the user.
input_data: FxHashMap<NodeIndex, NativeData>,
/// Buffers for HLIR input tensors (set by user)
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
/// Buffers for LLIR intermediate/output tensors
@@ -26,18 +30,110 @@ pub struct MetalRuntime {
dyn_buffer: Buffer,
/// The current LLIR graph
llir_graph: LLIRGraph,
/// Inferred runtime dtype for each LLIR node.
node_dtypes: FxHashMap<NodeIndex, DType>,
/// Compiled pipeline states for each kernel node
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
}
impl MetalRuntime {
pub fn set_data(&mut self, id: impl ToId, data: &[f32]) {
let buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const _,
std::mem::size_of_val(data) as u64,
MTLResourceOptions::StorageModeShared,
);
self.hlir_buffers.insert(id.to_id(), buffer);
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
let mut graph = llir_graph.clone();
let planner = MetalMatmulPlanner;
let mut rewrites = Vec::new();
for sum_node in graph.node_indices().collect::<Vec<_>>() {
let Some(sum_info) = graph[sum_node]
.to_dialect::<dyn MetalKernelOp>()
.and_then(|op| op.sum_reduce_info())
else {
continue;
};
let input_edges: Vec<_> = graph
.edges_directed(sum_node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
if input_edges.len() != 1 {
continue;
}
let mul_node = input_edges[0];
let Some(mul_info) = graph[mul_node]
.to_dialect::<dyn MetalKernelOp>()
.and_then(|op| op.mul_info())
else {
continue;
};
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
continue;
};
let mul_inputs: Vec<_> = graph
.edges_directed(mul_node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
if mul_inputs.len() != 2 {
continue;
}
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
}
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
graph[sum_node] =
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
m: plan.m,
n: plan.n,
k: plan.k,
lda: plan.lda,
ldb: plan.ldb,
ldd: plan.ldd,
family: plan.family,
bm: plan.bm,
bn: plan.bn,
bk: plan.bk,
wm: plan.wm,
wn: plan.wn,
batch_size: plan.batch_size,
batch_stride_a: plan.batch_stride_a,
batch_stride_b: plan.batch_stride_b,
batch_stride_d: plan.batch_stride_d,
}));
graph.remove_node(mul_node);
graph.add_edge(mul_inputs[0], sum_node, ());
graph.add_edge(mul_inputs[1], sum_node, ());
}
graph
}
#[cfg(test)]
pub(crate) fn contains_matmul(&self) -> bool {
self.llir_graph.node_indices().any(|node| {
self.llir_graph[node]
.to_dialect::<dyn MetalKernelOp>()
.is_some_and(|op| op.is_matmul())
})
}
#[cfg(test)]
pub(crate) fn debug_kernel_ops(&self) -> Vec<String> {
self.llir_graph
.node_indices()
.filter_map(|node| {
self.llir_graph[node]
.to_dialect::<dyn MetalKernelOp>()
.map(|op| format!("{op:?}"))
})
.collect()
}
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
self.input_data.insert(id.to_id(), data.into());
}
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
@@ -72,10 +168,42 @@ impl MetalRuntime {
}
})
.expect("Cannot find tensor in runtime!");
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
let dtype = self
.node_dtypes
.get(&data_id)
.copied()
.or_else(|| {
self.llir_graph[data_id]
.to_op::<Input>()
.map(|inp| inp.dtype)
})
.unwrap_or(DType::F32);
unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec()
unsafe {
match dtype {
DType::F16 => {
let ptr = buffer.contents() as *const f16;
let len = buffer.length() as usize / std::mem::size_of::<f16>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| v.to_f32())
.collect()
}
DType::Int => {
let ptr = buffer.contents() as *const i32;
let len = buffer.length() as usize / std::mem::size_of::<i32>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| *v as f32)
.collect()
}
_ => {
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
std::slice::from_raw_parts(ptr, len).to_vec()
}
}
}
}
}
@@ -96,10 +224,12 @@ impl Runtime for MetalRuntime {
Self {
device,
command_queue,
input_data: FxHashMap::default(),
hlir_buffers: FxHashMap::default(),
buffers: FxHashMap::default(),
dyn_buffer,
llir_graph: StableGraph::default(),
node_dtypes: FxHashMap::default(),
pipelines: FxHashMap::default(),
}
}
@@ -108,16 +238,48 @@ impl Runtime for MetalRuntime {
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
self.pipelines.clear();
self.buffers.clear();
self.hlir_buffers.clear();
self.node_dtypes.clear();
self.llir_graph = Self::fuse_matmuls(llir_graph);
// Compile all kernel ops
for node in llir_graph.node_indices() {
if let Some(kernel_op) = llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = kernel_op.compile(&self.device);
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
for node in topo_order {
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
self.node_dtypes.insert(node, input.dtype);
let hlir_id = NodeIndex::new(input.node);
if let Some(data) = self.input_data.get(&hlir_id) {
let buffer = self.create_input_buffer(data, input.dtype);
self.hlir_buffers.insert(hlir_id, buffer);
}
continue;
}
if self.llir_graph[node].to_op::<Output>().is_some() {
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let input_dtypes: Vec<DType> = input_nodes
.iter()
.map(|n| {
self.node_dtypes
.get(n)
.copied()
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
})
.collect();
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
self.node_dtypes.insert(node, output_dtype);
self.pipelines.insert(node, pipeline);
}
}
self.llir_graph = llir_graph.clone();
}
#[tracing::instrument(skip_all)]
@@ -161,7 +323,6 @@ impl Runtime for MetalRuntime {
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
@@ -200,6 +361,11 @@ impl Runtime for MetalRuntime {
.get(&node)
.expect("Output buffer not allocated!");
// Bind dyn dims right after the output slot:
// [inputs..., output, dyn, bytes...]
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}
@@ -236,6 +402,36 @@ impl RuntimeStats for MetalRuntime {
}
impl MetalRuntime {
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
match dtype {
DType::F32 => {
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::F16 => {
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::Int => {
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
unsupported => panic!("Metal input dtype {unsupported:?} is not supported yet"),
}
}
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
for node in self.llir_graph.node_indices() {
if self.llir_graph[node].to_op::<Input>().is_some() {
@@ -244,8 +440,9 @@ impl MetalRuntime {
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let size = kernel_op.output_size().exec(dyn_map).unwrap();
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
let buffer = self.device.new_buffer(
(size * std::mem::size_of::<f32>()) as u64,
(size * dtype.bits().div_ceil(8)) as u64,
MTLResourceOptions::StorageModeShared,
);
self.buffers.insert(node, buffer);
@@ -289,7 +486,6 @@ impl MetalRuntime {
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
@@ -328,6 +524,9 @@ impl MetalRuntime {
.get(&node)
.expect("Output buffer not allocated!");
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}

View File

@@ -1,4 +1,6 @@
use crate::{kernel::lower_expression_for_metal, runtime::MetalRuntime};
use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
use half::f16;
use luminal::prelude::*;
use proptest::prelude::*;
@@ -24,6 +26,194 @@ fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
}
}
const TRANSFORMER_SEQ: usize = 4;
const TRANSFORMER_HIDDEN: usize = 16;
const TRANSFORMER_INTERMEDIATE: usize = 32;
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
let normed = x.std_norm(x.shape.last_axis(), eps);
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
}
fn self_attention(
x: GraphTensor,
wq: GraphTensor,
wk: GraphTensor,
wv: GraphTensor,
wo: GraphTensor,
) -> GraphTensor {
let q = x.matmul(wq.t());
let k = x.matmul(wk.t());
let v = x.matmul(wv.t());
let scale = 1.0 / (TRANSFORMER_HIDDEN as f32).sqrt();
let scores = q.matmul(k.t()) * scale;
let attn_weights = scores.softmax(1);
attn_weights.matmul(v).matmul(wo.t())
}
fn swiglu_mlp(
x: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
) -> GraphTensor {
let gate = x.matmul(w_gate.t()).swish();
let up = x.matmul(w_up.t());
(gate * up).matmul(w_down.t())
}
struct MiniTransformerLayer {
attn_norm_w: GraphTensor,
wq: GraphTensor,
wk: GraphTensor,
wv: GraphTensor,
wo: GraphTensor,
mlp_norm_w: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
}
impl MiniTransformerLayer {
fn init(cx: &mut Graph) -> Self {
Self {
attn_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
wq: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wk: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wv: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wo: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
mlp_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
w_gate: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
w_up: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
w_down: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE)),
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
let normed = rms_norm(x, self.attn_norm_w, 1e-5);
let attn_out = self_attention(normed, self.wq, self.wk, self.wv, self.wo);
let x = x + attn_out;
let normed = rms_norm(x, self.mlp_norm_w, 1e-5);
let mlp_out = swiglu_mlp(normed, self.w_gate, self.w_up, self.w_down);
x + mlp_out
}
fn weights(&self) -> Vec<(GraphTensor, usize)> {
vec![
(self.attn_norm_w, TRANSFORMER_HIDDEN),
(self.wq, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wk, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wv, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wo, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.mlp_norm_w, TRANSFORMER_HIDDEN),
(self.w_gate, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
(self.w_up, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
(self.w_down, TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE),
]
}
}
fn rms_norm_ref(x: &CandleTensor, weight: &CandleTensor, eps: f64) -> CandleTensor {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
let normed = x.broadcast_mul(&rsqrt).unwrap();
normed
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
.unwrap()
}
fn self_attention_ref(
x: &CandleTensor,
wq: &CandleTensor,
wk: &CandleTensor,
wv: &CandleTensor,
wo: &CandleTensor,
) -> CandleTensor {
let q = x.matmul(&wq.t().unwrap()).unwrap();
let k = x.matmul(&wk.t().unwrap()).unwrap();
let v = x.matmul(&wv.t().unwrap()).unwrap();
let scale = 1.0 / (TRANSFORMER_HIDDEN as f64).sqrt();
let scores = q.matmul(&k.t().unwrap()).unwrap();
let scores = (scores * scale).unwrap();
let max_val = scores.max(1).unwrap().unsqueeze(1).unwrap();
let shifted = scores.broadcast_sub(&max_val).unwrap();
let exps = shifted.exp().unwrap();
let sum_exps = exps.sum(1).unwrap().unsqueeze(1).unwrap();
let attn_weights = exps.broadcast_div(&sum_exps).unwrap();
attn_weights
.matmul(&v)
.unwrap()
.matmul(&wo.t().unwrap())
.unwrap()
}
fn swiglu_mlp_ref(
x: &CandleTensor,
w_gate: &CandleTensor,
w_up: &CandleTensor,
w_down: &CandleTensor,
) -> CandleTensor {
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
let up = x.matmul(&w_up.t().unwrap()).unwrap();
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
}
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &CandleTensor,
attn_norm_w: &CandleTensor,
wq: &CandleTensor,
wk: &CandleTensor,
wv: &CandleTensor,
wo: &CandleTensor,
mlp_norm_w: &CandleTensor,
w_gate: &CandleTensor,
w_up: &CandleTensor,
w_down: &CandleTensor,
) -> CandleTensor {
let normed = rms_norm_ref(x, attn_norm_w, 1e-5);
let attn_out = self_attention_ref(&normed, wq, wk, wv, wo);
let x = (x + attn_out).unwrap();
let normed = rms_norm_ref(&x, mlp_norm_w, 1e-5);
let mlp_out = swiglu_mlp_ref(&normed, w_gate, w_up, w_down);
(x + mlp_out).unwrap()
}
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
(0..len)
.map(|i| (((i * 37 + 11) % 97) as f32 / 97.0) * scale + bias)
.collect()
}
fn to_f16_vec(values: &[f32]) -> Vec<f16> {
values.iter().copied().map(f16::from_f32).collect()
}
fn generate_layer_weights(layer: &MiniTransformerLayer) -> Vec<(GraphTensor, Vec<f32>)> {
layer
.weights()
.iter()
.enumerate()
.map(|(i, (tensor, size))| {
let data = seeded_data(*size, 0.8 - i as f32 * 0.03, -0.4 + i as f32 * 0.02);
let data = if *size == TRANSFORMER_HIDDEN {
data.iter().map(|x| x + 1.0).collect::<Vec<_>>()
} else {
data
};
(*tensor, data)
})
.collect()
}
/// dynamic symbols in kernel expressions should route through dyn buffer.
#[test]
fn dynamic_const_codegen_uses_dyn_buffer() {
@@ -340,3 +530,485 @@ fn metal_simple_max_reduce() {
let out = rt.get_f32(output);
assert_close(&out, &[4.0, 8.0], 0.001);
}
#[test]
fn metal_f16_cast_roundtrip() {
let mut cx = Graph::default();
let input = cx.tensor(4);
let output = input.cast(DType::F16).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
rt = cx.search(rt, 3);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(output);
assert_close(&out, &[1.0, -2.5, 3.25, 4.75], 0.002);
}
#[test]
fn metal_f16_intermediate_add_roundtrip() {
let mut cx = Graph::default();
let a = cx.tensor(4);
let b = cx.tensor(4);
let output = (a.cast(DType::F16) + b.cast(DType::F16))
.cast(DType::F32)
.output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
rt = cx.search(rt, 3);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(output);
assert_close(&out, &[1.5, 1.0, 0.0, 4.75], 0.003);
}
#[test]
fn metal_specialized_matmul() {
let mut cx = Graph::default();
let a = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let b = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let b_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
rt.set_data(a, &a_data);
rt.set_data(b, &b_data);
rt = cx.search(rt, 1);
assert!(
rt.contains_matmul(),
"expected Metal runtime to fuse matmul, kernels: {:?}",
rt.debug_kernel_ops()
);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_a =
CandleTensor::from_vec(a_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_b =
CandleTensor::from_vec(b_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let expected = ref_a.matmul(&ref_b).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-3);
}
#[test]
fn metal_regular_tiled_matmul_path() {
let mut cx = Graph::default();
let m = 64;
let k = 64;
let n = 64;
let a = cx.tensor((m, k));
let b = cx.tensor((k, n));
let output = a.matmul(b).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let a_data = seeded_data(m * k, 0.4, -0.2);
let b_data = seeded_data(k * n, 0.3, -0.15);
rt.set_data(a, &a_data);
rt.set_data(b, &b_data);
rt = cx.search(rt, 1);
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("family: RegularTiled")),
"expected regular tiled matmul path, kernels: {:?}",
kernels
);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
let ref_b = CandleTensor::from_vec(b_data, (k, n), &device).unwrap();
let expected = ref_a.matmul(&ref_b).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-3);
}
#[test]
fn metal_rms_norm() {
let mut cx = Graph::default();
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let weight = cx.tensor(TRANSFORMER_HIDDEN);
let output = rms_norm(input, weight, 1e-5).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let weight_data: Vec<f32> = seeded_data(TRANSFORMER_HIDDEN, 0.5, 0.75);
rt.set_data(input, &input_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_weight = CandleTensor::from_vec(weight_data, TRANSFORMER_HIDDEN, &device).unwrap();
let expected = rms_norm_ref(&ref_input, &ref_weight, 1e-5);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-3);
}
#[test]
fn metal_self_attention() {
let mut cx = Graph::default();
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let wq = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let wk = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let wv = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let wo = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
let output = self_attention(input, wq, wk, wv, wo).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let wq_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
let wk_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.7, -0.35);
let wv_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.6, -0.3);
let wo_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.5, -0.25);
rt.set_data(input, &input_data);
rt.set_data(wq, &wq_data);
rt.set_data(wk, &wk_data);
rt.set_data(wv, &wv_data);
rt.set_data(wo, &wo_data);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wq =
CandleTensor::from_vec(wq_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wk =
CandleTensor::from_vec(wk_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wv =
CandleTensor::from_vec(wv_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wo =
CandleTensor::from_vec(wo_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2);
}
#[test]
fn metal_self_attention_f16_weights() {
let mut cx = Graph::default();
let input = cx
.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN))
.as_dtype(DType::F16);
let wq = cx
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
.as_dtype(DType::F16);
let wk = cx
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
.as_dtype(DType::F16);
let wv = cx
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
.as_dtype(DType::F16);
let wo = cx
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
.as_dtype(DType::F16);
let output = self_attention(input, wq, wk, wv, wo)
.cast(DType::F32)
.output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let wq_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
let wk_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.7, -0.35);
let wv_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.6, -0.3);
let wo_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.5, -0.25);
rt.set_data(input, to_f16_vec(&input_data));
rt.set_data(wq, to_f16_vec(&wq_data));
rt.set_data(wk, to_f16_vec(&wk_data));
rt.set_data(wv, to_f16_vec(&wv_data));
rt.set_data(wo, to_f16_vec(&wo_data));
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wq =
CandleTensor::from_vec(wq_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wk =
CandleTensor::from_vec(wk_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wv =
CandleTensor::from_vec(wv_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_wo =
CandleTensor::from_vec(wo_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-2);
}
#[test]
fn metal_swiglu_mlp() {
let mut cx = Graph::default();
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let w_gate = cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN));
let w_up = cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN));
let w_down = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE));
let output = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let gate_data = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.8, -0.4);
let up_data = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.7, -0.35);
let down_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE, 0.6, -0.3);
rt.set_data(input, &input_data);
rt.set_data(w_gate, &gate_data);
rt.set_data(w_up, &up_data);
rt.set_data(w_down, &down_data);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let ref_gate = CandleTensor::from_vec(
gate_data,
(TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN),
&device,
)
.unwrap();
let ref_up = CandleTensor::from_vec(
up_data,
(TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN),
&device,
)
.unwrap();
let ref_down = CandleTensor::from_vec(
down_data,
(TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE),
&device,
)
.unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-3);
}
#[test]
fn metal_mini_transformer_layer() {
let mut cx = Graph::default();
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let layer = MiniTransformerLayer::init(&mut cx);
let output = layer.forward(input).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let weight_data = generate_layer_weights(&layer);
rt.set_data(input, &input_data);
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let w = |idx: usize, shape: &[usize]| {
CandleTensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
};
let expected = transformer_layer_ref(
&ref_input,
&w(0, &[TRANSFORMER_HIDDEN]),
&w(1, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(2, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(3, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(4, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(5, &[TRANSFORMER_HIDDEN]),
&w(6, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
&w(7, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
&w(8, &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE]),
);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2);
}
#[test]
fn metal_mini_transformer_layer_f16_intermediate() {
let mut cx = Graph::default();
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
let layer = MiniTransformerLayer::init(&mut cx);
let normed = rms_norm(input, layer.attn_norm_w, 1e-5).cast(DType::F16);
let attn_out = self_attention(
normed,
layer.wq.cast(DType::F16),
layer.wk.cast(DType::F16),
layer.wv.cast(DType::F16),
layer.wo.cast(DType::F16),
)
.cast(DType::F32);
let x = input + attn_out;
let normed = rms_norm(x, layer.mlp_norm_w, 1e-5).cast(DType::F16);
let mlp_out = swiglu_mlp(
normed,
layer.w_gate.cast(DType::F16),
layer.w_up.cast(DType::F16),
layer.w_down.cast(DType::F16),
)
.cast(DType::F32);
let output = (x + mlp_out).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
let weight_data = generate_layer_weights(&layer);
rt.set_data(input, &input_data);
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let device = CandleDevice::Cpu;
let ref_input =
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
let w = |idx: usize, shape: &[usize]| {
CandleTensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
};
let expected = transformer_layer_ref(
&ref_input,
&w(0, &[TRANSFORMER_HIDDEN]),
&w(1, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(2, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(3, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(4, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
&w(5, &[TRANSFORMER_HIDDEN]),
&w(6, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
&w(7, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
&w(8, &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE]),
);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 3e-2);
}
#[test]
fn test_scatter_basic() {
let mut cx = Graph::default();
let src = cx.tensor(3);
let indexes = cx.tensor(3).as_dtype(DType::Int);
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[10.0, 20.0, 30.0]);
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(result);
assert_close(&out, &[0.0, 10.0, 0.0, 20.0, 30.0], 0.001);
}
#[test]
fn test_scatter_into_nonzero_dest() {
let mut cx = Graph::default();
let src = cx.tensor(1);
let indexes = cx.tensor(1).as_dtype(DType::Int);
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[2f32]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(result);
assert_close(&out, &[1.0, 2.0, 99.0, 4.0, 5.0], 0.001);
}
#[test]
fn test_scatter_all_positions() {
let mut cx = Graph::default();
let src = cx.tensor(4);
let indexes = cx.tensor(4).as_dtype(DType::Int);
let dest = cx.tensor(4);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(result);
assert_close(&out, &[10.0, 20.0, 30.0, 40.0], 0.001);
}

View File

@@ -0,0 +1,451 @@
use luminal::prelude::*;
use luminal::shape::Expression;
/// Gather entire rows from a 2D tensor using row indices.
///
/// - `data`: (R, D) tensor
/// - `indices`: (N,) Int tensor of row indices
/// - `d`: the number of columns (D), must match data's second dimension
///
/// Returns: (N, D) tensor where output[i] = data[indices[i]]
pub fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
assert_eq!(indices.dtype, DType::Int);
let n = indices.dims1();
// base[i] = indices[i] * D → flat starting position for each row
let base = (indices * d).expand_dim(1, d); // (N, D) broadcast along cols
// col[j] = j → column offsets 0..D
let col = data.graph().arange(d as i32).expand_dim(0, n); // (N, D) broadcast along rows
// flat_idx[i,j] = indices[i] * D + j
let flat_idx = base + col;
data.gather(flat_idx)
}
/// Scatter entire rows into a 2D tensor using row indices.
///
/// - `src`: (N, D) tensor of values to write
/// - `indices`: (N,) Int tensor of destination row indices
/// - `dest`: (R, D) tensor to write into (copied first, then overwritten at index positions)
/// - `d`: the number of columns (D)
///
/// Returns: (R, D) tensor where output = copy(dest); output[indices[i]] = src[i]
pub fn scatter_rows(
src: GraphTensor,
indices: GraphTensor,
dest: GraphTensor,
d: usize,
) -> GraphTensor {
assert_eq!(indices.dtype, DType::Int);
let n = indices.dims1();
// Same index expansion as gather_rows
let base = (indices * d).expand_dim(1, d);
let col = src.graph().arange(d as i32).expand_dim(0, n);
let flat_idx = base + col;
src.scatter(flat_idx, dest)
}
/// Pure HLIR paged attention for one layer with causal masking.
///
/// Inputs:
/// - `q`: (s, hidden) f32 — query vectors
/// - `k_new`: (s, kv_dim) f32 — new key vectors
/// - `v_new`: (s, kv_dim) f32 — new value vectors
/// - `k_cache`: (num_slots, kv_dim) f32 — key cache (preallocated)
/// - `v_cache`: (num_slots, kv_dim) f32 — value cache (preallocated)
/// - `gather_idx`: (ctx_len,) Int — which cache slots to read
/// - `scatter_idx`: (s,) Int — which cache slots to write new KV into
/// - `prev_seq`: number of previously cached tokens (for causal mask offset)
/// - `n_heads`: number of query heads
/// - `n_kv_heads`: number of KV heads (for GQA)
/// - `head_dim`: dimension per head
///
/// Returns: (attn_out, k_cache_new, v_cache_new)
/// - `attn_out`: (s, hidden) f32
/// - `k_cache_new`: (num_slots, kv_dim) f32
/// - `v_cache_new`: (num_slots, kv_dim) f32
#[allow(clippy::too_many_arguments)]
pub fn paged_attention(
q: GraphTensor,
k_new: GraphTensor,
v_new: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
gather_idx: GraphTensor,
scatter_idx: GraphTensor,
prev_seq: Expression,
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let kv_dim = n_kv_heads * head_dim;
let kv_groups = n_heads / n_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let s = q.dims()[0];
let ctx = gather_idx.dims()[0];
let cx = q.graph();
// ── Phase 1: Write new KV into cache ──
let k_cache = scatter_rows(k_new, scatter_idx, k_cache, kv_dim);
let v_cache = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
// ── Phase 2: Gather context KV from cache ──
let k = gather_rows(k_cache, gather_idx, kv_dim); // (ctx, kv_dim)
let v = gather_rows(v_cache, gather_idx, kv_dim); // (ctx, kv_dim)
// ── Phase 3: Reshape for multi-head attention ──
// Q: (s, hidden) → (s, n_heads, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
// → (n_kv_heads, kv_groups, s, head_dim)
let q = q
.split_dims(1, head_dim) // (s, n_heads, head_dim)
.split_dims(1, kv_groups) // (s, n_kv_heads, kv_groups, head_dim)
.permute((1, 2, 0, 3)); // (n_kv_heads, kv_groups, s, head_dim)
// K: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, head_dim, ctx)
let k = k
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
.permute((1, 2, 0)); // (n_kv_heads, head_dim, ctx)
// V: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, ctx, head_dim)
let v = v
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
.permute((1, 0, 2)); // (n_kv_heads, ctx, head_dim)
// ── Phase 4: Attention ──
// Broadcast K, V over kv_groups dimension
let k = k.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, head_dim, ctx)
let v = v.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, ctx, head_dim)
// QK^T: (n_kv_heads, kv_groups, s, head_dim) @ (n_kv_heads, kv_groups, head_dim, ctx)
// → (n_kv_heads, kv_groups, s, ctx)
let scores = q.matmul(k) * scale;
// Build causal mask: query at position prev_seq+i can attend to context j iff j <= prev_seq+i.
// row_vals[i] = prev_seq + i, col_vals[j] = j
// mask[i,j] = -1e9 where row_vals[i] < col_vals[j], else 0
let z = Expression::from('z');
let row_vals = cx.iota(z + prev_seq, s).expand_dim(1, ctx); // (s, ctx)
let col_vals = cx.arange(ctx).expand_dim(0, s); // (s, ctx)
let mask = row_vals
.cast(DType::F32)
.lt(col_vals.cast(DType::F32))
.cast(DType::F32)
* -1e9;
// Broadcast (s, ctx) → (n_kv_heads, kv_groups, s, ctx)
let mask = mask.expand_dim(0, n_kv_heads).expand_dim(1, kv_groups);
let scores = scores + mask;
// Softmax over context dimension (axis 3)
let weights = scores.softmax(3);
// Weighted sum: (n_kv_heads, kv_groups, s, ctx) @ (n_kv_heads, kv_groups, ctx, head_dim)
// → (n_kv_heads, kv_groups, s, head_dim)
let out = weights.matmul(v);
// ── Phase 5: Reshape output ──
// (n_kv_heads, kv_groups, s, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
let mut out = out.permute((2, 0, 1, 3));
out.shape = ShapeTracker::new((s, n_heads * head_dim));
(out, k_cache, v_cache)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gather_rows() {
let mut cx = Graph::new();
let data = cx.tensor((4, 3)); // 4 rows, 3 cols
let indices = cx.tensor(3).as_dtype(DType::Int);
let result = gather_rows(data, indices, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
rt.set_data(
data.id,
vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
);
// Gather rows 0, 2, 3
rt.set_data(indices.id, vec![0, 2, 3]);
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(result.id),
vec![1., 2., 3., 7., 8., 9., 10., 11., 12.]
);
}
#[test]
fn test_scatter_rows() {
let mut cx = Graph::new();
let src = cx.tensor((2, 3));
let indices = cx.tensor(2).as_dtype(DType::Int);
let dest = cx.tensor((4, 3));
let result = scatter_rows(src, indices, dest, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
rt.set_data(indices.id, vec![1, 3]);
rt.set_data(dest.id, vec![0.; 12]);
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(result.id),
vec![0., 0., 0., 10., 20., 30., 0., 0., 0., 40., 50., 60.]
);
}
#[test]
fn test_scatter_then_gather_roundtrip() {
let mut cx = Graph::new();
let kv_new = cx.tensor((2, 4)); // 2 new rows, dim=4
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
let cache = cx.tensor((6, 4)); // 6 slots
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
// Scatter new rows into cache, then gather them back
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
rt.set_data(cache.id, vec![0.; 24]); // Zero cache
rt.set_data(gather_idx.id, vec![1, 4]); // Read back from same slots
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(gathered.id),
vec![1., 2., 3., 4., 5., 6., 7., 8.]
);
}
#[test]
fn test_paged_attention_shape_and_cache_update() {
// Minimal config: n_heads=2, n_kv_heads=2, head_dim=2, kv_groups=1
// hidden = 4, kv_dim = 4
let n_heads = 2;
let n_kv_heads = 2;
let head_dim = 2;
let hidden = n_heads * head_dim; // 4
let kv_dim = n_kv_heads * head_dim; // 4
let num_slots = 8;
let mut cx = Graph::new();
let q = cx.tensor((1, hidden)); // 1 new token
let k_new = cx.tensor((1, kv_dim));
let v_new = cx.tensor((1, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context tokens
let scatter_idx = cx.tensor(1).as_dtype(DType::Int); // 1 new token
// prev_seq=2: this is the 3rd token (positions 0,1 cached, position 2 is new)
let (attn_out, k_cache_new, v_cache_new) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
2.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
let k_cache_new = k_cache_new.output();
let v_cache_new = v_cache_new.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
rt.set_data(q.id, vec![1., 0., 1., 0.]);
// k_new = [0.5, 0.5, 0.5, 0.5]
rt.set_data(k_new.id, vec![0.5, 0.5, 0.5, 0.5]);
// v_new = [1, 2, 3, 4]
rt.set_data(v_new.id, vec![1., 2., 3., 4.]);
// Zero caches
rt.set_data(k_cache.id, vec![0.; num_slots * kv_dim]);
rt.set_data(v_cache.id, vec![0.; num_slots * kv_dim]);
// Scatter new KV to slot 2
rt.set_data(scatter_idx.id, vec![2]);
// Gather context from slots 0, 1, 2 (slots 0,1 are zeros, slot 2 is the new KV)
rt.set_data(gather_idx.id, vec![0, 1, 2]);
rt.execute(&cx.dyn_map);
// Verify output shape: (1, hidden=4)
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), hidden);
// Verify KV cache was updated: k_cache_new should have [0.5, 0.5, 0.5, 0.5] at slot 2
let k_out = rt.get_f32(k_cache_new.id);
assert_eq!(k_out.len(), num_slots * kv_dim);
// Slot 2 is at offset 2*4=8..12
assert_eq!(&k_out[8..12], &[0.5, 0.5, 0.5, 0.5]);
// Slot 0 should still be zeros
assert_eq!(&k_out[0..4], &[0., 0., 0., 0.]);
let v_out = rt.get_f32(v_cache_new.id);
assert_eq!(&v_out[8..12], &[1., 2., 3., 4.]);
}
#[test]
fn test_paged_attention_known_values() {
// Test with values where we can compute expected attention output.
// n_heads=1, n_kv_heads=1, head_dim=2, kv_groups=1
// hidden=2, kv_dim=2
let n_heads = 1;
let n_kv_heads = 1;
let head_dim = 2;
let hidden = 2;
let kv_dim = 2;
let num_slots = 4;
let mut cx = Graph::new();
let q = cx.tensor((1, hidden));
let k_new = cx.tensor((1, kv_dim));
let v_new = cx.tensor((1, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
let scatter_idx = cx.tensor(1).as_dtype(DType::Int);
// prev_seq=1: 1 cached token + 1 new token, context len=2
// Query at absolute position 1 can attend to context positions 0 and 1
let (attn_out, _, _) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
1.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
// K cached at slot 0: [1, 0]
// K new (written to slot 1): [0, 1]
// V cached at slot 0: [10, 20]
// V new (written to slot 1): [30, 40]
// Q: [1, 1]
let mut k_cache_data = vec![0.; num_slots * kv_dim];
k_cache_data[0] = 1.;
k_cache_data[1] = 0.; // slot 0 K = [1, 0]
let mut v_cache_data = vec![0.; num_slots * kv_dim];
v_cache_data[0] = 10.;
v_cache_data[1] = 20.; // slot 0 V = [10, 20]
rt.set_data(q.id, vec![1., 1.]);
rt.set_data(k_new.id, vec![0., 1.]); // new K = [0, 1]
rt.set_data(v_new.id, vec![30., 40.]); // new V = [30, 40]
rt.set_data(k_cache.id, k_cache_data);
rt.set_data(v_cache.id, v_cache_data);
rt.set_data(scatter_idx.id, vec![1]); // write to slot 1
rt.set_data(gather_idx.id, vec![0, 1]); // gather slots 0, 1
rt.execute(&cx.dyn_map);
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), hidden);
let expected = vec![20.0, 30.0];
for (a, b) in out.iter().zip(&expected) {
assert!((a - b).abs() < 0.1, "Expected {expected:?}, got {out:?}");
}
}
#[test]
fn test_paged_attention_causal_mask() {
// Verify that the causal mask blocks future positions.
// n_heads=1, n_kv_heads=1, head_dim=2
let n_heads = 1;
let n_kv_heads = 1;
let head_dim = 2;
let hidden = 2;
let kv_dim = 2;
let num_slots = 4;
let mut cx = Graph::new();
let q = cx.tensor((2, hidden)); // 2 new tokens
let k_new = cx.tensor((2, kv_dim));
let v_new = cx.tensor((2, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context (1 cached + 2 new)
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
// prev_seq=1: 1 cached token, 2 new tokens → context len=3
// Query 0 at absolute pos 1: can see ctx 0,1 (not 2)
// Query 1 at absolute pos 2: can see ctx 0,1,2
let (attn_out, _, _) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
1.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Cache has 1 token at slot 0
let mut k_cache_data = vec![0.; num_slots * kv_dim];
k_cache_data[0] = 1.;
k_cache_data[1] = 0.; // slot 0: K=[1,0]
let mut v_cache_data = vec![0.; num_slots * kv_dim];
v_cache_data[0] = 100.;
v_cache_data[1] = 0.; // slot 0: V=[100,0]
// 2 new tokens
rt.set_data(q.id, vec![1., 0., 0., 1.]);
rt.set_data(k_new.id, vec![0., 1., 1., 1.]); // token0 K=[0,1], token1 K=[1,1]
rt.set_data(v_new.id, vec![0., 10., 0., 20.]); // token0 V=[0,10], token1 V=[0,20]
rt.set_data(k_cache.id, k_cache_data);
rt.set_data(v_cache.id, v_cache_data);
rt.set_data(scatter_idx.id, vec![1, 2]); // write to slots 1, 2
rt.set_data(gather_idx.id, vec![0, 1, 2]); // gather all 3
rt.execute(&cx.dyn_map);
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), 2 * hidden);
// Token 0 (abs pos 1): attends to ctx 0,1 only (ctx 2 is masked)
// Token 1 (abs pos 2): attends to ctx 0,1,2
// Verify output has valid (non-NaN, non-inf) values and correct length
for val in out.iter() {
assert!(val.is_finite(), "Output contains non-finite value: {}", val);
}
}
}

View File

@@ -55,9 +55,11 @@ impl ConvND {
let kernel_product: usize = kernel.iter().product();
Self {
weight: cx.named_tensor("ConvWeight", (ch_out, ch_in * kernel_product)),
weight: cx
.named_tensor("ConvWeight", (ch_out, ch_in * kernel_product))
.persist(),
bias: if bias {
Some(cx.named_tensor("ConvBias", ch_out))
Some(cx.named_tensor("ConvBias", ch_out).persist())
} else {
None
},

View File

@@ -14,3 +14,5 @@ mod pooling;
pub use pooling::*;
mod moe;
pub use moe::*;
mod attention;
pub use attention::*;

View File

@@ -10,9 +10,9 @@ pub struct Linear {
impl Linear {
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (inp, out)),
weight: cx.named_tensor("Weight", (inp, out)).persist(),
bias: if bias {
Some(cx.named_tensor("Bias", out))
Some(cx.named_tensor("Bias", out).persist())
} else {
None
},
@@ -22,9 +22,9 @@ impl Linear {
pub fn new_permuted(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (out, inp)),
weight: cx.named_tensor("Weight", (out, inp)).persist(),
bias: if bias {
Some(cx.named_tensor("Bias", out))
Some(cx.named_tensor("Bias", out).persist())
} else {
None
},

View File

@@ -19,8 +19,8 @@ impl LayerNorm {
cx: &mut Graph,
) -> Self {
Self {
weight: weight.map(|w| cx.named_tensor(w, dim)),
bias: bias.map(|b| cx.named_tensor(b, dim)),
weight: weight.map(|w| cx.named_tensor(w, dim).persist()),
bias: bias.map(|b| cx.named_tensor(b, dim).persist()),
mean_norm,
epsilon,
}

5
crates/luminal_python/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
*.onnx
__pycache__/
*.pyc
uv.lock
.venv

View File

@@ -0,0 +1,120 @@
## Python Environment
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
## Lessons Learned
At the end of any session that involved a hard or non-obvious bug, append an entry to
`LessonsLearned.md` in this directory. A "hard bug" means any bug that required significant
investigation — intermittent failures, wrong output without a crash, egglog/optimizer issues,
or anything that took more than a few minutes to locate.
Each entry should cover:
1. **What the symptom was** (test failure, wrong output, panic, etc.)
2. **What the actual root cause was** (the specific code/logic that was wrong)
3. **Why it was hard to find** (what made it non-obvious or intermittent)
4. **The fix** (what changed and why it works)
5. **A general principle** extracted from the bug — something that helps avoid the same
class of mistake in future code
The goal is to build a living record of codebase-specific pitfalls that future sessions can
consult before writing new egglog rules, CUDA kernels, or optimizer passes.
1. If you want to run tests:
- `./run_test.sh` - runs tests with the native backend
- `./run_tests_cuda.sh` - runs tests with the CUDA backend
## Testing Best Practices
### Overview
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
### Test Pattern (CORRECT)
All tests should follow this standard pattern:
```python
def test_operation():
"""Brief description of what operation is being tested."""
# 1. Instantiate PyTorch model
model: torch.nn.Module = OperationTestModel()
# 2. Compile with luminal backend
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
# 3. Create test input
x: torch.Tensor = torch.tensor([...]) # or torch.rand(...)
# 4. Run both original and compiled versions
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
# 5. Verify outputs match
assert torch.allclose(output, original)
```
### Test Models
- Define test model classes in `tests/test_models.py`
- Each model should be a simple `torch.nn.Module` that demonstrates one operation or pattern
- Use clear, descriptive class names (e.g., `AddTestModel`, `TransposeTestModel`)
- Include docstrings explaining what the model tests
Example:
```python
class AddTestModel(torch.nn.Module):
"""Tests element-wise addition."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + x
```
### What NOT to Do
**❌ DO NOT create ONNX files directly in tests:**
```python
# WRONG - bypasses the PyTorch integration
model_path = create_onnx_model(...)
graph_result = luminal.process_onnx(model_path, backend='native')
```
**✓ DO create PyTorch models and use torch.compile:**
```python
# CORRECT - tests actual user workflow
model: torch.nn.Module = MyTestModel()
model_compiled = torch.compile(model, backend=luminal_backend)
```
### Rationale
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
- **User-facing API**: Tests use the same API that users will use (torch.compile)
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
### Special Cases
**Testing constants:**
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1.0, 2.0, 3.0])
return x + constant
```
**Testing type casts:**
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
```
**Testing complex operations:**
Chain operations naturally in PyTorch - ONNX export handles the conversion:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
transposed = x.transpose(0, 1)
scaled = transposed * 2.0
return scaled + 1.0
```

View File

@@ -0,0 +1,758 @@
# Lessons Learned
This file documents hard bugs encountered in this codebase, their root causes, and principles
to prevent similar issues in the future.
---
## 2026-02-24 — Intermittent CUDA Backend Failures: Embed False Match + Batched Matmul Dimension Drop
### Background: Why the Failures Were Intermittent
Both bugs only appeared on roughly 50% of test runs. The source of non-determinism is
`FxHashMap` (a fixed-seed hash map). The egglog optimizer's `SerializedEGraph::new` builds
`Vec<NodeId>` orderings for each e-class by iterating a `FxHashMap`, producing non-deterministic
node orderings. `random_initial_choice()` in `src/egglog_utils/mod.rs` then randomly picks one
e-node per e-class as the starting representation for the profiling phase. The combination means
some runs pick a correct kernel and some pick a broken one from the same e-class.
**Lesson**: When a test fails intermittently at a roughly 50% rate, suspect the egglog extractor
choosing between two e-nodes in the same e-class — one correct, one broken. The fix is always in
the broken e-node's rewrite rule.
---
### Bug 1: `test_gather_elements` — KernelEmbed and RowEmbed False Match
**Files changed**:
- `crates/luminal_cuda/src/kernel/hlir.rs` (KernelEmbed, 4 rules)
- `crates/luminal_cuda/src/block/ops.rs` (RowEmbed, 4 rules)
#### What happened
`gather_elements` (axis-aware gather) decomposes into a flat gather by computing:
```
flat_idx = Add(
Mul(indices, stride[axis]),
Mul(Expand(Iota(dim_size)), stride[non_axis])
)
```
`KernelEmbed` and `RowEmbed` are optimized embedding lookup kernels. A genuine embedding
lookup produces:
```
flat_idx = Add(
Mul(Cast(token_ids), embed_dim),
Iota(embed_dim) ← bare Iota, the position within an embedding row
)
```
The egglog rewrite rules for both ops matched `Add(?mul_result, ?iota_result)` where
`?iota_result` was **unconstrained** — it could bind to anything, including
`Mul(Expand(Iota(n)), stride)` from `gather_elements`. This created a `KernelEmbed`/`RowEmbed`
node in the same e-class as the `Gather` node. When the extractor picked it, `build_payload`
called `flatten_mul_strides(range, token_stride)` which asserted `range.len() == token_stride.len()`:
- `range` came from `RemoveNthFromEnd(idx_shape, 0)` → length 1
- `token_stride` came from the indices strides → length 2
- Assertion failed → panic.
#### The fix
Add `(= ?iota_result (Iota ?iota_expr ?iota_range))` to all 8 rules, requiring the positional
component to be a bare `Iota` node:
```egglog
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
(= ?iota_result (Iota ?iota_expr ?iota_range)) ← added
(= ?mul_result (Mul ...))
```
#### Investigation note
The initial plan correctly identified `KernelEmbed` as faulty, but missed `RowEmbed`. The two
ops are structurally identical but live in different parts of the codebase (`kernel/` vs
`block/`). The second bug was only discovered when the backtrace pointed to
`RowEmbed::build_payload` instead of `KernelEmbed::compile`. Always search for sibling
implementations when fixing a pattern-matching bug in one op.
---
### Bug 2: `test_matmul_batched` — CuBlasLt Drops Batch Dimension
**Files changed**:
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg`
#### What happened
The luminal frontend decomposes `(2,3,4) @ (2,4,5)` into:
```rust
let w = rhs.permute((0, 2, 1)); // (2,4,5) → (2,5,4)
let mul = self.expand_dim(2, d) // (2,3,4) → (2,3,5,4)
* w.expand_dim(1, b); // (2,5,4) → (2,3,5,4)
mul.sum(3) // → (2,3,5), correct out_shape
```
All four cublaslt rewrite rules extracted `m` and `n` from the output shape using
`nth_from_end`, which succeeds for any rank:
```egglog
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
For `out_shape = [2, 3, 5]`: `?m = 3`, `?n = 5`. The batch dim `2` is never extracted or
stored. The rules also validated stride patterns using `nth_from_end` on the stride arrays —
but for this batched case, **all stride checks coincidentally passed** because the last three
strides of the 4D expanded tensors happened to satisfy the 2D row/column-major patterns.
The resulting `CuBlasLt` node had `output_size() = m * n = 15`. The batch dimension was
silently discarded. The runtime allocated a 15-element output buffer, cuBLAS wrote a 3×5
result, and the test got back 15 values instead of 30.
#### The fix
Add `(= (len ?out_shape) 2)` to all 4 rules:
```egglog
(= (len ?out_shape) 2) ← added: cuBLAS is 2D only
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
`len` counts elements in the `ECons`-list shape. With this constraint, any `Sum` node with a
3D+ output shape (batched matmul) is not matched by cuBLAS rules and falls through to
`KernelSumReduce + KernelMul` (or the tiling block ops), which correctly use
`out_shape.iter().product()` for their output sizes.
Note: `TileMatmulSplitK` and `TileMatmulFullSplit` do NOT need this fix — their `output_size()`
already returns `untiled_range.iter().product()` which includes all dimensions.
---
### General Principle: Always Constrain Shape Rank in Egglog Rules
Both bugs share the same structural cause: **egglog rewrite rules that used `nth_from_end` to
extract dimensions from a shape list without constraining the list's length.** Since
`nth_from_end` silently succeeds for any list with enough trailing elements, rules written for
2D tensors accidentally matched higher-rank tensors.
**Rule for writing egglog rewrite rules in this codebase**:
> If a rule is designed for a specific tensor rank, always add an explicit
> `(= (len ?shape) N)` constraint. If a rule is designed to handle arbitrary ranks but an
> op's output only covers a subset of dimensions (like cuBLAS covering only the last 2),
> that is a correctness bug — either implement strided batched cuBLAS or add the rank
> constraint and fall back to a kernel that handles all dimensions.
---
### Debugging Intermittent CUDA Failures: Effective Approach
The investigation used extensive `eprintln!` debug logging to trace which kernels were compiled
vs. skipped. Key observations:
1. **In the passing case**: `KernelSumReduce::compile()` was called, kernels were allocated.
2. **In the failing case**: `KernelSumReduce::compile()` was never called, yet output was produced.
This asymmetry pointed to a `HostOp` path (cuBLAS) executing instead of the `KernelOp` path,
which narrowed the search to cublaslt rewrite rules. The HLIR-level `SumReduce::to_egglog` log
confirmed the correct HLIR node existed — the bug was in the e-graph optimization choosing
a different (broken) e-node from the same e-class.
**Effective debug strategy for egglog non-determinism bugs**:
1. Add logging at compile time for each kernel type (`KernelFoo::compile`, `HostFoo::execute`)
2. Compare passing vs. failing runs to see which kernels are/aren't invoked
3. The missing kernel's e-class contains a broken alternative — find it via the egglog rewrite rules
4. Check the op that *is* executing — its `output_size()` reveals what's wrong with the false match
---
## 2026-02-25 — OneHot Test Panic: Cast(Int→F32) Produces Int Output
### What the symptom was
`test_onehot` panicked at `src/hlir.rs:1625` in `get_f32()`: the output buffer was
`NativeData::Int` instead of the expected `NativeData::F32`.
### What the actual root cause was
The Cast parser's `* 1.0` workaround for `Int → F32` casts used `input * one_expanded`
(Int GraphTensor on the left, F32 constant on the right). However, `Mul for GraphTensor`
always uses `self.dtype` (the **left** operand's dtype) for the result, and the native
runtime's `Mul::execute` dispatches on the **first** input's `NativeData` variant. So
`Int * F32` produced `DType::Int` / `NativeData::Int` — the exact opposite of the intended
F32 output.
### Why it was hard to find
1. **The OneHot parser was a red herring**: The initial plan assumed the OneHot ONNX node
was being parsed, but `torch.onnx.export` decomposes `one_hot` into
`Unsqueeze → Equal → Cast(Bool→Int) → Cast(Int→F32)`. The OneHot parser was never called.
2. **The `* 1.0` workaround looked correct**: It was used successfully in many other parsers,
but those all had F32 inputs (where `F32 * F32 = F32`). The Int→F32 case was the only
path where the left operand was Int.
3. **Operand order matters silently**: Nothing warns about mixed-dtype Mul — it just takes
the left operand's dtype.
### The fix
In `ops_parse/unary.rs` `parse_cast_node`, split the combined condition into two cases:
- **No-op cast** (`cast_result.id == input.id`): `input * one_expanded` — preserves dtype
- **Int source** (`input.dtype == DType::Int`): `one_expanded * input` — F32 on the left
ensures F32 output
### General principle
**In luminal, binary op dtype is always the LEFT operand's dtype.** When constructing
`GraphTensor * constant_float(1.0)` for type materialization, always put the operand
whose dtype you want to preserve on the LEFT side. When converting Int→F32, the F32
constant must be the left operand.
---
## 2026-02-26 — ScatterND Fails on CUDA: "does not produce an egraph"
### What the symptom was
`test_scatter_nd` passed on native backend but failed on CUDA with "does not produce an
egraph". The CUDA compilation could not extract a valid program from the e-graph.
### What the actual root cause was
`scatter_nd` in `movement.rs` does `indices * 1` (line 353) to materialize the tensor for
reshaping. The `* 1` dispatches to `Mul<S: Into<Expression>>`, which creates a `constant(1)`
`Iota(1,1)``DType::Int`. But the ONNX parser creates all tensors as `DType::F32`
(via `named_tensor()` in `compiled_graph.rs:70`), so indices arrive as F32. This produces
`Mul(F32, Int)` — mixed dtypes.
The HLIR Mul dtype rule (`hlir.rs:886-888`) uses `(= ?dty (dtype ?lhs))` and
`(= ?dty (dtype ?rhs))` with the same `?dty` variable, requiring both inputs to have
matching dtypes. `F32 != Int` → the rule never fires → the Mul node gets **no dtype**.
Every downstream op checks `(= ?dty (dtype ?upstream))`. Without dtype on the Mul, no
CUDA kernel rewrite rules fire for any downstream op (KernelMul, KernelAdd, KernelLessThan,
etc.). When `cleanup_hlir` runs (enabled for CUDA, disabled for native), it deletes all
unrewritten HLIR ops, leaving empty e-classes → egraph extraction fails.
### Why it was hard to find
1. **Works on native**: `cleanup_hlir = false` for NativeRuntime, so unrewritten HLIR ops
are never deleted. NativeOp dispatches on actual runtime data, not egglog dtype.
2. **Cascading failure**: The root cause (missing dtype on one Mul) silently propagated
through every downstream op, making it look like a systemic CUDA issue rather than a
single dtype mismatch.
3. **`scatter_elements` works fine**: The sibling op already cast indices via
`(idx_f32 + (is_neg * adj)).cast(DType::Int)`, so only `scatter_nd` had this bug.
### The fix
Added `let indices = indices.cast(DType::Int);` at the top of `scatter_nd` in
`movement.rs`, before any arithmetic on indices. `GraphTensor::cast()` short-circuits
when `self.dtype == dtype`, so this is safe for callers already passing Int indices.
Also added the same cast in `parse_scatter_nd_node` for explicitness.
### General principle
**Always cast index tensors to `DType::Int` before arithmetic in graph-building code.**
ONNX tensors arrive as F32 from the Python bridge. Any `indices * stride` or
`indices * 1` will produce `Mul(F32, Int)` which breaks HLIR dtype propagation on CUDA.
The pattern `let indices = indices.cast(DType::Int);` at the top of any index-consuming
function is defensive and free (no-op when already Int).
---
## 2026-03-04 — Dynamic Shapes: Empty Buffer for BOOL Scalar Initializer
### What the symptom was
`test_hf_llama_decode_loop_dynamic` panicked at `bin_fn: a index 0 out of bounds (a.len=0), shape=[1, 1, 4, 4], strides=[0, 0, 0, 0]`. An Input node labeled `"new_ones"` had an empty buffer at runtime.
### What the actual root cause was
Two issues combined:
1. **`load_tensor_floats` didn't handle ONNX data_type=9 (BOOL)**. The `new_ones` initializer was a BOOL scalar (1 byte in `raw_data`). `load_tensor_floats` fell through to the fallback case, which tried `chunks_exact(4)` on 1 byte → produced 0 chunks → returned empty vec `[]`. The buffer was set with empty data.
2. **Scalar initializers with empty `dims` created 0-dimensional tensors**. ONNX represents scalars with `dims=[]`. The initializer loop computed `shape = init.dims.iter().map(|&d| d as usize).collect()` → empty vec `[]`, then called `named_tensor(name, [])` which created a tensor with 0 dimensions instead of the intended scalar `[1]`.
### Why it was hard to find
1. **Misdiagnosed as ConstantOfShape issue**: The original plan targeted `ConstantOfShape` with dynamic shapes. The shape `[1,1,4,4]` with strides `[0,0,0,0]` looked like a broadcast from a constant fill. But `parse_constant_of_shape` was never called — the `new_ones` tensor came from an ONNX initializer, not a computation node.
2. **The BOOL data type is unusual**: Most ONNX tensors are FLOAT, INT32, or INT64. BOOL initializers only appear in specific patterns (like `torch.ones()` in attention mask computation). `load_initializer_as_f32` already handled BOOL, but its sibling `load_tensor_floats` didn't.
3. **Empty vec is valid data**: `set_data(node_id, [])` doesn't panic — it silently sets an empty buffer. The error only manifests later when a downstream op tries to read index 0.
### The fix
1. Added `data_type=9` (BOOL) handling to `load_tensor_floats` in `util.rs` — same logic as `load_initializer_as_f32`: 1 byte per element, non-zero → 1.0, zero → 0.0.
2. In `compiled_graph.rs`, initializer tensor creation: if `shape.is_empty()`, set `shape = vec![1]` (scalar representation in luminal).
### General principle
**Keep data loading functions in sync.** `load_tensor_floats` and `load_initializer_as_f32` serve the same purpose (loading ONNX TensorProto data as f32) but had different data type coverage. When adding a new data type to one, check and update the other. Better yet, refactor them into a single function.
**ONNX scalars have `dims=[]`, luminal scalars have shape `[1]`.** Always convert empty dims to `[1]` when creating luminal tensors from ONNX data.
---
## 2026-03-04 — Where Node Missing Broadcast: KernelMul flatten_strides Panic on CUDA
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` panicked at `flatten_strides` with `left: 4, right: 1` during
CUDA `KernelMul::compile`. The `KernelMul` had `out_shape=[1, 1, a, a]` but `b_stride=[z]` (1D).
### What the actual root cause was
`parse_where_node` called `x.cond(condition, y)` without broadcasting the inputs to matching ranks.
The ONNX Where op for the attention mask had condition=[1,1,a,a] (4D), x=[1] (scalar), y=[1] (scalar).
Luminal's `cond` doesn't auto-broadcast — it passes the shape trackers directly to the HLIR node.
The resulting Mul had input A with 4D strides and input B with 1D strides.
### Why it was hard to find
1. **Only triggered by 1B model**: The tiny model's Where inputs all had matching ranks (no scalars).
2. **CUDA-only**: The native runtime's `bin_fn` uses `StridedIterator` which handles mismatched
strides more gracefully. CUDA's `KernelMul::compile` calls `flatten_strides` which asserts
`range.len() == strides.len()`.
3. **Delayed crash**: The mismatch was created during ONNX parsing but only manifested during
CUDA kernel compilation (graph search phase).
### The fix
Added numpy-style broadcasting to `parse_where_node`: compute the broadcast shape across all 3
inputs, then `broadcast_to_expr` each to the common shape before calling `cond`.
### General principle
**ONNX binary/ternary ops all use numpy broadcasting.** When parsing ONNX ops that take multiple
tensor inputs (Where, Add, Mul, etc.), always broadcast all inputs to a common shape BEFORE
calling the luminal graph operation. Luminal graph ops do NOT auto-broadcast — they expect inputs
with matching shape tracker dimensions.
---
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
the value at column 0 of each row (all three top-k positions got the same value).
Native backend was fine.
2. **Root cause**: `gather_elements` was called with a non-contiguous index tensor produced by
`argsort(axis=1) → slice_along(..k, axis=1)`. The slice creates a ShapeTracker view of the
[4,8] argsort buffer with dims [4,3] and strides [8,1]. When this flowed through the
gather_elements Int arithmetic chain (cast, multiply, add) and into the final Gather CUDA
kernel, the non-contiguous strides caused incorrect index reads for later rows.
3. **Why it was hard to find**: `test_topk_indices` passed (it only tests argsort+slice, not
the downstream gather_elements). A standalone `test_gather_elements` with constant indices
also passed because constant indices are contiguous. The bug only manifested when runtime-
computed non-contiguous indices were used with data of a different size along the gather axis.
4. **Fix**: In `parse_topk_node`, compute `gather_elements(x, full_argsort, axis)` with the
full [4,8] argsort result (same size as data), then slice the gathered values to [4,3].
This ensures gather_elements always operates on same-sized contiguous tensors.
5. **General principle**: When building graph operations that chain shape-tracker views
(slice, transpose, etc.) into downstream HLIR ops on CUDA, prefer operating on full
contiguous tensors first and slicing the result afterward. Non-contiguous views flowing
through multiple CUDA kernels can trigger stride-related bugs in the egglog-compiled code.
---
## 2026-03-07 — Non-deterministic CUDA_ERROR_ILLEGAL_ADDRESS: Multiple Missing Rank Constraints
### What the symptom was
`test_hf_llama_tiny` on CUDA failed ~70% of runs with `CUDA_ERROR_ILLEGAL_ADDRESS`. Failures
were non-deterministic due to egglog's `FxHashMap` iteration order in `random_initial_choice()`.
### What the actual root cause was
**Multiple** matmul egglog rules lacked `(= (len ?out_shape) 2)` constraints:
1. `TileMatmulSplitK` in `block/ops.rs` (disabled via comment but rule still registered)
2. `TileMatmulFullSplit` in `block/ops.rs`
3. All 4 `sgemm_v2_*.egg` rules in `host/cublas/`
The `cublaslt_*.egg` rules already had the constraint. When egglog picked TileMatmul or sgemm
for a 3D+ batched matmul, the generated CUDA kernels accessed out-of-bounds memory.
Additionally, `KernelEmbed` in `kernel/hlir.rs` had an output indexing bug:
`out[out_offset * embed_dim + embed_idx]` should be `out[out_offset + embed_idx]` because
`out_offset` already includes the embed_dim factor from `flatten_strides`.
**Most critically**, the KernelEmbed and RowEmbed "with cast" egglog rules passed the
**pre-cast** float token_ids (`?token_ids`) to the embed kernel instead of the **post-cast**
int token_ids (`?token_ids_cast`). The CUDA kernel reads token_ids as `const int*`, so float
data gets reinterpreted as enormous garbage integers, causing out-of-bounds embed table access.
### Why it was hard to find
1. **Multiple independent bug sources**: The ~70% failure rate was caused by three separate bugs
(matmul rank, embed output indexing, embed pre-cast input). Each fix only reduced the rate
partially, making it seem like each fix was insufficient.
2. **CudaGraph wrapping**: The crash occurred inside `CudaGraphOp::execute_internal` which
batches multiple kernels via CUDA graphs. The error just said "CudaGraph" — it
didn't identify which kernel crashed. Adding per-kernel debug launches was essential.
3. **Cascading failures**: When the Megakernel (containing RowEmbed with the pre-cast bug)
corrupted the embed output, the NEXT CudaGraph group's kernels crashed reading the garbage.
This made the Megakernel appear to be the victim, not the source.
4. **The pre-cast bug only crashes SOMETIMES**: Egglog's random choice determines whether
KernelEmbed/RowEmbed is selected (crash) or the generic Gather path is used (works).
Float token_id 1.0 (= 0x3F800000 = 1065353216 as int) produces an astronomically large
embed table index, causing ILLEGAL_ADDRESS.
### The fix
- Added `(= (len ?out_shape) 2)` to TileMatmulSplitK, TileMatmulFullSplit, and all 4 sgemm_v2 rules
- Fixed KernelEmbed output indexing: `out[out_offset + embed_idx]`
- **Fixed KernelEmbed/RowEmbed "with cast" rules**: Changed input from `?token_ids` to
`?token_ids_cast` — using the post-Cast int tensor instead of the pre-Cast float tensor
### Results
Failure rate: ~70% → 0% (20/20 passing). All three bugs needed to be fixed together.
### General principle
**When an egglog rule matches a sub-expression chain (like Cast→Mul→Add), be precise about
which intermediate result becomes each input.** The "with cast" embed rules matched
`Cast(?token_ids, ...)` to verify the Cast existed, but then passed `?token_ids` (the Cast
INPUT) instead of `?token_ids_cast` (the Cast OUTPUT) to the embed kernel. The kernel expects
int data, so the pre-cast float data was reinterpreted as garbage ints.
**Always search for sibling implementations**: KernelEmbed (in `kernel/hlir.rs`) and RowEmbed
(in `block/ops.rs`) had the SAME bug in their "with cast" rules. Fixing one without the other
only reduces the failure rate — both must be fixed.
---
## 2026-03-09 — TileMatmulFullSplit Matches Element-wise Square+Sum from LayerNorm
### What the symptom was
`test_qwen_image_transformer_tiny` on CUDA produced NaN in specific output rows. The failure
was non-deterministic (~85% failure rate) due to egglog's random e-class extraction picking
TileMatmulFullSplit for some operations.
### What the actual root cause was
The `TileMatmulFullSplit` rewrite rule in `block/ops.rs` matched any `Mul + Sum` pattern with
a 2D output, contiguous K-strides, and F32 inputs. This correctly matched real matmuls, but
ALSO matched the element-wise `x * x + Sum(last_dim)` pattern from LayerNorm/RMSNorm
(Pow(x, 2) → ReduceMean).
For a [1, 4, 64] activation tensor `x`:
- `Mul(x, x)` shape: [1, 4, 64], strides: [256z, 64z, z] for both inputs
- `Sum(dim=2)` output: [1, 4], len=2 ✓
TileMatmulFullSplit interpreted this as a [1, 64] × [64, 4] → [1, 4] matmul with:
- A = row 0 of x (64 elements), B = same buffer at column offsets
The kernel computed `C[j] = sum_k x[k] * x[j*64+k]` (cross-products) instead of the correct
`C[j] = sum_k x[j*64+k]^2` (squared sums). This produced subtly wrong values for j > 0
(correct for j=0 since cross-product with self = squared sum). These wrong values propagated
through LayerNorm → downstream operations → softmax → NaN.
Key diagnostic: adding `printf` to the kernel showed `a_ptr == b_ptr` (same buffer for both
inputs), confirming the kernel was operating on `x * x` not a real matmul.
### Why it was hard to find
1. **Individual op tests passed**: Simple Gemm tests, attention tests, and all other bisection
tests passed because they didn't have the specific `x*x → Sum` pattern.
2. **Non-deterministic**: The bug only manifested when egglog selected TileMatmulFullSplit
over the kernel fallback for the square+sum operation.
3. **No NaN from TileMatmulFullSplit itself**: The kernel produced wrong-but-finite values.
NaN only appeared downstream through softmax (exp(large) → ∞ → ∞/∞ = NaN).
4. **Systematic elimination needed**: Had to disable all block ops, then enable one at a time,
to narrow down TileMatmulFullSplit as the culprit.
### The fix
Added matmul broadcast constraints to both `TileMatmulFullSplit` and `TileMatmulSplitK` rules:
```egglog
; Assert proper matmul broadcast pattern:
; A is broadcast over N (a_n_stride = 0), B is broadcast over M (b_m_stride = 0)
(= ?a_n_stride (MNum 0))
(= ?b_m_stride (MNum 0))
```
In a real matmul `[M, K] × [K, N]`, the Mul is created by expanding dims:
- A is broadcast over N → a_n_stride = 0
- B is broadcast over M → b_m_stride = 0
In element-wise `x * x`, both strides are identical (non-zero for all dims), so the
constraints correctly reject it. The cuBLAS `.egg` rules already had these constraints.
### General principle
**Matmul Mul+Sum patterns have specific broadcast structure: one input is broadcast over M
and the other over N.** When writing egglog rules that match `Mul + Sum` patterns for matmul
optimization, always verify the broadcast pattern (`a_n_stride = 0` and `b_m_stride = 0`).
This prevents matching element-wise operations like `x*x → sum` that happen to have a 2D
output and contiguous strides.
---
## 2026-03-09 — Conv3D Permute Axis Mismatch in ONNX Conv Parser
### Symptom
`test_qwen_image_vae_decoder_tiny` panicked with:
> Permute axes (5) doesn't match shape axes (6)
at `src/shape/tracker.rs:153`, during `parse_conv_node`.
### Root cause
The Conv parser's unfold → matmul algorithm used two consecutive permutes with incorrect
index calculations. After unfold produces a 2N-dimensional tensor
`[win_0..win_{N-1}, k_0..k_{N-1}]`, the first permute swapped kernel dims to the front.
But the second permute's index math still assumed the original (pre-first-permute) ordering,
confusing kernel dimensions with window dimensions. Additionally:
1. `output_spatial_dims` was captured from wrong indices (kernel dims instead of window
spatial dims)
2. The `split_dims` loop iterated `spatial` times instead of `spatial-1`, creating a
spurious size-1 dimension
3. The final permute array had `1+spatial` elements for a tensor with `2+spatial` dims
For Conv2D (spatial=2) this was never caught because the xfail'd VAE decoder test was the
only test exercising the Conv parser — the transformer tests don't use Conv ONNX nodes.
### Why it was hard to find
The Conv parser was written and the VAE test immediately xfail'd due to a *different* bug
(`merge_dims` being `todo!()`). Once `merge_dims` was implemented, the Conv parser's own
bugs surfaced for the first time.
### Fix
Rewrote the unfold → matmul section with a single correct permute:
1. **One permute** to `[N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]`
— groups batch | output spatial | channel+kernel
2. **Capture** `output_spatial_dims` from correct indices `[1..1+spatial]`
3. **Merge** all channel+kernel dims from the end into one
4. **Merge** spatial dims into one → `[N, spatial_product, C_in*kernel_product]`
5. **Matmul**`[N, spatial_product, C_out]`
6. **Split** spatial back with `spatial-1` splits (not `spatial`)
7. **Permute** C_out to position 1 with correct `2+spatial` element array
### General principle
**When chaining permutes on high-dimensional tensors, prefer a single combined permute.**
Multiple permutes with hand-computed index arrays are error-prone because each permute
redefines what indices mean. A single permute from the original layout to the target layout
is easier to verify and less likely to confuse source/destination ordering. Also, ensure
`split_dims` loop counts match: splitting N dims out of a product requires N-1 splits
(the outermost dim is the quotient, not split out separately).
---
## 2026-03-18 — CUDA Search Rejects All Candidates: Zero Dummy Data Causes NaN for Div/Pow/Mod/Erf
### What the symptom was
6 CUDA tests (`test_pow`, `test_pow_broadcast`, `test_div`, `test_mod`, `test_mod_broadcast`,
`test_erf`) consistently failed with `Failed to find a viable initial genome for group 0 after
100 attempts`. All 6 passed on native backend.
### What the actual root cause was
The CUDA two-phase initialization in `build_cuda_backend` set ALL input tensor buffers to
`0.0f32` as dummy data for profiling. When `torch.compile` decomposes a model, it passes
model weights as additional ONNX graph inputs (not initializers). Since there were no ONNX
initializers to overwrite the zeros, weight buffers stayed all-zero during search.
Operations with zero inputs produced NaN:
- `fmod(0, 0) = NaN` (Mod test)
- `weight * recip(0) = weight * inf` → with any zero weight → `0 * inf = NaN` (Div test)
- `abs(0).log() = log(0) = -inf` → downstream NaN (Pow test)
- `sign(0)` chain → operations on zero inputs (Erf test)
The `has_nan_outputs` check rejected every candidate genome, exhausting all 100 attempts.
### Why it was hard to find
1. **No panic, no crash — silent NaN rejection**: The error message said "Failed to find a
viable initial genome" which suggested an egglog rewrite issue, not a data issue.
2. **Works on native**: `NativeRuntime::has_nan_outputs()` returns `false` by default (no NaN
check), so zero inputs never caused problems on native.
3. **torch.compile vs direct export difference**: Directly exporting a model via
`torch.onnx.export(model, ...)` produces initializers. But `torch.compile`'s backend
receives a `GraphModule` where weights are graph inputs, not initializers. The ONNX file
from `torch.compile` has 0 initializers.
4. **CudaRuntime's own `allocate_dummy_input` already uses 1.0**: The runtime knew zeros
were problematic (comment: "Zero inputs often hide numerical issues"), but the
`compiled_graph.rs` code used `0.0f32` independently.
### The fix
Changed dummy data from `vec![0.0f32; n_elements]` to `vec![1.0f32; n_elements]` in
`build_cuda_backend`. Using 1.0 is numerically safe: `fmod(1,1)=0`, `recip(1)=1`,
`log(1)=0`, `exp(1)≈2.7` — no NaN or inf. Profiling timing is unaffected (same number
of FLOPs and memory accesses).
### General principle
**Use small non-zero values (1.0) for dummy profiling data, never zeros.** Zero is a
singularity for many floating-point operations (division, log, fmod with zero divisor).
The CUDA runtime's `allocate_dummy_input` already followed this principle — the ONNX
pipeline's `build_cuda_backend` was inconsistent. When creating dummy data for GPU
profiling, always match the runtime's safer default.
---
## 2026-03-18 — Dynamic Decode Loop Fails: HLIR Weight Buffers Consumed After First Execute
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` passed step 0 (seq_len=6) but panicked on step 1
(seq_len=7) with `no entry found for key` at `cublaslt/mod.rs:294` — the CuBlasLt op couldn't
find its weight input buffer.
### What the actual root cause was
**Two bugs:**
1. **Missing `)` in egglog rule** (`luminal_cuda_lite/src/kernel/hlir.rs:3042`): The fourth
KernelEmbed rule ("kernel embed with mul reversed") had 3 closing parens after `INil` instead
of 4. The missing `)` failed to close the `(= ?mul_result ...)` form. This caused an egglog
parse error during search, caught by `catch_unwind`. The rule was dead code — it never fired,
but the parse error consumed a search iteration.
2. **HLIR buffer consumption killed weight buffers** (`luminal_cuda_lite/src/runtime.rs:1010-1057`):
After each `execute()`, the runtime removed all HLIR buffers (weights, constants) except those
directly connected to Output nodes. This was intended to free one-shot input data, but it also
deleted all 168 weight buffers. On the next `graph.run()`, CuBlasLt couldn't find any of its
weight inputs — `hlir_buffers` had 1 entry (the just-set `input_ids`) instead of 169.
### Why it was hard to find
1. **Misdirection by the egglog syntax error**: The plan identified the missing `)` as THE cause.
Fixing it allowed the rule to parse correctly, but the real runtime failure was independent.
2. **Step 0 always succeeds**: The weight consumption happens AFTER a successful execution. So
the first `graph.run()` works perfectly — all 169 HLIR buffers exist. The panic only occurs
on the second call, after consumption has cleared 168 of them.
3. **The consumption code was deliberately designed**: Comments said "weight tensors must have
`.persist()` to survive." The ONNX pipeline didn't call `.persist()` on weights, but this
had never been a problem before because single-shot inference only calls `execute()` once.
4. **Search phase panics masked by `catch_unwind`**: The same "no entry found for key" error
occurred during profiling of search candidates, but was silently caught. This made it look
like only certain LLIR variants had the issue, not all of them.
5. **Debug output needed 4 iterations to find**: The first debug showed which NodeIndex was
missing, the second showed it was an Input node, the third showed the HLIR mapping, and
the fourth revealed `hlir_buffers_count` dropping from 169 to 1 between steps.
### The fix
1. Added missing `)` to the KernelEmbed egglog rule at `hlir.rs:3042`.
2. In `compiled_graph.rs`, added `.persist()` calls on all weight/constant tensors (anything
not in `input_names`) after `process_onnx_nodes` completes. `.persist()` creates an Output
node connected to the Input, which the consumption code recognizes as "do not consume."
User inputs (like `input_ids`) are intentionally NOT persisted — they are consumed after
each `execute()` and re-set via `set_input()` before the next call.
### General principle
**Mark weight/constant tensors as persistent in the graph-building pipeline.** The runtime's
`execute()` consumes all HLIR buffers not connected to Output nodes. This is correct behavior
for one-shot user inputs, but weights must survive across calls. Always call `.persist()` on
tensors that should outlive a single execution. In the ONNX pipeline, the distinction is clear:
`input_names` (user-provided data per step) vs everything else (weights/constants loaded once).
---
## 2026-03-20 — PT2 CUDA Search Rejects All Candidates: Integer Buffers Misinterpreted as Float NaN
### What the symptom was
`test_hf_llama_tiny` on CUDA via PT2 failed with:
`pyo3_runtime.PanicException: Failed to find a viable initial genome for group 0 after 100 attempts`
The search tried 100 different egglog rewrites and ALL were rejected by the `has_nan_outputs` check.
### What the actual root cause was
**Two issues, both required to fix:**
1. **Integer buffers misinterpreted as float in NaN check.** `has_nan_outputs` in
`luminal_cuda_lite/src/runtime.rs` checks ALL `self.buffers` by reinterpreting raw bytes
as `f32` and calling `is_nan()`. The PT2-translated graph has integer intermediate
buffers (from `arange`, `cast(Int)`, integer arithmetic for embedding index computation).
Certain valid `i32` bit patterns (e.g., large integers from `token_id * hidden_dim`)
have exponent=0xFF and non-zero mantissa when reinterpreted as f32 — matching the
IEEE 754 NaN pattern. This caused false NaN rejections for EVERY candidate genome.
2. **Real weights/constants loaded before search contain -inf.** The PT2 path loaded real
safetensors weights and model constants (including the causal attention mask with `-inf`
values) BEFORE the search. While the ONNX path also loads real initializer data before
search, the PT2 graph's different structure (more explicit integer operations) made the
integer NaN false-positive the blocking issue.
### Why it was hard to find
- The original plan diagnosed this as the same zero-dummy-data bug fixed on 2026-03-18.
Changing `0.0` to `1.0` was insufficient because the root cause was different.
- `has_nan_outputs` checking ALL intermediate buffers (not just outputs) masked the real
issue — the NaN was in integer index-computation buffers, not in the model's float outputs.
- The ONNX-translated graph didn't have this problem because it doesn't produce as many
integer intermediate buffers (ONNX embedding uses different ops).
- The NaN pattern was identical across all 100 search attempts, which was the key clue:
it was deterministic and independent of egglog rewrite choices, pointing to input data
or buffer interpretation rather than graph optimization issues.
### The fix
Four changes:
1. **`luminal_cuda_lite/src/kernel/mod.rs`** (`KernelOp` trait): Added `output_dtype()`
method with default `DType::F32`. Each kernel now reports its actual output dtype.
2. **`luminal_cuda_lite/src/kernel/hlir.rs`** and **`other_ops.rs`**: Overrode
`output_dtype()` in all kernels with a `dtype` field (returns `self.dtype`), plus
special cases: `KernelIota``DType::Int`, `KernelLessThan``DType::Bool`,
`KernelCast``self.out_dtype`.
3. **`luminal_cuda_lite/src/runtime.rs`** (`has_nan_outputs`): Replaced fragile
`format!("{:?}").contains("dtype: Int")` string matching with proper
`op.to_dialect::<dyn KernelOp>().output_dtype()` check. Only F32 buffers are
checked for NaN; integer and bool buffers are skipped.
4. **`rust/src/pt2_compiled_model.rs`** (`init_cuda_runtime`): Set ALL input nodes
(weights, constants, user inputs) to `vec![1.0f32; n_elements]` before search via
new `set_all_inputs_dummy_cuda` function, then reload real data after search.
This prevents any -inf values from the causal mask from polluting intermediate
float computations during profiling.
### General principle
**Never reinterpret integer buffer bytes as float for NaN checking.** When a graph has
mixed-dtype operations (float model computation + integer index computation), raw byte
buffers from integer kernels contain valid i32 values that look like NaN when cast to f32.
The search's `has_nan_outputs` must be dtype-aware — use the kernel's `output_dtype()`
method rather than string-matching on Debug output. Additionally, when diagnosing "all
candidates rejected" during search, check whether the rejection is from actual float NaN
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
identical across all attempts (dtype issue) vs varying (actual numerical issue).
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.

View File

View File

@@ -0,0 +1,369 @@
"""Run pytest on Modal with a dynamically selected GPU.
Usage:
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
uv run modal run modal_pytest_runner.py --gpu T4 tests/
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
"""
import argparse
import json
import os
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import modal
from modal.volume import FileEntryType
app = modal.App("luminal-tests")
DEFAULT_TIMEOUT = 30 * 60
CUDARC_CUDA_VERSION = "12080"
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
PROJECT_DIR = "/root/luminal/crates/luminal_python"
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
SRC_PATH = f"{PROJECT_DIR}/src"
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
HF_TOKEN_ENV_KEY = "HF_TOKEN"
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
HF_CACHE_VOLUME = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
image = (
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
.uv_sync(
str(LOCAL_PROJECT_DIR),
frozen=False,
groups=["dev"],
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
)
.workdir(PROJECT_DIR)
.add_local_dir(
str(LOCAL_PROJECT_DIR.parent.parent),
remote_path="/root/luminal",
copy=True,
ignore=[
".git",
".claude-project",
".cargo-local",
"**/.venv",
"**/.pytest_cache",
"**/__pycache__",
"**/luminal_artifacts",
"**/target",
"docs",
],
)
)
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _hf_token_secret() -> modal.Secret | None:
hf_token = os.environ.get(HF_TOKEN_ENV_KEY)
if not hf_token:
return None
return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: hf_token})
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
return any(arg == flag for arg in pytest_args)
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
return (
cli_profile
or _has_pytest_flag(pytest_args, "--profile")
or _has_pytest_flag(pytest_args, "--profile-svg")
)
def _run_id() -> str:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
def _prepare_scratch_dir(scratch_dir: Path) -> None:
scratch_dir.mkdir(parents=True, exist_ok=True)
linked_names = {
".venv",
".pytest_cache",
"__pycache__",
"luminal_artifacts",
"prof",
}
for entry in Path(PROJECT_DIR).iterdir():
if entry.name in linked_names:
continue
target = scratch_dir / entry.name
if target.exists() or target.is_symlink():
continue
target.symlink_to(entry, target_is_directory=entry.is_dir())
def _default_profile_output_dir(run_id: str) -> Path:
return (LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT / run_id).resolve()
def _prepare_local_profile_dir(output_dir: Path) -> None:
if output_dir.exists() and not output_dir.is_dir():
raise NotADirectoryError(f"{output_dir} is not a directory")
output_dir.mkdir(parents=True, exist_ok=True)
prof_dir = output_dir / "prof"
if prof_dir.exists():
shutil.rmtree(prof_dir)
manifest_path = output_dir / "manifest.json"
if manifest_path.exists():
manifest_path.unlink()
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
_prepare_local_profile_dir(output_dir)
for entry in entries:
relative_path = Path(entry.path).relative_to(run_id)
if relative_path == Path("."):
continue
destination = output_dir / relative_path
if entry.type == FileEntryType.DIRECTORY:
destination.mkdir(parents=True, exist_ok=True)
continue
if entry.type != FileEntryType.FILE:
continue
destination.parent.mkdir(parents=True, exist_ok=True)
with destination.open("wb") as handle:
for chunk in PROFILE_VOLUME.read_file(entry.path):
handle.write(chunk)
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
try:
PROFILE_VOLUME.remove_file(run_id, recursive=True)
except FileNotFoundError:
return
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
class TestRunner:
@modal.method()
def run(
self,
pytest_args: list[str],
pytest_addopts: str = "",
profile_enabled: bool = False,
) -> dict[str, Any]:
started_at = _utc_now()
run_id = _run_id() if profile_enabled else None
scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id if run_id else None
if scratch_dir is not None:
_prepare_scratch_dir(scratch_dir)
env = os.environ.copy()
existing = env.get("PYTHONPATH")
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
env["LUMINAL_BACKEND"] = "cuda"
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
env["HF_HOME"] = HF_CACHE_PATH
if pytest_addopts:
env["PYTEST_ADDOPTS"] = pytest_addopts
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
dot_available = shutil.which("dot") is not None
sanitized_pytest_args = [
arg for arg in pytest_args if arg not in {"--profile", "--profile-svg"}
]
if profile_enabled:
sanitized_pytest_args.append("--profile")
if dot_available:
sanitized_pytest_args.append("--profile-svg")
elif original_svg_requested:
print(
"Graphviz 'dot' is unavailable in the Modal container; "
"falling back to raw .prof artifacts only.",
file=sys.stderr,
)
svg_requested = profile_enabled and dot_available
cmd = [
"uv",
"run",
"--project",
PROJECT_DIR,
"--group",
"dev",
"--reinstall-package",
"luminal_python",
"python",
"-m",
"pytest",
*sanitized_pytest_args,
]
exit_code = subprocess.run(
cmd,
env=env,
cwd=str(scratch_dir) if scratch_dir is not None else PROJECT_DIR,
).returncode
HF_CACHE_VOLUME.commit()
finished_at = _utc_now()
if not profile_enabled:
return {
"exit_code": exit_code,
"run_id": None,
"profile_enabled": False,
"remote_profile_dir": None,
"local_default_dirname": None,
}
volume_root = Path(PROFILE_VOLUME_PATH)
if not volume_root.exists():
raise RuntimeError(
"Profiling requested but the profile volume is not mounted."
)
remote_run_dir = volume_root / run_id
remote_run_dir.mkdir(parents=True, exist_ok=True)
prof_dir = scratch_dir / "prof"
if prof_dir.is_dir():
shutil.copytree(prof_dir, remote_run_dir / "prof")
svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()
manifest = {
"exit_code": exit_code,
"finished_at": finished_at,
"profile_enabled": True,
"pytest_args": sanitized_pytest_args,
"run_id": run_id,
"started_at": started_at,
"svg_generated": svg_generated,
"svg_requested": svg_requested,
}
(remote_run_dir / "manifest.json").write_text(
json.dumps(manifest, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
PROFILE_VOLUME.commit()
return {
"exit_code": exit_code,
"run_id": run_id,
"profile_enabled": True,
"remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
"local_default_dirname": run_id,
"svg_generated": svg_generated,
"svg_requested": svg_requested,
}
def _parse_cli_args(
cli_args: tuple[str, ...],
) -> tuple[str, int | None, bool, str | None, list[str]]:
parser = argparse.ArgumentParser(
prog="modal run modal_pytest_runner.py",
add_help=False,
allow_abbrev=False,
description="Run pytest on Modal with a dynamically selected GPU.",
)
parser.add_argument(
"--gpu",
required=True,
help="GPU type to request from Modal (for example: A100, T4, H100).",
)
parser.add_argument(
"--timeout",
type=int,
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Enable pytest-profiling and download the resulting artifacts locally.",
)
parser.add_argument(
"--profile-output-dir",
help="Directory to download profiling artifacts into when profiling is enabled.",
)
parsed, pytest_args = parser.parse_known_args(cli_args)
if pytest_args and pytest_args[0] == "--":
pytest_args = pytest_args[1:]
if not pytest_args:
pytest_args = ["tests/"]
return (
parsed.gpu,
parsed.timeout,
parsed.profile,
parsed.profile_output_dir,
pytest_args,
)
@app.local_entrypoint()
def main(*cli_args: str):
gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
cli_args
)
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
runner_options = {"gpu": gpu}
hf_token_secret = _hf_token_secret()
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
if timeout is not None:
runner_options["timeout"] = timeout
if profile_enabled:
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
runner_options["volumes"] = runner_volumes
if hf_token_secret is not None:
runner_options["secrets"] = [hf_token_secret]
runner = TestRunner.with_options(**runner_options)()
result = runner.run.remote(
pytest_args=pytest_args,
pytest_addopts=pytest_addopts,
profile_enabled=profile_enabled,
)
if result["profile_enabled"] and result["run_id"] is not None:
if profile_output_dir:
output_dir = Path(profile_output_dir).expanduser().resolve()
else:
output_dir = _default_profile_output_dir(result["local_default_dirname"])
try:
_download_profile_artifacts(result["run_id"], output_dir)
print(f"Profile artifacts downloaded to {output_dir}")
_cleanup_remote_profile_artifacts(result["run_id"])
except FileNotFoundError as exc:
print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
except OSError as exc:
print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
sys.exit(result["exit_code"])

View File

@@ -0,0 +1,65 @@
[project]
name = "luminal_python"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.0.2",
"torch>=2.10.0",
"onnx",
"onnxscript",
"safetensors",
"flash-attn-3>=3.0.0",
]
[tool.uv]
no-build-isolation-package = ["flash-attn"]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
flash-attn-3 = { index = "pytorch-cu128" }
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.maturin]
python-source = "src"
manifest-path = "rust/Cargo.toml"
module-name = "luminal.luminal"
[tool.pytest.ini_options]
markers = [
"slow: tests that download large models or require pre-generated artifacts",
]
[dependency-groups]
dev = [
"maturin>=1.0,<2.0",
"maturin-import-hook>=0.3.0",
"pytest>=9.0.2",
"pytest-profiling",
"snakeviz",
"pytest-randomly>=4.0.1",
"transformers>=5.5.0,<6",
"diffusers>=0.35.0",
"onnxsim",
"tiktoken>=0.12.0",
"pydantic>=2.12.5",
"psutil>=7.2.2",
"modal>=1.3.5",
"pillow",
"flash-attn>=2.8.3",
]
flash-attention-4 = [
"nvidia-cutlass-dsl==4.1.0",
]

View File

@@ -0,0 +1,44 @@
#!/bin/bash
set -e
echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
# ── Phase 1: Native Backend ─────────────────────────────────
echo ""
echo "=== Phase 1: Building native backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml
echo ""
echo "--- 1a: Native + ONNX ---"
uv run pytest $NATIVE_TESTS -v
echo ""
echo "--- 1b: Native + PT2 ---"
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
# ── Phase 2: CUDA Backend ───────────────────────────────────
echo ""
echo "=== Phase 2: Building CUDA backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
echo ""
echo "--- 2a: CUDA + ONNX ---"
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "--- 2b: CUDA + PT2 ---"
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "=========================================="
echo " All tests passed!"
echo "=========================================="

View File

@@ -0,0 +1,22 @@
#!/bin/bash
set -e
echo "=== Luminal Python Test Runner ==="
echo ""
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml
# Run pytest
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far to long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -0,0 +1,20 @@
#!/bin/bash
set -e
echo "=== Luminal Python Test Runner (PT2 Export Mode) ==="
echo ""
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml
# Run pytest with PT2 export mode
echo "Step 3: Running pytest with PT2 export mode..."
LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -0,0 +1,20 @@
#!/bin/bash
set -e
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
echo ""
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -0,0 +1,19 @@
#!/bin/bash
set -e
echo "=== Luminal Python Test Runner (CUDA + PT2 Export Mode) ==="
echo ""
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend and PT2 export mode
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -0,0 +1,30 @@
[package]
name = "luminal_python"
version = "0.1.0"
edition.workspace = true
[lib]
name = "luminal"
crate-type = ["cdylib"]
path = "src/lib.rs"
[features]
cuda = ["dep:luminal_cuda_lite"]
[dependencies]
onnx-protobuf = "0.2"
protobuf = "~3.4"
rustc-hash = "2.1.1"
luminal = {path= "../../.."}
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
serde = { version = "1", features = ["derive"] }
serde_json = "1"
zip = "2"
anyhow = "1"
memmap2 = "0.9"
safetensors = "0.5"
half = "2"
[dependencies.pyo3]
version = "0.28.0"
features = ["abi3-py38"]

View File

@@ -0,0 +1,514 @@
#[cfg(feature = "cuda")]
use luminal::prelude::tracing::{trace, warn};
use luminal::{prelude::*, shape::Expression, visualization::ToDot};
use pyo3::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::collections::HashSet;
use crate::{runtime::RuntimeBackend, util::DimParamMap};
/// Common intermediate result from translating a model graph (ONNX or FX).
pub struct GraphTranslation {
pub graph: Graph,
pub tensor_ids: HashMap<String, NodeIndex>,
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
/// Pre-loaded weight data from any model format.
///
/// NOTE: Currently assumes all data is F32. When the type system branch lands
/// with proper multi-dtype support, this struct (and all callers) will need
/// updating to carry dtype metadata alongside the raw data.
pub struct WeightData {
/// (Input node label, f32 data) for weights and constants.
pub weights: Vec<(String, Vec<f32>)>,
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
pub tensor_sizes: HashMap<String, usize>,
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
pub device_ptrs: HashMap<String, (u64, usize)>,
}
#[pyclass(unsendable)]
pub struct CompiledGraph {
pub graph: Graph,
pub runtime: RuntimeBackend,
pub tensor_ids: HashMap<String, NodeIndex>,
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
label_map: HashMap<String, NodeIndex>,
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
impl CompiledGraph {
/// Shared compilation pipeline for both ONNX and FX/PT2 graphs.
///
/// Takes a format-neutral `GraphTranslation` (produced by `translate_onnx` or
/// `translate_pt2`) and `WeightData`, builds the backend, loads weights, and
/// returns a ready-to-execute `CompiledGraph`.
pub fn parse_graph(
translation: GraphTranslation,
weight_data: WeightData,
backend: &str,
search_iters: usize,
) -> Result<CompiledGraph, String> {
let GraphTranslation {
mut graph,
tensor_ids,
input_names,
output_names,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
} = translation;
let rt = match backend {
#[cfg(feature = "cuda")]
"cuda" | "gpu" => {
CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
}
"native" | "cpu" => {
CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
}
_ => {
#[cfg(feature = "cuda")]
{
return Err(format!(
"Invalid backend '{}'. Must be 'native' or 'cuda'",
backend
));
}
#[cfg(not(feature = "cuda"))]
{
if backend == "cuda" {
return Err(
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
.to_string(),
);
}
return Err(format!(
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
backend
));
}
}
};
// Resolve concrete output shapes from expressions
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
.iter()
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
.collect();
let label_map = CompiledGraph::build_label_map(&graph);
Ok(CompiledGraph {
graph,
runtime: rt,
tensor_ids,
label_map,
input_names,
output_names,
output_shapes,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
})
}
/// Build a label → NodeIndex map for all Input nodes in the graph.
/// Used for efficient weight loading by label matching.
fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
graph
.graph
.node_indices()
.filter_map(|node_id| {
(*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
.map(|input| (input.label.clone(), node_id))
})
.collect()
}
#[cfg(feature = "cuda")]
fn build_cuda_backend(
graph: &mut Graph,
weight_data: &WeightData,
search_iters: usize,
) -> Result<RuntimeBackend, String> {
let device_ptrs = &weight_data.device_ptrs;
use luminal_cuda_lite::cudarc::driver::CudaContext;
use luminal_cuda_lite::runtime::CudaRuntime;
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
let stream = cuda_ctx.default_stream();
graph.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
// Build label → NodeIndex map for device pointer matching.
let label_map = CompiledGraph::build_label_map(graph);
// For weights with device pointers: use them directly (zero-copy).
// This avoids allocating ~N GB of dummy data during search.
// The pointers survive search because profiling mode skips buffer consumption,
// and persist_hlir_node ensures they survive post-search execution too.
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
let mut matched_count = 0usize;
let mut missed_labels: Vec<String> = Vec::new();
for (label, &(ptr, n_bytes)) in device_ptrs {
if let Some(&node_id) = label_map.get(label) {
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
rt.persist_hlir_node(node_id);
device_ptr_nodes.insert(node_id);
matched_count += 1;
} else {
missed_labels.push(label.clone());
}
}
let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
trace!(
"[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
matched_count,
missed_labels.len(),
device_ptrs.len(),
total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
);
if !missed_labels.is_empty() {
warn!(
"[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
missed_labels.len(),
&missed_labels[..missed_labels.len().min(10)]
);
let available: Vec<&String> = label_map.keys().take(10).collect();
warn!(
"[CUDA BUILD] Available label_map keys (first 10): {:?}",
available
);
}
// Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
// device pointers) for safe search profiling.
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
// - fmod(0, 0) = NaN (Mod)
// - recip(0) = inf → weight * inf = NaN (Div)
// - log(0) = -inf (Pow)
// - chain ops with zero produce NaN (Erf)
let mut dummy_total_elements = 0usize;
let mut dummy_count = 0usize;
for node_id in graph.graph.node_indices() {
if device_ptr_nodes.contains(&node_id) {
continue;
}
if let Some(input) = (*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
{
if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
if n > 0 {
dummy_total_elements += n;
dummy_count += 1;
rt.set_data(node_id, vec![1.0f32; n]);
}
}
}
}
trace!(
"[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
dummy_count,
dummy_total_elements,
(dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
);
// Search (device-pointer weights are used directly; dummy data for the rest)
let mut rt = graph.search(rt, search_iters);
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
let mut loaded_weight_elements = 0usize;
let mut loaded_weight_count = 0usize;
for (label, data) in &weight_data.weights {
if !device_ptrs.contains_key(label) {
if let Some(&node_id) = label_map.get(label) {
loaded_weight_elements += data.len();
loaded_weight_count += 1;
rt.set_data(node_id, data.clone());
}
}
}
trace!(
"[CUDA BUILD] Post-search weight load: {} weights, {} elements ({:.3} GiB as f32)",
loaded_weight_count,
loaded_weight_elements,
(loaded_weight_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
);
Ok(RuntimeBackend::Cuda(Box::new(rt)))
}
fn build_native_backend(
graph: &mut Graph,
weight_data: &WeightData,
search_iters: usize,
) -> Result<RuntimeBackend, String> {
graph.build_search_space::<NativeRuntime>();
let mut rt = graph.search(NativeRuntime::default(), search_iters);
// Load weight data after search
let label_map = CompiledGraph::build_label_map(graph);
for (label, data) in &weight_data.weights {
if let Some(&node_id) = label_map.get(label) {
rt.set_data(node_id, data.clone());
}
}
Ok(RuntimeBackend::Native(rt))
}
}
#[pymethods]
impl CompiledGraph {
/// Get the list of input tensor names.
#[getter]
fn input_names(&self) -> Vec<String> {
self.input_names.clone()
}
/// Get the list of output tensor names.
#[getter]
fn output_names(&self) -> Vec<String> {
self.output_names.clone()
}
/// Get the output shapes.
#[getter]
fn output_shapes(&self) -> Vec<Vec<usize>> {
self.output_shapes.clone()
}
/// Get all tensor names in the graph.
#[getter]
fn tensor_names(&self) -> Vec<String> {
self.tensor_ids.keys().cloned().collect()
}
/// Get the name of the active backend (native or cuda).
#[getter]
fn backend(&self) -> &'static str {
self.runtime.name()
}
/// Whether this graph has dynamic (symbolic) dimensions.
#[getter]
fn has_dynamic_dims(&self) -> bool {
!self.dim_param_map.is_empty()
}
/// Get the dynamic dimension parameter names (e.g. ["seq_len"]).
#[getter]
fn dim_params(&self) -> Vec<String> {
self.dim_param_map.keys().cloned().collect()
}
/// Set a dynamic dimension value by its param name (e.g. "seq_len").
fn set_dim(&mut self, param_name: &str, value: usize) -> PyResult<()> {
let ch = self.dim_param_map.get(param_name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown dim param '{}'. Available: {:?}",
param_name,
self.dim_param_map.keys().collect::<Vec<_>>()
))
})?;
self.graph.set_dim(*ch, value);
Ok(())
}
/// Auto-detect and set dynamic dimensions from input tensor shapes.
/// For each user input, matches the concrete shape against its symbolic
/// shape expressions and sets the corresponding dyn_map entries.
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
// Check if this expression is a bare symbolic variable
let terms = dim_expr.terms.read();
if terms.len() == 1
&& let luminal::shape::Term::Var(c) = terms[0]
{
self.graph.set_dim(c, dim_val);
}
}
}
}
/// Resolve output shapes using current dynamic dimension values.
/// Returns concrete shapes after substituting all symbolic dims.
fn resolve_output_shapes(&self) -> PyResult<Vec<Vec<usize>>> {
let dyn_map = &self.graph.dyn_map;
let mut result = Vec::new();
for shape_exprs in &self.output_shape_exprs {
let shape: Vec<usize> = shape_exprs
.iter()
.map(|e| {
e.exec(dyn_map).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"Cannot resolve dimension expression {:?}. Set all dynamic dims first.",
e
))
})
})
.collect::<PyResult<Vec<usize>>>()?;
result.push(shape);
}
Ok(result)
}
/// Set input tensor data by name.
fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
self.runtime.set_data(*node_id, data);
Ok(())
}
/// Set input tensor data from a CPU host memory pointer (avoids Python list conversion).
/// The pointer must point to contiguous f32 data (from tensor.data_ptr() on a CPU float32 tensor).
fn set_input_from_ptr(&mut self, name: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
let data: Vec<f32> =
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
self.runtime.set_data(*node_id, data);
Ok(())
}
/// Set input from a CUDA device pointer. Zero-copy on device.
/// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
#[cfg(feature = "cuda")]
fn set_input_device_ptr(
&mut self,
name: &str,
device_ptr: u64,
n_bytes: usize,
) -> PyResult<()> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
match &mut self.runtime {
RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"set_input_device_ptr requires CUDA backend",
));
}
}
Ok(())
}
/// Mark an input tensor as persistent (survives execute() calls).
/// Call this for weight tensors that should not be consumed after each execution.
fn persist_input(&mut self, name: &str) -> PyResult<()> {
let _node_id = *self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
match &mut self.runtime {
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => rt.persist_hlir_node(_node_id),
RuntimeBackend::Native(_) => {} // Native: persist is handled at graph level
}
Ok(())
}
/// Set a weight tensor from a CUDA device pointer, matching by Input node label.
/// Also marks the weight as persistent. For PT2 weights (e.g. "fc1.weight").
#[cfg(feature = "cuda")]
fn set_weight_device_ptr(
&mut self,
label: &str,
device_ptr: u64,
n_bytes: usize,
) -> PyResult<()> {
let &node_id = self.label_map.get(label).ok_or_else(|| {
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
})?;
match &mut self.runtime {
RuntimeBackend::Cuda(rt) => {
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
rt.persist_hlir_node(node_id);
}
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"set_weight_device_ptr requires CUDA backend",
));
}
}
Ok(())
}
/// Set a weight tensor from a CPU host pointer, matching by Input node label.
fn set_weight_from_ptr(&mut self, label: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
let &node_id = self.label_map.get(label).ok_or_else(|| {
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
})?;
let data: Vec<f32> =
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
self.runtime.set_data(node_id, data);
Ok(())
}
/// Execute the graph.
fn run(&mut self) {
self.runtime.execute(&self.graph.dyn_map);
}
/// Return the HLIR graph as a DOT string for visualization.
fn to_dot(&self) -> PyResult<String> {
self.graph.graph.to_dot().map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("DOT generation failed: {e}"))
})
}
/// Get output tensor data by name (copies to host).
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
Ok(self.runtime.get_f32(*node_id))
}
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
#[cfg(feature = "cuda")]
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
match &self.runtime {
RuntimeBackend::Cuda(rt) => {
unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
Ok(())
}
_ => Err(pyo3::exceptions::PyValueError::new_err(
"copy_output_to_device_ptr requires CUDA backend",
)),
}
}
}

View File

@@ -0,0 +1,248 @@
use std::collections::HashMap;
use luminal::{prelude::*, shape::Expression};
use onnx_protobuf::NodeProto;
use crate::ops_parse::*;
pub fn process_onnx_nodes(
nodes: &[NodeProto],
tensors: &mut HashMap<String, GraphTensor>,
cx: &mut Graph,
weight_data: &mut Vec<(String, Vec<f32>)>,
known_values: &mut HashMap<String, Vec<f32>>,
shape_exprs: &mut HashMap<String, Vec<Expression>>,
) -> Result<(), String> {
for node in nodes {
match node.op_type.as_str() {
"Add" => parse_binary_broadcast_op(
node,
tensors,
"Add",
|a, b| a + b,
shape_exprs,
known_values,
)?,
"Mod" => parse_binary_broadcast_op(
node,
tensors,
"Mod",
|a, b| a % b,
shape_exprs,
known_values,
)?,
"Sub" => parse_binary_broadcast_op(
node,
tensors,
"Sub",
|a, b| a - b,
shape_exprs,
known_values,
)?,
"Mul" => parse_binary_broadcast_op(
node,
tensors,
"Mul",
|a, b| a * b,
shape_exprs,
known_values,
)?,
"Div" => parse_binary_broadcast_op(
node,
tensors,
"Div",
|a, b| a / b,
shape_exprs,
known_values,
)?,
"Sqrt" => parse_unary_op(node, tensors, "Sqrt", |a| a.sqrt())?,
"Transpose" => parse_transpose_node(node, tensors)?,
"Concat" => parse_concat_node(node, tensors, shape_exprs, known_values)?,
"Floor" => parse_floor_node(node, tensors)?,
"Ceil" => parse_ceil_node(node, tensors)?,
"Sin" => parse_unary_op(node, tensors, "Sin", |a| a.sin())?,
"Neg" => parse_unary_op(node, tensors, "Neg", |a| -a)?,
"Cos" => parse_unary_op(node, tensors, "Cos", |a| a.cos())?,
"Pow" => parse_binary_broadcast_op(
node,
tensors,
"Pow",
|a, b| a.pow(b),
shape_exprs,
known_values,
)?,
"Sigmoid" => parse_unary_op(node, tensors, "Sigmoid", |a| a.sigmoid())?,
"Tanh" => parse_unary_op(node, tensors, "Tanh", |a| a.tanh())?,
"Relu" => parse_unary_op(node, tensors, "Relu", |a| a.relu())?,
"Softmax" => parse_softmax_node(node, tensors)?,
"Abs" => parse_unary_op(node, tensors, "Abs", |a| a.abs())?,
"Reciprocal" => parse_unary_op(node, tensors, "Reciprocal", |a| a.reciprocal())?,
"Clip" => parse_clip_node(node, tensors, known_values)?,
"Equal" => parse_binary_broadcast_op(
node,
tensors,
"Equal",
|a, b| a.eq(b),
shape_exprs,
known_values,
)?,
"Where" => parse_where_node(node, tensors)?,
"Constant" => {
parse_constant_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
}
"ConstantOfShape" => {
parse_constant_of_shape(node, tensors, cx, weight_data, known_values, shape_exprs)?
}
"Cast" => parse_cast_node(node, tensors, weight_data, known_values, shape_exprs)?,
"MatMul" => parse_matmul_node(node, tensors)?,
"Reshape" => parse_reshape_node(node, tensors, known_values, shape_exprs)?,
"Shape" => parse_shape_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
"Gather" => {
parse_gather_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
}
"GatherND" => parse_gathernd_node(node, tensors, cx, weight_data, known_values)?,
"Less" => parse_binary_broadcast_op(
node,
tensors,
"Less",
|a, b| a.lt(b),
shape_exprs,
known_values,
)?,
"Greater" => parse_binary_broadcast_op(
node,
tensors,
"Greater",
|a, b| b.lt(a),
shape_exprs,
known_values,
)?,
"LessOrEqual" => parse_binary_broadcast_op(
node,
tensors,
"LessOrEqual",
|a, b| a.le(b),
shape_exprs,
known_values,
)?,
"GreaterOrEqual" => parse_binary_broadcast_op(
node,
tensors,
"GreaterOrEqual",
|a, b| a.ge(b),
shape_exprs,
known_values,
)?,
"Not" => parse_not_node(node, tensors)?,
"And" => parse_binary_broadcast_op(
node,
tensors,
"And",
|a, b| a.cast(DType::F32) * b.cast(DType::F32),
shape_exprs,
known_values,
)?,
"Or" => parse_binary_broadcast_op(
node,
tensors,
"Or",
|a, b| (a.cast(DType::F32) + b.cast(DType::F32)).minimum_f32(1.0),
shape_exprs,
known_values,
)?,
"Xor" => parse_binary_broadcast_op(
node,
tensors,
"Xor",
|a, b| a.ne(b),
shape_exprs,
known_values,
)?,
"Min" => parse_variadic_broadcast_op(
node,
tensors,
"Min",
|a, b| a.minimum(b),
shape_exprs,
known_values,
)?,
"Max" => parse_variadic_broadcast_op(
node,
tensors,
"Max",
|a, b| a.maximum(b),
shape_exprs,
known_values,
)?,
"Identity" => parse_identity(node, tensors, known_values, shape_exprs)?,
"Unsqueeze" => parse_unsqueeze_node(node, tensors, known_values, shape_exprs)?,
"Squeeze" => parse_squeeze_node(node, tensors, known_values, shape_exprs)?,
"ReduceSum" => parse_reduce_op(
node,
tensors,
known_values,
"ReduceSum",
|t, axes| t.sum(axes),
|flat, _n| flat.sum(1),
)?,
"ReduceMax" => parse_reduce_op(
node,
tensors,
known_values,
"ReduceMax",
|t, axes| t.max(axes),
|flat, _n| flat.max(1),
)?,
"ReduceMin" => parse_reduce_op(
node,
tensors,
known_values,
"ReduceMin",
|t, axes| t.min(axes),
|flat, _n| flat.min(1),
)?,
"ReduceMean" => parse_reduce_op(
node,
tensors,
known_values,
"ReduceMean",
|t, axes| t.mean(axes),
|flat, n| flat.sum(1) / n as f32,
)?,
"Trilu" => parse_trilu_node(node, tensors, cx, known_values)?,
"GatherElements" => parse_gather_elements_node(node, tensors)?,
"ScatterElements" => parse_scatter_elements_node(node, tensors)?,
"ScatterND" => parse_scatter_nd_node(node, tensors)?,
"Expand" => parse_expand_node(node, tensors, known_values, shape_exprs)?,
"IsNaN" => parse_unary_op(node, tensors, "IsNaN", |a| a.ne(a))?,
"LayerNormalization" => parse_layernorm_node(node, tensors)?,
"Gemm" => parse_gemm_node(node, tensors)?,
"Erf" => parse_erf_node(node, tensors)?,
"Slice" => parse_slice_node(node, tensors, known_values, shape_exprs)?,
"Split" => parse_split_node(node, tensors, known_values)?,
"TopK" => parse_topk_node(node, tensors, known_values)?,
"OneHot" => parse_onehot_node(node, tensors, known_values)?,
"Range" => parse_range_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
"CumSum" => parse_cumsum_node(node, tensors, known_values)?,
"Gelu" => parse_unary_op(node, tensors, "Gelu", |a| a.gelu())?,
"Conv" => parse_conv_node(node, tensors)?,
"Pad" => parse_pad_node(node, tensors, known_values)?,
"Resize" => parse_resize_node(node, tensors, known_values)?,
"Tile" => parse_tile_node(node, tensors, known_values)?,
"ReduceL2" => parse_reduce_op(
node,
tensors,
known_values,
"ReduceL2",
|t, axes| (t * t).sum(axes).sqrt(),
|flat, _n| (flat * flat).sum(1).sqrt(),
)?,
"GroupNormalization" => parse_group_norm_node(node, tensors)?,
_ => {
panic!("Missing Node {}", node.op_type)
}
}
}
Ok(())
}

View File

@@ -0,0 +1,73 @@
mod compiled_graph;
mod dispatch;
mod onnx_translator;
mod ops_parse;
mod runtime;
mod util;
// PT2 modules
mod pt2_compiled_model;
mod pt2_parser;
mod pt2_schema;
mod pt2_util;
mod translator;
use compiled_graph::CompiledGraph;
use pt2_compiled_model::process_pt2;
use pyo3::prelude::*;
use std::collections::HashMap;
fn validate_backend(backend: &str) -> PyResult<()> {
match backend {
"native" => Ok(()),
#[cfg(feature = "cuda")]
"cuda" => Ok(()),
#[cfg(not(feature = "cuda"))]
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
)),
_ => {
#[cfg(feature = "cuda")]
{
Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid backend '{}'. Must be 'native' or 'cuda'",
backend
)))
}
#[cfg(not(feature = "cuda"))]
{
Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
backend
)))
}
}
}
}
#[pyfunction]
#[pyo3(signature = (path, backend="native", search_iters=10, weight_device_ptrs=None))]
fn process_onnx(
path: &str,
backend: &str,
search_iters: usize,
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
) -> PyResult<CompiledGraph> {
validate_backend(backend)?;
onnx_translator::compile_onnx(
path,
backend,
weight_device_ptrs.unwrap_or_default(),
search_iters,
)
.map_err(pyo3::exceptions::PyRuntimeError::new_err)
}
#[pymodule]
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
m.add_class::<CompiledGraph>()?;
Ok(())
}

View File

@@ -0,0 +1,283 @@
use luminal::{
prelude::{
tracing::{Level, span, trace},
*,
},
shape::Expression,
};
use onnx_protobuf::ModelProto;
use protobuf::Message;
use std::{
collections::{HashMap, HashSet},
fs,
path::Path,
};
use crate::{
compiled_graph::{CompiledGraph, GraphTranslation, WeightData},
dispatch::process_onnx_nodes,
util::{
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
load_all_tensor_floats, load_initializer_as_f32,
},
};
/// Load, validate, translate, and compile an ONNX model.
///
/// This is the ONNX counterpart of `pt2_compiled_model::compile_pt2()`.
pub fn compile_onnx(
path: &str,
backend: &str,
weight_device_ptrs: HashMap<String, (u64, usize)>,
search_iters: usize,
) -> Result<CompiledGraph, String> {
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
let model = ModelProto::parse_from_bytes(&data)
.map_err(|e| format!("Failed to parse ONNX model: {}", e))?;
let opset_version = model
.opset_import
.iter()
.find(|entry| entry.domain.is_empty())
.map(|entry| entry.version);
match opset_version {
Some(20) => {}
Some(v) => {
return Err(format!(
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
));
}
None => {
return Err(
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
);
}
}
let (translation, mut weights) = translate_onnx(model, model_directory)?;
weights.device_ptrs = weight_device_ptrs;
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
}
/// Translate an ONNX model into a format-neutral GraphTranslation + WeightData.
pub fn translate_onnx(
model: ModelProto,
model_directory: &Path,
) -> Result<(GraphTranslation, WeightData), String> {
let _span = span!(Level::TRACE, "ONNX Graph Translation").entered();
let onnx_graph = &model.graph;
let mut cx = Graph::new();
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
// Dynamic dimension tracking
let mut dim_param_map: DimParamMap = HashMap::new();
let mut next_char = 'a';
// Separate initializers (weights) from true user inputs
let initializer_names: HashSet<&str> = onnx_graph
.initializer
.iter()
.map(|t| t.name.as_str())
.collect();
let input_names: Vec<String> = onnx_graph
.input
.iter()
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
.map(|inp| inp.name.clone())
.collect();
// Create input tensors with dynamic dimension support
for input in &onnx_graph.input {
let shape_exprs = get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
if shape_exprs.is_empty() {
let shape = get_shape_for_onnx_value(input);
if shape.is_empty() {
trace!("Input {} skipped because it is empty", input.name.clone());
continue;
}
let tensor = cx.named_tensor(input.name.clone(), shape);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
continue;
}
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
}
// Create initializer (weight) tensors
for init in &onnx_graph.initializer {
if !tensors.contains_key(&init.name) {
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
if shape.is_empty() {
shape = vec![1];
}
let tensor = cx.named_tensor(init.name.clone(), shape);
tensors.insert(init.name.clone(), tensor);
}
}
// Load small constants for constant folding
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
for init in &onnx_graph.initializer {
let n_elements: usize = init
.dims
.iter()
.map(|&d| d as usize)
.product::<usize>()
.max(1);
if n_elements <= 32 {
if let Some(floats) = load_initializer_as_f32(init) {
known_values.insert(init.name.clone(), floats);
} else {
panic!("Unable to load initializer values for {:?}", init.name);
}
}
}
// Shape expressions for propagating symbolic shapes through ONNX graphs
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
// Accumulates constant node data from process_onnx_nodes
let mut constant_data: Vec<(String, Vec<f32>)> = Vec::new();
// Process computation nodes
process_onnx_nodes(
&onnx_graph.node,
&mut tensors,
&mut cx,
&mut constant_data,
&mut known_values,
&mut shape_exprs,
)
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
// Mark weight/constant tensors as persistent so their buffers survive execute()
for (name, gt) in &tensors {
if !input_names.contains(name) {
gt.persist();
}
}
// Mark graph outputs (must happen before build_search_space)
let mut output_names = Vec::new();
let mut output_shape_exprs = Vec::new();
for output_vi in &onnx_graph.output {
if let Some(&gt) = tensors.get(&output_vi.name) {
// Force contiguous if the shape tracker is a non-contiguous view
let gt = if gt.shape != gt.shape.contiguous() {
let contiguous = gt * 1.0;
tensors.insert(output_vi.name.clone(), contiguous);
contiguous
} else {
gt
};
gt.output();
let dims = gt.dims();
output_shape_exprs.push(dims.clone());
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
if shape.is_empty() {
return Err(format!(
"Output tensor '{}' has no shape information in the ONNX model",
output_vi.name
));
}
output_names.push(output_vi.name.clone());
}
}
// Set initial dynamic dimension values from example input shapes
let has_dynamic = !dim_param_map.is_empty();
if has_dynamic {
for input in &onnx_graph.input {
if initializer_names.contains(input.name.as_str()) {
continue;
}
let concrete_shape = get_shape_for_onnx_value(input);
let expr_shape =
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
if expr.to_usize().is_none()
&& let Some(ch) = dim_param_map
.values()
.find(|&&ch| Expression::from(ch) == *expr)
{
cx.set_dim(*ch, *concrete);
}
}
}
}
// Build weight data: initializers + constants from process_onnx_nodes
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
if let Some(f) = floats {
weights.push((name, f));
}
}
weights.extend(constant_data);
// Build tensor sizes for CUDA dummy data allocation
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
for input in &onnx_graph.input {
if !initializer_names.contains(input.name.as_str()) {
let shape = get_shape_for_onnx_value(input);
let n: usize = shape.iter().product::<usize>().max(1);
tensor_sizes.insert(input.name.clone(), n);
}
}
for init in &onnx_graph.initializer {
let n: usize = init
.dims
.iter()
.map(|&d| d as usize)
.product::<usize>()
.max(1);
tensor_sizes.insert(init.name.clone(), n);
}
for (name, data) in &weights {
if !tensor_sizes.contains_key(name) {
tensor_sizes.insert(name.clone(), data.len());
}
}
// Collect tensor name → NodeIndex mapping
let tensor_ids: HashMap<String, NodeIndex> = tensors
.iter()
.map(|(name, gt)| (name.clone(), gt.id))
.collect();
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
let input_shape_exprs: Vec<Vec<Expression>> = input_names
.iter()
.map(|name| {
if let Some(&gt) = tensors.get(name) {
gt.dims()
} else {
vec![]
}
})
.collect();
let translation = GraphTranslation {
graph: cx,
tensor_ids,
input_names,
output_names,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
};
let weight_data = WeightData {
weights,
tensor_sizes,
device_ptrs: HashMap::new(),
};
Ok((translation, weight_data))
}

View File

@@ -0,0 +1,187 @@
use std::collections::HashMap;
use luminal::{
prelude::{tracing::trace, *},
shape::Expression,
};
use onnx_protobuf::NodeProto;
use crate::util::{broadcast_to_expr, compute_broadcast_shape_expr};
/// Handle Where node: conditional select — output[i] = condition[i] ? x[i] : y[i]
///
/// ONNX Where uses numpy-style broadcasting across all three inputs.
pub fn parse_where_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
assert!(node.input.len() == 3, "Where should have 3 inputs");
let condition = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Where: missing condition tensor '{}'", node.input[0]))?;
let x = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Where: missing X tensor '{}'", node.input[1]))?;
let y = *tensors
.get(&node.input[2])
.ok_or_else(|| format!("Where: missing Y tensor '{}'", node.input[2]))?;
let output_name = &node.output[0];
// ONNX Where broadcasts all 3 inputs to a common shape
let bc_shape = compute_broadcast_shape_expr(
&condition.dims(),
&compute_broadcast_shape_expr(&x.dims(), &y.dims()),
);
let condition = broadcast_to_expr(condition, &bc_shape);
let x = broadcast_to_expr(x, &bc_shape);
let y = broadcast_to_expr(y, &bc_shape);
let result = x.cond(condition, y);
tensors.insert(output_name.clone(), result);
Ok(())
}
pub fn parse_binary_broadcast_op(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
op_name: &str,
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
shape_exprs: &mut HashMap<String, Vec<Expression>>,
known_values: &HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: {} Node", op_name);
assert!(
node.input.len() == 2,
"{} should have 2 inputs, got {}",
op_name,
node.input.len()
);
assert!(
node.output.len() == 1,
"{} should have 1 output, got {}",
op_name,
node.output.len()
);
// Shape-only path: if any input is shape-only (not in tensors), do Expression arithmetic
let a_missing = !tensors.contains_key(&node.input[0]);
let b_missing = !tensors.contains_key(&node.input[1]);
if a_missing || b_missing {
// At least one input is shape-only. Do shape_exprs arithmetic and return.
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
known_values
.get(&node.input[0])
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
});
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
known_values
.get(&node.input[1])
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
});
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
&& se_a.len() == 1
&& se_b.len() == 1
{
let result_expr = match op_name {
"Add" => Some(se_a[0] + se_b[0]),
"Sub" => Some(se_a[0] - se_b[0]),
"Mul" => Some(se_a[0] * se_b[0]),
"Div" => Some(se_a[0] / se_b[0]),
_ => None,
};
if let Some(expr) = result_expr {
shape_exprs.insert(node.output[0].clone(), vec![expr]);
}
}
trace!("Finished parse: {} Node (shape-only)", op_name);
return Ok(());
}
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[1]))?;
let broadcast_shape = compute_broadcast_shape_expr(&a.dims(), &b.dims());
let a_bc = broadcast_to_expr(a, &broadcast_shape);
let b_bc = broadcast_to_expr(b, &broadcast_shape);
let result = op(a_bc, b_bc);
tensors.insert(node.output[0].clone(), result);
// Propagate shape_exprs for scalar shape arithmetic (e.g., Add(1, seq_len))
// At least one input must be in shape_exprs; the other can come from known_values.
let has_shape_expr =
shape_exprs.contains_key(&node.input[0]) || shape_exprs.contains_key(&node.input[1]);
if has_shape_expr {
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
known_values
.get(&node.input[0])
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
});
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
known_values
.get(&node.input[1])
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
});
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
&& se_a.len() == 1
&& se_b.len() == 1
{
let result_expr = match op_name {
"Add" => Some(se_a[0] + se_b[0]),
"Sub" => Some(se_a[0] - se_b[0]),
"Mul" => Some(se_a[0] * se_b[0]),
"Div" => Some(se_a[0] / se_b[0]),
_ => None,
};
if let Some(expr) = result_expr {
shape_exprs.insert(node.output[0].clone(), vec![expr]);
}
}
}
trace!("Finished parse: {} Node", op_name);
Ok(())
}
pub fn parse_variadic_broadcast_op(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
op_name: &str,
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
_shape_exprs: &mut HashMap<String, Vec<Expression>>,
_known_values: &HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: {} Node", op_name);
assert!(
node.input.len() >= 2,
"{} needs at least two inputs, got {}",
op_name,
node.input.len()
);
assert!(
node.output.len() == 1,
"{} nodes only have one output, got {}",
op_name,
node.output.len()
);
let mut result = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
for input_name in &node.input[1..] {
let rhs = *tensors
.get(input_name)
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, input_name))?;
let broadcast_shape = compute_broadcast_shape_expr(&result.dims(), &rhs.dims());
let lhs_bc = broadcast_to_expr(result, &broadcast_shape);
let rhs_bc = broadcast_to_expr(rhs, &broadcast_shape);
result = op(lhs_bc, rhs_bc);
}
tensors.insert(node.output[0].clone(), result);
trace!("Finished parse: {} Node", op_name);
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More