Compare commits

...

2501 Commits

Author SHA1 Message Date
Austin Glover
7ee5b54438 idiomatic changes 2026-04-06 23:05:06 +00:00
Austin Glover
389c05abeb auto rebuild, silence onnx, cache images. 2026-04-06 22:43:24 +00:00
Austin Glover
dcc2c9cbb4 hf tests 2026-04-06 22:42:58 +00:00
Austin Glover
a9af4c3923 test deps 2026-04-06 22:42:27 +00:00
Austin Glover
3092d0d68b skill slop 2026-04-06 22:29:34 +00:00
Austin Glover
8a2bd714ac test classes wip 2026-04-02 01:32:37 +00:00
Austin Glover
54a26a044c save codex data on container restart 2026-04-02 01:31:23 +00:00
Austin Glover
5a0d3f87cc Merge remote-tracking branch 'origin/main' into pytest-classes
# Conflicts:
#	crates/luminal_python/pyproject.toml
#	crates/luminal_python/tests/conftest.py
#	crates/luminal_python/tests/generate_llama38b_artifacts.py
#	crates/luminal_python/tests/test_llama3.py
2026-04-02 01:18:40 +00:00
Joe Fioti
a28b755245 Merge pull request #259 from luminal-ai/tucker_shared_pytorch_memory 2026-04-01 12:51:15 -07:00
Tucker Morgan
fd83534e53 Remove dead logical.rs stub from luminal_cuda_lite
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:58:12 +00:00
Tucker Morgan
b5d984c3fa Move KernelExp/KernelSigmoid to other_ops.rs and remove logical intermediaries
hlir.rs should only contain 1:1 HLIR op analogues. KernelExp and KernelSigmoid
are fused kernels, so they belong in other_ops.rs. Also removed the redundant
logical::Exp and logical::Sigmoid intermediary ops since the kernel ops match
HLIR patterns directly via their direct-fusion egglog rules.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:46:17 +00:00
Tucker Morgan
64a5ca41b5 Merge remote-tracking branch 'origin/main' into tucker_shared_pytorch_memory 2026-04-01 16:45:16 +00:00
Joe Fioti
9bda47714a Merge pull request #256 from luminal-ai/asglover/modal_ci_ready 2026-04-01 05:21:02 -07:00
Austin Glover
9e513b6589 Fix git safe.directory for pre-commit in CUDA clippy container
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:32:40 -07:00
Austin Glover
a62d728bd7 Fix CUDA clippy container image to luminal-docker
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:21:36 -07:00
Austin Glover
4114714d3f Rename clippy workflow to cuda-clippy and fix container image
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:17:47 -07:00
Austin Glover
6191597571 Remove Modal CUDA clippy job, now handled by T4 runner
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:10:17 -07:00
Austin Glover
253cd95ab0 Run clippy on T4 runner with CUDA container for full lint coverage
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:05:05 -07:00
Austin Glover
d7e396ba5b Gate Modal CI on 'modal-ready' label and convert CUDA tests to Modal
- Gate test-cuda.yml and test-python-cuda.yml behind 'modal-ready' label
- Convert CUDA clippy and unit tests from self-hosted runner to Modal
- Add ci/modal_cargo_test.py and ci/modal_cargo_clippy.py runners

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 16:03:17 -07:00
Joe Fioti
1a53626716 Merge pull request #260 from luminal-ai/nvidia-devcontainer-args 2026-03-31 15:55:21 -07:00
Austin Glover
4329d68adc Merge main and resolve workflow conflicts
Resolve conflicts from main's pre-commit migration and Modal pytest runner.
Split new lint jobs (ruff, ruff-format, metal-clippy) into individual files
and update test-python-cuda to use Modal runner from main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 15:44:29 -07:00
Tucker Morgan
989e7e2d44 Fixing native tests 2026-03-31 21:27:34 +00:00
Austin Glover
4f0a3ab102 Merge branch 'main' into nvidia-devcontainer-args 2026-03-31 13:52:26 -07:00
Tucker Morgan
019972cdd4 Fixing ruff lint issue 2026-03-31 20:46:17 +00:00
Tucker Morgan
d7a3f468bd Ruff formatting 2026-03-31 20:44:23 +00:00
Tucker Morgan
c504fbf8a1 Merge cleanip 2026-03-31 20:41:40 +00:00
Austin Glover
648720caf9 force a 12.8 cuda version of torch. 2026-03-31 20:33:44 +00:00
Tucker Morgan
625be7f4da Merge origin/main into tucker_shared_pytorch_memory
Resolved conflicts:
- other_ops.rs: kept kernel_rewrite import, dropped unused compile_kernel
- lib.rs: kept weight_device_ptrs param, added validate_backend call
- runtime.rs: accepted two-phase CUDA init helpers from main
- compiled_model.py: kept weight_refs/user_indices/is_cuda fields
- pt2.py: kept original_weights tracking for zero-copy
- test_llama3.py: kept xfail + device param for dynamic test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 20:27:30 +00:00
Austin Glover
21ed7ef31f gitignore claude , codex 2026-03-31 20:24:00 +00:00
Joe Fioti
6e94f80c9e Merge pull request #255 from luminal-ai/modal-pytest-runnner
Modal pytest runner
2026-03-31 12:33:49 -07:00
Tucker Morgan
c2a17a4854 Removing uneeded qwen3 moe test file 2026-03-31 19:04:29 +00:00
Austin Glover
386b3df983 try recommended nvidia container args 2026-03-31 18:57:14 +00:00
Tucker Morgan
5c60f1d768 Fixing up small things for review 2026-03-31 18:24:16 +00:00
Tucker Morgan
4c51e3ea84 Cargo fmt: 2026-03-31 16:44:03 +00:00
Tucker Morgan
846551aa6f Cargo clippy 2026-03-31 16:42:30 +00:00
Tucker Morgan
c26076bc75 Cargo fmt 2026-03-31 16:38:09 +00:00
Tucker Morgan
871629b770 fmt and clippy 2026-03-31 16:35:13 +00:00
Tucker Morgan
c6dfa9c62f Unify ONNX/PT2 compilation paths and extract shared helpers
Restructure so both ONNX and PT2 paths follow the same call flow:
  lib.rs (thin PyO3 wrapper)
    → onnx_translator.rs / pt2_compiled_model.rs (format-specific translate + compile)
      → compiled_graph.rs::parse_graph (shared backend pipeline)

Rust changes:
- Create onnx_translator.rs with compile_onnx() and translate_onnx()
  (moved from compiled_graph.rs and lib.rs)
- compiled_graph.rs now only contains shared code (GraphTranslation,
  WeightData, CompiledGraph, parse_graph)
- Cache label_map in CompiledGraph for O(1) set_weight_* lookups
- Move weight_device_ptrs into WeightData.device_ptrs
- Add search_iters param to process_onnx (parity with PT2)
- Fix .unwrap() → ? error propagation in ONNX file loading
- lib.rs reduced to thin PyO3 registration layer

Python changes:
- Extract _collect_weight_pointers(), _detect_backend(),
  _load_cpu_weights() shared helpers in main.py
- Both ONNX and PT2 paths use the same helpers
- Centralize _register_cache_serialization() in __init__.py
- CompiledModel: add input_names override, keep user_indices for
  torch.compile lifted-param filtering

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 23:29:15 +00:00
Tucker Morgan
90e3a915d7 Cargo fmt 2026-03-30 22:20:41 +00:00
Tucker Morgan
56cb237aa2 removing uneeded prints 2026-03-30 22:20:32 +00:00
Tucker Morgan
a2c42b35c8 Cleaning up qwen tests 2026-03-30 21:36:30 +00:00
Tucker Morgan
898204b2dd setting test right 2026-03-30 17:51:32 +00:00
Tucker Morgan
2c1a7f087f removing uneeded logs 2026-03-30 17:36:26 +00:00
Austin Glover
112d064700 remove unnecessary ignore 2026-03-28 00:39:34 +00:00
Austin Glover
c51c36fbcb add node for mcp servers 2026-03-28 00:37:46 +00:00
Austin Glover
ee372d464e ignore codex 2026-03-28 00:37:34 +00:00
Austin Glover
1bef1344d1 pytest native approach to caching 2026-03-28 00:36:41 +00:00
Austin Glover
2e27c29b47 Gate Modal CI on 'modal-ready' label and split workflows into one-job-per-file
Modal examples now only run on PRs when the 'modal-ready' label is applied,
preventing expensive GPU runs on every push. Split test.yml and lint.yml
into individual workflow files for clearer CI organization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-27 15:45:40 -07:00
Austin Glover
8d41c491fd test prototype 2026-03-27 01:36:37 +00:00
Austin Glover
64f390a833 silence maturin logs 2026-03-27 01:36:23 +00:00
Austin Glover
8d20581f38 testing infra 2026-03-27 01:35:56 +00:00
Austin Glover
bfd4ae9b27 temp 2026-03-27 01:35:49 +00:00
Tucker Morgan
92e4260f1e Fixing weight stripping issues 2026-03-26 21:31:48 +00:00
Tucker Morgan
662a564efc Cleaning up a set of changes 2026-03-26 18:39:57 +00:00
Tucker Morgan
1761dc6b66 Missed a directory 2026-03-26 18:05:28 +00:00
Tucker Morgan
da71273d7e Getting LLama tests closer to proper passing 2026-03-26 18:05:13 +00:00
Austin Glover
39122672b4 clippy fixes 2026-03-26 01:26:51 +00:00
Austin Glover
d866ba6407 plz 2026-03-26 01:05:50 +00:00
Austin Glover
9a0fb453ed plz 2026-03-26 01:00:39 +00:00
Austin Glover
dab60f0b21 simplify cuda clippy 2026-03-26 00:54:21 +00:00
Austin Glover
1ea872bd2a allow lots of args 2026-03-25 23:44:11 +00:00
Austin Glover
90a66ac704 metal clippy 2026-03-25 23:38:33 +00:00
Austin Glover
2b94ba0b71 retry 2026-03-25 23:26:56 +00:00
Austin Glover
2ed65d5386 precommit 2026-03-25 22:50:21 +00:00
Austin Glover
336d49c147 pre-commit stuff 2026-03-25 22:50:12 +00:00
Austin Glover
1ff5840a76 disable TF32 2026-03-25 22:41:42 +00:00
Austin Glover
bc94b10648 ruff changes 2026-03-25 22:34:45 +00:00
Tucker Morgan
7c921d03a8 Working weight sharing in both onnx and pt 2026-03-25 21:27:14 +00:00
Austin Glover
4e46051617 skip slow tests 2026-03-25 20:46:40 +00:00
Austin Glover
a55952d591 skip test_hf_llama3_1b_decode_loop_dynamic 2026-03-25 20:42:39 +00:00
Tucker Morgan
679aa7e092 Fixing up the onnx and fx parsing layer to share more of their code paths 2026-03-25 17:25:00 +00:00
Tucker Morgan
3dd2be2fb2 First pass of the new memory model 2026-03-25 15:59:06 +00:00
Austin Glover
c290e266f7 move dependency back to candle 2026-03-25 04:02:08 +00:00
Austin Glover
be3d8aa064 fix file permission issue w/ workaround 2026-03-25 04:01:53 +00:00
Austin Glover
84e0c842a1 ignore venv? 2026-03-25 03:37:37 +00:00
Austin Glover
403fd36b1f hugging face caching, all files on host 2026-03-25 02:06:28 +00:00
Austin Glover
651a4c2aee pass HF token 2026-03-25 01:45:09 +00:00
Austin Glover
41ddd244ef lowkirkenuenly 2026-03-25 00:54:09 +00:00
Austin Glover
0dbee87a8c Merge branch 'modal-pytest-runnner' of https://github.com/luminal-ai/luminal into modal-pytest-runnner 2026-03-25 00:41:24 +00:00
Austin Glover
194b8adfa5 satisfy clippy 2026-03-25 00:31:44 +00:00
Austin Glover
86e616800d Merge branch 'main' into modal-pytest-runnner 2026-03-24 17:23:51 -07:00
Austin Glover
683205121d add lots of stuff for creating profiling outputs 2026-03-25 00:17:05 +00:00
Austin Glover
6d653e854d add profiling stuff 2026-03-25 00:16:38 +00:00
Austin Glover
08397b566d compile to cubin 2026-03-25 00:16:25 +00:00
Austin Glover
5310335256 force cuda 12.8 for modal 2026-03-25 00:12:55 +00:00
Austin Glover
638765b62b add artifacts 2026-03-25 00:12:35 +00:00
Austin Glover
3850b3a533 chatgpt 2026-03-25 00:12:12 +00:00
Austin Glover
a4c84c6cf5 force CUDA 12.8 on modal 2026-03-25 00:12:00 +00:00
Joe Fioti
f32161d43b Merge pull request #254 from luminal-ai/tucker_fx_parsing_integration
Add PT2/FX graph parsing integration with bucketed compilation
2026-03-24 14:07:11 -07:00
Austin Glover
da83d51b27 extend timeout 2026-03-24 01:07:40 +00:00
Austin Glover
29a3ffa3e3 wip 2026-03-23 23:22:38 +00:00
Austin Glover
97e358916a simplify runner 2026-03-23 23:20:51 +00:00
Tucker Morgan
631c1b53d7 Fixing bad test config 2026-03-23 22:48:05 +00:00
Austin Glover
02449a6bea move CICD to go through modal 2026-03-23 22:18:08 +00:00
Tucker Morgan
6c4597102e Fixing test 2026-03-23 22:02:41 +00:00
Austin Glover
b077cfdb76 reduce diff 2026-03-23 21:50:47 +00:00
Tucker Morgan
869b519e39 Fixing up tests 2026-03-23 21:50:40 +00:00
Tucker Morgan
2b831c9f25 cargo fmt 2026-03-23 21:10:44 +00:00
Tucker Morgan
f35a950496 Running a review to clean things up 2026-03-23 21:08:15 +00:00
Tucker Morgan
9ab0e1472c Fmt and Clippy 2026-03-23 18:18:08 +00:00
Tucker Morgan
88f2601d5e Preclippy push 2026-03-23 18:05:25 +00:00
Tucker Morgan
b0ebdcba8c Preclipply fix 2026-03-23 18:04:55 +00:00
Tucker Morgan
0ab124194b Pre commit cleanup 2026-03-23 17:57:51 +00:00
Austin Glover
7f042ae615 Merge remote-tracking branch 'origin/main' into tucker-branch-austin-wip
# Conflicts:
#	crates/luminal_python/.gitignore
#	crates/luminal_python/rust/src/ops_parse/unary.rs
2026-03-23 17:52:46 +00:00
Joe Fioti
082d9c48bd renamed progress bar ui 2026-03-23 05:57:03 +00:00
Tucker Morgan
251e9526f3 Merge remote-tracking branch 'origin/main' into tucker_fx_parsing_integration 2026-03-22 21:34:49 +00:00
Tucker Morgan
41b3774ec2 Getting the errors on most models to pass 1e-5 2026-03-22 21:32:02 +00:00
Joe Fioti
3fdb464f5a fmt 2026-03-21 18:58:02 +00:00
Austin Glover
a3a4fd94ec ignore uv.lock 2026-03-20 23:55:55 +00:00
Austin Glover
5446dccb04 clean up conf test 2026-03-20 23:47:40 +00:00
Austin Glover
8e6535563e clear errors 2026-03-20 23:46:57 +00:00
Austin Glover
bdc923aa50 add logic to prevent unclear errors 2026-03-20 23:46:38 +00:00
Austin Glover
ea67742b3b ignore uv.lock 2026-03-20 23:45:52 +00:00
Austin Glover
149e570f26 modal in dev, not sure if the module_name is necessary? 2026-03-20 23:45:14 +00:00
Tucker Morgan
ac52098d5c Fixing brioken llama issue 2026-03-20 23:06:16 +00:00
Austin Glover
01946ecd10 modal 2026-03-20 22:57:02 +00:00
Tucker Morgan
ef70fee204 Work to get the early version of llama to work 2026-03-20 22:24:06 +00:00
Austin Glover
a6fea110dc ignore 2026-03-20 20:59:20 +00:00
Austin Glover
2d4ebb2cb6 sure 2026-03-20 20:52:15 +00:00
Austin Glover
5adb875b04 ignore cargo-local 2026-03-20 20:50:55 +00:00
Joe Fioti
2adfcfa70e bucketed compilation 2026-03-20 18:33:15 +00:00
Tucker Morgan
65600e8730 Cargo fmt 2026-03-20 17:45:40 +00:00
Tucker Morgan
4c4f39b4af A quick review pass 2026-03-20 17:45:06 +00:00
Tucker Morgan
49b9209ad0 All basic tests passing for both Onnx and FX 2026-03-20 00:08:57 +00:00
Austin Glover
dea5df51dd wip 2026-03-19 23:59:12 +00:00
Austin Glover
eb6a6c2174 ignore luminal artifacts 2026-03-19 23:59:05 +00:00
Austin Glover
8864ef31fb ignore claude 2026-03-19 23:58:23 +00:00
Austin Glover
38c98a8835 fix devcontainers 2026-03-19 23:58:13 +00:00
Tucker Morgan
31b5fd886d First pass, copy over scartch pad work, easy test fixes 2026-03-19 22:57:32 +00:00
Joe Fioti
04b2753aa8 fixed test 2026-03-19 15:54:21 -07:00
Joe Fioti
dbb5282fd6 removed kimi submodule 2026-03-19 15:48:00 -07:00
Joe Fioti
8e315c62df Merge pull request #240 from luminal-ai/tucker_pytorch_integegration_layer
PyTorch Integration Layer (luminal_python)
2026-03-19 13:40:46 -07:00
Tucker Morgan
f53d990581 Adding Flatten 2026-03-19 20:18:15 +00:00
Tucker Morgan
2c8ecba6a5 Flatten function on ShapeTracker and GraphTensor 2026-03-19 18:51:30 +00:00
Tucker Morgan
1644cce031 Clippy and fmt 2026-03-19 18:20:57 +00:00
Tucker Morgan
b2bb455b30 Removing as many forced materlizaing as possibling 2026-03-19 18:18:55 +00:00
Tucker Morgan
8628b1425a Actually removing the uv lock 2026-03-19 16:36:01 +00:00
Tucker Morgan
ca66609d6f cargo fmt 2026-03-18 23:39:03 +00:00
Tucker Morgan
c50e122ac1 making scatter element and gather elemnets to be simpler 2026-03-18 23:37:19 +00:00
Tucker Morgan
272acabd0c Small fixes from comments 2026-03-18 23:16:27 +00:00
Tucker Morgan
f772c0529a Removing GPU dashboard 2026-03-18 22:28:49 +00:00
Tucker Morgan
c3e1f568ea Wrong name for docket imag 2026-03-18 21:55:16 +00:00
Tucker Morgan
eb3dd02836 updating docker image 2026-03-18 21:44:47 +00:00
Tucker Morgan
8a9f85b0ce Cargo fmt fixes 2026-03-18 21:27:06 +00:00
Tucker Morgan
372501e527 Fixing clippy and review issues 2026-03-18 21:23:55 +00:00
Tucker Morgan
cda12a6d84 Clippy and fmt fixes 2026-03-18 21:04:33 +00:00
Tucker Morgan
566fb00ed2 Fixes in bad tests 2026-03-18 21:03:20 +00:00
Tucker Morgan
ecb78a2635 Fixing up issues left over from a merge 2026-03-18 18:56:39 +00:00
Tucker Morgan
f401ffb900 Merge remote-tracking branch 'origin/main' into tucker_pytorch_integegration_layer 2026-03-18 16:36:08 +00:00
Joe Fioti
cd94000140 Merge pull request #253 from luminal-ai/matmul-flattening-minimal
Matmul flattening rules for cuBLAS matching
2026-03-18 09:35:42 -07:00
Austin Glover
1f1636e188 cargo fmt
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 01:20:55 +00:00
Austin Glover
371fa8491a Update matmul flattening rules for Op() wrapper syntax
Adapt the 3 egglog flattening rules (squeeze, batch_merge_a_contig,
batch_merge_b_contig) to use the new Op(...) (ICons ...) wrapper
syntax from the HLIR refactor, so they match the current egglog
representation of Mul and Sum nodes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 01:04:37 +00:00
Austin Glover
6d1fe67b66 flatten matmuls 2026-03-18 00:58:55 +00:00
Austin Glover
189d1e2594 match on flattened matmuls 2026-03-18 00:58:55 +00:00
Austin Glover
c7acfb9794 flatten matmul when applicable 2026-03-18 00:58:33 +00:00
Tucker Morgan
fe6af5290a Merge remote-tracking branch 'origin/main' into tucker_pytorch_integegration_layer
# Conflicts:
#	.github/workflows/test.yml
#	crates/luminal_cuda/src/block/ops.rs
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg
#	crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg
#	crates/luminal_cuda/src/kernel/other_ops.rs
#	crates/luminal_cuda_lite/src/kernel/hlir.rs
2026-03-17 17:19:04 +00:00
Joe Fioti
8b8669c744 Merge pull request #252 from luminal-ai/hlir_enhancements
tweaks
2026-03-16 21:54:09 -07:00
Joe Fioti
7af771b999 tweaks 2026-03-17 04:46:27 +00:00
Joe Fioti
133757f187 Merge pull request #250 from xiaoniaoyouhuajiang/feature/metal-mega-kernel
metal backend optimization for transformer
2026-03-13 16:16:31 -07:00
xiaoniaoyouhuajiang
07ee241b25 merge upstream work 2026-03-13 17:01:07 +08:00
Joe Fioti
958331ab6c Merge pull request #251 from luminal-ai/hlir_enhancements
Hlir enhancements
2026-03-12 22:28:33 -07:00
Joe Fioti
340199d4a8 fixed cuda testss 2026-03-13 05:16:09 +00:00
Joe Fioti
f17a95e673 metal fixes 2026-03-12 20:24:07 -07:00
Joe Fioti
6bb576e711 more test fixes 2026-03-13 03:10:37 +00:00
Joe Fioti
744e4d767a changed actions 2026-03-13 03:08:29 +00:00
Joe Fioti
c940161f25 fmt 2026-03-12 23:27:12 +00:00
Joe Fioti
3aa2c309f5 clippy 2026-03-12 23:24:13 +00:00
Joe Fioti
8a0592646b added qwen3 moe 2026-03-12 23:10:32 +00:00
Joe Fioti
68ce81e52b Merge remote-tracking branch 'origin/main' into hlir_enhancements 2026-03-12 22:37:07 +00:00
Joe Fioti
39789404f4 Merge pull request #248 from luminal-ai/modal-ci
Add CI workflows: lint, CUDA tests, Modal llama
2026-03-12 15:28:34 -07:00
Joe Fioti
8c53234966 Merge branch 'main' into modal-ci 2026-03-12 14:45:30 -07:00
Joe Fioti
71eca945cb changed luminal_cuda_lite 2026-03-12 21:34:13 +00:00
Joe Fioti
8da130ae1c Merge remote-tracking branch 'origin/main' into hlir_enhancements 2026-03-12 20:33:09 +00:00
Joe Fioti
fef6a45c9c paged attention llama example 2026-03-12 20:29:39 +00:00
Joe Fioti
c6763a69ba fixed metal ci 2026-03-12 11:24:45 -07:00
Joe Fioti
30caca106c fixed metal ci 2026-03-12 11:16:45 -07:00
Joe Fioti
6c90bb5059 metal ci/cd 2026-03-12 11:01:04 -07:00
xiaoniaoyouhuajiang
82189cd602 threadgroup staging 2026-03-12 18:56:05 +08:00
xiaoniaoyouhuajiang
cc5e0a639d support simdgroup micro-kernel 2026-03-12 18:23:50 +08:00
xiaoniaoyouhuajiang
8dc05233cb support tiled matmul for gemm 2026-03-12 17:32:19 +08:00
xiaoniaoyouhuajiang
0ab9947292 matmul family template & match pipeline 2026-03-12 16:56:02 +08:00
Joe Fioti
f11ba3a388 added tensor.repeat 2026-03-11 22:35:32 -07:00
Austin Glover
a346e503db make it so draft PRs don't trigger 2026-03-11 18:16:10 +00:00
xiaoniaoyouhuajiang
6bbf244924 support f16 for metal 2026-03-11 14:02:48 +08:00
Joe Fioti
a8505668ac converted gemma and qwen 2026-03-11 03:39:51 +00:00
Austin Glover
a0b237c424 use new images 2026-03-11 02:03:39 +00:00
Austin Glover
8fabacd17e manually override build 2026-03-11 01:50:21 +00:00
Austin Glover
df0128ad04 pass example str 2026-03-11 01:28:04 +00:00
Austin Glover
de55e67594 lock to cuda 13 2026-03-11 01:22:22 +00:00
Austin Glover
cdff26755f fmt 2026-03-11 01:17:01 +00:00
Austin Glover
9f11b7e24a try matrix formulation 2026-03-11 01:15:54 +00:00
Austin Glover
27344a0e45 updating CI 2026-03-11 00:57:56 +00:00
Joe Fioti
e9e6f824a1 fixed memory leak 2026-03-10 23:08:28 +00:00
Austin Glover
d894eeae50 remove caching 2026-03-10 22:45:25 +00:00
Joe Fioti
da078b5bdd fixed in-place scatter for llama, cleaned up cleanup rules 2026-03-10 21:26:35 +00:00
Tucker Morgan
f156265ff4 Add TileMatmul broadcast constraints and expand qwen image tests
Add (= ?a_n_stride (MNum 0)) and (= ?b_m_stride (MNum 0)) constraints
to TileMatmulSplitK and TileMatmulFullSplit to prevent matching
element-wise Mul+Sum patterns (e.g. x*x from LayerNorm) as matmuls.
Also refactor qwen image tests to use torch.compile backend for
transformer tests, add medium configs, and add full production-scale
test cases.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-10 18:04:54 +00:00
Joe Fioti
b34f104cea normalized egglog ops 2026-03-10 04:49:41 +00:00
Austin Glover
1873e26185 fix: add Modal environment to modal-llama workflow for secrets access
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 00:22:14 +00:00
Austin Glover
384a426ba3 modal prototype 2026-03-09 23:58:47 +00:00
Tucker Morgan
d25654a0ec Small fixes to qwen image 2026-03-09 20:54:34 +00:00
Tucker Morgan
579daa1a57 Merge origin/main into tucker_pytorch_integegration_layer
Resolved conflicts:
- other_ops.rs: kept parallelized reduction loop from branch
- base.rs: kept existing commutativity rules, avoided duplicate add-comm

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-09 17:48:29 +00:00
Tucker Morgan
4fa8a92086 Pushed from only the luminal_python crate 2026-03-09 17:46:25 +00:00
Tucker Morgan
bf7debc1d6 Various fixes after a merge 2026-03-09 17:45:54 +00:00
Joe Fioti
a49c970029 hlir attn 2026-03-09 06:05:24 +00:00
xiaoniaoyouhuajiang
1d2db8f88f Merge branch 'feature/metal-mega-kernel' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-mega-kernel 2026-03-09 10:06:33 +08:00
Tucker Morgan
4fff906b8d Merge branch 'tucker_pytorch_integegration_layer' of https://github.com/luminal-ai/luminal into tucker_pytorch_integegration_layer
# Conflicts:
#	.github/workflows/test.yml
#	crates/luminal_python/LessonsLearned.md
#	crates/luminal_python/run_tests_cuda.sh
2026-03-08 02:10:05 +00:00
Tucker Morgan
85b49018a3 KernelEmbed failure fix 2026-03-08 02:02:52 +00:00
Joe Fioti
a82c530ae7 consumable inputs 2026-03-07 20:28:27 +00:00
Joe Fioti
8fdfac19e1 fixed model examples 2026-03-07 19:00:47 +00:00
Joe Fioti
c61dfa0a13 merge dims 2026-03-07 00:59:53 +00:00
Tucker Morgan
3815a6de67 Working on dynamic caching handling from transformers 2026-03-06 22:48:22 +00:00
Tucker Morgan
375c54b641 Other part of i64 converstion 2026-03-06 22:48:22 +00:00
Tucker Morgan
1dd8853a60 Fixing up Num expression in Term to use i64 not i32 2026-03-06 22:48:22 +00:00
Tucker Morgan
050eeba815 Fixing borken topk on cuda 2026-03-06 22:48:22 +00:00
Tucker Morgan
05e0a2fc31 Fixing llama3 decode loop so it odes not crash 2026-03-06 22:48:21 +00:00
Tucker Morgan
af7d0b002f Working LLama3 1b instruct 2026-03-06 22:48:21 +00:00
Tucker Morgan
f6c3b68f86 Fixing broken tests 2026-03-06 22:48:21 +00:00
Tucker Morgan
afc17eac31 Fix for failing reduction matmul tests, adjusting z stride 2026-03-06 22:47:20 +00:00
Tucker Morgan
5330ab9159 Dynamic Shapes 2026-03-06 22:47:20 +00:00
Tucker Morgan
ea2cd9da45 Scatter ND cuda test fix 2026-03-06 22:47:20 +00:00
Tucker Morgan
92957228f5 inlining parsing functions into dispatch.rs 2026-03-06 22:47:19 +00:00
Tucker Morgan
4152c8a732 Locking onnx opset to 20, reject all models that are not that explict model version 2026-03-06 22:47:19 +00:00
Tucker Morgan
d95e847dba removing unsupported datatype constants 2026-03-06 22:47:19 +00:00
Tucker Morgan
153ac59773 removing eprint 2026-03-06 22:47:19 +00:00
Tucker Morgan
bcd1fe673a fixing onehot test 2026-03-06 22:47:19 +00:00
Tucker Morgan
99d65ae9df Fixing Slice2d test 2026-03-06 22:47:19 +00:00
Tucker Morgan
af9442bfa0 Llama3 2026-03-06 22:47:18 +00:00
Tucker Morgan
3414a9fe6c Pushing local changes 2026-03-06 22:47:18 +00:00
Tucker Morgan
d3a60de7fa Just creating a nice point for me 2026-03-06 22:47:18 +00:00
Tucker Morgan
63bad7d5a2 Gather Elements, LayerNorm, Gemm, IsNan, and Expand 2026-03-06 22:47:18 +00:00
Tucker Morgan
69e143f33d Adjusting how we determine where to compile the model 2026-03-06 22:47:18 +00:00
Tucker Morgan
ca9113b8a8 Fixing up uses of to axes 2026-03-06 22:47:17 +00:00
Tucker Morgan
4e27a9fb31 Refactor parsing functions to reuse as code 2026-03-06 22:47:17 +00:00
Tucker Morgan
ad51ad88cd Ceil and Cargo fmt 2026-03-06 22:47:17 +00:00
Tucker Morgan
9afc41e0a1 Reduce Mean 2026-03-06 22:47:17 +00:00
Tucker Morgan
0cbe874492 ReduceMin handling and tests 2026-03-06 22:47:17 +00:00
Tucker Morgan
987ac5b5ec Trilu trill handling and tests 2026-03-06 22:47:17 +00:00
Tucker Morgan
d4faf80cf9 Added common condtionals, Not, And, Or, Xoe, LessOrEqual, GreaterOrEqual 2026-03-06 22:47:17 +00:00
Tucker Morgan
8ce60f61ec Soft max function and tests 2026-03-06 22:47:16 +00:00
Tucker Morgan
3a9c945e71 Min and max handling 2026-03-06 22:47:16 +00:00
Tucker Morgan
765ddd5070 Adding where node handling and tests 2026-03-06 22:47:16 +00:00
Tucker Morgan
4178b38eb0 Added pow handling and testing 2026-03-06 22:47:16 +00:00
Tucker Morgan
516faaa205 Fixed mod tests failing by changing from using mod operator to fmod function 2026-03-06 22:47:16 +00:00
Tucker Morgan
5033efc8d7 Clip and Equal Nodes 2026-03-06 22:47:16 +00:00
Tucker Morgan
de659b4a06 Removing unsued code 2026-03-06 22:47:16 +00:00
Tucker Morgan
612c76bd7c Rust format issues 2026-03-06 22:47:15 +00:00
Tucker Morgan
2230b28751 Trigger CI 2026-03-06 22:47:15 +00:00
Tucker Morgan
ae3e2c9331 Disable mod test 2026-03-06 22:47:15 +00:00
Tucker Morgan
f09c2ddd68 Fix all 30 clippy errors in luminal_python crate
- Remove unused imports (onnx_protobuf::Message, protobuf::MessageField,
  std::sync::Arc, get_dtype_for_onnx_value)
- Gate cuda-only code with #[cfg(feature = "cuda")]
  (transpose_weight_data, Arc import, runtime.rs)
- Prefix unused variables/params with _ (_known_values, _input_tensor_names)
- Remove needless return statements in binary.rs and unary.rs (13 total)
- Replace never-loop pattern in compiled_graph.rs with direct panic
- Replace result = result * x with result *= x in movement.rs
- Replace manual memcpy loop with copy_from_slice in movement.rs
- Replace needless range loops with iterator form in reduction.rs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 22:47:15 +00:00
Tucker Morgan
dae508f26c Add maturin to dev dependency-groups so uv run can find it in CI
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 22:47:15 +00:00
Tucker Morgan
9b970e9d01 Fixing up fmt issues that was casuing ci/cd to fail 2026-03-06 22:47:15 +00:00
Tucker Morgan
c14c7f4be4 Added Abs, Sigmoid, Relu, and Tanh, and tests for them 2026-03-06 22:47:15 +00:00
Tucker Morgan
f60147f4a7 Adding tests to cicd 2026-03-06 22:47:14 +00:00
Tucker Morgan
2f2ba2fb1f breaking up ops parse into different files, added pytest random 2026-03-06 22:46:03 +00:00
Tucker Morgan
695df9cb29 Removing unsued parameters, adding in trace! 2026-03-06 22:46:02 +00:00
Tucker Morgan
97af33e671 Sum and Max Reduce 2026-03-06 22:46:02 +00:00
Tucker Morgan
c1824a1e8d Squeeze parsing and tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
96cfdd9fcf Fixed broken cuda tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
e4d2be9c89 Fixing a bunch of changes 2026-03-06 22:46:02 +00:00
Tucker Morgan
10ea5be27d Fixed native tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
3d5c939a7f fixing issue with floor op for mod 2026-03-06 22:46:02 +00:00
Tucker Morgan
4ba5b3b121 Handling the less node and tests 2026-03-06 22:46:02 +00:00
Tucker Morgan
2dc8e23496 Reshape, Shape, ops parsing added, tests as well 2026-03-06 22:46:01 +00:00
Tucker Morgan
9798fd0e38 Added Mod and tests for Mod op 2026-03-06 22:46:01 +00:00
Tucker Morgan
cf36d45b77 Fixing up cuda cast tests 2026-03-06 22:46:01 +00:00
Tucker Morgan
dccda51f30 Adding handling constant nodes 2026-03-06 22:46:01 +00:00
Tucker Morgan
34e9f6f5e8 Added transpose op and tests 2026-03-06 22:46:01 +00:00
Tucker Morgan
ba1aa9575c Cleaning up test files 2026-03-06 22:46:01 +00:00
Tucker Morgan
77a68d3091 Adding a way to select which backend is used for compile 2026-03-06 22:46:01 +00:00
Tucker Morgan
aaaa476352 Adding Sqrt node handling 2026-03-06 22:46:00 +00:00
Tucker Morgan
879e4203af Extending to use cuda, and fixing up luminal_python project structure so it follows best practices 2026-03-06 22:46:00 +00:00
tucker-luminal
f5185b86d6 Adding cuda support 2026-03-06 22:46:00 +00:00
tucker-luminal
7bf2543909 Add and Sub first pass
Add and sub first pass, pushing this so I can test on a cuda enviroment
2026-03-06 22:46:00 +00:00
Tucker Morgan
00b797e281 Working on dynamic caching handling from transformers 2026-03-06 22:45:05 +00:00
Joe Fioti
41f6a64746 clippy fix 2026-03-06 22:38:08 +00:00
Joe Fioti
2556bfa90b merged 2026-03-06 22:35:15 +00:00
Joe Fioti
32756f04c1 Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-03-06 22:31:34 +00:00
Joe Fioti
eaf8ba9219 fixed z-strides 2026-03-06 22:31:30 +00:00
Joe Fioti
b37c06d9b9 fmt 2026-03-06 11:52:27 -08:00
Joe Fioti
dd85e23e60 removed shapetracker on hlir edges 2026-03-06 11:48:45 -08:00
xiaoniaoyouhuajiang
a0a162049e make dyn dim = n_inputs +1 2026-03-06 17:11:36 +08:00
Austin Glover
d63cb1a115 try to prevent permissions issues 2026-03-06 04:39:12 +00:00
Austin Glover
e0413c640a Merge branch 'new-docker-images' of https://github.com/luminal-ai/luminal into new-docker-images 2026-03-06 04:24:30 +00:00
Austin Glover
da9a45a044 gh token forwarding 2026-03-06 04:22:47 +00:00
Austin Glover
a736d1aa2f use user ubuntu 2026-03-06 04:22:47 +00:00
Austin Glover
0d5880296a new split dev container 2026-03-06 04:22:46 +00:00
Austin Glover
ea3fa459ec gh token forwarding 2026-03-06 02:30:52 +00:00
Austin Glover
d21969370f use user ubuntu 2026-03-06 01:52:02 +00:00
Tucker Morgan
26ffc23c12 Other part of i64 converstion 2026-03-05 23:55:49 +00:00
Tucker Morgan
df3f6b539b Fixing up Num expression in Term to use i64 not i32 2026-03-05 23:31:21 +00:00
Austin Glover
45a6a62909 new split dev container 2026-03-05 23:18:38 +00:00
Tucker Morgan
4d2bb35e9e Fixing borken topk on cuda 2026-03-05 21:54:24 +00:00
Tucker Morgan
fa6991d8fb Fixing llama3 decode loop so it odes not crash 2026-03-05 20:10:10 +00:00
Joe Fioti
a5c02dd6f4 Merge pull request #246 from luminal-ai/remove-false-dtypes
Remove false dtypes
2026-03-05 11:40:11 -08:00
Joe Fioti
4dc7623b9d Merge pull request #245 from xiaoniaoyouhuajiang/feature/metal-mega-kernel
adapt to upstream code&add option to reproduce benchmark result
2026-03-05 11:16:52 -08:00
xiaoniaoyouhuajiang
f8fd9d568d remove env var parameters 2026-03-05 14:58:00 +08:00
xiaoniaoyouhuajiang
1aeb825e12 fix reduce error 2026-03-05 11:20:08 +08:00
Austin Glover
ff45626ae1 clippy 2026-03-05 01:16:22 +00:00
Austin Glover
659952ef13 fmt 2026-03-05 00:57:06 +00:00
Tucker Morgan
45a4e8c617 Working LLama3 1b instruct 2026-03-04 20:52:54 +00:00
Austin Glover
548a3eab83 import renames 2026-03-04 18:57:05 +00:00
Austin Glover
8c0beb1dbf Merge branch 'main' of https://github.com/luminal-ai/luminal into remove-false-dtypes 2026-03-04 18:49:57 +00:00
Austin Glover
ec30bd6b6b fmt 2026-03-04 18:19:26 +00:00
Austin Glover
b78b57b41b spelling 2026-03-04 18:18:41 +00:00
Austin Glover
539c705c22 tests run 2026-03-04 18:15:05 +00:00
Joe Fioti
440130a68f fixed tests 2026-03-03 22:37:23 +00:00
Tucker Morgan
c6ee7e2f21 Fixing broken tests 2026-03-03 22:03:21 +00:00
Tucker Morgan
335ec78d3e Fix for failing reduction matmul tests, adjusting z stride 2026-03-03 20:11:39 +00:00
Austin Glover
057e40b26c adding cuda dtype mapping 2026-03-03 19:55:41 +00:00
Austin Glover
9843d37278 dtypes to egglog 2026-03-03 19:18:37 +00:00
Austin Glover
38533176df core dtype change 2026-03-03 19:18:09 +00:00
Austin Glover
c828275906 move to explicit scaling factors 2026-03-03 19:17:10 +00:00
Tucker Morgan
7f760ad847 Dynamic Shapes 2026-03-03 18:54:42 +00:00
Tucker Morgan
a2e9b5209f Scatter ND cuda test fix 2026-03-03 18:54:42 +00:00
Tucker Morgan
d18e93d671 inlining parsing functions into dispatch.rs 2026-03-03 18:54:42 +00:00
Tucker Morgan
5a432c717c Locking onnx opset to 20, reject all models that are not that explict model version 2026-03-03 18:54:41 +00:00
Tucker Morgan
a15ae2d41c removing unsupported datatype constants 2026-03-03 18:54:41 +00:00
Tucker Morgan
2a03239c8d removing eprint 2026-03-03 18:54:41 +00:00
Tucker Morgan
9b11ce16ba fixing onehot test 2026-03-03 18:54:41 +00:00
Tucker Morgan
428d28e307 Fixing Slice2d test 2026-03-03 18:54:41 +00:00
Tucker Morgan
8f4680e7c0 Llama3 2026-03-03 18:54:41 +00:00
Tucker Morgan
4cb7c2c337 Pushing local changes 2026-03-03 18:54:40 +00:00
Tucker Morgan
38be807c11 Just creating a nice point for me 2026-03-03 18:54:40 +00:00
Tucker Morgan
26902c3cbf Gather Elements, LayerNorm, Gemm, IsNan, and Expand 2026-03-03 18:54:40 +00:00
Tucker Morgan
8550e98370 Adjusting how we determine where to compile the model 2026-03-03 18:54:40 +00:00
Tucker Morgan
b47aa9f4c5 Fixing up uses of to axes 2026-03-03 18:54:39 +00:00
Tucker Morgan
0264290091 Refactor parsing functions to reuse as code 2026-03-03 18:54:39 +00:00
Tucker Morgan
d3c8bbb838 Ceil and Cargo fmt 2026-03-03 18:54:39 +00:00
Tucker Morgan
09af10df2d Reduce Mean 2026-03-03 18:54:39 +00:00
Tucker Morgan
a3bc233cc9 ReduceMin handling and tests 2026-03-03 18:54:39 +00:00
Tucker Morgan
472148ec2e Trilu trill handling and tests 2026-03-03 18:54:39 +00:00
Tucker Morgan
202ee2e4b4 Added common condtionals, Not, And, Or, Xoe, LessOrEqual, GreaterOrEqual 2026-03-03 18:52:46 +00:00
Tucker Morgan
a52c90de55 Soft max function and tests 2026-03-03 18:52:46 +00:00
Tucker Morgan
d0fc4e528b Min and max handling 2026-03-03 18:52:46 +00:00
Tucker Morgan
3c51d2cd6e Adding where node handling and tests 2026-03-03 18:52:46 +00:00
Tucker Morgan
a0783896a1 Added pow handling and testing 2026-03-03 18:52:46 +00:00
Tucker Morgan
64700aa2a8 Fixed mod tests failing by changing from using mod operator to fmod function 2026-03-03 18:52:45 +00:00
Tucker Morgan
c7325f5590 Clip and Equal Nodes 2026-03-03 18:52:45 +00:00
Tucker Morgan
e1f8d2366f Removing unsued code 2026-03-03 18:52:45 +00:00
Tucker Morgan
bfee79d764 Rust format issues 2026-03-03 18:52:45 +00:00
Tucker Morgan
aab5db7d10 Trigger CI 2026-03-03 18:52:45 +00:00
Tucker Morgan
840ab6abd2 Disable mod test 2026-03-03 18:52:45 +00:00
Tucker Morgan
4079eebf1f Fix all 30 clippy errors in luminal_python crate
- Remove unused imports (onnx_protobuf::Message, protobuf::MessageField,
  std::sync::Arc, get_dtype_for_onnx_value)
- Gate cuda-only code with #[cfg(feature = "cuda")]
  (transpose_weight_data, Arc import, runtime.rs)
- Prefix unused variables/params with _ (_known_values, _input_tensor_names)
- Remove needless return statements in binary.rs and unary.rs (13 total)
- Replace never-loop pattern in compiled_graph.rs with direct panic
- Replace result = result * x with result *= x in movement.rs
- Replace manual memcpy loop with copy_from_slice in movement.rs
- Replace needless range loops with iterator form in reduction.rs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 18:52:44 +00:00
Tucker Morgan
0ad7b4e509 Add maturin to dev dependency-groups so uv run can find it in CI
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 18:52:44 +00:00
Tucker Morgan
60206e362c Fixing up fmt issues that was casuing ci/cd to fail 2026-03-03 18:52:44 +00:00
Tucker Morgan
ce72c5223b Added Abs, Sigmoid, Relu, and Tanh, and tests for them 2026-03-03 18:52:44 +00:00
Tucker Morgan
d0d958d69d Adding tests to cicd 2026-03-03 18:52:44 +00:00
Tucker Morgan
e554532108 breaking up ops parse into different files, added pytest random 2026-03-03 18:52:44 +00:00
Tucker Morgan
4ebb762724 Removing unsued parameters, adding in trace! 2026-03-03 18:52:43 +00:00
Tucker Morgan
734787b7c4 Sum and Max Reduce 2026-03-03 18:52:43 +00:00
Tucker Morgan
97e7e6e3f2 Squeeze parsing and tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
adf1e6be8b Fixed broken cuda tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
dcf49e3a61 Fixing a bunch of changes 2026-03-03 18:52:43 +00:00
Tucker Morgan
a3ff32d7a7 Fixed native tests 2026-03-03 18:52:43 +00:00
Tucker Morgan
8760818756 fixing issue with floor op for mod 2026-03-03 18:52:42 +00:00
Tucker Morgan
e4c153b388 Handling the less node and tests 2026-03-03 18:52:42 +00:00
Tucker Morgan
6470b74a9d Reshape, Shape, ops parsing added, tests as well 2026-03-03 18:52:42 +00:00
Tucker Morgan
e2cf926e5a Added Mod and tests for Mod op 2026-03-03 18:52:42 +00:00
Tucker Morgan
97766c59cf Fixing up cuda cast tests 2026-03-03 18:52:42 +00:00
Tucker Morgan
379ea51474 Adding handling constant nodes 2026-03-03 18:52:41 +00:00
Tucker Morgan
4e87658460 Added transpose op and tests 2026-03-03 18:52:41 +00:00
Tucker Morgan
2b16cdea24 Cleaning up test files 2026-03-03 18:52:41 +00:00
Tucker Morgan
bc7d5e8b14 Adding a way to select which backend is used for compile 2026-03-03 18:52:41 +00:00
Tucker Morgan
4e15740876 Adding Sqrt node handling 2026-03-03 18:52:41 +00:00
Tucker Morgan
53a056ba8a Extending to use cuda, and fixing up luminal_python project structure so it follows best practices 2026-03-03 18:52:41 +00:00
tucker-luminal
615f43ed05 Adding cuda support 2026-03-03 18:52:40 +00:00
tucker-luminal
5847810f52 Add and Sub first pass
Add and sub first pass, pushing this so I can test on a cuda enviroment
2026-03-03 18:52:40 +00:00
xiaoniaoyouhuajiang
1fbce19cfc fix clippy error 2026-02-28 18:44:07 +08:00
xiaoniaoyouhuajiang
4908ab0db1 Merge remote-tracking branch 'upstream/main' into feature/metal-mega-kernel 2026-02-28 18:40:19 +08:00
xiaoniaoyouhuajiang
68da1b69e0 fix clippy error 2026-02-28 18:28:29 +08:00
xiaoniaoyouhuajiang
64bc9d1786 add choice-set SIGNATURE && option to fix search seed 2026-02-28 18:17:34 +08:00
Joe Fioti
d81253a759 removed protobuf install 2026-02-26 16:12:56 -08:00
Joe Fioti
0cab8bf8c8 removed dep 2026-02-26 14:44:21 -08:00
Joe Fioti
c59f54b503 fixed some tests 2026-02-26 12:03:08 -08:00
Joe Fioti
4b541db780 z-strides 2026-02-26 11:36:41 -08:00
xiaoniaoyouhuajiang
30e1aab3bb Adapt to the upstream from_source to ensure the Metal backend works correctly 2026-02-26 17:10:03 +08:00
Joe Fioti
f3b929b1cc fixed workspace 2026-02-25 18:22:07 +00:00
Joe Fioti
5f3bc73753 :Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-02-25 17:54:42 +00:00
Joe Fioti
c10702e638 stable argsort 2026-02-25 17:49:55 +00:00
Joe Fioti
cf1c916bdb Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-02-24 23:07:20 -08:00
Joe Fioti
553f096907 cloud modifications 2026-02-24 23:07:08 -08:00
Joe Fioti
3d98fba66c fix 2026-02-24 06:14:07 +00:00
Joe Fioti
dbe4d9d2e5 fmt and clippy 2026-02-24 06:05:09 +00:00
Joe Fioti
95e5b40e2b moe 2026-02-24 05:34:01 +00:00
Joe Fioti
c4866efd25 moe in luminal_nn 2026-02-23 19:49:04 +00:00
Joe Fioti
944000fca8 Merge pull request #243 from luminal-ai/luminal_fuzz
Luminal fuzz
2026-02-22 13:30:54 -05:00
Joe Fioti
013246ee8f merged with fuzz testing 2026-02-22 18:26:12 +00:00
Joe Fioti
3942a4dc9d Merge origin/main into luminal_fuzz
Resolves merge conflicts from upstream test restructuring:
- Accepts new test structure (utilities.rs, op_functional_tests.rs, performance_tests.rs)
- Removes old ops.rs, misc.rs, mxfp4.rs, nvfp4.rs
- Adapts transformer.rs to use new utilities API (assert_close with rtol/atol, random_f32_vec)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 02:49:13 +00:00
Joe Fioti
3cfad049b0 Merge pull request #242 from luminal-ai/egglog_api
Egglog api
2026-02-21 21:34:53 -05:00
Joe Fioti
398b533354 removed narrow precision dtypes 2026-02-22 02:23:05 +00:00
Joe Fioti
b6e40d1b57 fmt 2026-02-22 01:27:19 +00:00
Joe Fioti
d591b55579 fixed cuda 2026-02-22 01:23:50 +00:00
Joe Fioti
8808c58790 fuzz testing 2026-02-21 23:30:50 +00:00
Joe Fioti
35bb113209 fmt 2026-02-21 15:12:44 -08:00
Joe Fioti
0cb890689a fixed tests 2026-02-21 15:10:23 -08:00
Joe Fioti
a4d49ac4e1 fmt 2026-02-21 14:33:12 -08:00
Joe Fioti
9abab37e57 Merge main into egglog_api: resolve conflicts
- Combined egglog_api's op imports with main's tracing imports in cublas
- Kept main's b_dtype/includes additions in kernel hlir binary ops
- Kept egglog_api's declarative Rule API for Mul dtype propagation
- Removed base.egg (replaced by declarative API in egglog_api)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 14:28:26 -08:00
Joe Fioti
9794146f65 tweaks 2026-02-21 14:24:40 -08:00
Joe Fioti
0659b2a0bd improved declarative egglog api 2026-02-21 14:23:04 -08:00
Joe Fioti
aca17a072f Merge remote egglog_api: resolve conflicts keeping local API
Resolved conflicts in api.rs, base.rs, mod.rs, and expression.rs by
keeping the local (HEAD) egglog API which uses pub const SortClass,
free-standing SortDef, and string-based Term variants.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:43:59 -08:00
Joe Fioti
b49c4cf36b simplified binary ops 2026-02-18 09:39:13 -08:00
Joe Fioti
50d16f6efa more egglog api changes 2026-02-18 08:39:13 -08:00
Austin Glover
4016a0a253 move dtype to new module, start adding additional dtypes 2026-02-17 01:35:30 +00:00
Joe Fioti
5ee6127bca closer to egglog api 2026-02-15 21:18:23 -08:00
Joe Fioti
bdf05ce95a Merge pull request #239 from withoutJ/mmrkaic/kernel-embed
Add kernel op for embedding lookups
2026-02-15 16:18:29 -05:00
Momcilo
dce379c151 fix fmt 2026-02-15 20:47:58 +01:00
Momcilo
54bb1a6370 proptest 2026-02-14 18:24:16 +01:00
Momcilo
cac5eacc35 add output_bytes to kernelEmbed impl 2026-02-14 18:24:09 +01:00
Momcilo
0749fd9ea4 Merge branch 'main' into mmrkaic/kernel-embed 2026-02-14 17:50:06 +01:00
Austin Glover
aeb7274d55 Merge remote-tracking branch 'origin' into remove-false-dtypes 2026-02-13 00:32:25 +00:00
Austin Glover
cd246bf7cb add gh cli. So that agents can initiate prs. 2026-02-13 00:31:14 +00:00
Joe Fioti
17035b273e initial transition to egglog api 2026-02-12 10:35:12 -08:00
Joe Fioti
d43b6aa214 Merge pull request #237 from luminal-ai/cublaslt
Cublaslt
2026-02-11 23:09:31 -05:00
Austin Glover
16b93235d7 fmt 2026-02-12 01:43:19 +00:00
Austin Glover
d3e6dccd76 keep random i32 vec 2026-02-12 01:41:25 +00:00
Austin Glover
cbf731d10c gate test on bf16 2026-02-12 01:40:53 +00:00
Austin Glover
07e637a584 patch candle-kernels to fix bf16 WMMA build on T4 (sm_75)
Use forked candle-kernels that disables bf16 WMMA kernel compilation
on pre-Ampere GPUs (compute cap < 80) via -DNO_BF16_KERNEL.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 01:09:29 +00:00
Austin Glover
1db3b4d837 attempt ci fix 2026-02-12 00:20:27 +00:00
Austin Glover
0a6366c398 remove duplicate tests 2026-02-11 23:46:24 +00:00
Austin Glover
87c91dd0c4 working to fix tests 2026-02-11 22:41:54 +00:00
Austin Glover
96f89e16e8 fixing tests 2026-02-10 23:31:17 +00:00
Austin Glover
e019f312ee temp adding nvfp4 and mxfp4 (will delete soon) 2026-02-10 23:23:05 +00:00
Austin Glover
623fd10ac2 Merge branch 'cublaslt' of https://github.com/luminal-ai/luminal into cublaslt 2026-02-10 19:39:23 +00:00
Austin Glover
855f996963 fmt fix bad merge 2026-02-10 19:39:18 +00:00
Joe Fioti
cd61de9362 fixed tests 2026-02-10 19:32:23 +00:00
Joe Fioti
694f66982e Merge branch 'main' into cublaslt 2026-02-10 12:57:42 -05:00
Joe Fioti
49ddba30fe Merge pull request #230 from luminal-ai/devcontainer-updates
Devcontainer updates
2026-02-10 12:55:20 -05:00
Joe Fioti
53c0b08f5e Merge pull request #238 from luminal-ai/nvfp4
gpt-oss 120b
2026-02-09 22:55:38 -05:00
Joe Fioti
92ab77c74c clippy 2026-02-10 03:49:28 +00:00
Joe Fioti
3e9c889fd2 clippy and fmt 2026-02-10 03:45:23 +00:00
Joe Fioti
f18556475f removed llama nvfp4 2026-02-10 03:41:30 +00:00
Austin Glover
a5f6abc6c0 fmt 2026-02-10 01:41:38 +00:00
Austin Glover
dcdc6864bf Add Bool dtype support and fix buffer allocation for dtype-aware sizes
- Add Bool dtype (1 byte) support throughout CUDA backend:
  - hlir.rs, other_ops.rs: output_bytes() returns correct sizes for Bool
  - block/mod.rs, to_host.rs: Add output_bytes() to MegakernelOp and CudaGraphOp
  - cublaslt/mod.rs: Handle Bool in dtype_to_cuda_types (panic, not supported)
  - utilities.rs: Bool epsilon handling for tests

- Fix KernelLessThan to output unsigned char (Bool) instead of F32
  - This fixes the relu lowering which expects Bool from comparisons

- Update runtime.rs to use output_bytes() for dtype-aware buffer allocation
  - Removed register_buffer function (used non-existent ExecutableKernel type)
  - Buffers now allocated with correct byte sizes for F16/Bf16/Bool

- Make fuzz_test_cuda_genomes fully deterministic with proptest seed
  - Uses seeded StdRng for reproducible test data and genome selection
  - egglog_utils: Changed random_initial_choice and extract_generation to
    accept impl Rng instead of ThreadRng for seedable RNG support

- Fix test_less_than to cast Bool output to F32 for comparison

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-10 00:19:44 +00:00
Joe Fioti
acbc2851c8 gpt-oss fast 2026-02-09 22:52:05 +00:00
Joe Fioti
d97629b759 gpt-oss slow 2026-02-09 22:12:21 +00:00
Austin Glover
13889b50e0 merging 2026-02-09 20:21:02 +00:00
Austin Glover
92037538e5 Merge remote-tracking branch 'origin' into cublaslt 2026-02-09 20:20:49 +00:00
Joe Fioti
710da39851 improved tilematmul nvfp4 path 2026-02-09 18:20:38 +00:00
Joe Fioti
736b459a81 llama nvfp4 2026-02-09 18:06:19 +00:00
Joe Fioti
aa136a09ae initial nvfp4 2026-02-09 17:44:32 +00:00
Joe Fioti
ee782ea829 initial egglog api 2026-02-08 22:26:11 -07:00
Joe Fioti
725c5947a1 initial egglog api 2026-02-08 22:25:43 -07:00
Joe Fioti
63667b724a Merge pull request #235 from xiaoniaoyouhuajiang/feature/metal-mega-kernel
[Metal] support dynamic shapes to ensure runtime dimensionality
2026-02-08 17:51:06 -05:00
Momcilo
a43319b8c0 kernel embed op 2026-02-08 16:55:01 +01:00
xiaoniaoyouhuajiang
26ef39cca0 Merge branch 'main' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-mega-kernel 2026-02-08 11:53:52 +08:00
xiaoniaoyouhuajiang
502a36046d add dynamic shapes for metal to ensure runtime dimensionality 2026-02-08 11:51:21 +08:00
xiaoniaoyouhuajiang
47494e1489 add regression tests for dyn buffer 2026-02-08 10:20:53 +08:00
Joe Fioti
5143759ef4 Merge pull request #234 from luminal-ai/waterfall
Waterfall compilation
2026-02-07 20:35:58 -05:00
Joe Fioti
69ec453651 fixed test 2026-02-08 00:20:10 +00:00
Joe Fioti
9e3e038e74 fmt and clippy 2026-02-07 22:43:45 +00:00
Joe Fioti
b7196102f9 fixed gemma 2026-02-07 22:35:24 +00:00
xiaoniaoyouhuajiang
9244a9acdc compatible with core's update 2026-02-07 17:52:14 +08:00
Joe Fioti
d48289b64b Merge origin/main into waterfall
Resolved conflicts keeping waterfall architecture:
- runtime.rs: kept waterfall's HostOp-based exec_graph
- block/mod.rs: kept FxHashMap kernel_cache (not LruCache)
- llama main.rs: kept waterfall comment style

Integrated main changes:
- Added trials parameter to Runtime::profile()
- New RuntimeStats trait and ExecutionStats
- GA search improvements
- Various bug fixes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 02:22:07 +00:00
Joe Fioti
dbcd747d70 Merge branch 'main' into waterfall 2026-02-07 01:53:47 +00:00
Joe Fioti
7e2df5a15b kernel to host waterfall 2026-02-07 01:52:45 +00:00
Joe Fioti
795cae0dfa fmt clippy 2026-02-07 01:22:49 +00:00
Joe Fioti
9c9a05ba4e fixed gemma and qwen 2026-02-07 01:12:40 +00:00
Joe Fioti
f4d8b880f5 Merge pull request #232 from luminal-ai/ga
Genetic Algorithm Search
2026-02-06 01:47:17 -05:00
Joe Fioti
75ccffa6db Merge remote-tracking branch 'origin/main' into ga 2026-02-06 06:38:30 +00:00
Joe Fioti
9d7c41a063 fmt and clippy 2026-02-06 06:32:55 +00:00
Joe Fioti
81d3c504bb removed extras 2026-02-06 06:28:04 +00:00
Joe Fioti
02ae6cce3d Merge pull request #221 from xiaoniaoyouhuajiang/feature/metal-backend
luminal_bench for measure performance changes & debug tool
2026-02-06 01:27:42 -05:00
Joe Fioti
223b81c872 grouped chunk compilation 2026-02-06 05:01:19 +00:00
Austin Glover
75d2b57a76 working on it 2026-02-06 00:28:26 +00:00
Joe Fioti
f16f3643ad ga search with ui 2026-02-06 00:13:33 +00:00
Joe Fioti
9b225f21ee waterfall block to kernel 2026-02-05 23:09:40 +00:00
Joe Fioti
5ede75ddd1 initial 2026-02-05 21:57:54 +00:00
Joe Fioti
86404b762d Merge remote-tracking branch 'origin/main' into ga
# Conflicts:
#	examples/llama/src/main.rs
2026-02-05 19:51:21 +00:00
Joe Fioti
83ae7a4f43 adj search 2026-02-05 19:50:17 +00:00
Joe Fioti
d4cc6b9851 fixed llama example 2026-02-05 18:43:53 +00:00
Joe Fioti
04c7389ec1 genetic algorithm 2026-02-05 07:13:01 +00:00
Austin Glover
f3846d482a sgemm equivalence 2026-02-04 01:42:59 +00:00
Joe Fioti
e8bbe199f6 fixed argsort issue and added bool dtype 2026-02-03 23:47:31 +00:00
Joe Fioti
3a0d80acc0 switched megakernels to input array 2026-02-03 17:46:48 +00:00
xiaoniaoyouhuajiang
47ce1f2c53 compatible with the new MLIR cast. 2026-02-03 14:31:45 +08:00
xiaoniaoyouhuajiang
eb4a5cfe0b Merge branch 'feature/metal-backend' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-backend 2026-02-03 14:20:16 +08:00
xiaosa
dcdff24d27 Merge branch 'luminal-ai:main' into feature/metal-backend 2026-02-03 14:20:02 +08:00
xiaoniaoyouhuajiang
1d37b3f279 Merge branch 'feature/metal-backend' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-backend 2026-02-03 14:18:20 +08:00
xiaoniaoyouhuajiang
7044913680 extract stats function from runtime trait 2026-02-03 14:10:30 +08:00
Austin Glover
c6f32e8770 start copying boilerplate 2026-02-03 00:54:53 +00:00
Joe Fioti
d9f6f71bdf Merge pull request #231 from luminal-ai/prefill
Optimizations on search and cuda graphs
2026-02-02 18:41:58 -05:00
Joe Fioti
35d3c35c31 fmt and clippy 2026-02-02 23:36:47 +00:00
Joe Fioti
6d9899a42a optimization on cuda graph 2026-02-02 23:28:41 +00:00
Austin Glover
b297a3ad9e run CI with docker image 2026-02-02 19:02:49 +00:00
Austin Glover
06f83aecdf proptest regressions is a folder 2026-02-02 17:40:10 +00:00
Austin Glover
38a79a57ab match remote 2026-02-02 17:31:33 +00:00
Austin Glover
dbbee06363 ignore claude folder 2026-02-02 17:30:40 +00:00
Austin Glover
ad973665ff fmt 2026-02-02 17:24:46 +00:00
Austin Glover
c0427d5680 no diff 2026-02-02 17:23:14 +00:00
Austin Glover
ed71697de4 Merge remote-tracking branch 'origin' into devcontainer-updates 2026-02-02 17:21:41 +00:00
Joe Fioti
04030bb5d6 kernel compile cache 2026-02-02 05:58:24 +00:00
Joe Fioti
4fd2091610 imrpoved tracing 2026-02-02 04:17:30 +00:00
Joe Fioti
86041eab27 improved timings 2026-02-01 21:37:54 +00:00
Joe Fioti
817c1416e1 fixed cuda_graph perfetto timings 2026-02-01 20:04:42 +00:00
Joe Fioti
56bbd29e21 removed special rope op for gemma 2026-02-01 07:17:07 +00:00
Joe Fioti
8f0790736d removed qwen rope 2026-02-01 06:06:20 +00:00
Joe Fioti
22365da202 fixed rope hlir 2026-02-01 05:19:42 +00:00
Joe Fioti
a8290a4505 Update readme 2026-01-31 21:47:36 -05:00
Joe Fioti
8fa52c8028 tracing cleanup 2026-01-31 22:40:51 +00:00
Joe Fioti
888512507a var for block timing slots 2026-01-31 20:37:42 +00:00
Joe Fioti
1b7a2e8c57 Merge pull request #225 from luminal-ai/cuda_graph_merge
Cuda graphs
2026-01-31 02:22:42 -05:00
Joe Fioti
691f66c030 clippy 2026-01-31 07:05:32 +00:00
Joe Fioti
b66a8b9370 clippy and fmt 2026-01-31 07:02:40 +00:00
Joe Fioti
98cc057d4d fixed block, kernel, host ops for llama 2026-01-31 06:52:29 +00:00
xiaosa
48dfd09e28 Merge branch 'main' into feature/metal-backend 2026-01-31 09:12:10 +08:00
Austin Glover
9ee58410e6 simplify 2026-01-30 23:56:13 +00:00
Joe Fioti
7623ee1d66 changed ignore 2026-01-30 21:55:11 +00:00
Austin Glover
75120847de Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-01-30 21:30:38 +00:00
Joe Fioti
e13e83089a Merge branch 'cuda_graph' into cuda_graph_merge 2026-01-30 14:12:55 -05:00
Joe Fioti
3031ab25c4 removed setup scripts / inlined hf setup code 2026-01-30 19:07:49 +00:00
Joe Fioti
23ea4c9b00 tmp disable host op 2026-01-30 15:57:48 +00:00
Joe Fioti
b0880e60f2 Merge pull request #211 from luminal-ai/cuda_host_op
Cuda Host Op
2026-01-30 10:52:20 -05:00
Joe Fioti
e1681f3b6b fixed issues with hlir kernel ops 2026-01-30 07:11:57 +00:00
Austin Glover
ca3159dfe3 attempt ci fix 2026-01-30 00:05:06 +00:00
Austin Glover
6f9f1f9078 add cudatoolkit 2026-01-29 23:35:05 +00:00
Austin Glover
b003e88169 remove auto building to fix CI clippy 2026-01-29 22:54:43 +00:00
Austin Glover
233a890e49 attempt to un-break CI 2026-01-29 22:40:01 +00:00
Austin Glover
642f4739d1 remove separate test folder, merge with other tests 2026-01-29 22:32:34 +00:00
Austin Glover
6cfd456f70 fmt 2026-01-29 21:30:02 +00:00
Austin Glover
4b710ac380 satisfy clippy 2026-01-29 21:28:46 +00:00
Austin Glover
2a3a510317 .gitignore 2026-01-29 19:26:27 +00:00
Austin Glover
3a3c610e20 remove dead code 2026-01-29 18:50:38 +00:00
Austin Glover
cf98f09cd9 remove cuda_matmul example 2026-01-29 18:47:30 +00:00
Austin Glover
eeba8bba19 remove tracing 2026-01-29 18:37:59 +00:00
xiaoniaoyouhuajiang
61139c4b51 Move derived metrics like MBU/MFU that tied to GPU into , keeping raw counts in 2026-01-29 14:25:05 +08:00
Joe Fioti
190a115366 cuda readme 2026-01-29 06:22:08 +00:00
Joe Fioti
0ae77dd630 added cuda graphs 2026-01-29 06:09:05 +00:00
Joe Fioti
5fbc260fb1 Add per-kernel timing for CUDA graphs in Perfetto traces
Similar to how megakernel block ops are tracked on SM timelines in
Perfetto, this adds per-kernel timing for CUDA graph executions.

Implementation:
- Add event record nodes between kernels when building the graph
- Query CUDA event elapsed times after graph execution
- Store kernel timing data (name, start_ns, end_ns) with span UUID
- Emit timing slices on "CUDA Graph N" tracks in record_cuda_perfetto_trace()

New structures in graph.rs:
- CudaGraphKernelTiming: Per-kernel timing data
- CudaGraphTiming: Collection of kernel timings for a graph execution
- create_cuda_event(), destroy_cuda_event(), event_elapsed_ms() helpers

Runtime changes:
- CudaGraphExec now stores timing_events (Vec<CUevent>)
- Graph execution records span UUID for Perfetto correlation
- Kernel timings collected and stored in cuda_graph_timings

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 04:33:56 +00:00
Joe Fioti
b3aca06857 Remove Kernel variant, use CudaGraphExec for all kernel ops
Simplifies the code by removing the separate Kernel variant from
ExecutableKernel. Now all kernel ops are wrapped in CUDA graphs,
even single kernels. This provides:

- Simpler code with one execution path for all kernel ops
- Consistent behavior for single and multiple kernels
- Future-proof for graph optimizations (e.g., kernel fusion)

The overhead of a single-kernel CUDA graph is minimal compared to
the code simplification benefits.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 04:15:49 +00:00
xiaoniaoyouhuajiang
6ec9088055 cargo fmt & remove useless comments 2026-01-29 11:25:24 +08:00
xiaoniaoyouhuajiang
7bd28301d1 Revert the automated check changes made by Zed 2026-01-29 11:16:21 +08:00
xiaoniaoyouhuajiang
30c2af5523 optimize debug ops cli 2026-01-29 10:40:25 +08:00
Joe Fioti
09644fe4a4 Use explicit CUDA graph construction instead of stream capture
This replaces the stream capture approach with explicit graph construction:
- Build graphs by adding kernel nodes directly with cuGraphAddKernelNode_v2
- Store kernel parameters in stable memory (KernelParams struct)
- Perform surgical updates via cuGraphExecKernelNodeSetParams_v2
- Track buffer pointers to detect when full rebuild is needed

Key changes:
- Fixed CUfunction extraction (Rust reorders fields - at offset 8, not 0)
- Added KernelParams struct to manage kernel parameter lifetime
- CudaGraphHandle/CudaGraphExecHandle properly track CUDA context
- Runtime builds graph on first execution when buffers exist
- Surgical update when only dyn dims change (same buffer pointers)
- Full rebuild when buffer pointers change

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 23:38:58 +00:00
Austin Glover
74ce1ad751 clippy 2026-01-28 23:23:39 +00:00
Joe Fioti
3ccfe54d1b Add CUDA graph support for kernel ops
- Add graph.rs with low-level CUDA graph API wrappers (CudaGraphHandle,
  CudaGraphExecHandle, CudaFunctionExt trait)
- Add CudaGraphExec variant to ExecutableKernel enum
- Partition kernel ops into subgraphs using partition_marked_convex
- Build CUDA graphs via stream capture at first execution
- Rebuild graphs when dynamic dimensions change
- Use non-blocking capture stream to avoid default stream limitations
- Use raw device pointers during capture to avoid cudarc's event tracking
- Add comprehensive tests for CUDA graph functionality

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-28 23:05:10 +00:00
Austin Glover
a283ee099d work on reusability for setup scripts 2026-01-28 23:01:19 +00:00
Austin Glover
86b04dff3b reduce diffs 2026-01-28 22:58:12 +00:00
Austin Glover
d322962b8a remove unnecessary dep 2026-01-28 22:58:02 +00:00
Austin Glover
b424230113 fmt 2026-01-28 20:17:35 +00:00
Austin Glover
557d5137ee restore perfetto 2026-01-28 20:16:30 +00:00
Austin Glover
472fdf0fa4 rewrite rules for each layout mapping. 2026-01-28 20:16:15 +00:00
xiaoniaoyouhuajiang
15cf055f6f revert back cast impl 2026-01-28 15:55:36 +08:00
xiaoniaoyouhuajiang
1b473f81b0 support layer_norm & add readme for luminal_bench 2026-01-28 14:36:14 +08:00
Austin Glover
038fdfca11 add small sizes and printing 2026-01-28 04:15:53 +00:00
Austin Glover
dec8aa90f9 add assert to ensure cublas 2026-01-28 04:15:44 +00:00
Austin Glover
1f359bc0d2 swapping trick to support row major 2026-01-28 04:15:31 +00:00
Austin Glover
06d50459a1 make simply pass through 2026-01-28 04:14:25 +00:00
Austin Glover
9d7d87de81 hide stats behind logging flag 2026-01-28 04:14:08 +00:00
xiaoniaoyouhuajiang
8409a1ef1c add lower analysis ability to debug_ops cli& find root cause for egraph's broken 2026-01-28 10:25:41 +08:00
Austin Glover
3f19f4e331 fix printing 2026-01-27 19:08:48 +00:00
Austin Glover
fcfa8806e5 cuda driver visibility 2026-01-27 19:07:19 +00:00
Austin Glover
8c6c6a5964 add devcontainer 2026-01-27 18:23:46 +00:00
xiaoniaoyouhuajiang
e9fce9ef81 merge remote commit from .gitignore 2026-01-27 16:49:47 +08:00
xiaoniaoyouhuajiang
69edcb9a7e add debug example to fix egglog graph problem 2026-01-27 16:41:45 +08:00
Austin Glover
7bfe0fd61e wip 2026-01-27 03:38:46 +00:00
Austin Glover
9e941e03c1 wip 2026-01-27 03:38:36 +00:00
Austin Glover
e1191e46d5 fix warning 2026-01-26 22:44:22 +00:00
Austin Glover
bb973bb6eb add host op execute branch (removed during merge by accident) 2026-01-26 22:44:11 +00:00
Austin Glover
2666bf00cc fix rename 2026-01-26 22:43:43 +00:00
Austin Glover
dd9b92517d Merge branch 'main' of https://github.com/luminal-ai/luminal into cuda_host_op 2026-01-26 22:14:06 +00:00
Austin Glover
2c2f486385 renames and bump cuda version 2026-01-26 22:10:08 +00:00
Austin Glover
42837f800d git ignore 2026-01-26 21:18:24 +00:00
Joe Fioti
f886b55bbf fixed gemma and qwen 2026-01-26 19:01:42 +00:00
Joe Fioti
613ca6895e Merge pull request #220 from luminal-ai/cstruct_api
Cstruct api
2026-01-26 13:39:45 -05:00
Joe Fioti
9d6dec791a round out cstruct api 2026-01-26 18:39:01 +00:00
Joe Fioti
5e6860d669 removed expressions 2026-01-26 18:33:10 +00:00
Joe Fioti
e39742a397 Merge pull request #219 from luminal-ai/qwen
Qwen 3 4B Dense and Gemma 3 4B Dense
2026-01-26 12:53:08 -05:00
Joe Fioti
d25210615c clippy and fmt 2026-01-26 17:06:41 +00:00
Joe Fioti
1b183f1515 gemma 2026-01-26 07:06:28 +00:00
Joe Fioti
b2e8a35d94 new cstruct api 2026-01-25 16:13:51 -08:00
Joe Fioti
0f2d109cf5 qwen 3 4b dense 2026-01-25 22:24:58 +00:00
xiaoniaoyouhuajiang
8bce1e1d38 add debug ops 2026-01-25 23:17:54 +08:00
xiaoniaoyouhuajiang
687e10b31b remove egglog search for each benchmark 2026-01-24 22:10:55 +08:00
xiaoniaoyouhuajiang
bf3f4a33ef fix lint 2026-01-23 17:40:15 +08:00
xiaoniaoyouhuajiang
47520d0291 record dynamic metrics into generated json file 2026-01-23 17:22:06 +08:00
xiaoniaoyouhuajiang
54ab273b88 modify bench time as gpu time 2026-01-23 15:50:16 +08:00
xiaoniaoyouhuajiang
abc500bc67 add more metrics for luminal_bench 2026-01-23 11:03:51 +08:00
xiaoniaoyouhuajiang
3d46196fb5 add benchmark for patterns.rs 2026-01-23 09:22:51 +08:00
xiaoniaoyouhuajiang
ec62c3441f fix gather-op micro benchmark 2026-01-23 00:19:12 +08:00
xiaoniaoyouhuajiang
6c3e3232ef ignore .zed file & add more hlir.rs op for micro.rs 2026-01-22 17:49:42 +08:00
xiaoniaoyouhuajiang
139c7b0d70 add a basic benchmark supported by criterion.rs 2026-01-22 15:08:16 +08:00
Joe Fioti
a194fea8ac sped up search 2026-01-21 19:44:13 +00:00
Joe Fioti
35e9695760 Merge pull request #216 from luminal-ai/sol
Much closer to SoL on llama
2026-01-21 12:57:16 -05:00
Joe Fioti
b9e3ecaeb8 disable mul-div simplification test 2026-01-21 17:56:41 +00:00
Joe Fioti
40655a3916 clippy 2026-01-21 17:51:53 +00:00
Joe Fioti
96af7c7670 fixed simplification issue 2026-01-21 17:46:30 +00:00
Joe Fioti
6dc8e8f21c Merge branch 'main' into sol 2026-01-21 11:52:31 -05:00
Joe Fioti
6658b8dae3 testing 2026-01-21 16:46:47 +00:00
Joe Fioti
fee1952e70 cleanups and tests passing 2026-01-21 16:41:46 +00:00
Joe Fioti
6eb45916a3 abstraction cleanup 2026-01-21 15:07:36 +00:00
Joe Fioti
adfb98ca46 contained full split logic 2026-01-21 04:16:58 +00:00
Joe Fioti
eb5becf391 optimized register_buffer calls 2026-01-21 00:42:33 +00:00
Joe Fioti
c294bee221 no memsets 2026-01-21 00:26:27 +00:00
Joe Fioti
65a03cd77e initial no-zero output buffers 2026-01-20 23:20:11 +00:00
Joe Fioti
3057849501 rowembed op 2026-01-20 19:12:47 +00:00
austin_glover
6747de13a8 mid renaming 2026-01-20 01:35:37 +00:00
Joe Fioti
8f6e5aaae2 cleanup 2026-01-20 00:47:48 +00:00
austin_glover
74d8564a72 check that it's not the egraph extract fix 2026-01-19 21:53:41 +00:00
austin_glover
7f2ea49a4f remove tracing specific decode - prefill distinction 2026-01-19 21:22:32 +00:00
Joe Fioti
7aa4898bec 75.8% MBU 2026-01-19 21:12:58 +00:00
Joe Fioti
4c7aab343a 71% mbu 2026-01-19 19:37:09 +00:00
austin_glover
d3a4e3c4cf fix merge 2026-01-19 19:37:06 +00:00
austin_glover
e8f0abab6a wip 2026-01-19 19:22:52 +00:00
austin_glover
d2bf7182e9 Merge remote-tracking branch 'origin' into cuda_host_op 2026-01-19 19:22:17 +00:00
austin_glover
f3a9422de0 wip 2026-01-19 18:21:41 +00:00
austin_glover
61262232b2 wip 2026-01-19 18:21:35 +00:00
Joe Fioti
89b22e0e9e some slight improvements 2026-01-19 05:14:20 +00:00
Joe Fioti
7d40194902 test matmul 2026-01-19 01:53:21 +00:00
Joe Fioti
53b45f127d fmt 2026-01-18 17:12:53 -08:00
Joe Fioti
6b621a3077 Merge pull request #214 from jonahsamost/jonah_1_18_ops
Common ops implementations
2026-01-18 20:12:26 -05:00
Joe Fioti
873c41f6b3 Merge pull request #215 from jonahsamost/jonah_kernel_sum_1_18
Bug fix for KernelSumReduce
2026-01-18 20:00:54 -05:00
Joe Fioti
bb639d2861 added dynamic k chunk size 2026-01-18 23:35:30 +00:00
jonah
e4fdcb7730 kernel sum reduce fix 2026-01-18 10:05:32 -08:00
jonah
31fa21d9d4 try to repro 2026-01-18 07:51:32 -08:00
jonah
ebb4f02a47 naming 2026-01-18 07:47:58 -08:00
jonah
d331a5c48c common operations 2026-01-18 07:41:35 -08:00
Joe Fioti
23d39031ce fixed timing 2026-01-17 06:34:37 +00:00
Joe Fioti
241313c018 53% mbu 2026-01-17 06:06:53 +00:00
Joe Fioti
ec47346080 clippy 2026-01-16 21:44:18 -08:00
Joe Fioti
2348bcfc20 Merge pull request #210 from xiaoniaoyouhuajiang/feature/metal-backend
Feature/metal backend
2026-01-17 00:26:04 -05:00
Joe Fioti
4ce67cdb11 improved issue latency and added op count to stats 2026-01-17 05:17:22 +00:00
Joe Fioti
0e8ca500ec attempted ci fix 2026-01-17 03:09:57 +00:00
Joe Fioti
18f8f09429 debug ci/cd 2026-01-17 02:59:29 +00:00
Joe Fioti
2664be3b81 improved expression simplification 2026-01-17 02:47:26 +00:00
Joe Fioti
6a7061497c fmt 2026-01-16 18:23:22 -08:00
Joe Fioti
17e5ab8690 fixed compile errors in cuda 2026-01-16 18:23:09 -08:00
Joe Fioti
9e6e1755e0 changed expr rules 2026-01-17 01:49:05 +00:00
austin_glover
c93070a096 fmt 2026-01-17 00:04:26 +00:00
austin_glover
488348a890 Merge remote-tracking branch 'origin' into cuda_host_op 2026-01-17 00:04:03 +00:00
austin_glover
852bd67d77 fmt 2026-01-17 00:02:28 +00:00
Joe Fioti
e9e169fb51 Merge pull request #207 from jonahsamost/jonah_argsort_1_10 2026-01-16 14:44:43 -05:00
jonah
a8b7600cf2 revert src op 2026-01-16 11:27:55 -08:00
jonah
fe7d7403e1 remove unused import 2026-01-16 11:17:14 -08:00
jonah
f7b945940e remove egglog pattern print 2026-01-16 10:53:39 -08:00
jonah
472ea979c7 change kernel arg sort to match on hlir argsort pattern 2026-01-16 10:51:45 -08:00
Joe Fioti
07fac1e0ca added tilematmulsplitk 2026-01-16 06:21:49 +00:00
Joe Fioti
e3b0d79c78 added producer_barriers_seperate function 2026-01-16 05:08:07 +00:00
austin_glover
6d0fff7f35 back to println 2026-01-16 01:42:56 +00:00
austin_glover
6fc80d1432 claude egraph export fix 2026-01-16 01:42:49 +00:00
austin_glover
445b8a621e Merge remote-tracking branch 'origin/main' into cuda_host_op 2026-01-16 00:11:34 +00:00
austin_glover
16c9b1b250 tracing! 2026-01-16 00:09:14 +00:00
austin_glover
ad190e4951 more tracing 2026-01-15 23:48:47 +00:00
Joe Fioti
fac121a8fc edge vis 2026-01-15 16:19:49 +00:00
xiaoniaoyouhuajiang
280be38da9 delete review content 2026-01-15 16:39:06 +08:00
xiaoniaoyouhuajiang
919cf6b97d exclude from linux ci 2026-01-15 16:36:03 +08:00
xiaoniaoyouhuajiang
32ecfa83df fix gather wrong index 2026-01-15 16:23:11 +08:00
Joe Fioti
d92295407a removed cubemul and tilesum 2026-01-15 05:11:24 +00:00
austin_glover
fa473087a6 add tracing for debugging 2026-01-15 01:21:36 +00:00
austin_glover
b88e3cb60c reduce 2026-01-14 21:16:10 +00:00
austin_glover
1bec5dfb9e reduce changes 2026-01-14 21:12:52 +00:00
austin_glover
56573ba532 have luminal re-export tracing 2026-01-14 21:11:47 +00:00
austin_glover
8abc155ff0 reduce changes 2026-01-14 21:09:03 +00:00
austin_glover
bc6ab17048 remove diff 2026-01-14 21:07:19 +00:00
austin_glover
293adaf6b0 make % 64 so it works (temporary fix for tilematmul bug) 2026-01-14 20:59:16 +00:00
austin_glover
21fe6cba57 clean up rewrite 2026-01-14 20:57:54 +00:00
austin_glover
a55ee71bff put host_matmul back as an op 2026-01-14 20:57:43 +00:00
austin_glover
9641f930a7 fix buffer pointer bug 2026-01-14 20:57:20 +00:00
Joe Fioti
6529fd48f9 Optimize RowSwishMul chunking, issue process, and remove prologues
- Break up RowSwishMul to process rows in 128-element chunks for better
  parallelization across SMs (37% faster)
- Optimize eval_expression by pre-computing constant expressions at
  compile time (20% reduction in Issue time)
- Remove prologues from RowAdd and TileMatmul, simplifying the code
  and reading directly from global memory

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-14 20:48:01 +00:00
austin_glover
83d8a9a7c2 wip 2026-01-14 19:37:29 +00:00
xiaoniaoyouhuajiang
07201d4d99 add remaining ops for metal backend 2026-01-14 16:45:26 +08:00
Joe Fioti
7280de75d9 prologue tile matmul 2026-01-14 06:39:53 +00:00
Joe Fioti
2e4939ef58 prologue per input on block ops 2026-01-14 05:13:15 +00:00
xiaoniaoyouhuajiang
44a8f1019f support all unary&binary operation; add relative tests 2026-01-14 10:59:26 +08:00
austin_glover
c52e64d73d cargo fmt 2026-01-14 01:53:26 +00:00
austin_glover
c09a9d8f50 Merge remote-tracking branch 'origin/main' into cuda_host_op 2026-01-14 01:49:54 +00:00
austin_glover
0acd632847 updates to get closer to running 2026-01-14 01:29:48 +00:00
austin_glover
ce71dd3ce9 refactors 2026-01-14 01:29:37 +00:00
Joe Fioti
b90a29b01a parametarized tile matmul tile size 2026-01-13 22:15:55 +00:00
Joe Fioti
f325e94e5b tweaks 2026-01-13 21:07:49 +00:00
austin_glover
f423c9aa67 Merge remote-tracking branch 'origin/main' into cuda_host_op (not working) 2026-01-13 19:17:14 +00:00
Joe Fioti
75be858339 fixed for real this time 2026-01-13 18:28:53 +00:00
Joe Fioti
7e76f68e82 fixed rowadd 2026-01-13 18:06:51 +00:00
Joe Fioti
cb96fbc669 Merge branch 'main' of https://github.com/luminal-ai/luminal 2026-01-13 17:55:44 +00:00
Joe Fioti
f63abbc638 better interpreter sync 2026-01-13 17:54:57 +00:00
Joe Fioti
b25c7a528c optimized rowadd op 2026-01-13 05:39:13 +00:00
xiaoniaoyouhuajiang
871e182647 restore .gitignore 2026-01-13 11:42:11 +08:00
xiaoniaoyouhuajiang
a8b59361e5 remove review docs 2026-01-13 11:39:54 +08:00
xiaoniaoyouhuajiang
398edb31e2 adjust upstream code 2026-01-13 11:38:23 +08:00
xiaoniaoyouhuajiang
e7d495f4de Merge branch 'feature/metal-backend' of https://github.com/xiaoniaoyouhuajiang/luminal into feature/metal-backend 2026-01-13 11:23:04 +08:00
xiaoniaoyouhuajiang
a51300adbe fix clippy lint 2026-01-13 11:22:20 +08:00
xiaoniaoyouhuajiang
5c2c1bce32 remove useless comments 2026-01-13 11:15:57 +08:00
xiaosa
04af8b6605 Merge branch 'luminal-ai:main' into feature/metal-backend 2026-01-13 11:09:20 +08:00
xiaoniaoyouhuajiang
352ef5bc69 remove protest test-suit 2026-01-13 10:59:36 +08:00
xiaoniaoyouhuajiang
0b0852c8e2 fix(metal): fix stride handling and get_f32() fallback
- Add stride-aware indexing using flatten_mul_strides in all kernels
- Fix get_f32() to fallback to hlir_buffers for Input nodes
- All 6 tests passing
2026-01-13 10:32:57 +08:00
austin_glover
5c11c41ac3 column major fix 2026-01-13 02:09:11 +00:00
austin_glover
b197180607 more tracing stuff 2026-01-13 02:08:45 +00:00
austin_glover
1bc238a0f2 idiomatic tracing 2026-01-13 02:08:10 +00:00
austin_glover
e8e8257cbc convert to candle and col-major b 2026-01-13 02:07:51 +00:00
austin_glover
5b4701a304 remove ndarray, include candle 2026-01-13 01:46:20 +00:00
austin_glover
c08881f3ca "document" convention for matmul layouts 2026-01-13 01:44:15 +00:00
austin_glover
9b8416b5e1 move to candle for tests 2026-01-13 01:43:44 +00:00
austin_glover
32637f4279 flush llir_graphs dir 2026-01-13 01:43:31 +00:00
austin_glover
16783e3851 more idiomatic tracing 2026-01-13 01:40:42 +00:00
austin_glover
f802faa528 ignore log and viz files 2026-01-13 01:31:50 +00:00
austin_glover
524c28b729 ignore proptests-regressions 2026-01-13 01:30:00 +00:00
Joe Fioti
fd569a71a4 better perfetto 2026-01-12 23:11:14 +00:00
Joe Fioti
f7bcea3eab added end to smevent 2026-01-12 20:13:15 +00:00
Joe Fioti
9c16272f03 Merge pull request #209 from luminal-ai/benchmarking
Benchmarking
2026-01-12 14:52:43 -05:00
Joe Fioti
e52087ef18 fmt 2026-01-12 19:48:19 +00:00
austin_glover
ca3ad47106 ignore viz files 2026-01-12 19:39:00 +00:00
Joe Fioti
ade077f2f3 removed llama benchmark and added computed stats 2026-01-12 19:36:59 +00:00
austin_glover
ce9d0b9fc8 make output llir files .dot files 2026-01-12 19:34:23 +00:00
austin_glover
a09dbe8b3a remove debugging stuff 2026-01-12 19:33:45 +00:00
austin_glover
a27032f9f4 add perfetto span 2026-01-12 19:33:28 +00:00
austin_glover
b21f5c17da ignore .pftrace 2026-01-12 19:30:17 +00:00
Joe Fioti
53345ea47c added 1 thread cargo test 2026-01-12 17:04:57 +00:00
Joe Fioti
a5ffafa300 fixed llama example 2026-01-12 16:57:45 +00:00
xiaoniaoyouhuajiang
a6fec2d1b6 fix(metal): fix cleanup() and add working tests
- Change cleanup() to return false for all Metal ops
- Make allocate_intermediate_buffers() public
- Add proptest and deterministic tests for Add, Mul, Exp2
- All 6 tests passing
2026-01-12 22:49:55 +08:00
Joe Fioti
b7468a4d2b measure bandwidth of kernelops 2026-01-12 06:50:14 +00:00
austin_glover
1aa149f787 wip 2026-01-12 05:32:10 +00:00
austin_glover
23f2952653 ignore llir_graphs 2026-01-12 05:27:47 +00:00
Joe Fioti
5dec359501 fmt 2026-01-12 03:47:12 +00:00
Joe Fioti
ab611aedae moved mk logic to block/mod 2026-01-12 03:37:58 +00:00
xiaoniaoyouhuajiang
6c7f60299d feat(metal): add initial luminal_metal crate skeleton
- Add MetalRuntime implementing Runtime trait
- Add MetalKernelOp trait for Metal kernel operations
- Implement initial ops: MetalExp2, MetalAdd, MetalMul
- Add egglog integration with rewrites for HLIR->Metal ops
2026-01-12 11:34:05 +08:00
Joe Fioti
51a57068ed runtime-sized megakernel payloads 2026-01-12 02:54:03 +00:00
Joe Fioti
5e060887dd re-enalbed gpu sm tracing 2026-01-12 02:24:30 +00:00
jonah
7576cc892b native and cuda argsort 2026-01-11 11:58:40 -08:00
Joe Fioti
fc0e9deb28 cache dyn map cuda runtime 2026-01-11 06:47:34 +00:00
Joe Fioti
31102a5443 tweaks 2026-01-11 06:04:19 +00:00
Joe Fioti
a2cad934ff tweaks 2026-01-11 05:46:29 +00:00
Joe Fioti
89752055a2 Merge pull request #206 from luminal-ai/hlir_attn
Attention as a custom op
2026-01-11 00:31:38 -05:00
Joe Fioti
51339e7564 merged 2026-01-11 05:25:57 +00:00
Joe Fioti
2afd077198 finally removed custom state 2026-01-11 05:10:16 +00:00
Joe Fioti
9565b8a324 merged from main 2026-01-11 04:04:25 +00:00
Joe Fioti
688cf9bc4a custom attention op 2026-01-11 03:28:47 +00:00
Joe Fioti
d50fd065e7 changed docs 2026-01-10 08:44:20 -08:00
austin_glover
857030005d normalize 2026-01-09 22:05:07 +00:00
austin_glover
6766b47c05 remove dev dependencies 2026-01-09 22:04:47 +00:00
austin_glover
5b27f661db convert to trace 2026-01-09 20:14:55 +00:00
austin_glover
edde16845b looks for changes in the folder, not just file 2026-01-09 19:18:09 +00:00
austin_glover
b93b13d3f3 more strict gitignore 2026-01-09 19:17:55 +00:00
austin_glover
450fddf98e make sure both inputs are f32 2026-01-09 18:58:35 +00:00
austin_glover
4abe144cfa fmt 2026-01-09 18:50:23 +00:00
austin_glover
c21a482025 add CudaRuntime::new 2026-01-09 18:50:18 +00:00
austin_glover
3dee4aa1e6 add * to make gitignore actually work 2026-01-09 18:49:51 +00:00
austin_glover
544520a71f matmul example 2026-01-08 23:44:52 +00:00
austin_glover
5bcb666ad9 ignore log files 2026-01-08 23:43:53 +00:00
austin_glover
80794386ce ignore proptest this should be defined more globally 2026-01-08 23:43:23 +00:00
austin_glover
0a1d0d70fe added tracing, but I think this should actually be just a part of luminal regular? maybe? because we will always want it to trace? maybe we make it a feature flag? 2026-01-08 23:42:58 +00:00
austin_glover
a4d5941437 WIP 2026-01-08 23:42:09 +00:00
austin_glover
dc73e445e6 HostMatmul impl 2026-01-08 23:41:58 +00:00
austin_glover
808ad3d4e8 fmt 2026-01-08 23:41:35 +00:00
austin_glover
337aad82c5 fmt 2026-01-08 23:41:24 +00:00
austin_glover
14fae6755e copy tests into luminal_cuda 2026-01-08 23:40:46 +00:00
austin_glover
80f0e48e08 fmt 2026-01-08 23:39:40 +00:00
austin_glover
950a108904 lots of changes here to support host op 2026-01-08 23:39:16 +00:00
austin_glover
a355414a70 dev dependencies 2026-01-08 23:38:45 +00:00
austin_glover
5b42419a9f change timing to info!, support for more args in the macro 2026-01-08 23:38:33 +00:00
Joe Fioti
baedeeec63 tweaked runtime trait 2026-01-08 05:51:41 +00:00
Joe Fioti
b1426ba8b2 llama tweaks 2026-01-08 05:41:35 +00:00
Joe Fioti
a8f6110fff big speedups with newer egglog 2026-01-08 03:39:29 +00:00
Joe Fioti
e6e1801426 removed prints 2026-01-08 01:42:34 +00:00
Joe Fioti
ec578b70a9 Merge pull request #203 from luminal-ai/hlir_rope
Hlir rope
2026-01-07 20:40:30 -05:00
Joe Fioti
ec91872d04 fixed argmax 2026-01-08 01:32:46 +00:00
Joe Fioti
ee1daf3979 clippy 2026-01-08 00:39:11 +00:00
Joe Fioti
aa55c46f9e fmt 2026-01-08 00:37:20 +00:00
Joe Fioti
48db7a7191 merge 2026-01-08 00:36:47 +00:00
Joe Fioti
da0b514f30 Merge branch 'main' into hlir_rope 2026-01-07 19:10:06 -05:00
Joe Fioti
eafc31de50 added explicit early_rewrites to egglogop 2026-01-08 00:00:45 +00:00
Joe Fioti
3fbdbad4e8 little changes 2026-01-07 21:08:03 +00:00
austin_glover
d0b7acd27d ignore logs, viz 2026-01-06 22:49:40 +00:00
Joe Fioti
33133ad7a8 fmt 2026-01-05 21:16:50 -08:00
Joe Fioti
a1af80c677 Merge pull request #201 from jonahsamost/jonah_1_4_reduction_cuda_kernels
Implement SumReduce and MaxReduce cuda kernels
2026-01-06 00:13:27 -05:00
austin_glover
5c4ee6b272 add build script support to make example batteries included 2026-01-06 00:37:23 +00:00
austin_glover
1bd80ec18b placeholder impl for matmul 2026-01-06 00:22:50 +00:00
austin_glover
cf43620a35 re-export anyhow 2026-01-06 00:20:53 +00:00
austin_glover
625ea8aaf2 add host module 2026-01-06 00:20:33 +00:00
austin_glover
a75d19d645 credit visualization libraries 2026-01-06 00:20:10 +00:00
jonah
710375673a remove debug statements 2026-01-05 15:49:53 -08:00
jonah
a945e03a14 added mean reduce kernel 2026-01-05 15:47:00 -08:00
Joe Fioti
00fb85a4da Merge pull request #202 from luminal-ai/codex/add-proptests-to-src-and-luminal_cuda
Update CUDA proptest and stabilize argmax checks
2026-01-05 17:37:27 -05:00
Joe Fioti
d22ff09f2b Run cuda test without env gate 2026-01-05 17:33:29 -05:00
Joe Fioti
d4342df432 proptests 2026-01-05 12:24:12 -08:00
Joe Fioti
22168d7169 llama on thor 2026-01-04 22:21:24 -08:00
Joe Fioti
437fb84ae0 cuda runtime changes 2026-01-04 22:05:32 -08:00
Joe Fioti
ed63b25cd6 changed action for cuda runner 2026-01-04 20:23:11 -08:00
Joe Fioti
66a56936b3 adjusted cudarc 2026-01-04 20:16:58 -08:00
Joe Fioti
9a31917f92 adjusted to work on jetson thor 2026-01-04 20:11:21 -08:00
Joe Fioti
7bf74b6afd Merge pull request #200 from ScottBrenner/patch-1
Bump actions/checkout to v6
2026-01-04 20:05:20 -05:00
jonah
94c5ac3977 kernel max reduce works 2026-01-04 15:43:04 -08:00
Joe Fioti
bb16f481b1 api and file cleanup 2026-01-04 23:10:40 +00:00
jonah
d344270c79 reduce kernel sum works 2026-01-04 12:54:18 -08:00
Scott Brenner
e533ad35f1 Bump actions/checkout to v6 2026-01-04 11:01:55 -08:00
Joe Fioti
b29e373cf1 Merge pull request #199 from luminal-ai/search
Re-enabled Search
2026-01-03 21:28:35 -05:00
Joe Fioti
d6321f3f6a changed cuda constants 2026-01-04 02:24:49 +00:00
Joe Fioti
2bcc5d54a6 symbolic test disables 2026-01-04 02:11:44 +00:00
Joe Fioti
eeec8d4eb5 fixed cuda test 2026-01-04 02:08:39 +00:00
Joe Fioti
6cb5e39e97 re-enabled timings 2026-01-04 01:36:09 +00:00
Joe Fioti
62170cb64a cstruct name change 2026-01-04 01:30:29 +00:00
Joe Fioti
fbf3e81ef7 search 2026-01-02 23:51:44 +00:00
Joe Fioti
5c6d37b67e removed deps 2026-01-01 19:15:57 +00:00
Joe Fioti
e51edb326e cleaned up luminal_tracing 2026-01-01 18:55:19 +00:00
Joe Fioti
cff26db0c0 cleaned up vis example 2026-01-01 18:40:28 +00:00
Joe Fioti
f6dad3b9c7 Merge pull request #197 from luminal-ai/codex/design-unified-tracing-mechanism
Wire GPU trace postprocessing into trace session
2026-01-01 13:23:48 -05:00
Joe Fioti
094eb86db0 clippy 2026-01-01 18:20:52 +00:00
Joe Fioti
8bc04477f3 fmt 2026-01-01 18:16:47 +00:00
Joe Fioti
44a60dc38a Merge branch 'main' into codex/design-unified-tracing-mechanism 2026-01-01 13:16:05 -05:00
Joe Fioti
1f7b29a7d9 cleaned up interface 2026-01-01 18:14:24 +00:00
Joe Fioti
d3fbc58173 Wire GPU trace postprocessing into trace session 2026-01-01 12:33:26 -05:00
Joe Fioti
5fa21d2f8f Merge pull request #196 from luminal-ai/benckmarking
llama benchmarking
2025-12-31 20:41:22 -05:00
Joe Fioti
53ed1f7898 llama benchmark 2026-01-01 01:36:30 +00:00
Joe Fioti
6fe0f7c9d1 Merge pull request #195 from luminal-ai/codex/add-benchmarking-functionality-for-llama-example
Add benchmarking metrics to llama example
2025-12-31 20:06:34 -05:00
Joe Fioti
9bfa4096e7 Add llama benchmarking metrics 2025-12-31 20:06:21 -05:00
Joe Fioti
0082fedd3c iterative egglog parsing/running 2025-12-31 13:10:05 -08:00
Joe Fioti
183aeae009 changed sort api 2025-12-31 09:38:13 -08:00
Joe Fioti
112b833989 attempted early destructive rewrite 2025-12-31 17:21:59 +00:00
Joe Fioti
09ada8acc3 rope in hlir 2025-12-30 23:12:47 +00:00
Joe Fioti
d654927b2e Add HLIR graph display and implement rotary embeddings
- Add display_graph call to output HLIR to hlir.txt
- Implement apply_rotary_embeddings function for positional encoding
- Simplify model forward pass to test rope implementation
- Comment out full layer stack to focus on rope debugging

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-30 11:20:11 -08:00
Joe Fioti
fe78f9b4a0 removed other readmes 2025-12-29 20:27:31 -08:00
Joe Fioti
f41e7600f2 fixed tests 2025-12-28 09:14:46 -05:00
Joe Fioti
8b0fafedd2 Merge pull request #193 from luminal-ai/codex/identify-areas-for-improved-test-coverage
Consolidate and group related test variants for ops, cumops, reductions, and shape tracker
2025-12-28 08:35:04 -05:00
Joe Fioti
cf7381529f Group related test variants 2025-12-28 08:31:01 -05:00
Joe Fioti
81f2276c62 Merge pull request #192 from luminal-ai/expr-equality
Egglog Expressions
2025-12-27 21:00:09 -05:00
Joe Fioti
b384319d42 Added global simplification cache 2025-12-28 01:57:14 +00:00
Joe Fioti
81251f6c63 llama works 2025-12-28 01:31:48 +00:00
Joe Fioti
0ef5a28c5f fixed striding bug 2025-12-28 00:29:19 +00:00
Joe Fioti
b8906af6f8 tmp fix 2025-12-27 22:35:38 +00:00
Joe Fioti
db56feb260 eq perf improvements, made egglog equality its own function 2025-12-27 16:04:29 -05:00
Joe Fioti
2d393d6569 Merge branch 'expr-equality' of https://github.com/luminal-ai/luminal into expr-equality 2025-12-27 09:28:43 -05:00
Joe Fioti
25be1205b1 clippy 2025-12-27 09:28:39 -05:00
Joe Fioti
2f15f449a2 Merge pull request #191 from luminal-ai/codex/implement-equality-for-expressions
Add tests for egglog-based expression equality
2025-12-27 09:24:39 -05:00
Joe Fioti
cd4c107bdf Add tests for egglog expression equality 2025-12-27 09:24:30 -05:00
Joe Fioti
fbe3b52703 simplifications 2025-12-27 08:45:27 -05:00
Joe Fioti
b23a6c7555 Merge pull request #190 from luminal-ai/codex/convert-to-use-egglog-for-eq-sat-engine
Refine egglog-based symbolic simplification
2025-12-27 08:01:26 -05:00
Joe Fioti
c0519572cc Refine egglog simplification rules 2025-12-27 07:59:58 -05:00
Joe Fioti
6967193f16 Merge branch 'main' of https://github.com/luminal-ai/luminal 2025-12-25 15:03:36 -05:00
Joe Fioti
4200d2f26a improved api 2025-12-25 15:03:32 -05:00
Joe Fioti
4ddd725232 Changed readme
Updated README.md to reflect changes in usage and examples.
2025-12-25 11:34:58 -05:00
Joe Fioti
e0727a3678 Merge pull request #188 from luminal-ai/visualization
Modules for visualization, serialized_egraph, egglog_utils, example for visualization
2025-12-24 22:18:43 -05:00
austin_glover
f9599ceb8e clippy 2025-12-24 05:38:43 +00:00
austin_glover
947e5901d2 fmt 2025-12-24 05:27:47 +00:00
austin_glover
1dcc626884 Merge branch 'main' of https://github.com/luminal-ai/luminal into visualization 2025-12-24 05:14:36 +00:00
austin_glover
2baac7abe6 egglog utils module 2025-12-24 05:13:40 +00:00
austin_glover
6015349e41 serialized egraph changes 2025-12-24 05:13:23 +00:00
austin_glover
72958f82c3 add visualization and egglog_utils 2025-12-24 05:13:10 +00:00
austin_glover
62bd3d6725 visualization module 2025-12-24 05:12:49 +00:00
austin_glover
8e1fe3b1f4 visualization example 2025-12-24 05:12:40 +00:00
austin_glover
6693df7a03 use luminal prelude 2025-12-24 05:09:27 +00:00
austin_glover
5712458fcb use luminal prelude 2025-12-24 05:09:14 +00:00
austin_glover
fc9047401f silence warnings 2025-12-24 05:08:45 +00:00
austin_glover
2039f4983a serialized egraph changes 2025-12-24 05:08:33 +00:00
austin_glover
9492cbaf5a remove mut to silence warnings 2025-12-24 05:07:13 +00:00
austin_glover
c546bc58f0 scootch egg and clean up examples related toml 2025-12-24 05:06:55 +00:00
austin_glover
2285849b38 renaming to jinja to say it's templated 2025-12-24 05:04:21 +00:00
Joe Fioti
8a4b2c3c13 disable llama in ci/cd 2025-12-24 02:03:46 +00:00
Joe Fioti
2f3fe5569d fixed compile args 2025-12-24 01:52:04 +00:00
Joe Fioti
af7fa415bf fixed compile args 2025-12-24 01:37:03 +00:00
Joe Fioti
d21914888d changed megakernal compile opts 2025-12-24 01:22:36 +00:00
Joe Fioti
ee84b4f534 changed workflow 2025-12-24 00:50:58 +00:00
Joe Fioti
1c33a35bbb changed workflow 2025-12-24 00:48:25 +00:00
Joe Fioti
8929fc5b21 changed workflow 2025-12-24 00:35:41 +00:00
Joe Fioti
5054cd90d8 changed workflow 2025-12-24 00:33:39 +00:00
Joe Fioti
3b4e332c65 changed workflow 2025-12-24 00:29:53 +00:00
Joe Fioti
9d8a46c665 changed workflow 2025-12-24 00:27:47 +00:00
Joe Fioti
d9aa330d1a changed workflow 2025-12-24 00:25:31 +00:00
Joe Fioti
648880f3bd changed workflow 2025-12-24 00:22:21 +00:00
Joe Fioti
16bbbf551e luminal cuda fixes 2025-12-23 18:43:57 +00:00
Joe Fioti
99c988a0b6 cuda stride conversion 2025-12-23 13:22:54 -05:00
Joe Fioti
8564970ad8 multiplicative shape strides 2025-12-23 13:16:42 -05:00
Joe Fioti
4c8368d230 clippy fixes 2025-12-23 10:56:18 -05:00
Joe Fioti
90c96fbe91 cumulative ops added back and topk fixed 2025-12-23 10:46:47 -05:00
Joe Fioti
6788b98d38 changing workflow 2025-12-23 03:15:42 +00:00
Joe Fioti
262b77283f changing workflow 2025-12-23 03:13:48 +00:00
Joe Fioti
5938a06bc5 workflow 2025-12-23 03:09:50 +00:00
Joe Fioti
1fded13426 workflow 2025-12-23 03:06:27 +00:00
Joe Fioti
4fe545709e workflow 2025-12-23 03:03:37 +00:00
Joe Fioti
665a7c47f0 workflow 2025-12-23 03:02:12 +00:00
Joe Fioti
019f5b05de workflow 2025-12-23 03:00:50 +00:00
Joe Fioti
0b44f1a67e workflow 2025-12-23 02:54:54 +00:00
Joe Fioti
aec2f0ca4f workflow 2025-12-23 02:50:27 +00:00
Joe Fioti
9140e012d7 workflow 2025-12-23 02:47:08 +00:00
Joe Fioti
eaf4e350e6 workflow 2025-12-23 02:45:36 +00:00
Joe Fioti
aaec23ea43 workflow 2025-12-23 02:44:56 +00:00
Joe Fioti
0ac28cf9f7 workflow 2025-12-23 02:43:08 +00:00
Joe Fioti
99b15d57bc worflow 2025-12-23 02:39:19 +00:00
Joe Fioti
012dc6a893 worflow 2025-12-23 02:36:19 +00:00
Joe Fioti
709d37b349 debugging cuda cicd 2025-12-22 19:04:18 +00:00
Joe Fioti
b0ca4bf007 debugging cuda cicd 2025-12-22 19:03:38 +00:00
Joe Fioti
dde521d0c8 workflow change 2025-12-22 18:58:35 +00:00
Joe Fioti
da359acf18 changed workflow 2025-12-22 18:56:39 +00:00
Joe Fioti
148a159a9a install rust on runner 2025-12-22 18:55:03 +00:00
Joe Fioti
b7e777f3eb changed workflow 2025-12-22 18:47:47 +00:00
Joe Fioti
3d214c8dc5 changed workflow 2025-12-22 18:43:00 +00:00
Joe Fioti
de42253657 changed workflow 2025-12-22 18:42:06 +00:00
Joe Fioti
c59c6d873e changed workflow 2025-12-22 18:36:57 +00:00
Joe Fioti
06175991af try gpu runner 2025-12-22 18:30:43 +00:00
Joe Fioti
e1d7a630b4 changed workflow 2025-12-22 18:21:08 +00:00
Joe Fioti
e946d7953a added cuda test 2025-12-22 18:16:40 +00:00
Joe Fioti
8304525d44 more tests 2025-12-22 12:30:25 -05:00
Joe Fioti
ba241070ed pad tests and bug fixes 2025-12-22 12:13:53 -05:00
Joe Fioti
44141edce8 clippy fix + cuda fix 2025-12-22 11:26:44 -05:00
Joe Fioti
2c5c6db989 more unit tests 2025-12-22 11:11:23 -05:00
Joe Fioti
fa607d6cef conv tests 2025-12-22 09:09:59 -05:00
Joe Fioti
f2f16f3931 matmul tests 2025-12-22 08:49:54 -05:00
Joe Fioti
180ad53ebf binary tests 2025-12-22 08:47:28 -05:00
Joe Fioti
abc6b4fbdf fixed llama via cuda rewrite rule fix 2025-12-22 03:34:31 +00:00
Joe Fioti
d9161d1e47 more unary tests 2025-12-21 21:29:34 -05:00
Joe Fioti
fc7c69c0ef unary tests and bug fixes 2025-12-21 21:16:55 -05:00
Joe Fioti
7dff991ab0 added native op dialect for semantic definitional checking 2025-12-21 20:18:39 -05:00
Joe Fioti
237497ea9a Merge branch 'main' of https://github.com/luminal-ai/luminal 2025-12-20 17:01:02 -05:00
Joe Fioti
e7a6ca52b3 argsort and topk untested 2025-12-20 17:00:51 -05:00
austin_glover
e05118d3ee ignore visualization outputs 2025-12-20 01:36:02 +00:00
austin_glover
0ef9de30b3 add Fx stuff to prelude 2025-12-20 01:35:17 +00:00
austin_glover
53cc3c22f3 working towards working example 2025-12-20 01:35:07 +00:00
austin_glover
a467e6927f visualization module 2025-12-20 01:34:22 +00:00
austin_glover
472df70919 add egraph_viz_template 2025-12-20 00:39:36 +00:00
austin_glover
a501f08bc5 add visualization module 2025-12-20 00:39:13 +00:00
austin_glover
192b5a544a add graphviz (no graphviz exec) and egraph-serialize [graphviz] to get dot file exports 2025-12-20 00:38:48 +00:00
austin_glover
b09b8f14ea generic error handling 2025-12-19 19:20:07 +00:00
Joe Fioti
b73d476ec5 Merge pull request #187 from luminal-ai/llama-example-setup
Add safetensors loading and combination script for llama example
2025-12-19 14:03:42 -05:00
austin_glover
647ab26736 lint and format 2025-12-19 18:51:27 +00:00
austin_glover
ad1ec37d60 revert name to model_combined for future-proofing 2025-12-19 18:31:51 +00:00
austin_glover
32e7cf3813 :Merge branch 'main' of https://github.com/luminal-ai/luminal into llama-example-setup 2025-12-19 18:07:58 +00:00
Joe Fioti
39fdf43354 changed workflow 2025-12-19 09:06:23 -05:00
Joe Fioti
05309ece1d fmt 2025-12-19 09:05:16 -05:00
Joe Fioti
7453c40d62 added llama / cuda back to workspace 2025-12-19 09:02:58 -05:00
Joe Fioti
0098b1e4e4 fixed barrier stride bug 2025-12-19 03:41:28 +00:00
austin_glover
ebff9a563d get setup script working out of the box 2025-12-19 01:41:37 +00:00
Joe Fioti
2a10f105b8 clippy fix 2025-12-18 18:11:47 +00:00
Joe Fioti
4692788b87 clippy fixes 2025-12-18 18:07:38 +00:00
Joe Fioti
3157a47585 adjust export 2025-12-18 17:23:10 +00:00
Joe Fioti
71584e5f3a export utils 2025-12-18 16:41:42 +00:00
Joe Fioti
0575328fa9 added unfold, slicing and padding back in 2025-12-18 16:08:18 +00:00
Joe Fioti
9074b6ffd2 api cleanup 2025-12-17 18:18:33 +00:00
Joe Fioti
6b2e7dd83f removed much runtime stuff from graph 2025-12-17 15:53:52 +00:00
Joe Fioti
db1de85fe1 more dtypes 2025-12-17 13:29:17 +00:00
Joe Fioti
fa62b4f3a5 removed old embedding code 2025-12-16 23:21:11 +00:00
Joe Fioti
d6692721ef types, iota and gather 2025-12-16 23:13:04 +00:00
Joe Fioti
e17269fcfc improved gather op 2025-12-16 19:29:04 +00:00
Joe Fioti
1d91ea244f updated runtime interface 2025-12-16 19:18:20 +00:00
Joe Fioti
e09598e97b removed logical.rs / merged with hlir ops 2025-12-15 16:28:28 +00:00
Joe Fioti
5df9fdb311 Merge pull request #186 from luminal-ai/next
New IR
2025-12-14 23:20:16 -05:00
Joe Fioti
92e0f259e4 begin updating to 2.0 2025-12-14 18:31:39 +00:00
Joe Fioti
8c7ea89ade changed neg 2025-11-16 18:46:33 +00:00
Joe Fioti
886d9b812e added split dims 2025-11-13 21:55:02 -08:00
Joe Fioti
bb694febc5 better tests 2025-11-13 21:45:47 -08:00
Joe Fioti
4288566c33 added complex strides 2025-11-13 21:35:49 -08:00
Joe Fioti
14f348ecb1 new shapetracker 2025-11-13 21:04:48 -08:00
Joe Fioti
86ebdeb7a4 ceil division 2025-11-06 15:55:21 -08:00
Joe Fioti
4a4c511323 removed demos 2025-11-05 15:15:43 -08:00
Joe Fioti
73daea1a40 Update README.md 2025-11-05 18:15:14 -05:00
Joe Fioti
b60d348af6 Merge pull request #180 from amemov/ShowHN-Fix
Fix egglog compilation error for ShowHN demo
2025-11-04 14:49:07 -05:00
amemov
58ce0722eb - Changed egglog version to 1.0.0 2025-11-03 07:36:31 +00:00
Joe Fioti
dca3c3bdda recip on sigmoid 2025-10-31 10:40:29 -04:00
Anton Shepelev
f1d75b97cf Fix egglog compilation error for ShowHN demo
- Remove unused egglog_proof dependency from MilkBlock fork
- Use egraph-serialize 0.3.0 from crates.io instead of git
- Fix Span namespace conflict between ratatui and egglog
- Add RustSpan and Span imports from egglog::prelude

Fixes compilation error: missing field `extra` in ClassData
The demo now builds successfully with --features cuda or --features metal
2025-10-30 17:41:56 -07:00
Joe Fioti
1ccbafdd32 Merge branch 'main' of https://github.com/luminal-ai/luminal 2025-10-15 12:45:31 -07:00
Joe Fioti
c94bb93c40 updated egglog 2025-10-15 12:45:22 -07:00
Joe Fioti
57e23e5a7d Merge pull request #170 from danielleiszen/feature/adam-optimizer 2025-10-12 21:32:25 -04:00
Dániel Leiszen
8607e10d4d changed back to convergence original values 2025-10-12 20:37:31 +02:00
Dániel Leiszen
fd0b796d78 formatted by fmt 2025-10-12 20:34:47 +02:00
Dániel Leiszen
b6e3894693 Merge branch 'feature/adam-optimizer' of github.com:danielleiszen/luminal into feature/adam-optimizer 2025-10-12 20:14:01 +02:00
Dániel Leiszen
08f118ba83 increased iterations for convergence test not to fail 2025-10-12 20:13:56 +02:00
Dániel Leiszen
a6d785c55c Merge branch 'luminal-ai:main' into feature/adam-optimizer 2025-10-12 20:09:05 +02:00
Dániel Leiszen
bb99be01bf improvement works with nodes instead of tensors 2025-10-12 20:08:18 +02:00
Joe Fioti
f462614f9b Update README.md 2025-10-07 22:20:49 -04:00
Dániel Leiszen
612146afed adam implementation with tests 2025-10-06 23:22:37 +02:00
Joe Fioti
aa0c8a6532 Merge pull request #164 from luminal-ai/match_cublas
cuda fixes
2025-09-21 00:47:53 -04:00
Joe Fioti
629411b514 fixed for metal 2025-09-20 21:45:15 -07:00
Joe Fioti
9a9a5bd9e5 removed txt 2025-09-21 04:38:41 +00:00
Joe Fioti
ff16cd988e cuda fixes 2025-09-21 04:38:22 +00:00
Joe Fioti
0e99c676b5 fixed cargo toml file 2025-09-20 21:03:29 -07:00
Joe Fioti
c6caca1afc removed incomplete yolo and moondream 2025-09-20 17:32:39 -07:00
Joe Fioti
617e3c771c updated to work with cuda 2025-09-19 17:33:17 +00:00
Joe Fioti
5a1ecbad1e updated to work with cuda 2025-09-19 17:32:02 +00:00
Joe Fioti
bccb7b4e27 updated 2025-09-19 09:24:01 -07:00
Joe Fioti
8e481afd9f llama mlp 2025-09-18 20:51:34 -07:00
Joe Fioti
0a514e5fcc more explicit ref tracking 2025-09-18 20:11:20 -07:00
Joe Fioti
cbfa3bccf9 fixed rewrite 2025-09-17 10:07:48 -07:00
Joe Fioti
b32e65652f loop swapping 2025-09-17 07:49:29 -07:00
Joe Fioti
355378ddda loop unrolling 2025-09-17 07:37:37 -07:00
Joe Fioti
8a6afcc28c added tiling 2025-09-17 07:30:47 -07:00
Joe Fioti
2a16edb7eb search refinements, find decent glu 2025-09-16 22:01:54 -07:00
Joe Fioti
7d67297081 added loop merging 2025-09-16 09:58:18 -07:00
Joe Fioti
5130641477 fixed egglog and added kernel timings 2025-09-15 21:51:44 -07:00
Joe Fioti
c6e99c7255 fixed cuda 2025-09-15 00:19:25 +00:00
Joe Fioti
edcc2fbb1c removed metal feature 2025-09-14 17:07:56 -07:00
Joe Fioti
3a8811a1b5 Merge pull request #162 from luminal-ai/match_cublas
Simplifications and bug fixes
2025-09-14 19:08:02 -04:00
Joe Fioti
7b25e910b3 removed loop labels 2025-09-14 16:05:34 -07:00
Joe Fioti
2ec40362b0 a few different bug fixes in extract 2025-09-14 15:58:24 -07:00
Joe Fioti
c5127855ca fixed codegen kernel assignment issue 2025-09-14 10:06:05 -07:00
Joe Fioti
fffd08120d better loop fusion 2025-09-13 10:28:32 -07:00
Joe Fioti
71b35e1fb0 changes 2025-09-11 20:35:53 -07:00
Joe Fioti
2ae52b3570 clean up extraction 2025-09-11 09:06:17 -07:00
Joe Fioti
7856b08a5a polish 2025-09-10 22:52:08 -07:00
Joe Fioti
cbf594bcf9 added egraph visualizer and maybe fixed loop level analysis 2025-09-10 22:46:54 -07:00
Joe Fioti
46a11956b4 removed smem stuff from codegen 2025-09-07 09:07:10 -07:00
Joe Fioti
0dc29128cb Update README.md 2025-09-07 10:48:02 -04:00
Joe Fioti
c980a9fd3a Merge pull request #155 from luminal-ai/match_cublas
GELU working
2025-09-07 01:26:34 -04:00
Joe Fioti
291a780ab9 somplifications 2025-09-06 22:25:55 -07:00
Joe Fioti
adcf4d02db split kernels cleanup 2025-09-06 22:02:58 -07:00
Joe Fioti
a8c18d9f42 gated linear unit layer working 2025-09-06 20:15:33 -07:00
Joe Fioti
4c6986b012 handle disjoints 2025-09-06 18:15:29 -07:00
Joe Fioti
6e0ca785b4 handle disjoints 2025-09-06 18:14:50 -07:00
Joe Fioti
ca86d99f8b Merge pull request #154 from EricHallahan/flash_attention
Repair and enhance flash_attention demo
2025-09-06 18:29:07 -04:00
Eric Hallahan
0bd3b80c3e Repair and enhance flash_attention demo 2025-09-06 18:21:02 -04:00
Joe Fioti
455dc75efa matmul + swish 2025-09-06 14:37:11 -07:00
Joe Fioti
6ff270d3a0 removed communativity 2025-09-06 13:07:02 -07:00
Joe Fioti
33c934e8fc removed egglog.txt 2025-09-06 12:46:35 -07:00
Joe Fioti
3e18df684a swish and self mul fully fused 2025-09-06 12:46:14 -07:00
Joe Fioti
5b658f6355 silu working 2025-09-05 16:55:39 -07:00
Joe Fioti
3f727aa9f7 tc mlp 2025-09-05 12:53:12 -07:00
Joe Fioti
bb02334848 changed debugger 2025-09-04 14:33:28 -07:00
Joe Fioti
a7f6e63170 initial loop level analysis 2025-09-04 09:05:31 -07:00
Joe Fioti
96bb3e5e1c removed ir-generic 2025-09-04 08:19:04 -07:00
Joe Fioti
28c501f9d0 multi matmul 2025-09-03 21:37:14 -07:00
Joe Fioti
7520e71d05 back off scheduler in egglog experimental 2025-09-03 06:04:41 -07:00
Joe Fioti
5a241283b6 Merge pull request #153 from luminal-ai/match_cublas
Added cuda
2025-09-03 00:27:17 -04:00
Joe Fioti
62568054ec Merge branch 'main' into match_cublas 2025-09-03 00:25:18 -04:00
Joe Fioti
ca8394d7af merge 2025-09-02 21:23:35 -07:00
Joe Fioti
63ace83876 merge 2025-09-02 21:20:11 -07:00
Joe Fioti
884d34abeb Merge pull request #152 from matthewjgunton/cuda
Cuda
2025-09-03 00:13:49 -04:00
Joe Fioti
dee4aa3b69 Merge branch 'match_cublas' into cuda 2025-09-03 00:13:42 -04:00
Joe Fioti
106f0a8fa8 changes 2025-09-02 21:12:42 -07:00
Matthew Gunton
3c99f23a8b cuda run successful 2025-09-02 22:51:56 +00:00
Matthew Gunton
d2339084ca merge fix p2 2025-09-02 13:39:22 -07:00
Matthew Gunton
db24774070 fixing merge 2025-09-02 13:26:04 -07:00
Matthew Gunton
7cd8662650 Merge branch 'match_cublas' into cuda 2025-09-02 15:06:46 -05:00
Joe Fioti
217de87959 added demo to workspace 2025-09-02 09:40:18 -07:00
Joe Fioti
961849fda6 Merge pull request #151 from luminal-ai/match_cublas
Cleaned up rewrite rules
2025-09-02 11:12:14 -04:00
Joe Fioti
fa2799a629 fixed symbolic 2025-09-02 08:06:15 -07:00
Joe Fioti
97b221458b refactored features in matmul demo 2025-09-02 07:45:31 -07:00
Joe Fioti
823df2a826 added cleanup 2025-09-01 20:58:39 -07:00
Joe Fioti
d3d5f7f69c removed txt 2025-09-01 20:49:29 -07:00
Joe Fioti
540ba3d0ba cleaned up window functionality a bit 2025-09-01 20:49:19 -07:00
Matthew Gunton
b732016f2e metal dependency fixing 2025-09-01 20:47:07 -07:00
Joe Fioti
444b146987 disabled swaploops 2025-09-01 15:26:07 -07:00
Matthew Gunton
554775e6f7 cuda running base case 2025-09-01 20:53:16 +00:00
Joe Fioti
2661475362 solved loop fusion problem 2025-09-01 13:13:49 -07:00
Joe Fioti
775b6cafb2 fast tc 2025-08-31 22:36:30 -07:00
Joe Fioti
ab0d6226dd fast tc 2025-08-31 22:25:53 -07:00
Joe Fioti
7bd08631ab fast naive matmul 2025-08-31 22:19:49 -07:00
Matthew Gunton
8a5711fdb3 cuda on luminal 2 crate 2025-08-31 00:30:23 +00:00
Matthew Gunton
f7cdff5ed7 Revert "cloning"
This reverts commit 90134c0a7b.
2025-08-30 14:20:46 -07:00
Matthew Gunton
90134c0a7b cloning 2025-08-30 13:58:25 -07:00
Matthew Gunton
edc50b02ad light refactoring 2025-08-30 11:25:54 -07:00
Joe Fioti
0b9e3da0a3 possibly fixed swaploops problem 2025-08-29 16:43:58 -07:00
Joe Fioti
3edf33d80d possibly fixed swaploops problem 2025-08-29 16:42:39 -07:00
Joe Fioti
b6151ec2f5 simplified ui 2025-08-27 14:02:48 -07:00
Joe Fioti
4f5086b457 simplified ui 2025-08-27 11:40:34 -07:00
Joe Fioti
ca1c509261 better debugging zoom 2025-08-27 11:05:51 -07:00
Joe Fioti
1bcae63f0b added debugger 2025-08-26 23:17:23 -07:00
Joe Fioti
7c259f4def Merge pull request #148 from abeleinin/print-empty
Safely print empty tensors
2025-08-24 23:33:42 -04:00
aleinin
1d63c86f6c Safely print empty tensors 2025-08-24 14:31:33 -07:00
Joe Fioti
0e83fa91ed clean up 2025-08-23 20:28:15 -07:00
Joe Fioti
4aad8acab4 clean up 2025-08-23 20:27:59 -07:00
Joe Fioti
c8373bc6ed merge 2025-08-23 09:47:13 -07:00
Joe Fioti
2ccefef1f0 upgrade to new metal 2025-08-23 09:46:22 -07:00
Matthew Gunton
1fcaf70272 Merge pull request #147 from matthewjgunton/main
deduplicate logic added to extract
2025-08-22 20:45:16 -05:00
Matthew Gunton
5c12a76318 Merge branch 'match_cublas' into main 2025-08-22 20:45:01 -05:00
Matthew Gunton
3954ecc168 removing excess println 2025-08-22 18:43:20 -07:00
Matthew Gunton
8c1f41058a removing print 2025-08-22 18:42:34 -07:00
Matthew Gunton
e9a0d8353b deduplicate logic added to extract 2025-08-22 18:41:24 -07:00
Joe Fioti
bd4357e281 updated metal 2025-08-21 14:36:28 -07:00
Joe Fioti
01cede3124 fixed showhn 2025-08-21 14:35:12 -07:00
Joe Fioti
1f998aa4b9 clean 2025-08-20 23:14:52 -07:00
Joe Fioti
4bd28878d6 Merge remote-tracking branch 'origin/main' into match_cublas 2025-08-20 22:43:24 -07:00
Joe Fioti
421270a822 Merge pull request #142 from IanBoyanZhang/main
Typo fixing and variable renaming only
2025-08-21 01:43:07 -04:00
Joe Fioti
35b6d05b93 merge 2025-08-20 22:41:42 -07:00
Joe Fioti
fd7962dbc9 small 2025-08-20 22:40:16 -07:00
Matthew Gunton
207d3b9253 Merge pull request #143 from matthewjgunton/main
merge rule re-added + greedy protection in extract added
2025-08-21 00:29:47 -05:00
Matthew Gunton
76ffbcdd84 merge rule re-added + greedy protection in extract added 2025-08-20 22:25:09 -07:00
Ian Zhang
4741ea9861 Fix typos in crates luminal_2 src codegen.rs 2025-08-20 21:27:57 -07:00
Ian Zhang
deecb40d21 Update search.rs
Fix typos in search.rs
2025-08-20 21:26:16 -07:00
Joe Fioti
781f6cb925 changed ui 2025-08-20 19:18:39 -07:00
Joe Fioti
647fc4cadd fix ui 2025-08-20 19:14:59 -07:00
Joe Fioti
5ff583d657 rm 2025-08-20 19:02:14 -07:00
Joe Fioti
d0ef1f7034 Update README.md 2025-08-20 02:03:02 -04:00
Joe Fioti
c9c8adb5b4 Update README.md 2025-08-20 01:58:54 -04:00
Joe Fioti
cf8509ff3b Update README.md 2025-08-20 01:50:25 -04:00
Joe Fioti
ebf14c6962 Merge pull request #140 from luminal-ai/llama_chunking
Llama chunking
2025-08-20 01:46:58 -04:00
Joe Fioti
7a5fb61199 workspace change 2025-08-19 22:43:54 -07:00
Joe Fioti
352128e066 fmt 2025-08-19 22:35:17 -07:00
Joe Fioti
d177f698de Merge branch 'main' into llama_chunking 2025-08-20 01:34:12 -04:00
Joe Fioti
d5c84c1a86 merge 2025-08-19 22:33:05 -07:00
Joe Fioti
4f3702bfa0 merge 2025-08-19 22:30:46 -07:00
Joe Fioti
3857950e7e moved matmul to demos folder 2025-08-19 21:39:17 -07:00
Joe Fioti
b2b4a8de9b cleanups 2025-08-19 14:44:38 -07:00
Joe Fioti
2e5cc0367a merge 2025-08-18 17:21:17 -07:00
Joe Fioti
6826e9b248 fixed display 2025-08-18 17:07:07 -07:00
Joe Fioti
27de8c17f3 Merge branch 'llama_chunking' of https://github.com/jafioti/luminal into HEAD 2025-08-18 16:43:49 -07:00
Joe Fioti
b804383ed5 frontend to tensor core 2025-08-18 16:43:09 -07:00
Joe Fioti
6b126a0733 naive to tc 2025-08-18 16:09:31 -07:00
Joe Fioti
eb63c3c44b broken tc rule 2025-08-18 13:13:45 -07:00
Joe Fioti
ec730d4374 naive to fused working 2025-08-18 08:46:07 -07:00
Joe Fioti
77a3bfa730 fast tensor core matches from naive matmul 2025-08-17 22:11:33 -07:00
Joe Fioti
2173531763 tc matching? 2025-08-17 20:57:28 -07:00
Joe Fioti
7d40d58a86 tc search working 1 step 2025-08-17 10:40:23 -07:00
Joe Fioti
b332538c59 huge tc rewrite 2025-08-17 09:01:03 -07:00
Joe Fioti
ec9da07f82 tc codegen working 2025-08-16 21:46:11 -04:00
Matthew Gunton
c8d58e89d8 Merge pull request #139 from matthewjgunton/llama_chunking
extraction & codegen unit tests
2025-08-16 14:30:32 -05:00
Matthew Gunton
4ae2f12815 codegen tests added 2025-08-16 12:29:31 -07:00
Matthew Gunton
a3eaf48e04 extraction tests (avoiding TLS) 2025-08-16 11:48:49 -07:00
Matthew Gunton
79e385c9e2 Merge pull request #138 from matthewjgunton/llama_chunking
graph break now only removes ancestors from orig_to_subgraph_node_map
2025-08-15 19:36:32 -05:00
Matthew Gunton
6586ddad6c removing unnecessary printlns 2025-08-15 17:35:06 -07:00
Matthew Gunton
92d04ceed0 graph break now only removes ancestors from orig_to_subgraph_node_map 2025-08-15 17:33:50 -07:00
Matthew Gunton
d8246c5fde Merge pull request #137 from matthewjgunton/llama_chunking
unit tests added for translate
2025-08-15 15:58:15 -05:00
Matthew Gunton
52aed4f431 end to end testing here 2025-08-15 13:47:45 -07:00
Matthew Gunton
f59b0fa089 refactored out the helper functions from search 2025-08-15 11:38:15 -07:00
Matthew Gunton
072e39e7e8 unit tests added for translate 2025-08-15 11:22:02 -07:00
Joe Fioti
88dedff7f2 clippy 2025-08-14 21:55:20 -04:00
Joe Fioti
fdeb011eea Update README.md 2025-08-14 21:26:24 -04:00
Joe Fioti
31c9caf5a6 Update README.md 2025-08-14 21:19:56 -04:00
Joe Fioti
be85bf5dbe removed scope in simple 2025-08-14 15:50:30 -04:00
Joe Fioti
32e5e81591 merge 2025-08-14 15:46:33 -04:00
Joe Fioti
b65aba9153 naive matmul with only runtime signal 2025-08-14 15:43:21 -04:00
Joe Fioti
5348150f22 Merge pull request #136 from matthewjgunton/llama_chunking
refactoring for better clarity
2025-08-14 13:49:51 -05:00
Matthew Gunton
53de6ee5b3 Merge branch 'llama_chunking' into llama_chunking 2025-08-14 13:49:34 -05:00
Matthew Gunton
dac589c17a putting search into run 2025-08-14 11:16:41 -07:00
Joe Fioti
4157526890 simplified output indexes 2025-08-13 21:21:39 -04:00
Joe Fioti
ae46ced304 2-level tiling for matmul fusion 2025-08-13 21:05:05 -04:00
Matthew Gunton
f422616636 refactoring for better clarity 2025-08-13 11:11:04 -07:00
Joe Fioti
c6c7797148 fixed simple example 2025-08-12 22:32:24 -07:00
Joe Fioti
44e70dffde Merge branch 'llama_chunking' of https://github.com/jafioti/luminal into llama_chunking 2025-08-12 22:20:11 -07:00
Joe Fioti
d6ac3f7a83 altered rules, merging not robust yet 2025-08-12 22:20:01 -07:00
Joe Fioti
70027b4691 Merge pull request #135 from matthewjgunton/llama_chunking
refactoring translate function
2025-08-13 00:18:50 -05:00
Matthew Gunton
9912d01160 refactoring translate function 2025-08-12 22:14:11 -07:00
Joe Fioti
6562848a23 matmul with no invalid kernels generated 2025-08-11 10:00:26 -07:00
Joe Fioti
27d300fffd updated egglog 2025-08-10 22:10:47 -07:00
Joe Fioti
d4420b8e5f fused matmulgit add . 2025-08-09 23:19:53 -07:00
Joe Fioti
903e8ca7a6 handle thread buffers better 2025-08-09 23:16:51 -07:00
Joe Fioti
51062ba74d removed generic propogation for tiling and added custom 2025-08-09 21:17:24 -07:00
Joe Fioti
97fd41b338 fixed extractor 2025-08-09 17:52:37 -07:00
Joe Fioti
7e0f5cc227 fixed rewrite rules but removed fusion, need to add back 2025-08-09 15:14:27 -07:00
Joe Fioti
454e520490 simplified extraction fn, still broken? 2025-08-09 11:23:16 -07:00
Joe Fioti
97773d77fc fixed searching, pruned 2025-08-08 18:23:07 -07:00
Joe Fioti
56da21a661 searching is correct 2025-08-08 13:09:55 -07:00
Joe Fioti
b600cc5e0d chunked searching 2025-08-07 22:52:47 -07:00
Joe Fioti
6b8053c210 fixed chunking 2025-08-07 21:27:38 -07:00
Joe Fioti
bd6afb8cfc fixed scope-out-in in translate 2025-08-07 21:15:23 -07:00
Joe Fioti
944dd01971 added metal back to luminal_2 and fixed matmul 2025-08-07 20:54:42 -07:00
Joe Fioti
f06d4ad138 fixed acc stitching 2025-08-07 17:06:57 -07:00
Joe Fioti
0fec01bf43 stiched accs together 2025-08-07 14:34:22 -07:00
Joe Fioti
5418817301 Update README.md 2025-08-06 02:17:11 -04:00
Joe Fioti
48a9fa324d Update README.md 2025-08-06 02:15:24 -04:00
Joe Fioti
d8cb189351 Merge pull request #134 from kstonekuan/main
docs: update graphtensor example and CI link
2025-08-05 11:56:40 -05:00
Kingston
a6dba9e8f7 Update README.md 2025-08-05 20:58:09 +08:00
Kingston
4f0ed828e6 Update graphtensor.mdx 2025-08-05 20:57:31 +08:00
Joe Fioti
423b2ffa8c started adding chunking code 2025-07-30 17:19:34 +00:00
Joe Fioti
e4b18fe5ab opt 2025-07-30 05:31:31 +00:00
Joe Fioti
7e872ad8af optimized llama furthur 2025-07-30 02:43:13 +00:00
Joe Fioti
a0697b201b added custom attn kernel 2025-07-29 21:18:59 +00:00
Joe Fioti
12c1893c13 rope kernel 2025-07-29 18:27:44 +00:00
Joe Fioti
66775a03be tweaks 2025-07-29 05:40:07 +00:00
Joe Fioti
b051f10302 hybrid llama opts 2025-07-28 22:00:23 +00:00
Joe Fioti
5f62412e7e fixed input ordering on mixed kernels 2025-07-28 18:46:57 +00:00
Joe Fioti
89a9b04f60 custom lm head kernel 2025-07-28 17:08:24 +00:00
Joe Fioti
7f7ef1de42 Tweaks 2025-07-28 06:06:02 +00:00
Joe Fioti
8ad99caca1 static memory buffer management 2025-07-28 06:01:18 +00:00
Joe Fioti
a53dddbac0 Merge branch 'main' of https://github.com/jafioti/luminal 2025-07-27 21:25:33 -07:00
Joe Fioti
57c266e783 website update 2025-07-27 21:25:24 -07:00
Joe Fioti
17eab5fd8d Merge branch 'main' of https://github.com/jafioti/luminal 2025-07-28 04:21:14 +00:00
Joe Fioti
2fb35a6420 favicon 2025-07-27 21:21:09 -07:00
Joe Fioti
3cf26e5239 favicon 2025-07-28 04:10:46 +00:00
Joe Fioti
dc19fdd2ab optimized codegen 2025-07-28 03:37:49 +00:00
Joe Fioti
668e678882 optimized translate 2025-07-27 20:43:12 +00:00
Joe Fioti
9001d2ad11 greatly simplified luminal 2 execution 2025-07-27 19:39:36 +00:00
Joe Fioti
f63d79ee1d cached model weights 2025-07-27 16:52:07 +00:00
Joe Fioti
1fedf03098 convert luminal 2 to cuda 2025-07-26 22:33:39 +00:00
Joe Fioti
a5c6afa6b0 format 2025-07-26 17:38:19 +00:00
Joe Fioti
1c4ea9bdb7 got cuda tests running 2025-07-26 17:36:15 +00:00
Joe Fioti
c6b06a2130 tried to update cuda version, still failing on lambda 2025-07-26 05:43:24 +00:00
Joe Fioti
05fe3d64e8 simplifications 2025-07-25 21:18:57 -07:00
Joe Fioti
7a61edb23c removed bins 2025-07-25 18:10:45 -07:00
Joe Fioti
8b7fc39b8b luminal 2 llama working 2025-07-25 18:10:16 -07:00
Joe Fioti
8be90191b1 progress? 2025-07-24 18:40:28 -07:00
Joe Fioti
9cf0eec4f8 got luminal 1.0 working on llama 2025-07-24 13:42:38 -07:00
Joe Fioti
6f71ab1450 refactored to translate to a single grid dim, works on funky kv cache sizes now 2025-07-23 22:57:21 -07:00
Joe Fioti
f322c862ad removed bin files 2025-07-23 14:05:39 -07:00
Joe Fioti
5012a86249 llama with different kv cache sizes works + diff2 added 2025-07-23 14:05:15 -07:00
Joe Fioti
a1e752d7cc Merge pull request #129 from above-avg/patch-1 2025-07-23 09:18:40 -05:00
Pallab
8862e254fa Fix broken PyTorch Dynamo documentation link
The original link led to a 404. Replaced with the updated Dynamo overview from pytorch.org.
2025-07-23 19:43:53 +05:30
Joe Fioti
3e566a2e37 Merge pull request #128 from MilkBlock/fix-colored-string-panic 2025-07-22 09:16:03 -05:00
Joe Fioti
88316762ec Update README.md 2025-07-22 09:14:32 -05:00
MilkBlock
ba692ecfe0 fix panic when printing emoji 2025-07-22 18:32:15 +08:00
Joe Fioti
752d3e2401 symbolic inputs / outputs 2025-07-21 22:24:02 -07:00
Joe Fioti
97d5e05820 forward pass with kv out 2025-07-21 12:39:52 -07:00
Joe Fioti
89d235cf2b changed codegen for multiple outputs 2025-07-20 20:59:30 -07:00
Joe Fioti
cde6e0e55a llama forward prefill multi-token confirmed 2025-07-20 17:26:25 -07:00
Joe Fioti
eff6a677a0 website change 2025-07-20 11:13:52 -07:00
Joe Fioti
a1393db02a fixed metal gather 2025-07-19 10:36:25 -07:00
Joe Fioti
c415a1021c added custom kernel 2025-07-18 21:50:03 -07:00
Joe Fioti
5aee772092 fixed simple 2025-07-18 14:14:02 -07:00
Joe Fioti
44caf80f23 llama 2025-07-18 13:07:51 -07:00
Joe Fioti
6450de534c updates to luminal 2 2025-07-17 14:36:47 -07:00
Joe Fioti
67166cac81 tweak 2025-07-17 10:42:55 -07:00
Joe Fioti
d6901cc018 Merge pull request #123 from anuragsingh-tt/anuhsing/minor_fix 2025-07-16 21:15:08 -05:00
Your Name
9c2754d42d minor fix to flash attention demo 2025-07-16 15:46:04 -05:00
Joe Fioti
73299b535e Merge pull request #122 from matthewjgunton/main
adding github stars to homepage to make clear we're OSS
2025-07-15 16:24:27 -05:00
Matthew Gunton
fbbe4e1abb adding github stars to homepage 2025-07-15 16:21:30 -05:00
Joe Fioti
bd33460c96 remove 2 from workspace 2025-07-14 13:19:53 -07:00
Joe Fioti
5e9e811de9 merge 2025-07-14 13:16:50 -07:00
Joe Fioti
cbd2104dc5 started merging search into main project 2025-07-14 13:16:16 -07:00
Joe Fioti
a550d003f2 Merge pull request #119 from jafioti/codex/fix-clippy-warnings 2025-07-12 09:38:05 -05:00
Joe Fioti
0bdecf6d92 fix clippy warnings 2025-07-12 09:36:04 -05:00
Joe Fioti
cf8eb769fe clippy fix 2025-07-11 21:01:42 -07:00
Joe Fioti
63d950d106 Merge pull request #118 from jafioti/next
updates
2025-07-11 21:03:09 -05:00
Joe Fioti
73af2e22e8 updates 2025-07-11 19:02:21 -07:00
Joe Fioti
c0cf78f36f Merge pull request #117 from jafioti/next
Search Compiler
2025-07-11 20:57:56 -05:00
Joe Fioti
6ec4439714 cleanups 2025-07-11 08:20:05 -07:00
Joe Fioti
d2e34a3093 softmax layernorm 2025-07-10 22:09:58 -07:00
Joe Fioti
e9be274c48 validated tiled matmul through search 2025-07-10 19:49:12 -07:00
Joe Fioti
c309cb7b12 fully tiled matmul 2025-07-10 17:09:07 -07:00
Joe Fioti
523d3021fb Update README.md 2025-07-10 14:38:57 -04:00
Joe Fioti
49498aa923 Update README.md 2025-07-10 14:35:23 -04:00
Joe Fioti
952b45e231 Update README.md 2025-07-10 14:25:10 -04:00
Joe Fioti
ea53fa922d proper acc loading and saving 2025-07-10 10:19:51 -07:00
Joe Fioti
26ae0c3654 new extraction and fusion works 2025-07-10 07:53:15 -07:00
Joe Fioti
555960fa22 tests passing'
:
2025-07-09 09:35:07 -07:00
Joe Fioti
1efd2dd4f8 naive attn 2025-07-08 20:42:31 -07:00
Joe Fioti
dd8c34b035 transposed matmul 2025-07-08 19:35:38 -07:00
Joe Fioti
e2d3f2beb0 more extensive search on square matmuls 2025-07-08 18:22:29 -07:00
Joe Fioti
e334dbabe4 kernel fusion 2025-07-08 15:36:42 -07:00
Joe Fioti
83736c4886 naive attn 2025-07-08 14:25:33 -07:00
Joe Fioti
ab790153ce Merge pull request #116 from matthewjgunton/next
Graph Transcription v1 -> v2
2025-07-08 15:45:52 -05:00
Matthew Gunton
f4c17b0f1c kernel demo prep 2025-07-08 11:35:25 -07:00
Matthew Gunton
c5cf97ffbe Merge branch 'jafioti:next' into next 2025-07-08 12:28:55 -05:00
Joe Fioti
5a8903c5f3 returned best graph from search 2025-07-08 10:28:16 -07:00
Matthew Gunton
6137cd3efd Merge branch 'jafioti:next' into next 2025-07-08 11:04:12 -05:00
Joe Fioti
177beb108e fixed 2025-07-08 09:02:23 -07:00
Joe Fioti
a22f8a73a1 change tests 2025-07-08 08:57:51 -07:00
Matthew Gunton
be2c70c335 Merge branch 'jafioti:next' into next 2025-07-08 00:14:45 -05:00
Matthew Gunton
51172d9324 sum implemented in tests, ready for kernel testing 2025-07-07 15:58:05 -07:00
Joe Fioti
8d00d7628b matmul tiling rule 2025-07-07 14:55:34 -04:00
Matthew Gunton
af9214b416 striding and unit tests put in, pending sum reduce conf 2025-07-07 11:42:14 -07:00
Matthew Gunton
6b12fda15e reached the llama torch! 2025-07-05 18:14:41 -07:00
Matthew Gunton
30a075bc06 striding on loopins done, moving to accumulators 2025-07-05 17:35:24 -07:00
Joe Fioti
596a2240f2 Merge branch 'next' of https://github.com/jafioti/luminal into next 2025-07-05 19:42:08 -04:00
Joe Fioti
577d5f141f revamped accumulator system 2025-07-05 19:42:03 -04:00
Joe Fioti
d8f4201929 clippy 2025-07-05 18:44:39 -04:00
Joe Fioti
03a927a174 website change 2025-07-05 18:34:36 -04:00
Joe Fioti
fff1672936 website change 2025-07-05 18:30:02 -04:00
Joe Fioti
8912478ec1 Merge pull request #115 from matthewjgunton/one_commit 2025-07-05 14:09:17 -05:00
Matthew Gunton
8e30b0f939 Merge branch 'jafioti:next' into next 2025-07-05 13:10:07 -05:00
Matthew Gunton
3e3c721457 binary ops added 2025-07-04 13:34:15 -07:00
Matthew Gunton
89dcc519a8 added max reduce, removed looping issue 2025-07-04 12:59:05 -07:00
Matthew Gunton
fd63b22378 sum reduce 2025-07-04 11:06:22 -07:00
Matthew Gunton
4fe67b2a29 basic unary & binary translations -- looking into disjoint scaling issue 2025-07-04 11:06:22 -07:00
Joe Fioti
976b9888a0 tweaks 2025-07-04 10:48:52 -04:00
Joe Fioti
1daa82aab0 naive matmul searching 2025-07-04 07:39:12 -04:00
Joe Fioti
537a9ab161 added termdag -> egglog fn 2025-07-03 14:59:03 -07:00
Joe Fioti
602b25cc20 added loop splitting and merging again 2025-07-03 09:24:11 -07:00
Joe Fioti
86fe320fcb more correctness 2025-07-02 21:43:23 -07:00
Joe Fioti
59d3d656e4 correctness 2025-07-02 08:27:15 -07:00
Joe Fioti
61acd7ab0d correctness checks 2025-07-02 08:25:54 -07:00
Joe Fioti
71ea2d0693 fixed kernel splitting: 2025-07-01 20:43:04 -07:00
Joe Fioti
2e65c6f8be search matmul 2025-06-30 21:38:19 -07:00
Joe Fioti
1580491e29 loop level bug fixes 2025-06-30 21:26:41 -07:00
Joe Fioti
9dea6a442e Merge branch 'next' of https://github.com/jafioti/luminal into next 2025-06-29 22:09:05 -07:00
Joe Fioti
1d644dcbc1 more complex kernel search 2025-06-29 22:08:58 -07:00
Matthew Gunton
5cf9449698 Merge branch 'jafioti:next' into next 2025-06-29 23:29:37 -05:00
Joe Fioti
00d8dbfe31 Merge pull request #113 from matthewjgunton/next
basic quiz made for easier onboarding to code gen
2025-06-29 23:29:14 -05:00
Matthew Gunton
a6072212f7 egg quiz 2025-06-29 21:29:05 -07:00
Joe Fioti
2d45b714a5 fixed search 2025-06-29 21:05:36 -07:00
Matthew Gunton
bf7b8faba1 Merge branch 'jafioti:next' into next 2025-06-29 17:37:43 -05:00
Joe Fioti
ea381f9088 fix 2025-06-29 15:37:04 -07:00
Matthew Gunton
c72dacacfd Merge branch 'jafioti:next' into next 2025-06-29 17:34:03 -05:00
Joe Fioti
fc0343918d . 2025-06-29 15:32:39 -07:00
Matthew Gunton
4423661fe9 adding html 2025-06-29 13:42:07 -07:00
Matthew Gunton
3630c538ef basic quiz made for easier onboarding to code gen 2025-06-29 13:41:06 -07:00
Joe Fioti
a0d648594f serrach 2025-06-29 13:06:54 -07:00
Joe Fioti
2f83a25ef8 kernel search 2025-06-28 23:00:21 -07:00
Joe Fioti
3d5b8959a7 updated extraction 2025-06-28 16:15:54 -07:00
Joe Fioti
0f3f0c52e7 kernel search, possibly broken rewrite rules or codegen? 2025-06-26 22:41:23 -05:00
Joe Fioti
2be18f3d98 Merge pull request #111 from matthewjgunton/main
Update index.html
2025-06-23 11:00:48 -05:00
Joe Fioti
85edb92a89 merge 2025-06-23 10:57:01 -05:00
Joe Fioti
fb42bf6c02 Merge pull request #112 from jss8649/jake
Fixed typo on company website.
2025-06-23 10:55:34 -05:00
Joe Fioti
9fcf133d57 Merge branch 'main' into jake 2025-06-23 10:55:19 -05:00
Jake Stevens
24760ba120 Fixed typos on company website. 2025-06-23 10:52:53 -05:00
Jake Stevens
11a388e7cc Fixed typo on the company website. 2025-06-23 10:46:46 -05:00
Joe Fioti
70c427dc71 realized time difference between tiled and untiled matmul 2025-06-23 10:07:01 -05:00
Matthew Gunton
002cfeb4ce Update index.html 2025-06-23 08:25:41 -05:00
Joe Fioti
127a25b6f4 kernels unit tested! refactored smem caching and removed zerostrideload 2025-06-23 00:05:57 -05:00
Joe Fioti
acf2ea59c6 naive attention test passes 2025-06-22 09:37:00 -05:00
Joe Fioti
c4e0a4871c flash attention passes unit test, naive does not 2025-06-21 22:57:43 -05:00
Joe Fioti
ae9233829f added symbolic expressions 2025-06-20 10:06:26 -05:00
Joe Fioti
be8caae06d Merge branch 'next' of https://github.com/jafioti/luminal into next 2025-06-19 22:40:12 -05:00
Joe Fioti
072394c341 merge 2025-06-19 22:40:10 -05:00
Joe Fioti
39259d7126 unit tests working 2025-06-19 22:39:34 -05:00
Joe Fioti
041092b0ea Merge pull request #110 from jafioti/main
merge changes
2025-06-19 11:14:32 -05:00
Joe Fioti
ee4f8ba5a3 fixed egglog dep 2025-06-19 11:09:02 -05:00
Joe Fioti
486925da01 simplified spacing 2025-06-19 10:59:40 -05:00
Joe Fioti
ad4ea3633a cleanups 2025-06-19 10:46:51 -05:00
Joe Fioti
355f1b7fce codegen fixed, flash and naive work now 2025-06-18 20:31:07 -05:00
Joe Fioti
9a23bbd3bd flash attention codegen working 2025-06-18 17:23:05 -05:00
Joe Fioti
fcbf03add6 naive attention codegens correctly 2025-06-17 16:25:09 -05:00
Joe Fioti
f72b92e4f3 complex kernel: 2025-06-16 14:28:34 -05:00
Joe Fioti
6c90b412c9 added first rendition of flash attention 2025-06-15 08:11:03 -05:00
Joe Fioti
4bb8477002 added first rendition of naive attention termdag 2025-06-14 21:55:40 -05:00
Joe Fioti
02e6ab96ff Merge branch 'next' of https://github.com/jafioti/luminal into next 2025-06-13 21:51:49 -05:00
Joe Fioti
f27df3ddf3 smem tiled complete i think 2025-06-13 21:51:44 -05:00
Joe Fioti
7399cf3f5a Merge pull request #109 from matthewjgunton/next 2025-06-12 12:30:27 -05:00
Matthew Gunton
25461ea5bb renaming split factor to tile factor 2025-06-12 11:40:06 -05:00
Joe Fioti
8baff10811 Merge pull request #108 from matthewjgunton/main 2025-06-12 11:08:19 -05:00
Matthew Gunton
f4497242c4 Website Discord Link Wont Expire 2025-06-12 10:51:09 -05:00
Joe Fioti
26c7cf7d0c Merge pull request #107 from jafioti/codex/add-normalize-function
Implement normalize op
2025-06-10 22:10:17 -05:00
Joe Fioti
2ff3dffa95 merge 2025-06-10 22:08:07 -05:00
Joe Fioti
5057689ecb Merge remote-tracking branch 'origin/main' into codex/add-normalize-function 2025-06-10 22:07:42 -05:00
Joe Fioti
59dd0e2be6 removed comment 2025-06-10 22:05:10 -05:00
Joe Fioti
2602a5676c Reorder normalize args 2025-06-10 22:01:49 -05:00
Joe Fioti
17f8a49185 Merge pull request #106 from jafioti/3skfwm-codex/rename-expand-to-expand_dim-and-add-pytorch-style-expand
Implement PyTorch-style expand
2025-06-10 21:56:29 -05:00
Joe Fioti
e827571ba1 fix 2025-06-10 21:54:41 -05:00
Joe Fioti
e32a744739 Rename expand to expand_dim and add pytorch-style expand 2025-06-10 21:47:41 -05:00
Joe Fioti
c07b3a0f69 Merge pull request #104 from jafioti/9k8z5x-codex/review-tests-for-coverage-improvements 2025-06-10 17:36:46 -05:00
Joe Fioti
eab1260d56 Silence clippy warnings 2025-06-10 17:34:52 -05:00
Joe Fioti
71745f08e3 added agents.md 2025-06-10 17:09:19 -05:00
Joe Fioti
6dcb4dde02 Merge pull request #102 from jafioti/umhowy-codex/review-tests-for-coverage-improvements 2025-06-10 17:02:10 -05:00
Joe Fioti
35892ecba5 Run rustfmt 2025-06-10 16:59:34 -05:00
Joe Fioti
d5bccfb503 Merge pull request #100 from jafioti/codex/find-mismatched-or-missing-pytorch-apis-in-hl_ops
Fix leftover recip uses
2025-06-10 15:27:14 -05:00
Joe Fioti
469c7a15f0 fixed tests 2025-06-10 15:24:48 -05:00
Joe Fioti
038df8109c fixed formatting 2025-06-10 15:18:26 -05:00
Joe Fioti
57e8e68973 Fix leftover recip usage 2025-06-10 15:15:39 -05:00
Joe Fioti
51d652e1eb Merge pull request #99 from matthewjgunton/main
sliding ui bug
2025-06-10 14:12:06 -05:00
Matthew Gunton
60bcbfc792 sliding ui bug 2025-06-10 14:11:11 -05:00
Joe Fioti
ca88789df2 Merge pull request #98 from jafioti/codex/define-ops-using-metal_binary_op-macro
Use metal_binary_op macro for MetalSub and MetalEqual
2025-06-10 11:57:40 -05:00
Joe Fioti
0c038ee56b fix 2025-06-10 11:57:13 -05:00
Joe Fioti
f5bbeb331e Use metal_binary_op macro for binary ops 2025-06-10 11:51:38 -05:00
Joe Fioti
46ab40e82a Merge pull request #97 from jafioti/codex/simplify-luminal_metal-crate
Introduce macro for Metal binary ops
2025-06-10 11:36:14 -05:00
Joe Fioti
6bf42da03c fixed 2025-06-10 11:35:37 -05:00
Joe Fioti
23b199c6f3 Add metal_binary_op macro and refactor binary ops 2025-06-10 11:28:28 -05:00
Joe Fioti
482b8771b3 fix warnings 2025-06-10 11:27:12 -05:00
Joe Fioti
6b41630179 remove unused 2025-06-10 11:22:49 -05:00
Joe Fioti
e27109b63c Merge pull request #96 from jafioti/codex/remove-unused-code-in-src/tests
Remove unused harness module
2025-06-10 11:21:10 -05:00
Joe Fioti
344b599bca Remove unused test harness 2025-06-10 11:20:05 -05:00
Joe Fioti
6e050631df Merge pull request #95 from jafioti/codex/create-test-for-execute_no_delete
Add execute_no_delete test
2025-06-10 11:12:35 -05:00
Joe Fioti
79de7b71e5 move execute_no_delete test 2025-06-10 11:10:39 -05:00
Joe Fioti
d0083a0607 Merge pull request #94 from jafioti/codex/find-and-fix-a-bug-in-codebase
Fix tolerance for floating point equality
2025-06-10 11:02:30 -05:00
Joe Fioti
57edd100e0 Implement tolerance in CPU equality 2025-06-10 11:00:26 -05:00
Joe Fioti
6fa5aaccdb changed website 2025-06-10 10:16:57 -05:00
Joe Fioti
e32ef05e5d changed smem tiling 2025-06-09 23:20:06 -05:00
Jacob Stevens
35db03f9ad test 2025-06-09 15:24:09 -05:00
Joe Fioti
9dab0f5466 Merge pull request #92 from matthewjgunton/main
unsqueeze function added + arange
2025-06-09 13:23:57 -05:00
Matthew Gunton
d504dfaf21 moving unsqueeze to expand ratehr than reshape 2025-06-09 13:07:06 -05:00
Matthew Gunton
f350e92952 fixing vec compiler issue 2025-06-09 11:57:35 -05:00
Matthew Gunton
81eeb92f4d added arange helper functions 2025-06-09 11:54:20 -05:00
Matthew Gunton
3a8fbc8a70 Merge branch 'jafioti:main' into main 2025-06-09 11:30:52 -05:00
Matthew Gunton
3147ad793e added unsqueeze 2025-06-09 11:30:24 -05:00
Joe Fioti
dd2ee45b2c Merge pull request #91 from matthewjgunton/main
Added in transpose functionality
2025-06-09 11:02:35 -05:00
Matthew Gunton
23af621600 Transpose only allows 2D 2025-06-09 11:01:48 -05:00
Matthew Gunton
e04065de31 Merge branch 'jafioti:main' into main 2025-06-09 10:46:57 -05:00
Matthew Gunton
e7f3434fb7 added in transpose 2025-06-09 10:46:40 -05:00
Joe Fioti
be1f2c5d60 tiled matmul no smem 2025-06-07 09:27:42 -07:00
Joe Fioti
f91d50f67d added' 2025-06-04 14:42:15 -07:00
Joe Fioti
9a2eed2dab redirect 2025-06-04 14:40:10 -07:00
Joe Fioti
3ad169ed86 fmt 2025-06-03 10:54:31 -07:00
Joe Fioti
920ed56a5d Update README.md 2025-06-03 13:51:53 -04:00
Joe Fioti
add2a7520a Merge pull request #90 from matthewjgunton/main
email fix
2025-06-03 12:49:22 -05:00
Matthew Gunton
bfc80bee4e email fix 2025-06-02 19:11:30 -05:00
Joe Fioti
5b1a1001de added co website 2025-06-02 15:48:43 -05:00
Joe Fioti
56621e294f Merge pull request #89 from matthewjgunton/main
Company Website
2025-06-02 15:32:31 -05:00
Matthew Gunton
b0c8ceca9f Add files via upload 2025-06-02 15:01:59 -05:00
Matthew Gunton
a67742d203 update company link 2025-06-02 15:01:13 -05:00
Joe Fioti
9760d335f1 changed company website 2025-06-02 14:51:02 -05:00
Joe Fioti
e89907b42a naive matmul 2025-05-31 13:57:27 -05:00
Joe Fioti
8bb6d5ba1c added graph validation 2025-05-31 09:14:25 -05:00
Joe Fioti
a94f9ba13c Merge pull request #86 from abeleinin/fix-typos
fix typos/formatting in graphtensor + update ops to 12 in docs
2025-05-29 05:55:33 -07:00
aleinin
bf2fbb2f4f fix typos/formatting in graphtensor + update ops to 12 2025-05-28 23:17:01 -05:00
Joe Fioti
ba25f6fff6 added accumulators 2025-05-28 11:02:09 -05:00
Joe Fioti
c80666c0da simple kernels codegen correctly 2025-05-26 21:59:40 -05:00
Joe Fioti
7831a112d3 Merge pull request #84 from matthewjgunton/diff_avgpool
avg pooling + diff has atol and rtol
2025-05-24 20:59:24 -07:00
Matthew Gunton
ae5609c146 avg pooling + diff has atol and rtol 2025-05-24 22:33:53 -05:00
Joe Fioti
a525933737 merge 2025-05-23 23:14:47 -05:00
Joe Fioti
980b7a9148 merge with main 2025-05-23 16:54:09 -05:00
Joe Fioti
0ccd344a69 added flash attention demo 2025-05-23 13:01:09 -05:00
Joe Fioti
0d61e77d83 Merge pull request #83 from matthewjgunton/next
simple readme explaining our open-sourced code
2025-05-23 10:55:35 -07:00
Matthew Gunton
8f427daff6 simple readme explaining our open-sourced code 2025-05-23 11:35:10 -05:00
Joe Fioti
1a7050a120 Merge branch 'main' of https://github.com/jafioti/luminal 2025-05-21 16:12:54 -07:00
Joe Fioti
3c5aa2c253 fixed kernel fusion test 2025-05-21 16:12:48 -07:00
Joe Fioti
969689f8e9 Update README.md 2025-05-21 18:55:06 -04:00
Joe Fioti
ea2eb367e9 Update README.md 2025-05-21 18:53:37 -04:00
Joe Fioti
8afc084553 Update README.md 2025-05-20 12:50:47 -05:00
Joe Fioti
f7997d877a Merge pull request #82 from matthewjgunton/main
Metal Kernel Adjusted to use Generic
2025-05-19 23:02:54 -05:00
Joe Fioti
67bf8e7eea fixed pow function in hl ops 2025-05-18 22:01:16 -05:00
Joe Fioti
e3634ebbd6 fixed exp 2025-05-18 21:40:55 -05:00
Matthew Gunton
0ba3e9e43b Metal Kernel Adjusted to use Generic 2025-05-17 19:04:30 -05:00
Joe Fioti
25dd25bb38 fixed mean reduce compiler 2025-05-15 22:52:39 -05:00
Joe Fioti
01862e2ab7 updated candle 2025-05-15 22:15:28 -05:00
Joe Fioti
6318d1d26e Merge pull request #81 from matthewjgunton/main
Random initialization of weights + fixing pow reciprocal
2025-05-14 21:08:57 -05:00
Matthew Gunton
7278c722b7 Merge branch 'jafioti:main' into main 2025-05-14 20:27:52 -05:00
Matthew Gunton
b36eaa9ba6 Adding random initialization of weights + fixing reciprocal issue with pow() 2025-05-14 20:27:26 -05:00
Joe Fioti
92c3fff60a Update README.md 2025-05-13 14:01:01 -04:00
Joe Fioti
96a2187539 technically contiguous is an op, though it shouldn't be... 2025-05-13 13:57:35 -04:00
Joe Fioti
d40b68c367 qwen readme 2025-04-30 09:03:27 -05:00
Joe Fioti
8b608ecde7 switched qwen to fp32 on cuda 2025-04-30 08:47:47 -05:00
Joe Fioti
c767877507 cleaned up qwen example 2025-04-30 00:11:09 -05:00
Joe Fioti
f2b8fbf0ad added qwen 3! 2025-04-29 23:14:40 -05:00
Joe Fioti
e110bb516f working qwen on full prec 2025-04-29 22:54:07 -05:00
Joe Fioti
9a8c59c86e tweak to qwen 2025-04-29 16:39:26 -05:00
Joe Fioti
11db75b649 docs 2025-04-29 16:28:55 -05:00
Joe Fioti
7a198f816e qwen working on longer sequences but still precision issues 2025-04-29 16:24:58 -05:00
Joe Fioti
d235fb5a73 Merge branch 'main' of https://github.com/jafioti/luminal 2025-04-29 12:00:42 -05:00
Joe Fioti
ee993f2caa added qwen, still has high error rate 2025-04-29 12:00:37 -05:00
Joe Fioti
6a24aab3d3 Merge pull request #80 from matthewjgunton/main
fixing typo in docs
2025-04-29 08:32:04 -05:00
Matthew Gunton
4b95be8c15 Merge branch 'jafioti:main' into main 2025-04-28 23:37:36 -05:00
Matthew Gunton
bc86f88b4a fixing typoe in docs 2025-04-28 23:37:00 -05:00
Joe Fioti
758f496355 rewrite-trajectory verified optimal matmul 2025-04-25 23:41:12 -05:00
Joe Fioti
6d596f8d2a fixed symbolic equivalent 2025-04-24 22:26:23 -05:00
Joe Fioti
211ee828ab Merge pull request #79 from matthewjgunton/main
adding example documentation for llama 3b inferencing
2025-04-24 15:40:33 -04:00
Joe Fioti
84a2e285b6 simdgroup matmul works! 2025-04-23 20:43:25 -05:00
Joe Fioti
a8810b908f fix 2025-04-23 19:59:49 -05:00
Joe Fioti
44530716bd vectorized thread looped simdgroup matmul working, but with single tile accumulation 2025-04-23 19:54:15 -05:00
Joe Fioti
11bfb96245 fixed cuda warnings 2025-04-23 11:04:39 -05:00
Joe Fioti
da5d167a61 updated luminal_cuda to newest cudarc 2025-04-23 10:51:32 -05:00
Joe Fioti
16177fc5e6 thread looping 2025-04-23 07:30:26 -05:00
Matthew Gunton
50619342e7 edits to llama explanation 2025-04-22 21:37:46 -05:00
Matthew Gunton
30cf346af8 adding example documentation for llama 3b inferencing 2025-04-21 23:46:02 -05:00
Joe Fioti
3822a38a6e fixed tests 2025-04-21 19:39:33 -04:00
Joe Fioti
e4a9bb31be Added moondream, not working yet 2025-04-21 14:46:41 -04:00
Joe Fioti
beeccc20f3 tweaks 2025-04-20 23:54:15 -04:00
Joe Fioti
922c9534de Fixed clippy 2025-04-20 23:43:14 -04:00
Joe Fioti
ca99a3c58c updated candle dep 2025-04-20 23:41:26 -04:00
Joe Fioti
40704801eb cleaned up metal 2025-04-20 23:23:28 -04:00
Joe Fioti
09f125ca49 simdgroup matmul on metal working 2025-04-19 16:46:32 -04:00
Joe Fioti
a8bb86ca07 moved pyluminal 2025-04-18 23:08:42 -04:00
Joe Fioti
d4d90dd75e Update README.md 2025-04-18 23:08:00 -04:00
Joe Fioti
621a4ea5e3 Merge branch 'main' of https://github.com/jafioti/luminal 2025-04-18 23:06:07 -04:00
Joe Fioti
20f831d3ee added pyluminal 2025-04-18 23:06:01 -04:00
Joe Fioti
1b46e022ba got metal tests running again 2025-04-18 22:42:04 -04:00
Joe Fioti
5dc4d355fa added cuda support 2025-03-11 10:04:50 -05:00
Joe Fioti
fd288cfa32 added cuda support 2025-03-11 10:00:38 -05:00
Joe Fioti
cc188db593 rewritten tiled smem matmul 2025-03-11 00:10:22 -05:00
Joe Fioti
81175151cb Update README.md 2025-03-09 15:04:27 -04:00
Joe Fioti
3533538731 smem matmul works 2025-03-09 12:05:28 -05:00
Joe Fioti
67028efbf4 fixed tile matmul 2025-02-11 11:00:58 -06:00
Joe Fioti
40293587b8 padding dims 2025-02-10 15:14:25 -06:00
Joe Fioti
24660e793d validate graphs 2025-02-10 15:11:18 -06:00
Joe Fioti
88f7418100 exp-sin outer product works 2025-02-10 15:04:29 -06:00
Joe Fioti
63b15ea3fc output buffers in kernel 2025-02-10 12:58:24 -06:00
Joe Fioti
80602e4278 multiple output buffers 2025-02-10 11:40:41 -06:00
Joe Fioti
5de0e273e6 remapped kernerl inputs 2025-02-10 11:14:21 -06:00
Joe Fioti
59060da192 minor 2025-02-10 10:08:36 -06:00
Joe Fioti
be9dc77b84 tiled matmul 2025-02-10 09:19:55 -06:00
Joe Fioti
06aae2fe13 changed to more readable codegen 2025-02-10 09:13:06 -06:00
Joe Fioti
b28363b86d run graphs 2025-02-09 11:14:06 -06:00
Joe Fioti
99ada89a2e test metal kernels running 2025-02-09 09:42:20 -06:00
Joe Fioti
ce496b9640 cleanup 2025-02-07 20:22:54 -06:00
Joe Fioti
2f5d299fd1 added strudes 2025-02-07 15:10:44 -06:00
Joe Fioti
385ca78757 handle input strides better 2025-02-07 10:04:40 -06:00
Joe Fioti
31b7ba8aa6 better codegen 2025-02-05 09:23:53 -06:00
Joe Fioti
128d7c5b0e cleanup 2025-02-04 15:30:31 -06:00
Joe Fioti
f2291ffb6b e2e codegen 2025-02-04 15:19:32 -06:00
Joe Fioti
1f6d629bd7 better matching 2025-02-04 11:17:54 -06:00
Joe Fioti
d8e9b898e9 added start of codegen 2025-02-03 18:39:37 -06:00
Joe Fioti
000db2f194 added ir 2025-02-03 09:15:10 -06:00
Joe Fioti
d311003e8e Update elementwise_fusion.rs 2024-10-09 06:47:19 -04:00
Joe Fioti
caa7e55524 mostly fixed symbolic 2024-08-03 22:46:37 -05:00
Joe Fioti
ecf6d65f23 Removed yolo files 2024-08-03 13:04:24 -05:00
Joe Fioti
5c7d67bff1 Changed symbolic 2024-08-03 13:03:41 -05:00
Joe Fioti
c23aa440b2 Changed mean reduce 2024-07-30 11:08:44 -05:00
Joe Fioti
c47b38a56a Merge branch 'main' of https://github.com/jafioti/luminal 2024-07-30 10:35:51 -05:00
Joe Fioti
e1279d9780 Changed mean reduce 2024-07-30 10:35:46 -05:00
Joe Fioti
0db8f6c793 Merge pull request #76 from janroden/fix-clip 2024-07-27 08:36:43 -05:00
Jan Roden
207bc3686d Fix clip operation on tensor and add test for it. 2024-07-27 11:14:29 +02:00
Joe Fioti
fe8eb971d5 Update README.md 2024-07-25 15:32:54 -05:00
Joe Fioti
81db899306 Clippy 2024-07-25 15:28:55 -05:00
Joe Fioti
38e2029ff4 Clippy 2024-07-25 15:28:14 -05:00
Joe Fioti
a31000a9a6 Clippy 2024-07-25 15:22:19 -05:00
Joe Fioti
50d1bb4d9f Fixed cuda 2024-07-25 15:16:06 -05:00
Joe Fioti
65909205e3 Merge branch 'main' of https://github.com/jafioti/luminal 2024-07-20 22:52:41 -05:00
Joe Fioti
1a1fd056b0 removed yolo data 2024-07-20 22:52:35 -05:00
Joe Fioti
b5c38dc6db Refactored expression system 2024-07-20 22:51:24 -05:00
Joe Fioti
56e5c5797d Update README.md 2024-07-11 22:48:03 -04:00
Joe Fioti
2bdad487d9 Fixed examples 2024-07-09 11:34:47 -04:00
Joe Fioti
9aba341b95 Merge branch 'main' of https://github.com/jafioti/luminal 2024-07-09 11:28:55 -04:00
Joe Fioti
fb544b7530 Fixed whisper 2024-07-09 11:28:51 -04:00
Joe Fioti
9d37f25c7e Update README.md 2024-07-08 13:17:15 -04:00
Joe Fioti
e917ad2b43 removed yolo target 2024-07-08 12:55:26 -04:00
Joe Fioti
8d7b8c8972 Switched to runtime shapes 2024-07-08 12:54:42 -04:00
Joe Fioti
ddc201e1c1 Merge branch 'main' of https://github.com/jafioti/luminal 2024-07-04 10:57:20 -04:00
Joe Fioti
0c98ce2701 Changed metal copy 2024-07-04 10:57:15 -04:00
Joe Fioti
7136f54404 Merge pull request #71 from swfsql/cuda_gather_failure
Cuda `GatherCompiler` fails on low dimensionality (failing test)
2024-07-02 21:13:25 -04:00
Thiago Machado
82c9833e30 add fix and reduced test dimensions 2024-06-25 14:05:42 -04:00
Joe Fioti
8d36e703d7 removed gguf warnings 2024-06-22 18:40:49 -05:00
Joe Fioti
14c269e604 Merge branch 'main' of https://github.com/jafioti/luminal 2024-06-22 17:55:36 -05:00
Joe Fioti
4db3120d68 minor changes 2024-06-22 17:55:32 -05:00
Thiago Machado
9f98da3581 add a passing and failing tests 2024-06-21 21:16:40 -04:00
Joe Fioti
f61d53f859 Fixed phi 2024-06-15 20:50:50 -05:00
Joe Fioti
82f2165e20 working 2024-06-15 20:20:44 -05:00
Joe Fioti
26af0a81d7 Partially fixed pphi 2024-06-15 16:24:58 -05:00
Joe Fioti
db69f64c31 remove metal dep from examples 2024-06-08 09:31:58 -05:00
Joe Fioti
f5c9f6d56b StorageBufferCompiler causing segfault 2024-06-05 14:11:31 -05:00
Joe Fioti
3c47c9f874 removed metal softmax 2024-06-05 10:35:44 -05:00
Joe Fioti
e4fecf85ea Merge branch 'main' of https://github.com/jafioti/luminal 2024-06-04 15:28:59 -05:00
Joe Fioti
2605a02d04 Changed metal include system 2024-06-04 15:28:48 -05:00
Joe Fioti
35b3883f98 Fixed llama server 2024-06-01 16:47:05 -05:00
Joe Fioti
ad1a3d9eca Fixed phi example 2024-06-01 11:45:01 -05:00
Joe Fioti
2756c87b42 Checked mul 2024-05-31 13:55:14 -05:00
Joe Fioti
1c59194427 Cleaned up luminal metal 2024-05-29 15:05:06 -05:00
Joe Fioti
ee69188842 Cleaned up luminal_cuda 2024-05-27 22:48:10 -05:00
Joe Fioti
d1a493b162 Small chagnge 2024-05-27 19:34:38 -05:00
Joe Fioti
aebbd9c08c Cuda test fixes 2024-05-27 19:21:10 -05:00
Joe Fioti
ad2d73fa9b Fixed whisper example 2024-05-26 11:45:59 -05:00
Joe Fioti
89b6d753bc Fixed deps 2024-05-26 10:38:43 -05:00
Joe Fioti
2d5393f29f Small changes 2024-05-24 21:16:55 -05:00
Joe Fioti
619b354bb9 Added graph serialization to metal 2024-05-22 10:31:41 -05:00
Joe Fioti
e35dd4b95b Updated action 2024-05-20 15:35:23 -05:00
Joe Fioti
07e6f64a3c Fixed workspace 2024-05-20 15:33:52 -05:00
Joe Fioti
38c5699977 Improved symbolic speed 2024-05-20 15:22:38 -05:00
Joe Fioti
c11fc64edd Fixed weight loading 2024-05-19 16:58:13 -05:00
Joe Fioti
706758b45e Whisper working! 2024-05-19 13:20:33 -05:00
Joe Fioti
e2f7e813dc Fixed conv partially 2024-05-13 11:35:58 -05:00
Joe Fioti
1df0327af3 removed luminal_cudarc 2024-05-12 13:23:56 -05:00
Joe Fioti
1b00a04238 Whisper update 2024-05-12 12:58:21 -05:00
Joe Fioti
b385db2c3a whisper decoding pass 2024-05-11 22:40:00 -05:00
Joe Fioti
63e22243a0 Whisper encode audio 2024-05-11 22:20:10 -05:00
Joe Fioti
571224ec28 exclude whisper example 2024-05-11 14:28:15 -05:00
Joe Fioti
efd9b1ce0b GI Actions change 2024-05-11 14:27:34 -05:00
Joe Fioti
252992f6f5 GI Actions change 2024-05-11 14:26:42 -05:00
Joe Fioti
2795c594a5 GI Actions change 2024-05-11 14:24:03 -05:00
Joe Fioti
f108b7aa99 GI Actions change 2024-05-11 14:22:40 -05:00
Joe Fioti
92bf7470dc Whisper model loads 2024-05-11 14:20:12 -05:00
Joe Fioti
25f4247208 Switched symbolic to use egg 2024-05-08 17:20:17 -05:00
Joe Fioti
60eba92e34 Merge pull request #60 from jafioti/cas
removed reduce triples and added start of cas integration into symbolic
2024-05-07 15:16:35 -05:00
Joe Fioti
2364298296 Merge branch 'main' into cas 2024-05-07 15:16:24 -05:00
Joe Fioti
3c56aa6842 Removed cas-rs 2024-05-07 15:10:40 -05:00
Joe Fioti
af0f7a150a Working on llama 2024-05-07 12:19:11 -05:00
Joe Fioti
d96cdc3430 Removed many clones from symbolic 2024-05-07 09:56:59 -05:00
Joe Fioti
a52ab8072b Initial cas implementation. Still need mod 2024-05-07 09:39:13 -05:00
Joe Fioti
07ab3b31d6 Fixed metal 2024-05-06 22:13:12 -05:00
Joe Fioti
a5d292a47d Added more whisper model and matmul errors 2024-05-06 12:52:26 -05:00
Joe Fioti
1e85a64f56 Fixed timed ops 2024-05-06 09:16:12 -05:00
Joe Fioti
3f75ed6c28 Slightly simplified conv 2024-05-05 17:13:14 -05:00
Joe Fioti
3ed9fa6353 Refactored conv 2024-05-05 17:06:41 -05:00
Joe Fioti
61fc7aaccb Merge pull request #53 from NewBornRustacean/feature-conv3d
Feature conv3d
2024-05-05 17:04:56 -05:00
Joe Fioti
2597455824 Merge pull request #56 from xnorpx/dev/fix_warns
Dev/fix warns and add ci
2024-05-05 09:06:36 -05:00
xnorpx
aa615c003e Add clippy and fmt to CI 2024-05-04 20:57:39 -07:00
xnorpx
5409c64349 Mark example as WIP and allow unused 2024-05-04 20:53:04 -07:00
xnorpx
4b2c0c6251 Clippy fixes 2024-05-04 20:51:58 -07:00
Joe Fioti
5f730aef1f Changed arcmax 2024-05-03 19:00:35 -05:00
Joe Fioti
6d917dd579 small changes 2024-05-03 18:47:54 -05:00
Joe Fioti
97b66a376f Fixed metal 2024-05-03 18:46:35 -05:00
Joe Fioti
17dca47be7 Better padding api 2024-05-03 13:48:57 -05:00
Joe Fioti
4213177982 phi model change 2024-05-03 13:13:38 -05:00
Joe Fioti
114c587a49 Small example changes 2024-05-03 12:55:39 -05:00
Joe Fioti
bf97c89873 Fixed phi example on metal 2024-05-03 12:49:42 -05:00
Joe Fioti
816deea270 Merge branch 'main' of https://github.com/jafioti/luminal 2024-05-03 12:17:21 -05:00
Joe Fioti
70bbd8d4cf phi changes 2024-05-03 12:16:57 -05:00
Joe Fioti
bb9552e13a Cleaned up slicing syntax 2024-05-03 10:52:29 -05:00
NewBornRustacean
a86b915b00 test pass 2024-05-03 21:52:35 +09:00
NewBornRustacean
a73c2b3dd5 conv3d dim. missmatch resolved 2024-05-03 21:47:07 +09:00
NewBornRustacean
ffb2baa3a8 conv3d test cased revised 2024-05-03 17:54:24 +09:00
NewBornRustacean
7c78ea4ec7 Merge branch 'jafioti:main' into feature-conv3d 2024-05-03 10:50:28 +09:00
Joe Fioti
dfd40edb95 updated phi example 2024-05-02 18:58:54 -05:00
NewBornRustacean
9f615d3319 Merge branch 'jafioti:main' into feature-conv3d 2024-05-03 07:51:00 +09:00
Joe Fioti
78de438035 gitignore changes 2024-05-02 16:17:21 -05:00
Joe Fioti
bb97aab756 Merge branch 'main' of https://github.com/jafioti/luminal 2024-05-02 16:14:47 -05:00
Joe Fioti
b7e40c1317 example changes 2024-05-02 16:14:42 -05:00
NewBornRustacean
aa939cb031 conv3d test: index out of bounds error 2024-05-02 17:38:54 +09:00
NewBornRustacean
349ff23685 Merge branch 'jafioti:main' into feature-conv3d 2024-05-02 14:45:59 +09:00
Joe Fioti
dc01280737 Merge branch 'main' of https://github.com/jafioti/luminal 2024-05-01 22:54:25 -05:00
Joe Fioti
07b7bb69b8 Per iteration timing 2024-05-01 22:54:20 -05:00
NewBornRustacean
0f38920aaa Merge branch 'jafioti:main' into feature-conv3d 2024-05-02 11:56:51 +09:00
Joe Fioti
9e048ff98d Merge pull request #52 from TheSeamau5/main
Add llama server
2024-05-01 21:28:06 -05:00
Hassan Hayat
962e301d2b Merge conflicts 2024-05-01 14:10:38 -05:00
Joe Fioti
f2c3130e3b Conv tests 2024-05-01 14:04:32 -05:00
Joe Fioti
fcb824df6a Conv tests 2024-05-01 14:04:23 -05:00
Joe Fioti
a92d13642f Simplified conv 2024-05-01 13:17:44 -05:00
Joe Fioti
284532420d Removed luminal_symbolic 2024-05-01 11:56:27 -05:00
Joe Fioti
a29319e2fb Small doc change 2024-05-01 14:52:58 +00:00
NewBornRustacean
d03c8a041a draft conv3d and test cases 2024-05-01 08:49:20 +09:00
NewBornRustacean
b323f827dd add permutation for Axes6 2024-05-01 08:48:44 +09:00
NewBornRustacean
c1cedbc268 Merge branch 'feature-conv3d' of https://github.com/NewBornRustacean/luminal into feature-conv3d 2024-05-01 08:12:05 +09:00
NewBornRustacean
0f196966a7 Merge branch 'jafioti:main' into feature-conv3d 2024-05-01 08:11:51 +09:00
Hassan Hayat
ae300ec7e4 Merge remote-tracking branch 'upstream/main' 2024-04-30 12:10:11 -05:00
Joe Fioti
965b362f34 Update README.md 2024-04-30 11:23:26 -05:00
Hassan Hayat
efb53d6900 Merge remote-tracking branch 'upstream/main' 2024-04-29 21:10:26 -05:00
NewBornRustacean
ced823c707 Merge branch 'jafioti:main' into feature-conv3d 2024-04-30 09:38:03 +09:00
Joe Fioti
c50d51a69e Added excalidraw images to docs 2024-04-29 23:34:46 +00:00
NewBornRustacean
c5f54b5104 Merge branch 'jafioti:main' into feature-conv3d 2024-04-30 07:53:11 +09:00
Hassan Hayat
eb774b7647 Merge remote-tracking branch 'upstream/main' 2024-04-29 16:58:15 -05:00
Joe Fioti
88ddc0cf09 Small docs 2024-04-29 18:17:43 +00:00
Hassan Hayat
fc9aa59a13 Remove eos token from output 2024-04-29 13:17:27 -05:00
Hassan Hayat
0ad2d92a5b Add gitignore, move llama folder 2024-04-29 13:07:16 -05:00
Joe Fioti
d3178b3443 Removed nohup 2024-04-29 17:44:20 +00:00
Joe Fioti
e2be277347 Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-29 17:43:49 +00:00
Joe Fioti
15ba813ca6 Added gpu blog post 2024-04-29 17:43:31 +00:00
Hassan Hayat
8f01f4dba3 Merge remote-tracking branch 'upstream/main' 2024-04-29 10:32:55 -05:00
Joe Fioti
0efbc51e41 Fixed metal tests 2024-04-29 09:59:29 -05:00
Joe Fioti
4fee36107d Small changes 2024-04-29 09:40:16 -05:00
NewBornRustacean
b48f611464 Merge branch 'main' into feature-conv3d 2024-04-29 17:55:57 +09:00
Joe Fioti
4e6fa93c9b Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-28 23:30:00 -05:00
Joe Fioti
41d6d08cd3 Phi 3 working 2024-04-28 23:29:54 -05:00
Joe Fioti
0f63429513 Update README.md 2024-04-28 22:00:20 -05:00
Hassan Hayat
eb5c40832e Get model running, but not currently functioning 2024-04-28 20:57:32 -05:00
Hassan Hayat
d1c06cc9f3 Upload code 2024-04-28 20:10:32 -05:00
Hassan Hayat
c62419270d Update main.rs 2024-04-28 19:58:33 -05:00
Hassan Hayat
fa01dbe099 Save code 2024-04-28 19:55:40 -05:00
Hassan Hayat
67f267a276 Get the timestamp and uuid working 2024-04-28 19:02:54 -05:00
NewBornRustacean
f497531908 Merge branch 'main' into feature-conv3d 2024-04-29 07:50:56 +09:00
Hassan Hayat
7b3214797e Merge remote-tracking branch 'upstream/main' 2024-04-28 17:48:59 -05:00
Hassan Hayat
9e6403ca75 Bring in llama implementation from example 2024-04-28 17:48:51 -05:00
Joe Fioti
003c824a02 Added elementwise fusion to cuda 2024-04-28 15:59:02 -05:00
Hassan Hayat
568b5de507 move to examples folder 2024-04-28 15:53:05 -05:00
Hassan Hayat
012734c044 Move to folder 2024-04-28 15:50:59 -05:00
Hassan Hayat
38bc35230f Base axum server mocking openai api 2024-04-28 14:57:23 -05:00
Joe Fioti
1e5d63949e Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-28 12:35:20 -05:00
Joe Fioti
fa2b7ac22e Fixed many cuda bugs 2024-04-28 12:35:14 -05:00
NewBornRustacean
18a9aaa38e Merge branch 'main' into feature-conv3d 2024-04-28 10:55:55 +09:00
Joe Fioti
868e1c667e Support fp16 on llama 2024-04-27 10:00:31 -05:00
Joe Fioti
9b734d6cbd Fixed llama layers 2024-04-27 09:43:51 -05:00
Joe Fioti
076a165904 Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-27 09:35:57 -05:00
Joe Fioti
fb84e93815 Metal fixes 2024-04-27 09:35:55 -05:00
NewBornRustacean
6869d81383 define conv3d struct 2024-04-27 21:54:50 +09:00
Joe Fioti
8bf379b3eb Changed tests 2024-04-26 22:03:22 -05:00
Joe Fioti
ad8908ab79 Fixed cpu gather compiler 2024-04-26 21:49:50 -05:00
Joe Fioti
3850e14649 Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-26 20:20:22 -05:00
Joe Fioti
5e3e69d109 Fixed cuda graph prints'
:
2024-04-26 20:20:17 -05:00
Joe Fioti
09264419ec Update README.md 2024-04-26 16:05:13 -05:00
Joe Fioti
b502804e98 Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-26 16:03:37 -05:00
Joe Fioti
a859870a08 move dag image 2024-04-26 15:43:27 -05:00
Joe Fioti
45dc0f688a Small opt 2024-04-26 15:38:54 -05:00
Joe Fioti
358dcebde2 Small shape changes 2024-04-26 15:23:44 -05:00
Joe Fioti
c49507e087 Removed device from CopyToDevice for metal 2024-04-26 14:20:21 -05:00
Joe Fioti
194f221372 Combined dims for expressions 2024-04-26 14:16:27 -05:00
Joe Fioti
cd0195ca81 Combine dims for index and valid expressions 2024-04-26 13:56:34 -05:00
Joe Fioti
5bc7417f55 Fixed metal 2024-04-26 11:15:26 -05:00
Joe Fioti
7357817b46 Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-26 11:07:27 -05:00
Joe Fioti
a7f823080d Simplified index expression fn 2024-04-26 11:07:20 -05:00
Joe Fioti
862af13096 Fixed metal 2024-04-26 10:26:44 -05:00
Joe Fioti
a39176dc7a Cleaned up symbolic more 2024-04-25 23:47:07 -05:00
Joe Fioti
b359ecf887 Fixed intro page 2024-04-25 02:47:11 +00:00
Joe Fioti
2bc3c3b00e remove nohup 2024-04-25 02:19:34 +00:00
Joe Fioti
4bceec314f Fixed menu issue 2024-04-25 02:19:19 +00:00
Joe Fioti
31a93c0045 Changed meta tags 2024-04-24 23:55:53 +00:00
Joe Fioti
f6029cefc4 Removed nohup 2024-04-24 22:18:01 +00:00
Joe Fioti
340a31f410 Fixed doc buttons 2024-04-24 22:17:34 +00:00
Joe Fioti
29b019fe01 Changed docs 2024-04-24 17:12:08 -05:00
Joe Fioti
5c9a236843 Changed docs 2024-04-24 16:59:42 -05:00
Joe Fioti
896d3913d9 remove docs 2024-04-24 16:57:28 -05:00
Joe Fioti
172ae602ca Changed docs 2024-04-24 16:57:01 -05:00
Joe Fioti
29a0bfe195 Added mint json 2024-04-24 16:55:13 -05:00
Joe Fioti
5b8e922a70 Added new doc site 2024-04-24 16:44:52 -05:00
Joe Fioti
1424c40384 Update README.md 2024-04-24 16:41:11 -05:00
Joe Fioti
b9d1f6b00e Update README.md 2024-04-24 16:35:14 -05:00
Joe Fioti
149154ba26 Update README.md 2024-04-24 16:34:58 -05:00
Joe Fioti
8fc438550c Merge branch 'main' of https://github.com/jafioti/luminal 2024-04-22 14:55:50 -05:00
Joe Fioti
5c99c9da04 Removed mistral example 2024-04-22 14:55:45 -05:00
Joe Fioti
3184d273cf Update README.md 2024-04-19 23:09:25 -05:00
Joe Fioti
1a43b7485e Updated cuda 2024-04-19 22:59:33 -05:00
Joe Fioti
1c40dd0ed4 Cleanup core 2024-04-19 19:54:15 -05:00
Joe Fioti
577fd5531b Cached regexes 2024-04-19 19:23:54 -05:00
Joe Fioti
3412708a91 Small changes 2024-04-19 12:42:16 -05:00
Joe Fioti
daaefaeabd Fixed compiler visibility 2024-04-19 11:53:14 -05:00
Joe Fioti
c576a4e7e4 Fixed timed compiler 2024-04-19 11:36:42 -05:00
Joe Fioti
f0d6a58a7f Added llama3 2024-04-19 11:34:52 -05:00
Joe Fioti
396cae039b Fixed speed 2024-04-19 11:08:53 -05:00
Joe Fioti
9c63b6456f Falliable symbolic ops 2024-04-19 09:07:34 -05:00
Joe Fioti
4a0bd46d4d Minor changes 2024-04-12 23:39:10 -05:00
Joe Fioti
a94c97e724 Update README.md 2024-04-12 13:03:12 -05:00
Joe Fioti
d603aac3f2 Simplified mistral example 2024-04-10 16:38:31 -05:00
Joe Fioti
90e350b6fb Removed serialization and TraitObjEq 2024-04-10 16:16:22 -05:00
Joe Fioti
2126556f25 Removed print 2024-04-09 20:09:41 -04:00
Joe Fioti
a15a79bff6 Fixed elementwise fusion 2024-04-09 19:54:18 -04:00
Joe Fioti
da68799982 Fixed elementwise replacements 2024-04-08 19:29:34 -04:00
Joe Fioti
ddd001ef0e Fixed cpu gather op 2024-04-08 10:02:55 -04:00
Joe Fioti
acf1e4b465 small 2024-04-06 20:56:06 -04:00
Joe Fioti
589148707f New elementwise fusion 2024-04-06 20:05:53 -04:00
Joe Fioti
239af0df26 Simplified primop reductions 2024-04-03 16:09:51 -04:00
Joe Fioti
5fcee0050c Started whisper 2024-04-03 15:40:47 -04:00
Joe Fioti
481e416b0d Fixed metal again 2024-04-02 21:36:07 -04:00
Joe Fioti
a6eef04817 Metal training works 2024-04-02 20:19:00 -04:00
Joe Fioti
a4d033a9af Fixed metal 2024-04-02 11:21:52 -04:00
Joe Fioti
a68d8bd341 reorganized core 2024-04-02 11:07:07 -04:00
Joe Fioti
8ef34a5805 Cleaned up training example 2024-04-02 10:53:41 -04:00
Joe Fioti
9a34224891 First training examplegit add . 2024-04-01 21:17:23 -04:00
Joe Fioti
21c1e72d1f Reogranized nn module 2024-04-01 17:07:24 -04:00
Joe Fioti
784fe20f32 name change 2024-04-01 14:01:42 -04:00
Joe Fioti
355f1a4816 Reorganized into luminal_train 2024-04-01 13:54:16 -04:00
Joe Fioti
aff47d5960 Added losses 2024-04-01 13:47:46 -04:00
Joe Fioti
5d07fbc853 tweaks 2024-04-01 10:53:37 -04:00
Joe Fioti
9540ab1c88 working autograd on transformer 2024-04-01 10:39:13 -04:00
Joe Fioti
03a9dedb04 Autograd working for matmul reshape 2024-03-29 10:59:13 -04:00
Joe Fioti
bf64f3cc7b Merge branch 'main' of https://github.com/jafioti/luminal 2024-03-28 20:30:12 -04:00
Joe Fioti
c95c514c52 Autograd works for mlps, does not work with reshapes 2024-03-28 20:30:07 -04:00
Joe Fioti
a30212bb11 Merge pull request #45 from zeux/cuda-fix
Fix CUDA build on Rust 1.77
2024-03-24 13:59:39 -04:00
Arseny Kapoulkine
6ed065d4b7 Remove nn submodule imports and rely on prelude re-exporting symbols 2024-03-23 18:05:32 -07:00
Arseny Kapoulkine
00be879d6a Fix CUDA compilation by defining Output type for Compiler trait 2024-03-23 17:50:58 -07:00
Joe Fioti
72c2ff6c3c Cleaned up autograd more 2024-03-22 14:44:34 -04:00
Joe Fioti
ca1c7666cb Fixed autograd 2024-03-22 14:07:42 -04:00
Joe Fioti
dcb8b0de22 Fixed maxreduce autograd 2024-03-20 22:34:00 -04:00
Joe Fioti
f65b029249 Initial autograd 2024-03-20 22:24:39 -04:00
Joe Fioti
297b0c3a9d Fixed metla api 2024-03-18 10:29:48 -04:00
Joe Fioti
f0248d2954 More hl ops 2024-03-17 21:00:42 -04:00
Joe Fioti
eb87e69b0e Compiler return values 2024-03-15 23:22:00 -04:00
Joe Fioti
6a76c4a7d5 Changed simple example 2024-03-14 22:50:15 -05:00
Joe Fioti
59cd270b3f Merge 2024-03-14 19:42:56 -05:00
Joe Fioti
4219e5c0ef Changed mistral layers 2024-03-14 19:42:13 -05:00
Joe Fioti
0c1885fce7 Broken quantized cuda mistral 2024-03-14 19:41:42 -05:00
Joe Fioti
c7777ecc3a Small mistral tweak 2024-03-13 22:58:12 -05:00
Joe Fioti
f7e699d9d8 Merge pull request #42 from TheSeamau5/main
Fix tokenizer issue by switching to HF Tokenizers
2024-03-13 22:52:17 -05:00
Joe Fioti
371df9ecb0 Small change 2024-03-13 22:51:51 -05:00
Hassan Hayat
95438654ac Print total token count 2024-03-13 19:02:32 -05:00
Hassan Hayat
9b178b99ed Switch to hf tokenizers library for correct mistral tokenizer 2024-03-13 18:47:00 -05:00
Joe Fioti
aac542edcb More tests 2024-03-13 11:18:03 -05:00
Joe Fioti
c636bd34d5 Simplified tests 2024-03-13 11:04:53 -05:00
Joe Fioti
780951c828 removed rope kernel 2024-03-12 14:53:10 -05:00
Joe Fioti
eb3e8a26b2 Furthur fusion cleanups 2024-03-12 12:44:00 -05:00
Joe Fioti
19abb4a6b0 Merge branch 'main' of https://github.com/jafioti/luminal 2024-03-12 12:27:53 -05:00
Joe Fioti
b58b9848ad Simplified kernel fusion 2024-03-12 12:27:49 -05:00
Joe Fioti
8ffd089b51 Fixed cuda for new remap api 2024-03-12 11:39:25 -05:00
Joe Fioti
32bbff4953 Cleaned up fusion compiler: 2024-03-11 18:08:22 -05:00
Joe Fioti
3e956f91e7 Better elementwise fusion 2024-03-11 15:28:30 -05:00
Joe Fioti
3be6fb103c Broken fusion 2024-03-09 12:17:43 -06:00
Joe Fioti
890e3bba93 Added initial new elementwise fusion to cuda 2024-03-08 18:04:32 -06:00
Joe Fioti
953ce1b1f8 Cuda cleanup 2024-03-07 12:55:18 -06:00
Joe Fioti
ede1454338 cuda cleanup 2024-03-07 12:36:43 -06:00
Joe Fioti
f7e3786455 cuda cleanup 2024-03-07 12:34:53 -06:00
Joe Fioti
bff20dd582 matmul compiler fix 2024-03-07 09:43:27 -06:00
Joe Fioti
3617460ac5 Merge branch 'main' of https://github.com/jafioti/luminal 2024-03-07 09:39:28 -06:00
Joe Fioti
34fc6156e5 minor ergonomics 2024-03-07 09:39:19 -06:00
Joe Fioti
f9fac0ceec Merge pull request #37 from eltociear/patch-1 2024-03-05 14:00:35 -06:00
Ikko Eltociear Ashimine
029c22ef95 Update README.md
signifigant -> significant
2024-03-06 00:25:03 +09:00
Joe Fioti
e89e3617b8 Fixed mistral, 17 tps 2024-03-04 17:25:32 -06:00
Joe Fioti
b3b34277f4 Merge branch 'main' of https://github.com/jafioti/luminal 2024-03-04 16:52:42 -06:00
Joe Fioti
9b18274db6 Partial fix 2024-03-04 16:52:37 -06:00
Joe Fioti
32f10ac347 Merge pull request #36 from jcsoo/mistral-bytes-no-copy
Use new_buffer_with_bytes_no_copy when creating Metal buffer.
2024-03-04 16:32:30 -06:00
Joe Fioti
d47ca517a4 Merged new selector with metal, still broken 2024-03-04 16:30:45 -06:00
Joe Fioti
8078dea150 new selector api 2024-03-04 15:20:23 -06:00
Jonathan Soo
b785132c2c Use bytes_no_copy option when creating Metal buffer. 2024-03-04 15:38:53 -05:00
Joe Fioti
cb07523f02 fp16 quantized q8 matmul 2024-03-03 11:23:25 -06:00
Joe Fioti
75dd9a2856 Added metal q8 matmul 2024-03-03 11:03:25 -06:00
Joe Fioti
94dd04d3d4 Update README.md 2024-03-03 10:45:30 -06:00
Joe Fioti
9394c5b24c Added test macros to cuda 2024-03-02 19:01:24 -06:00
Joe Fioti
13604b696b removed print 2024-03-02 15:22:12 -06:00
Joe Fioti
9d4d9a51da broken test 2024-03-02 15:04:56 -06:00
Joe Fioti
dc9b57b91a Added test permutations 2024-03-02 14:42:35 -06:00
Joe Fioti
11da4b339d Merge branch 'main' of https://github.com/jafioti/luminal 2024-03-02 14:05:18 -06:00
Joe Fioti
b7e02f7995 Added initial test macros 2024-03-02 14:05:10 -06:00
Joe Fioti
be4fb7dd9f Changed expressionstorage vis 2024-03-01 14:16:04 -06:00
Joe Fioti
6b2c216e45 Update README.md 2024-03-01 11:05:17 -06:00
Joe Fioti
d309b4f338 Update README.md 2024-03-01 07:57:39 -06:00
Joe Fioti
d562a0321b Update README.md 2024-03-01 07:21:43 -06:00
Joe Fioti
1b61fd2e4d Pushed to cratesio 2024-02-29 22:52:52 -06:00
Joe Fioti
e0c9b2b1ff removed luminal_macro target 2024-02-29 22:42:31 -06:00
Joe Fioti
d976f71585 Begin publishing crates 2024-02-29 22:41:47 -06:00
Joe Fioti
00acf7aebe removed diffs 2024-02-29 14:41:14 -06:00
Joe Fioti
29751edf20 Passing metal tests 2024-02-29 13:52:23 -06:00
Joe Fioti
235db905da Update README.md 2024-02-29 12:37:57 -06:00
Joe Fioti
6337d90bce Update README.md 2024-02-29 12:37:36 -06:00
Joe Fioti
8c434b5081 Update README.md 2024-02-29 12:36:45 -06:00
Joe Fioti
84ba491b56 readme 2024-02-28 22:16:42 -06:00
Joe Fioti
c8044504c5 Refactored tests / api 2024-02-28 22:16:05 -06:00
Joe Fioti
9132ad8d94 Updated docs 2024-02-28 21:54:24 -06:00
Joe Fioti
c20b257657 tweak 2024-02-28 21:12:04 -06:00
Joe Fioti
89d9bbe105 Merge branch 'main' of https://github.com/jafioti/luminal 2024-02-28 21:10:05 -06:00
Joe Fioti
2f34d413e1 CPU mistral 2024-02-28 21:09:57 -06:00
Joe Fioti
a9875fde4d Fixed matvec test 2024-02-28 10:58:29 -06:00
Joe Fioti
c19e211629 readme 2024-02-27 22:30:23 -06:00
Joe Fioti
3e49033616 Match llama and mistral 2024-02-27 22:26:55 -06:00
Joe Fioti
f0135920aa Bring llama closer to mistral 2024-02-27 21:51:41 -06:00
Joe Fioti
8776a1c3de Small changess 2024-02-27 20:07:27 -06:00
Joe Fioti
b6022900b0 Refined llama 2024-02-27 19:46:10 -06:00
Joe Fioti
cc94802ed0 Update README.md 2024-02-27 18:35:58 -06:00
Joe Fioti
ed2fc61c73 Fixed test 2024-02-27 18:29:33 -06:00
Joe Fioti
ebbc86a312 CPU mistral 2024-02-27 18:11:05 -06:00
Joe Fioti
b065cdd22b Small tweaks 2024-02-27 16:12:58 -06:00
Joe Fioti
c7a0944eda gitignore 2024-02-27 16:05:17 -06:00
Joe Fioti
b2d6a48eab Updated llama 2024-02-27 16:04:05 -06:00
Joe Fioti
ca71f0dd16 CPU mistral 2024-02-27 15:56:04 -06:00
Joe Fioti
94306b086a feature change 2024-02-27 15:48:12 -06:00
Joe Fioti
7f2b9cf336 Mistral working on cuda 2024-02-27 15:44:30 -06:00
Joe Fioti
df1f5c3ca8 Fixed mistral metal 2024-02-27 12:53:48 -06:00
Joe Fioti
ffb7e4c706 Updated gitignore 2024-02-27 11:41:08 -06:00
Joe Fioti
598e303649 Split crates 2024-02-27 11:40:35 -06:00
Joe Fioti
866dfb7804 Spun compilers into crates 2024-02-27 10:43:25 -06:00
Joe Fioti
1cfefed1ce Update README.md 2024-02-27 00:14:08 -06:00
Joe Fioti
d5c6ef451c Update README.md 2024-02-27 00:11:25 -06:00
Joe Fioti
e363385d3f Update readme 2024-02-27 00:08:38 -06:00
Joe Fioti
258d1be49f updated deps 2024-02-26 23:37:40 -06:00
Joe Fioti
18f15d98e2 re-added sqrt 2024-02-26 23:11:58 -06:00
Joe Fioti
0b5ce105a3 removed sqrt 2024-02-26 17:16:55 -06:00
Joe Fioti
d66bf3412a Fixed cuda 2024-02-26 17:10:01 -06:00
Joe Fioti
5ede551fcb Merge pull request #24 from jafioti/q8
Q8 Weight Quantization
2024-02-26 14:33:26 -06:00
Joe Fioti
e10a19668c tweaks 2024-02-26 13:16:18 -06:00
Joe Fioti
881fa13a13 Working fused rope kernel 2024-02-26 12:58:55 -06:00
Joe Fioti
e279912bb8 Started rope metal 2024-02-23 16:42:32 -06:00
Joe Fioti
6dc2a996d2 furthur cleanup 2024-02-23 15:04:03 -06:00
Joe Fioti
7ee1cad15c Cleanup 2024-02-23 14:52:19 -06:00
Joe Fioti
b8a0f08cea Working 8bit mistral 2024-02-23 14:10:18 -06:00
Joe Fioti
edc96f3626 comparisons 2024-02-20 22:09:50 -06:00
Joe Fioti
dd369f18a9 temp 2024-02-17 11:47:05 -06:00
Joe Fioti
9326fe3cc8 broke af 2024-02-17 00:15:39 -06:00
Joe Fioti
3bd99c9f24 Kernel cleanup 2024-02-13 00:10:35 -06:00
Joe Fioti
bd56364160 Batched matvec 2024-02-13 00:05:12 -06:00
Joe Fioti
9547004247 Added quantized matvec 2024-02-12 14:27:11 -06:00
Joe Fioti
647f119d3c Fixed compiler macros 2024-02-09 11:53:00 -06:00
Joe Fioti
8952443ebd refactor metal compiler 2024-02-07 12:21:15 -06:00
Joe Fioti
5947e5cd3d refactor metal compiler 2024-02-07 12:12:58 -06:00
Joe Fioti
10d94710f7 Change feature flags 2024-02-07 12:02:56 -06:00
Joe Fioti
d13af7c562 Remove local cudarc fork 2024-02-07 12:01:39 -06:00
Joe Fioti
c2bbe446da Merged cuda compilers 2024-02-07 11:25:33 -06:00
Joe Fioti
b0a732e5b0 chagen readme 2024-02-07 10:56:37 -06:00
Joe Fioti
59cf7998c9 Fixed cuda tests 2024-02-07 10:12:44 -06:00
Joe Fioti
a6f38be402 Changed features 2024-02-06 21:52:19 -06:00
Joe Fioti
bc92e3137f Fixed many cuda bugs 2024-02-06 21:48:13 -06:00
Joe Fioti
30310a173d Update CONTRIBUTING.md 2024-02-06 17:22:08 -06:00
Joe Fioti
c00935b451 Addded contributing 2024-02-06 17:19:14 -06:00
Joe Fioti
15e4ee6aa3 fix doctests 2024-02-06 09:53:40 -06:00
Joe Fioti
9ec1e75fe6 tweak 2024-02-04 13:29:31 -06:00
Joe Fioti
5898076da5 Added documentation 2024-02-01 22:11:40 -06:00
Joe Fioti
5b17c1880e bug fix 2024-01-29 17:11:37 -06:00
Joe Fioti
1afea6bd86 renaming 2024-01-29 15:49:30 -06:00
Joe Fioti
8dff3619b9 Fixed speed 2024-01-29 15:46:19 -06:00
Joe Fioti
111452a68e Single mistral graph 2024-01-29 15:11:46 -06:00
Joe Fioti
d147ed5063 Mistral 10tps on M1 pro 2024-01-29 09:51:39 -06:00
Joe Fioti
162859dedb small changes 2024-01-28 16:17:49 -06:00
Joe Fioti
56de7fa4c3 small chagnes 2024-01-27 20:09:34 -06:00
Joe Fioti
7cc02dd51d core optimizations 2024-01-27 12:44:52 -06:00
Joe Fioti
e5963f1c9a Update Cargo.toml 2024-01-26 22:21:13 -06:00
Joe Fioti
9d32721ca7 dep changes 2024-01-26 21:07:00 -06:00
Joe Fioti
bc6b8fb283 Small kernel change 2024-01-26 19:57:24 -06:00
Joe Fioti
12381b2624 Changed tril triu api 2024-01-26 18:54:14 -06:00
Joe Fioti
2821145268 Removed isize 2024-01-26 17:40:15 -06:00
Joe Fioti
959528efad Added matmul support for repeated B batches 2024-01-26 17:32:39 -06:00
Joe Fioti
6a5a45eeae Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-25 21:22:14 -06:00
Joe Fioti
4166e27055 gemm refactor 2024-01-25 21:22:06 -06:00
Joe Fioti
f55cf6c0f7 Merge pull request #17 from TheSeamau5/debug
Small Improvements to main
2024-01-25 09:15:59 -06:00
Hassan Hayat
6ddabf2995 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 05:01:46 +01:00
Joe Fioti
54461a6d33 non-contiguous rotate 2024-01-24 21:24:26 -06:00
Hassan Hayat
b5d6f424d9 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 04:22:24 +01:00
Joe Fioti
f846af5901 rotation speedup 2024-01-24 21:20:43 -06:00
Hassan Hayat
f9c766dca7 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 03:41:07 +01:00
Joe Fioti
218db50c79 Small improvements to std_norm 2024-01-24 15:54:18 -06:00
Hassan Hayat
3fddb7e5a8 Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 23:22:32 +01:00
Joe Fioti
7bd8de272b steel matmul is ass 2024-01-23 16:21:02 -06:00
Joe Fioti
80915d3f3a Fixed rotate compiler 2024-01-23 15:19:38 -06:00
Joe Fioti
791f1395d5 Small changes 2024-01-23 14:26:15 -06:00
Hassan Hayat
b5a13381a9 Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 18:34:23 +01:00
Joe Fioti
c64e408471 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-23 10:19:44 -06:00
Joe Fioti
b1770a0b0e Added broken rotate op 2024-01-23 10:19:38 -06:00
Hassan Hayat
37dc4428af Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 09:03:28 +01:00
Joe Fioti
2d198b6be7 Rename LICENSE to LICENSE-APACHE 2024-01-22 21:47:39 -06:00
Joe Fioti
67e8e439c0 Create LICENSE 2024-01-22 21:47:26 -06:00
Joe Fioti
908d2c9222 Create LICENSE-MIT 2024-01-22 21:47:12 -06:00
Joe Fioti
c401a95af2 Update README.md 2024-01-22 21:46:00 -06:00
Joe Fioti
e2864d852f Update README.md 2024-01-22 21:44:53 -06:00
Hassan Hayat
f043ba2d5e Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 04:44:01 +01:00
Joe Fioti
cf8412d3bf Small gemv change 2024-01-22 16:24:21 -06:00
Joe Fioti
5b4bde0070 Change how metal imports work gemv 2024-01-22 16:19:13 -06:00
Joe Fioti
9fead8dad3 Change how metal imports work 2024-01-22 16:17:13 -06:00
Joe Fioti
0d44507f3c Fused softmax kernel 2024-01-22 15:26:22 -06:00
Hassan Hayat
3272749663 Merge remote-tracking branch 'upstream/main' into debug 2024-01-22 19:37:46 +01:00
Joe Fioti
5f917dcbcf Removed one contiguous call 2024-01-22 10:54:29 -06:00
Hassan Hayat
85a08aca3f Merge remote-tracking branch 'upstream/main' into debug 2024-01-22 11:01:24 +01:00
Joe Fioti
192858edf1 Simplified mistral 2024-01-21 23:58:14 -06:00
Joe Fioti
9a5e6f6e69 Simplified mistral 2024-01-21 23:57:15 -06:00
Joe Fioti
6884bd010d Moved CSE to pre generic compiler 2024-01-21 23:40:50 -06:00
Hassan Hayat
9dd852c27e move clap to dev dependencies 2024-01-22 00:27:52 +01:00
Hassan Hayat
198fe76cb3 Update main.rs 2024-01-22 00:12:21 +01:00
Hassan Hayat
9696c4ce09 Improvements to main 2024-01-22 00:09:59 +01:00
Joe Fioti
9a2f8fadd3 Reorg 2024-01-21 13:30:33 -06:00
Joe Fioti
b59fefaa11 Reorganizing 2024-01-21 12:13:47 -06:00
Joe Fioti
8348d06902 reorganized tests 2024-01-21 12:06:13 -06:00
Joe Fioti
8f7f6a6ab3 Commonized metal compiler 2024-01-21 12:03:10 -06:00
Joe Fioti
13e6dc6da5 More commonalities 2024-01-21 11:54:18 -06:00
Joe Fioti
244711d46e Commonized matmul 2024-01-21 11:30:44 -06:00
Joe Fioti
9695bcef84 Fused constants 2024-01-21 10:33:47 -06:00
Joe Fioti
2f20b9959c Removed custom swish kernel 2024-01-21 10:22:04 -06:00
Joe Fioti
308938ec02 Fixed elementwise fusion 2024-01-21 10:07:15 -06:00
Joe Fioti
b1c435b6be Fixed matmuls 2024-01-21 09:32:59 -06:00
Joe Fioti
4219d8ec7b Fixed layer norm 2024-01-20 21:54:24 -06:00
Joe Fioti
8bd7598678 Generalized matmul compiler 2024-01-19 23:37:37 -06:00
Joe Fioti
e89bdbb612 Closer to working elementwise fusion 2024-01-19 17:45:55 -06:00
Joe Fioti
ebb0df6c69 Disabled elementwise fusion 2024-01-19 11:21:30 -06:00
Joe Fioti
8f2d13df3d Enabled elementwise on metal prims 2024-01-16 17:12:56 -06:00
Joe Fioti
69c207b599 Fixed fusion bugs 2024-01-16 17:06:29 -06:00
Joe Fioti
fa04b05b5d Custom fn util 2024-01-16 15:44:29 -06:00
Joe Fioti
54912c4f6a Initial version of elementwise fusion 2024-01-16 15:36:58 -06:00
Joe Fioti
1c0f525e57 Added looped compiler 2024-01-16 09:03:41 -06:00
Joe Fioti
26c0de512f Unified matmul and matvec 2024-01-15 21:25:43 -06:00
Joe Fioti
0c27cb02a8 util functions 2024-01-15 11:57:15 -06:00
Joe Fioti
b822800ffe export node index 2024-01-13 10:24:01 -06:00
Joe Fioti
b54da0ddde bring in line with ggml kernel 2024-01-12 16:23:09 -06:00
Joe Fioti
9295ff8d72 Changed matvec 2024-01-12 16:16:04 -06:00
Joe Fioti
e5dcff3f34 Test commit 2024-01-12 13:28:04 -06:00
Joe Fioti
a1acd5883b Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:26:32 -06:00
Joe Fioti
556e386621 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:23:08 -06:00
Joe Fioti
9f9256f08a Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:23:08 -06:00
Joe Fioti
f3c53c1193 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
9f668ee333 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
617ef95c09 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
c539946c25 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
7e9f1c7fc0 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
cf0e6ad2f6 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
9813b188f3 reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
bf7c1c5608 reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
ec09c0202b reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
71365cf2d4 Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
481d074f5a Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
a240e2adc8 Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
c3643925ef removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
a6b368fa14 removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
ab9df3d94e removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
c727113351 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
d203df40d5 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
c506d1e783 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
56ce86f194 Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
54a8ebc60d Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
b3e07bd638 Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
94a6a0a9e9 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
fb279c9ee6 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
3ae34ad3b3 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
6b08212df8 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
03d2d02d00 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
0f09b19199 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
fcf232699f Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
1ed89b5656 Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
69da97727b Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
9edf9cdc0b Fixed swish compiler 2024-01-10 15:46:53 -06:00
Joe Fioti
2f13fd6100 Fixed swish compiler 2024-01-10 15:46:53 -06:00
Joe Fioti
ed278c9be3 Merge pull request #12 from TheSeamau5/matmul
Minor improvement to f16 matmul, Longer prompt and token generation for testing
2024-01-10 12:38:32 -06:00
Joe Fioti
9e04457895 Merge pull request #12 from TheSeamau5/matmul
Minor improvement to f16 matmul, Longer prompt and token generation for testing
2024-01-10 12:38:32 -06:00
Joe Fioti
e6c4291db6 Update other.rs 2024-01-10 12:36:34 -06:00
Joe Fioti
f62e6ad85e Update other.rs 2024-01-10 12:36:34 -06:00
Hassan Hayat
0ba62fde38 Minor improvement to f16 matmul, Longer prompt and token generation for testing 2024-01-10 12:31:43 -06:00
Hassan Hayat
d62f2e217a Minor improvement to f16 matmul, Longer prompt and token generation for testing 2024-01-10 12:31:43 -06:00
Joe Fioti
f385ea287e Fix 2024-01-10 12:30:25 -06:00
Joe Fioti
140ee69480 Fix 2024-01-10 12:30:25 -06:00
Joe Fioti
2c93b7788c Simplified copy compiler 2024-01-10 12:29:43 -06:00
Joe Fioti
4fdc8f38eb Simplified copy compiler 2024-01-10 12:29:43 -06:00
Joe Fioti
c0645fe35e Small changes 2024-01-10 09:18:42 -06:00
Joe Fioti
5b5812defa Small changes 2024-01-10 09:18:42 -06:00
Joe Fioti
349e3d2472 Merge pull request #11 from TheSeamau5/fix_swish
Fix swish
2024-01-10 09:12:27 -06:00
Joe Fioti
fa67608d48 Merge pull request #11 from TheSeamau5/fix_swish
Fix swish
2024-01-10 09:12:27 -06:00
Hassan Hayat
527c20d146 Fix swish 2024-01-10 00:52:04 -06:00
Hassan Hayat
ff1da67423 Fix swish 2024-01-10 00:52:04 -06:00
Joe Fioti
efd7489a1c Small kernel simplifications 2024-01-09 22:22:36 -06:00
Joe Fioti
4dd7cd7cfd Small kernel simplifications 2024-01-09 22:22:36 -06:00
Joe Fioti
33274b905e Fixed 2024-01-09 22:04:12 -06:00
Joe Fioti
3670378bc6 Fixed 2024-01-09 22:04:12 -06:00
Joe Fioti
275180be20 Improvements to vecmat 2024-01-09 22:02:47 -06:00
Joe Fioti
40a62e70be Improvements to vecmat 2024-01-09 22:02:47 -06:00
Joe Fioti
95462aa89e Shapetracker hack 2024-01-09 10:02:32 -06:00
Joe Fioti
7a9f9e04d0 Shapetracker hack 2024-01-09 10:02:32 -06:00
Joe Fioti
cf35b286f2 organization 2024-01-09 09:56:45 -06:00
Joe Fioti
e1cf44a4e0 organization 2024-01-09 09:56:45 -06:00
Joe Fioti
b891b8b595 Added unused softmax op 2024-01-09 00:54:02 -06:00
Joe Fioti
67366e1a2f Added unused softmax op 2024-01-09 00:54:02 -06:00
Joe Fioti
ee8206e2ca Improved compiler matching 2024-01-08 23:29:09 -06:00
Joe Fioti
5cdc559241 Improved compiler matching 2024-01-08 23:29:09 -06:00
Joe Fioti
daa7166534 Small cse improvement 2024-01-08 16:30:04 -06:00
Joe Fioti
2cf0bc29c8 Small cse improvement 2024-01-08 16:30:04 -06:00
Joe Fioti
139ae0ddad ggml rms norm 2024-01-08 12:46:17 -06:00
Joe Fioti
703f4d3847 ggml rms norm 2024-01-08 12:46:17 -06:00
Joe Fioti
d79042d334 dyn symbols in ops 2024-01-07 16:42:45 -06:00
Joe Fioti
f9b52f0058 dyn symbols in ops 2024-01-07 16:42:45 -06:00
Joe Fioti
5b50192830 Swish op 2024-01-07 14:13:43 -06:00
Joe Fioti
ae431e0dd4 Swish op 2024-01-07 14:13:43 -06:00
Joe Fioti
35626309ac Fixed generic compiler 2024-01-07 00:06:02 -06:00
Joe Fioti
a38168a91c Fixed generic compiler 2024-01-07 00:06:02 -06:00
Joe Fioti
64ebab654f Small changes 2024-01-06 23:38:52 -06:00
Joe Fioti
ec0ea40bbe Small changes 2024-01-06 23:38:52 -06:00
Joe Fioti
49ae10a25e Fixes 2024-01-06 22:26:56 -06:00
Joe Fioti
1a1ba5216b Fixes 2024-01-06 22:26:56 -06:00
Joe Fioti
0bbc6215d8 named structs 2024-01-06 22:11:06 -06:00
Joe Fioti
4e5300c4d4 named structs 2024-01-06 22:11:06 -06:00
Joe Fioti
166d4a12a5 Small 2024-01-06 12:11:48 -06:00
Joe Fioti
e4f90c304b Small 2024-01-06 12:11:48 -06:00
Joe Fioti
fa966c8c7c Contiguous elimination 2024-01-06 12:11:17 -06:00
Joe Fioti
9a0261acd2 Contiguous elimination 2024-01-06 12:11:17 -06:00
Joe Fioti
743bacb125 Fixed graph selector bug and added broken contiguous elimination 2024-01-05 22:58:44 -06:00
Joe Fioti
d0afd42eb2 Fixed graph selector bug and added broken contiguous elimination 2024-01-05 22:58:44 -06:00
Joe Fioti
4c9691c49d Fast mistral loading 2024-01-05 10:13:01 -06:00
Joe Fioti
9aaff41dfa Fast mistral loading 2024-01-05 10:13:01 -06:00
Joe Fioti
a8b6508155 Fast mistral loading 2024-01-05 10:11:29 -06:00
Joe Fioti
a23e536fa0 Fast mistral loading 2024-01-05 10:11:29 -06:00
Joe Fioti
e654f3e72d No copy metal buffers 2024-01-03 20:39:55 -05:00
Joe Fioti
1a6ce5df82 No copy metal buffers 2024-01-03 20:39:55 -05:00
Joe Fioti
a6cd8d9b0f Small changes 2024-01-03 19:46:54 -05:00
Joe Fioti
8a62e090a3 Small changes 2024-01-03 19:46:54 -05:00
Joe Fioti
b550de47e4 reinterpret entire array at once 2024-01-03 19:29:40 -05:00
Joe Fioti
5bc2477352 reinterpret entire array at once 2024-01-03 19:29:40 -05:00
Joe Fioti
370973108d Changed embedding test 2024-01-03 19:22:03 -05:00
Joe Fioti
88ed1ded6d Changed embedding test 2024-01-03 19:22:03 -05:00
Joe Fioti
e9b8a883d0 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-03 19:19:34 -05:00
Joe Fioti
4a7db75715 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-03 19:19:34 -05:00
Joe Fioti
72b3cba68b removed embedding init 2024-01-03 19:19:33 -05:00
Joe Fioti
e7c78e9b46 removed embedding init 2024-01-03 19:19:33 -05:00
Joe Fioti
0bc32b9c92 Changed graphselector and to_ids 2024-01-03 19:16:28 -05:00
Joe Fioti
9b81ef2326 Changed graphselector and to_ids 2024-01-03 19:16:28 -05:00
Joe Fioti
cfc8e7dae2 Update README.md 2024-01-03 12:24:44 -05:00
Joe Fioti
09666f93ab Update README.md 2024-01-03 12:24:44 -05:00
Joe Fioti
b489a86fa9 Removed petgraph fork 2024-01-03 11:14:31 -05:00
Joe Fioti
4d4338fb58 Removed petgraph fork 2024-01-03 11:14:31 -05:00
Joe Fioti
805ebb1931 Fixed mistral and llama 2024-01-02 20:20:03 -05:00
Joe Fioti
a57b316216 Fixed mistral and llama 2024-01-02 20:20:03 -05:00
Joe Fioti
94e08ae947 reenabled metal tests 2024-01-02 13:34:02 -05:00
Joe Fioti
21aee96114 reenabled metal tests 2024-01-02 13:34:02 -05:00
Joe Fioti
ac802a3273 tests pasing 2024-01-02 13:13:29 -05:00
Joe Fioti
70f4fff5c2 tests pasing 2024-01-02 13:13:29 -05:00
Joe Fioti
f2e1c17c8c Remoded id_remap 2024-01-02 13:05:36 -05:00
Joe Fioti
9493c11a53 Remoded id_remap 2024-01-02 13:05:36 -05:00
Joe Fioti
7c72d5b06f Started adding remap infra 2024-01-02 12:48:50 -05:00
Joe Fioti
a15cfbae65 Started adding remap infra 2024-01-02 12:48:50 -05:00
Joe Fioti
34ab545763 Small changes 2024-01-01 21:26:04 -05:00
Joe Fioti
e67d3e6598 Small changes 2024-01-01 21:26:04 -05:00
Joe Fioti
621536a1dd Fixed cse 2024-01-01 14:04:59 -05:00
Joe Fioti
6d9f9176cd Fixed cse 2024-01-01 14:04:59 -05:00
Joe Fioti
2e81b54446 Arange fix 2024-01-01 12:04:32 -05:00
Joe Fioti
38acdf315e Arange fix 2024-01-01 12:04:32 -05:00
Joe Fioti
30dff8597c Fix 2024-01-01 11:32:11 -05:00
Joe Fioti
2ebd5f2deb Fix 2024-01-01 11:32:11 -05:00
Joe Fioti
162b8c38a1 Changeed hl_ops 2024-01-01 11:31:42 -05:00
Joe Fioti
1fb155ddfd Changeed hl_ops 2024-01-01 11:31:42 -05:00
Joe Fioti
241b9f527b Optimized storage compiler 2024-01-01 00:02:06 -05:00
Joe Fioti
53dc4dd9df Optimized storage compiler 2024-01-01 00:02:06 -05:00
Joe Fioti
7c7558fcb3 Optimized graph selector 2023-12-31 23:18:06 -05:00
Joe Fioti
5262e32346 Optimized graph selector 2023-12-31 23:18:06 -05:00
Joe Fioti
664fad5f84 Changed graph searcher 2023-12-31 13:33:40 -05:00
Joe Fioti
4c3e530ef3 Changed graph searcher 2023-12-31 13:33:40 -05:00
Joe Fioti
d582111d04 Working limited reuse 2023-12-30 17:27:03 -05:00
Joe Fioti
e9384dc714 Working limited reuse 2023-12-30 17:27:03 -05:00
Joe Fioti
b97da50c9d Update README.md 2023-12-30 14:18:52 -05:00
Joe Fioti
517124b424 Update README.md 2023-12-30 14:18:52 -05:00
Joe Fioti
0ef3121ac6 Update readme 2023-12-30 14:12:58 -05:00
Joe Fioti
542f74f404 Update readme 2023-12-30 14:12:58 -05:00
Joe Fioti
8662ba864d Update readme 2023-12-30 14:12:06 -05:00
Joe Fioti
0fc68006d5 Update readme 2023-12-30 14:12:06 -05:00
Joe Fioti
eac3a57b6d Update readme 2023-12-30 14:11:12 -05:00
Joe Fioti
f46bc1cb99 Update readme 2023-12-30 14:11:12 -05:00
Joe Fioti
8f7004c4c3 Merge pull request #10 from TheSeamau5/mistral
Shell script to download mistral that actually works
2023-12-30 13:59:05 -05:00
Joe Fioti
185facb1d5 Merge pull request #10 from TheSeamau5/mistral
Shell script to download mistral that actually works
2023-12-30 13:59:05 -05:00
Joe Fioti
07d0febef1 Fixed memory leak in shared storage buffers 2023-12-30 13:52:02 -05:00
Joe Fioti
d35a40eacb Fixed memory leak in shared storage buffers 2023-12-30 13:52:02 -05:00
Hassan Hayat
8a744e6035 Shell script that actually works 2023-12-30 02:26:31 -06:00
Hassan Hayat
50b47f8610 Shell script that actually works 2023-12-30 02:26:31 -06:00
Joe Fioti
c1af144891 Cleaned up llama 2023-12-29 21:19:19 -05:00
Joe Fioti
dd123fec89 Cleaned up llama 2023-12-29 21:19:19 -05:00
Joe Fioti
92cca97a76 Updates 2023-12-29 20:42:25 -05:00
Joe Fioti
a5d01c7576 Updates 2023-12-29 20:42:25 -05:00
Joe Fioti
84fbf805c3 Merge 2023-12-29 14:54:50 -05:00
Joe Fioti
51545ee82c Merge 2023-12-29 14:54:50 -05:00
Joe Fioti
e16771035f Small changes 2023-12-29 14:51:14 -05:00
Joe Fioti
10ee2c7343 Small changes 2023-12-29 14:51:14 -05:00
Joe Fioti
f637fff192 Merge pull request #9 from TheSeamau5/mistral
Add support for Mistral
2023-12-29 14:45:58 -05:00
Joe Fioti
3e0cafbae3 Merge pull request #9 from TheSeamau5/mistral
Add support for Mistral
2023-12-29 14:45:58 -05:00
Joe Fioti
d6c9c977d8 removed metal 2023-12-29 14:43:24 -05:00
Joe Fioti
1a0f59943e removed metal 2023-12-29 14:43:24 -05:00
Joe Fioti
5032d894b8 Fixed mistral example 2023-12-29 14:30:12 -05:00
Joe Fioti
2046ee9ade Fixed mistral example 2023-12-29 14:30:12 -05:00
Joe Fioti
d48ac14458 KV cached mistral 2023-12-29 14:21:52 -05:00
Joe Fioti
24ff638e43 KV cached mistral 2023-12-29 14:21:52 -05:00
Joe Fioti
7bb4e856ec Small alterations 2023-12-28 16:06:49 -05:00
Joe Fioti
be93cfe817 Small alterations 2023-12-28 16:06:49 -05:00
Joe Fioti
140aeb4591 Faster mistral 2023-12-28 15:59:01 -05:00
Joe Fioti
907dadc6a0 Faster mistral 2023-12-28 15:59:01 -05:00
Hassan Hayat
c9a1e5c47d Convert the transformer layers into an array 2023-12-28 02:06:05 -06:00
Hassan Hayat
f36d98363c Convert the transformer layers into an array 2023-12-28 02:06:05 -06:00
Hassan Hayat
c05c0e0575 Implemented serialize module, and call keep_weights 2023-12-27 23:17:09 -06:00
Hassan Hayat
123b48d5ec Implemented serialize module, and call keep_weights 2023-12-27 23:17:09 -06:00
Hassan Hayat
c4553fc132 Revert "Merge remote-tracking branch 'upstream/main' into mistral"
This reverts commit c2a11bf114, reversing
changes made to 22ae700048.
2023-12-27 21:19:08 -06:00
Hassan Hayat
b38be86191 Revert "Merge remote-tracking branch 'upstream/main' into mistral"
This reverts commit c2a11bf114, reversing
changes made to 22ae700048.
2023-12-27 21:19:08 -06:00
Hassan Hayat
da3970082a Merge remote-tracking branch 'upstream/main' into mistral 2023-12-27 21:06:24 -06:00
Hassan Hayat
c2a11bf114 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-27 21:06:24 -06:00
Joe Fioti
75a141d8ba Shared Metal storage buffers 2023-12-27 21:36:54 -05:00
Joe Fioti
ee17a48dbe Shared Metal storage buffers 2023-12-27 21:36:54 -05:00
Hassan Hayat
02df7e7f8d Spring cleaning 2023-12-27 19:30:22 -06:00
Hassan Hayat
22ae700048 Spring cleaning 2023-12-27 19:30:22 -06:00
Hassan Hayat
dbe6a42018 Iteration works! Slowly but it works 2023-12-27 11:50:45 -06:00
Hassan Hayat
7387ca1b19 Iteration works! Slowly but it works 2023-12-27 11:50:45 -06:00
Hassan Hayat
9e03c3421f Update model.rs 2023-12-27 02:16:21 -06:00
Hassan Hayat
8a6d088ff3 Update model.rs 2023-12-27 02:16:21 -06:00
Hassan Hayat
305e8f104c Remove more unused code 2023-12-27 02:08:21 -06:00
Hassan Hayat
472eae1576 Remove more unused code 2023-12-27 02:08:21 -06:00
Hassan Hayat
6c234daba2 Simplify code 2023-12-27 02:00:43 -06:00
Hassan Hayat
dfb8691923 Simplify code 2023-12-27 02:00:43 -06:00
Hassan Hayat
2784738e41 Loading is fast 2023-12-27 01:57:32 -06:00
Hassan Hayat
b86b27e0c7 Loading is fast 2023-12-27 01:57:32 -06:00
Hassan Hayat
ba3faa49df It works! 2023-12-27 00:38:45 -06:00
Hassan Hayat
c833a65153 It works! 2023-12-27 00:38:45 -06:00
Hassan Hayat
cab6b2fff2 Made the slices into expressions 2023-12-26 17:10:50 -06:00
Hassan Hayat
35e5da1ff4 Made the slices into expressions 2023-12-26 17:10:50 -06:00
Hassan Hayat
67aac97299 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 17:01:56 -06:00
Hassan Hayat
2b884d6304 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 17:01:56 -06:00
Joe Fioti
9fa0b8d0a5 Added symbolic slicing 2023-12-26 17:59:19 -05:00
Joe Fioti
422fd32d74 Added symbolic slicing 2023-12-26 17:59:19 -05:00
Hassan Hayat
1d88be2001 Update model.rs 2023-12-26 14:22:49 -06:00
Hassan Hayat
3d5c3180be Update model.rs 2023-12-26 14:22:49 -06:00
Hassan Hayat
47b61ac847 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 14:20:05 -06:00
Hassan Hayat
18560d0852 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 14:20:05 -06:00
Joe Fioti
1400aecf1d Small changes 2023-12-26 13:36:25 -05:00
Joe Fioti
9e3bea8cac Small changes 2023-12-26 13:36:25 -05:00
Joe Fioti
b4bf84840e Assign operators 2023-12-26 12:50:48 -05:00
Joe Fioti
941a8b93eb Assign operators 2023-12-26 12:50:48 -05:00
Joe Fioti
666cbe6c5a Llama reductions 2023-12-26 12:40:07 -05:00
Joe Fioti
33b7f0914f Llama reductions 2023-12-26 12:40:07 -05:00
Joe Fioti
2b2e06d6fa Simplified llama 2023-12-26 12:27:44 -05:00
Joe Fioti
eaa4ad8ef5 Simplified llama 2023-12-26 12:27:44 -05:00
Hassan Hayat
750a6e9e8b Remove unused imports 2023-12-26 11:13:30 -06:00
Hassan Hayat
0028b5ca78 Remove unused imports 2023-12-26 11:13:30 -06:00
Hassan Hayat
d4b18a0e35 Update model.rs 2023-12-26 11:00:34 -06:00
Hassan Hayat
22d7c563cb Update model.rs 2023-12-26 11:00:34 -06:00
Hassan Hayat
4671708601 Try to do an inference loop and failt 2023-12-26 10:50:03 -06:00
Hassan Hayat
9b3948a3ff Try to do an inference loop and failt 2023-12-26 10:50:03 -06:00
Hassan Hayat
4c415fba7b Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 10:48:51 -06:00
Hassan Hayat
f775833e10 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 10:48:51 -06:00
Joe Fioti
b1b06b1e15 Removed forward_kv from llama 2023-12-26 11:47:27 -05:00
Joe Fioti
bf8f3d91d2 Removed forward_kv from llama 2023-12-26 11:47:27 -05:00
Hassan Hayat
b40fb1a94b we have logits 2023-12-26 09:58:01 -06:00
Hassan Hayat
2e52833bb5 we have logits 2023-12-26 09:58:01 -06:00
Hassan Hayat
7e2518bbba comment out attention mask 2023-12-26 06:58:20 -06:00
Hassan Hayat
808cf7849e comment out attention mask 2023-12-26 06:58:20 -06:00
Hassan Hayat
1a454b23f8 Remove unused code 2023-12-26 06:39:53 -06:00
Hassan Hayat
bc4483706b Remove unused code 2023-12-26 06:39:53 -06:00
Hassan Hayat
8d0cff2b0b Start debugging the full transformer pass 2023-12-26 06:34:27 -06:00
Hassan Hayat
35097e8e2b Start debugging the full transformer pass 2023-12-26 06:34:27 -06:00
Hassan Hayat
1fdc8de899 Single layer is correct now 2023-12-26 06:00:23 -06:00
Hassan Hayat
995293e5da Single layer is correct now 2023-12-26 06:00:23 -06:00
Hassan Hayat
e9d7604f0b Successfully move the code into a forward method 2023-12-26 05:37:34 -06:00
Hassan Hayat
85824bb1ee Successfully move the code into a forward method 2023-12-26 05:37:34 -06:00
Hassan Hayat
652f0e365f Yay! A full layer now just werks 2023-12-25 19:17:30 -06:00
Hassan Hayat
554331f567 Yay! A full layer now just werks 2023-12-25 19:17:30 -06:00
Hassan Hayat
7311c8f48c Got query states working 2023-12-25 18:28:21 -06:00
Hassan Hayat
9a904b6dcc Got query states working 2023-12-25 18:28:21 -06:00
Hassan Hayat
8b4234eb60 Get rotary embeddings working dammit 2023-12-25 18:24:54 -06:00
Hassan Hayat
b121bcb20b Get rotary embeddings working dammit 2023-12-25 18:24:54 -06:00
Hassan Hayat
d63ceba488 Precompute rope once using throwaway graph 2023-12-25 14:18:22 -06:00
Hassan Hayat
67d8d6b992 Precompute rope once using throwaway graph 2023-12-25 14:18:22 -06:00
Hassan Hayat
0130d5dfd9 Found the issue with rotary embeddings
It was f16. Rotary embeddings have to be precomputed in f32
2023-12-25 13:51:31 -06:00
Hassan Hayat
ab8f7187e6 Found the issue with rotary embeddings
It was f16. Rotary embeddings have to be precomputed in f32
2023-12-25 13:51:31 -06:00
Hassan Hayat
0c291d594b Merge remote-tracking branch 'upstream/main' into mistral 2023-12-25 12:53:05 -06:00
Hassan Hayat
1d828f7982 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-25 12:53:05 -06:00
Hassan Hayat
6f3d52f345 save progress 2023-12-25 12:52:57 -06:00
Hassan Hayat
ed1f76808d save progress 2023-12-25 12:52:57 -06:00
Joe Fioti
a00fe78aa1 Fixed scalar ops 2023-12-25 09:33:37 -05:00
Joe Fioti
58a56f9fc0 Fixed scalar ops 2023-12-25 09:33:37 -05:00
Hassan Hayat
2254b4c96c save progress 2023-12-25 02:15:15 -06:00
Hassan Hayat
b6a0caa79b save progress 2023-12-25 02:15:15 -06:00
Hassan Hayat
4e7c6c27ce repeat kv 2023-12-25 01:32:00 -06:00
Hassan Hayat
4eb0a8e1fb repeat kv 2023-12-25 01:32:00 -06:00
Hassan Hayat
e222cb7a97 Update model.rs 2023-12-25 01:14:26 -06:00
Hassan Hayat
858f198b43 Update model.rs 2023-12-25 01:14:26 -06:00
Hassan Hayat
4b5872b5d1 Applying rotary embeddings work 2023-12-25 01:12:21 -06:00
Hassan Hayat
d2269eebf7 Applying rotary embeddings work 2023-12-25 01:12:21 -06:00
Hassan Hayat
7f8b21f71f Get rotate half working 2023-12-25 00:36:21 -06:00
Hassan Hayat
ca1703745f Get rotate half working 2023-12-25 00:36:21 -06:00
Hassan Hayat
c3f2547349 Update model.rs 2023-12-24 23:55:01 -06:00
Hassan Hayat
5c24050775 Update model.rs 2023-12-24 23:55:01 -06:00
Hassan Hayat
927fb9fac2 Still not there with attention 2023-12-24 23:29:39 -06:00
Hassan Hayat
167944b422 Still not there with attention 2023-12-24 23:29:39 -06:00
Hassan Hayat
4cec36f4b5 Fix broken division 2023-12-24 23:23:55 -06:00
Hassan Hayat
66fbf23d67 Fix broken division 2023-12-24 23:23:55 -06:00
Hassan Hayat
7e401c69c7 past norm, to query states 2023-12-24 21:45:10 -06:00
Hassan Hayat
5c4076bc8c past norm, to query states 2023-12-24 21:45:10 -06:00
Hassan Hayat
d2cb4f0d48 Get Mistral RMS norm working 2023-12-24 21:04:25 -06:00
Hassan Hayat
ef16ee6b23 Get Mistral RMS norm working 2023-12-24 21:04:25 -06:00
Hassan Hayat
c518caacf2 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:28:52 -06:00
Hassan Hayat
acef1725f3 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:28:52 -06:00
Joe Fioti
6cad14a20b Changed gather 2023-12-24 21:28:32 -05:00
Joe Fioti
c33333724d Changed gather 2023-12-24 21:28:32 -05:00
Hassan Hayat
8149440f8f Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:25:16 -06:00
Hassan Hayat
b4717747d5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:25:16 -06:00
Joe Fioti
9fc98f3288 Fixed contiguous 2023-12-24 21:07:48 -05:00
Joe Fioti
24347bf69c Fixed contiguous 2023-12-24 21:07:48 -05:00
Hassan Hayat
c5aa4d2975 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:07:26 -06:00
Hassan Hayat
9ec05b25a8 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:07:26 -06:00
Joe Fioti
bd83d880a9 Fixed metal prim compiler and other things 2023-12-24 21:05:45 -05:00
Joe Fioti
8037d370ee Fixed metal prim compiler and other things 2023-12-24 21:05:45 -05:00
Hassan Hayat
4a0a86577e Embeddings is correct 2023-12-24 19:25:10 -06:00
Hassan Hayat
75ea980bd2 Embeddings is correct 2023-12-24 19:25:10 -06:00
Hassan Hayat
de5049577c Push code 2023-12-24 11:47:45 -06:00
Hassan Hayat
5f99756be4 Push code 2023-12-24 11:47:45 -06:00
Hassan Hayat
33724c7214 found the source of crash 2023-12-24 11:35:16 -06:00
Hassan Hayat
4f75032c7e found the source of crash 2023-12-24 11:35:16 -06:00
Hassan Hayat
a45b4b6e85 Focus on debug 2023-12-24 11:03:29 -06:00
Hassan Hayat
912db261fe Focus on debug 2023-12-24 11:03:29 -06:00
Hassan Hayat
fad53704fd include the correct value for rope theta 2023-12-24 07:17:43 -06:00
Hassan Hayat
29aeac0531 include the correct value for rope theta 2023-12-24 07:17:43 -06:00
Hassan Hayat
a426971470 Found a panic 2023-12-24 07:12:55 -06:00
Hassan Hayat
1bd50bff21 Found a panic 2023-12-24 07:12:55 -06:00
Hassan Hayat
d2d733b931 Add print statement, find all zeros 2023-12-24 06:03:24 -06:00
Hassan Hayat
1ad6edd9ce Add print statement, find all zeros 2023-12-24 06:03:24 -06:00
Hassan Hayat
d924809d85 Fix grouped query attention 2023-12-24 05:51:45 -06:00
Hassan Hayat
24b1b324e6 Fix grouped query attention 2023-12-24 05:51:45 -06:00
Hassan Hayat
531b28f75a Test inference code 2023-12-24 02:58:05 -06:00
Hassan Hayat
ce40bb7f58 Test inference code 2023-12-24 02:58:05 -06:00
Hassan Hayat
be667fb936 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 02:55:33 -06:00
Hassan Hayat
de2a2c8bb8 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 02:55:33 -06:00
Joe Fioti
55e68dff43 Moved allocations outside MetalKernelForward 2023-12-23 15:52:34 -05:00
Joe Fioti
e922d565a7 Moved allocations outside MetalKernelForward 2023-12-23 15:52:34 -05:00
Hassan Hayat
1763e85aa7 Remove unneeded annotation 2023-12-23 13:50:21 -06:00
Hassan Hayat
5d10422881 Remove unneeded annotation 2023-12-23 13:50:21 -06:00
Hassan Hayat
a004408327 Implement generic arange and argmax 2023-12-23 13:49:22 -06:00
Hassan Hayat
cc0b34a640 Implement generic arange and argmax 2023-12-23 13:49:22 -06:00
Hassan Hayat
21596a01d7 Successfully load all the weights 2023-12-23 02:02:54 -06:00
Hassan Hayat
68f0c6f6ca Successfully load all the weights 2023-12-23 02:02:54 -06:00
Hassan Hayat
3a2ab1d176 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 17:21:12 -06:00
Hassan Hayat
ff7289ef39 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 17:21:12 -06:00
Hassan Hayat
0b370359c4 precompute inverse freqs works 2023-12-22 16:41:27 -06:00
Hassan Hayat
d2b720da3f precompute inverse freqs works 2023-12-22 16:41:27 -06:00
Joe Fioti
ef1054a921 fixed embedding test 2023-12-22 16:27:11 -05:00
Joe Fioti
0b30af2a7a fixed embedding test 2023-12-22 16:27:11 -05:00
Joe Fioti
5d97b4ee52 Fixed metal subtraction and llama 2023-12-22 16:23:45 -05:00
Joe Fioti
e179494ac4 Fixed metal subtraction and llama 2023-12-22 16:23:45 -05:00
Hassan Hayat
db2fc3cbb0 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 13:46:05 -06:00
Hassan Hayat
935caa24ce Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 13:46:05 -06:00
Joe Fioti
268a9b2cf8 Added metal equal compiler 2023-12-22 10:38:29 -05:00
Joe Fioti
5d8238bcf4 Added metal equal compiler 2023-12-22 10:38:29 -05:00
Hassan Hayat
a1c4f18725 Almost done loading model 2023-12-22 02:13:39 -06:00
Hassan Hayat
19ec1f1d36 Almost done loading model 2023-12-22 02:13:39 -06:00
Hassan Hayat
3032c685cd yoke 2023-12-22 00:43:48 -06:00
Hassan Hayat
c890ebdbe1 yoke 2023-12-22 00:43:48 -06:00
Hassan Hayat
a26d2fe86f Get it to compile again 2023-12-21 23:32:08 -06:00
Hassan Hayat
312305fcb7 Get it to compile again 2023-12-21 23:32:08 -06:00
Hassan Hayat
a402a29f93 Implement the model 2023-12-21 23:00:33 -06:00
Hassan Hayat
ed964105ec Implement the model 2023-12-21 23:00:33 -06:00
Hassan Hayat
414a3dcc83 Initial attention impl 2023-12-21 21:41:12 -06:00
Hassan Hayat
c51e87385f Initial attention impl 2023-12-21 21:41:12 -06:00
Hassan Hayat
fec403b9f5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-21 17:35:06 -06:00
Hassan Hayat
90e06d90e5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-21 17:35:06 -06:00
Hassan Hayat
25bf6ee63a Load the embeddings properly 2023-12-21 17:31:51 -06:00
Hassan Hayat
9e5880b130 Load the embeddings properly 2023-12-21 17:31:51 -06:00
Joe Fioti
9220e7b1e0 redid metal arange compiler 2023-12-21 14:54:16 -05:00
Joe Fioti
4a97c8bee9 redid metal arange compiler 2023-12-21 14:54:16 -05:00
Joe Fioti
93d45509ad Added subtraction compiler 2023-12-21 14:20:19 -05:00
Joe Fioti
c0d0ec0c32 Added subtraction compiler 2023-12-21 14:20:19 -05:00
Hassan Hayat
c4b4233e20 Include start token 2023-12-21 06:09:14 -06:00
Hassan Hayat
61e59b27ec Include start token 2023-12-21 06:09:14 -06:00
Hassan Hayat
d16d22492e Get tokenizer working 2023-12-21 06:02:33 -06:00
Hassan Hayat
3a25325d37 Get tokenizer working 2023-12-21 06:02:33 -06:00
Joe Fioti
4599fec534 Primop gather 2023-12-20 19:36:11 -05:00
Joe Fioti
1637d0fdb8 Primop gather 2023-12-20 19:36:11 -05:00
Joe Fioti
13de77b68c New common metal buffer compiler 2023-12-20 16:56:16 -05:00
Joe Fioti
700d8f71e2 New common metal buffer compiler 2023-12-20 16:56:16 -05:00
Joe Fioti
84eea2a0eb Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-20 10:38:02 -05:00
Joe Fioti
b2be7b2583 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-20 10:38:02 -05:00
Joe Fioti
5c396368b6 Changed readme 2023-12-20 10:37:57 -05:00
Joe Fioti
7335d07755 Changed readme 2023-12-20 10:37:57 -05:00
Joe Fioti
b63746fe84 Merge pull request #8 from TheSeamau5/conv2d
Conv2d
2023-12-20 09:35:38 -05:00
Joe Fioti
ef964536e9 Merge pull request #8 from TheSeamau5/conv2d
Conv2d
2023-12-20 09:35:38 -05:00
Hassan Hayat
f96e3a903e Conv2D forward implemented 2023-12-19 23:23:13 -06:00
Hassan Hayat
7ec82a97d6 Conv2D forward implemented 2023-12-19 23:23:13 -06:00
Hassan Hayat
741b167910 First implementation of conv2d 2023-12-19 23:02:32 -06:00
Hassan Hayat
268fb4e9aa First implementation of conv2d 2023-12-19 23:02:32 -06:00
Joe Fioti
3b0b264ba5 Update README.md 2023-12-19 20:19:08 -05:00
Joe Fioti
1c7a3b8ed9 Update README.md 2023-12-19 20:19:08 -05:00
Joe Fioti
7c307c886e Made conv forward public 2023-12-19 20:00:50 -05:00
Joe Fioti
e00a89c647 Made conv forward public 2023-12-19 20:00:50 -05:00
Joe Fioti
c6d37ed5c5 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-19 19:59:54 -05:00
Joe Fioti
deef279977 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-19 19:59:54 -05:00
Joe Fioti
835527333c Move cumsum 2023-12-19 19:59:47 -05:00
Joe Fioti
b2735b8dc6 Move cumsum 2023-12-19 19:59:47 -05:00
Joe Fioti
7050a8bd7a Merge pull request #7 from TheSeamau5/conv1d
Conv1D module
2023-12-19 19:54:10 -05:00
Joe Fioti
2c6ac7124e Merge pull request #7 from TheSeamau5/conv1d
Conv1D module
2023-12-19 19:54:10 -05:00
Hassan Hayat
cc0c2bf8cb Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 16:40:57 -06:00
Hassan Hayat
e335bb24df Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 16:40:57 -06:00
Joe Fioti
ef0768ebef ARange in llama 2023-12-19 17:39:19 -05:00
Joe Fioti
a6c8c4c254 ARange in llama 2023-12-19 17:39:19 -05:00
Hassan Hayat
1f81ffb182 Remove extra comment 2023-12-19 16:32:39 -06:00
Hassan Hayat
23b7937507 Remove extra comment 2023-12-19 16:32:39 -06:00
Hassan Hayat
ac23472220 Remove extra comments 2023-12-19 16:31:49 -06:00
Hassan Hayat
0e07eb7614 Remove extra comments 2023-12-19 16:31:49 -06:00
Hassan Hayat
f6e2fd1be2 Fix and pass the tests, define conv as a Rank-2 tensor (remove a reshape) 2023-12-19 16:31:10 -06:00
Hassan Hayat
ddc6644a87 Fix and pass the tests, define conv as a Rank-2 tensor (remove a reshape) 2023-12-19 16:31:10 -06:00
Joe Fioti
6d987df3e2 Small change 2023-12-19 17:23:34 -05:00
Joe Fioti
7b2fd581b6 Small change 2023-12-19 17:23:34 -05:00
Joe Fioti
6f810111c4 Tril and triu 2023-12-19 17:20:57 -05:00
Joe Fioti
3b154540da Tril and triu 2023-12-19 17:20:57 -05:00
Joe Fioti
0675610007 ARange, better symbolic minimizer 2023-12-19 17:02:50 -05:00
Joe Fioti
2ae67dd894 ARange, better symbolic minimizer 2023-12-19 17:02:50 -05:00
Hassan Hayat
8623843e72 Remove extra .DS_Store 2023-12-19 16:01:37 -06:00
Hassan Hayat
98ef29fec0 Remove extra .DS_Store 2023-12-19 16:01:37 -06:00
Hassan Hayat
7d37b56c20 Add harder test, doesn't pass yet 2023-12-19 15:42:21 -06:00
Hassan Hayat
54c48df279 Add harder test, doesn't pass yet 2023-12-19 15:42:21 -06:00
Hassan Hayat
7460fcde9d Simple design, no pool_out 2023-12-19 15:16:21 -06:00
Hassan Hayat
fc2a56039a Simple design, no pool_out 2023-12-19 15:16:21 -06:00
Hassan Hayat
b29f8e3a0f Alternative design, custom forward with generics at the forward function 2023-12-19 14:51:47 -06:00
Hassan Hayat
3031ead6dc Alternative design, custom forward with generics at the forward function 2023-12-19 14:51:47 -06:00
Hassan Hayat
20951c0721 Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 14:28:37 -06:00
Hassan Hayat
75b1064922 Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 14:28:37 -06:00
Joe Fioti
1a4135515b Symbolic changes 2023-12-19 11:28:15 -05:00
Joe Fioti
acbb1b6e2c Symbolic changes 2023-12-19 11:28:15 -05:00
Hassan Hayat
c0632cb689 Update convolution.rs 2023-12-18 22:35:58 -06:00
Hassan Hayat
144e3b7a98 Remove unnecessary comments 2023-12-18 22:30:58 -06:00
Hassan Hayat
dfd21a343b Conv1D module first pass 2023-12-18 22:29:50 -06:00
Joe Fioti
0faadea621 Added cumsum 2023-12-18 23:28:37 -05:00
Joe Fioti
0dd8f4b7c7 2D convolutions 2023-12-18 18:17:42 -05:00
Joe Fioti
96e39c2535 1D last dim pooling on 2D tensors 2023-12-18 11:51:09 -06:00
Joe Fioti
909d5b7836 Pooling with dilation 2023-12-17 17:06:39 -06:00
Joe Fioti
1125351f4c 1D pooling 2023-12-17 12:45:04 -06:00
Joe Fioti
345622f452 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-15 21:55:36 -06:00
Joe Fioti
53b9bd6e61 Added MetalArange 2023-12-15 21:55:29 -06:00
Joe Fioti
e7d0a08150 Merge pull request #4 from TheSeamau5/arange
Simple Pooling implementation
2023-12-15 11:23:24 -06:00
Joe Fioti
0939f50ce2 removed fake sum reduction, generalized constants 2023-12-15 11:18:31 -06:00
Hassan Hayat
84d7a0cedc Remove unused imports 2023-12-15 10:51:13 -06:00
Hassan Hayat
c9c540057b Let's keep it simple for now, kernel size = stride 2023-12-15 10:46:19 -06:00
Joe Fioti
a2edbe14ec Commonize more metal compilers 2023-12-15 10:46:06 -06:00
Hassan Hayat
694fa93d30 Reverting to simpler impl 2023-12-15 01:58:05 -06:00
Hassan Hayat
4214a33525 Save code 2023-12-15 01:34:06 -06:00
Hassan Hayat
404322b4ab Save code 2023-12-15 01:23:26 -06:00
Joe Fioti
afd3eeee88 Removed Function output type 2023-12-14 21:13:56 -06:00
Joe Fioti
84adc99c33 llama cleanups 2023-12-14 21:00:12 -06:00
Joe Fioti
3f4b592c60 Fixed 2023-12-14 20:47:06 -06:00
Joe Fioti
d61c848f6a Added more symbolic minimization rules 2023-12-14 20:38:03 -06:00
Joe Fioti
94c7d00517 Small changes 2023-12-14 19:44:01 -06:00
Hassan Hayat
e799363d0d remove unused code 2023-12-14 19:41:24 -06:00
Hassan Hayat
d0d7f74e42 remove comments 2023-12-14 19:39:08 -06:00
Hassan Hayat
de5835822d Merge remote-tracking branch 'upstream/main' into arange 2023-12-14 19:29:45 -06:00
Hassan Hayat
77fb4305e8 add tests 2023-12-14 19:28:05 -06:00
Joe Fioti
4cdb364e4a Metal vecmat 2023-12-14 19:27:23 -06:00
Hassan Hayat
fbebf6d485 Added more tests 2023-12-14 15:41:16 -06:00
Hassan Hayat
4fde0f4524 Make pool n-dimensional 2023-12-14 14:22:36 -06:00
Joe Fioti
6f3cff1cd4 symbolic changes 2023-12-13 15:02:08 -06:00
Hassan Hayat
b90847c43f Start working on pooling 2023-12-13 13:55:32 -06:00
Joe Fioti
e5c7c8b2a2 Removed indexer 2023-12-12 23:47:58 -06:00
Joe Fioti
9e453719e3 Unified expressions 2023-12-12 22:36:38 -06:00
Joe Fioti
63b04f1e9a remvoed checking stuff in print 2023-12-12 16:50:28 -06:00
Joe Fioti
904baefa68 removed unsafe graph ref dereference 2023-12-12 16:48:17 -06:00
Joe Fioti
c82a00981a Tweaks 2023-12-12 15:07:13 -06:00
Joe Fioti
d1add4231f Fixed slow vecmat 2023-12-12 15:05:29 -06:00
Joe Fioti
678591a1a5 Low performance vecmat 2023-12-12 13:16:07 -06:00
Joe Fioti
f4a07f5259 Small refinements 2023-12-11 22:19:28 -06:00
Joe Fioti
e5e904498c Batch matmul fixes 2023-12-11 20:34:42 -06:00
Joe Fioti
89740bdd30 small 2023-12-09 09:36:20 -06:00
Joe Fioti
c6b72fa317 still broken bmm 2023-12-09 09:35:52 -06:00
Joe Fioti
80b917b02f Added RemapDownstream compiler 2023-12-08 22:31:19 -06:00
Joe Fioti
971361feac Removed copy remap 2023-12-08 21:53:05 -06:00
Joe Fioti
4a553724a2 Symblic changes 2023-12-08 14:17:24 -06:00
Joe Fioti
b87b30f045 simd batch matmul 2023-12-08 12:01:54 -06:00
Joe Fioti
a3a69f53da Working paddded batch matmul 2023-12-08 11:55:04 -06:00
Joe Fioti
cb659f3c25 improved symbolic algebra minimizer 2023-12-08 10:08:32 -06:00
Joe Fioti
8135540b22 Changed metal input rendering 2023-12-07 20:45:05 -06:00
Joe Fioti
3a10b6f4db Changed test 2023-12-07 16:41:02 -06:00
Joe Fioti
82d4a96ae1 Simd matmul 2D 2023-12-07 16:25:45 -06:00
Joe Fioti
802091e15e Removed expression interfaces 2023-12-07 00:00:43 -06:00
Joe Fioti
4292259db1 Replace dims with expressions 2023-12-06 23:36:35 -06:00
Joe Fioti
b6efabf216 Expr interface 2023-12-06 16:02:58 -06:00
Joe Fioti
7e5471bdfa Added big expression 2023-12-06 15:05:30 -06:00
Joe Fioti
5fa5aff813 Added small symbolic algebra lib and removed savage 2023-12-06 14:57:25 -06:00
Joe Fioti
5035ad1d99 Small improvements' 2023-12-06 11:56:29 -06:00
Joe Fioti
63797c90f9 small changes 2023-12-05 21:44:50 -06:00
Joe Fioti
8b475ea4f2 Fix CI 2023-12-05 10:28:48 -06:00
Joe Fioti
1cd4fb2e73 Merge 2023-12-05 10:21:10 -06:00
Joe Fioti
946ea8dfb8 Hybrid matmul 2023-12-05 10:15:15 -06:00
Joe Fioti
8c264fb2a5 broken matmul 2023-12-03 19:17:43 -05:00
Joe Fioti
b96b792612 Changed tensor api 2023-11-28 18:41:56 -05:00
Joe Fioti
909ea995b6 Fixed tests 2023-11-27 14:36:41 -05:00
Joe Fioti
4e197b512f Changed metal dispatching 2023-11-27 14:35:24 -05:00
Joe Fioti
aa3f8cce3d Matmul correction 2023-11-27 14:24:32 -05:00
Joe Fioti
00dcc29eb1 Faster batch matmul 2023-11-27 14:03:00 -05:00
Joe Fioti
032cec5c5a Compile time col-row-major ordering 2023-11-26 16:31:15 -05:00
Joe Fioti
f0d6fedc90 Small rmsnorm opt 2023-11-24 10:33:53 -05:00
Joe Fioti
bb9ff4f113 removed mutex from shared command buffer 2023-11-21 16:50:11 -06:00
Joe Fioti
1c3f6735f8 Remvoed metal attn matmul 2023-11-21 15:45:33 -06:00
Joe Fioti
2d210641d3 Added metal gpu gather 2023-11-21 13:59:17 -06:00
Joe Fioti
4078d895c7 Common metal prim ops 2023-11-20 12:17:37 -06:00
Joe Fioti
d67820b6ba Finished cuda prim unification 2023-11-20 10:42:42 -06:00
Joe Fioti
6869047b44 Unifyed cuda unary ops 2023-11-20 01:02:22 -06:00
Joe Fioti
ba7c3972b5 Common cuda copyto and copyfrom 2023-11-19 23:08:04 -06:00
Joe Fioti
f931504a09 Fixed cuda 2023-11-19 22:13:10 -06:00
Joe Fioti
52c18171a1 Small 2023-11-19 16:57:49 -06:00
Joe Fioti
b89dbefb3c Added set marking 2023-11-19 16:54:25 -06:00
Joe Fioti
07936bc8e4 remove mac test 2023-11-19 15:25:44 -06:00
Joe Fioti
647eda7895 Update and rename rust.yml to test.yml 2023-11-19 15:25:19 -06:00
Joe Fioti
5bb703084c tests passing 2023-11-19 15:13:17 -06:00
Joe Fioti
e7283e9105 Fixed native concats 2023-11-19 14:12:50 -06:00
Joe Fioti
7102d06e73 Partial fix 2023-11-18 10:34:45 -06:00
Joe Fioti
cf54dee88e broken version of native concat 2023-11-18 09:03:47 -06:00
Joe Fioti
854864ac5e Merge branch 'main' of https://github.com/jafioti/luminal 2023-11-18 08:38:33 -06:00
Joe Fioti
3eceaae45f Added dim slices and pading 2023-11-18 08:38:25 -06:00
Joe Fioti
e604a8cba0 Update README.md 2023-11-17 23:28:58 -06:00
Joe Fioti
c351acb075 more optimizations 2023-11-17 23:27:05 -06:00
Joe Fioti
d880efc1db graph selector optimizations 2023-11-16 18:07:16 -06:00
Joe Fioti
e118b293fd toposort at compile 2023-11-16 17:40:37 -06:00
Joe Fioti
254996063a Reworked selector API 2023-11-16 17:37:19 -06:00
Joe Fioti
a85b2ac301 Shared Metal Command Buffers 2023-11-16 15:21:01 -06:00
Joe Fioti
d2f8471943 action change 2023-11-12 14:43:06 -06:00
Joe Fioti
b6a7a3bc1e action change 2023-11-12 14:41:15 -06:00
Joe Fioti
3b3007cbdd Changed action 2023-11-12 14:38:27 -06:00
Joe Fioti
adc2092275 Macos gpu action 2023-11-12 14:35:21 -06:00
Joe Fioti
96831f2d4e Small 2023-11-12 12:04:50 -06:00
Joe Fioti
baf8664d10 Small changes 2023-11-12 11:55:15 -06:00
Joe Fioti
d071fd5397 Pasing tests 2023-11-12 11:38:00 -06:00
Joe Fioti
44f9415811 New testing, fixed cpu bug 2023-11-12 10:34:55 -06:00
Joe Fioti
b104364edb Removed simple tracker again 2023-11-10 10:19:11 -05:00
Joe Fioti
24bbf0ead9 Removed dyn_data 2023-11-10 10:18:15 -05:00
Joe Fioti
e8af292958 Changed tests 2023-11-09 22:17:53 -05:00
Joe Fioti
8d14f83bc3 Simplifications and API changes 2023-11-09 22:15:48 -05:00
Joe Fioti
2ff89167c2 Closer to working CommonBufferCompiler 2023-11-09 21:18:59 -05:00
Joe Fioti
87854bbdf0 Working llama at same speed as before 2023-11-08 16:38:51 -05:00
Joe Fioti
a3a4a972d7 optimized common buffer compilation 2023-11-06 21:23:25 -05:00
Joe Fioti
75fbb709d7 Share command queues on fp16 primops 2023-11-06 16:28:01 -05:00
Joe Fioti
6a311347bf New common buffer 2023-11-06 16:04:20 -05:00
Joe Fioti
2787fdd8b6 Switched to compilers 2023-11-05 22:31:37 -05:00
Joe Fioti
634f5c26ee Added schedule dependencies 2023-11-05 14:04:28 -05:00
Joe Fioti
e7683ac3ff Bugged common buffer 2023-11-01 22:20:42 -04:00
Joe Fioti
72fdc3bcfe Added preliminary internal graph to shared buffer op 2023-10-27 21:47:27 -04:00
Joe Fioti
271977d1dd Added unary shared metal command buffer 2023-10-27 21:17:34 -04:00
Joe Fioti
e6de090ed3 Started shared command buffers 2023-10-26 21:53:17 -04:00
Joe Fioti
65c0224ae5 Small 2023-10-15 22:03:04 -05:00
Joe Fioti
e957a4c99a Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-14 09:53:09 -05:00
Joe Fioti
4e6d5b733c Added constant to fakesumreduce opt 2023-10-14 09:53:03 -05:00
Joe Fioti
be591b2f4a Removed arch flags 2023-10-11 23:09:35 -05:00
Joe Fioti
bb90b73533 Fixed example 2023-10-11 23:03:25 -05:00
Joe Fioti
0abf5c2379 Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-11 17:31:36 -05:00
Joe Fioti
ccbf55923d Small changes 2023-10-11 11:56:05 -05:00
Joe Fioti
eb842428b7 Small changes 2023-10-10 23:04:24 -05:00
Joe Fioti
e61aa736db Added slowbatch matmul 2023-10-10 19:42:03 -05:00
Joe Fioti
8794afb246 re-added toposort caching 2023-10-10 16:00:05 -05:00
Joe Fioti
ddf32b6215 Cached weights 2023-10-10 15:57:17 -05:00
Joe Fioti
811fe65412 Working llama fp16 metal 2023-10-10 15:10:13 -05:00
Joe Fioti
0ad73d19ed Added metal fp16 copy opt 2023-10-03 13:02:55 -05:00
Joe Fioti
0b845dc7ee Merged 2023-10-03 12:48:01 -05:00
Joe Fioti
a0449b4d6b CopyOptimizer 2023-10-03 12:46:13 -05:00
Joe Fioti
7e58e1f299 Added MeanReduce and RMSNorm fused ops 2023-10-02 20:53:57 -05:00
Joe Fioti
f1da8c3cb7 Added mps matmul 2023-10-02 11:31:36 -05:00
Joe Fioti
b87f0124b7 Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-01 23:28:55 -05:00
Joe Fioti
49db4cdea8 Added metal half precision (tests failing) 2023-10-01 23:28:46 -05:00
Joe Fioti
b72e0a2270 Update Introduction.md 2023-09-30 23:58:55 -05:00
Joe Fioti
67965bc275 Added metal half precision (tests failing) 2023-09-30 23:48:05 -05:00
Joe Fioti
a8abee1422 Llama running on metal 2023-09-30 22:58:31 -05:00
Joe Fioti
4da5e94adf Complete metal primops 2023-09-30 14:08:59 -05:00
Joe Fioti
0dc1e71148 Added metal contiguous 2023-09-30 11:34:44 -05:00
Joe Fioti
7d7972d54c Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-29 23:54:34 -05:00
Joe Fioti
41de512cdc Fixed llama fp16 2023-09-29 23:54:32 -05:00
Joe Fioti
07b2b1f28c Removed cuda kernel 2023-09-29 11:18:34 -05:00
Joe Fioti
2aec49d0e5 Started metal primops 2023-09-29 11:17:33 -05:00
Joe Fioti
ffa50d43c5 Fixed feature 2023-09-27 23:20:25 -05:00
Joe Fioti
ef06f5a746 Added half precision (llama not working 2023-09-27 23:14:27 -05:00
Joe Fioti
aebdbe5ca8 Simplifications 2023-09-26 23:45:49 -05:00
Joe Fioti
acdcfc14fb Added test improvements 2023-09-26 23:27:27 -05:00
Joe Fioti
a6b403e667 Fixed feature 2023-09-26 20:02:28 -05:00
Joe Fioti
e5cfe80029 Added transfer_weights and mark_weights 2023-09-26 19:57:09 -05:00
Joe Fioti
798ac9dd69 Fixed llama setup script 2023-09-26 18:40:36 -05:00
Joe Fioti
64a05e2f14 Added cuda batch matmul 2023-09-26 18:19:26 -05:00
Joe Fioti
99f5843c42 remove rerun 2023-09-26 14:03:23 -05:00
Joe Fioti
3f2250e51f Added cublas matmul 2023-09-26 14:03:04 -05:00
Joe Fioti
9a3de0103d Fixed cuda! Validated llama run 2023-09-25 13:07:22 -05:00
Joe Fioti
1848ef4905 Partial fix of cuda 2023-09-24 23:34:36 -05:00
Joe Fioti
b8725ec9aa Cuda still broken 2023-09-23 23:55:10 -05:00
Joe Fioti
fbba2eb1db Fully precompiled cuda kernels 2023-09-18 23:52:57 -05:00
Joe Fioti
8554a1fcfc Precompiled unary cuda ops 2023-09-18 20:50:43 -05:00
Joe Fioti
2f4e189f93 First precompiled kernel? 2023-09-18 18:10:24 -05:00
Joe Fioti
1bda13aec0 Re-added cuda 2023-09-18 16:36:45 -05:00
Joe Fioti
49cadac789 Comment 2023-09-18 11:23:18 -05:00
Joe Fioti
1fe9f3a068 Selectors for multi-output 2023-09-18 11:22:17 -05:00
Joe Fioti
edb102f7a2 Added multi-output ops 2023-09-18 11:15:39 -05:00
Joe Fioti
97376b36bc Small changes 2023-09-17 22:52:23 -05:00
Joe Fioti
41d88b0c4a Multi-graph llama 2023-09-17 11:30:16 -05:00
Joe Fioti
ee70a44f8b Added proint 2023-09-17 11:14:22 -05:00
Joe Fioti
1be715f322 Small changes 2023-09-17 11:12:49 -05:00
Joe Fioti
519319c9b2 Added debug prints 2023-09-17 10:51:21 -05:00
Joe Fioti
76abe671e4 Added batched matmul cpu op 2023-09-16 23:45:28 -05:00
Joe Fioti
14541394dc Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-16 17:19:17 -05:00
Joe Fioti
78a10f89ed Re-added optimizers 2023-09-16 17:19:12 -05:00
Joe Fioti
f20a9fd2ed Chaanged tokenizer 2023-09-16 14:19:03 -05:00
Joe Fioti
18eb48735d Private data in ndexer 2023-09-16 10:19:29 -05:00
Joe Fioti
6274ba8169 Added indexer for CPU 2023-09-16 10:18:54 -05:00
Joe Fioti
a63bae227e Fixed llama 2023-09-16 08:33:33 -05:00
Joe Fioti
6182590829 Removed noop 2023-09-11 16:25:27 -05:00
Joe Fioti
0922bcb903 Fixed serialization test 2023-09-11 15:32:16 -05:00
Joe Fioti
6eb62664a5 Indexing fix 2023-09-11 14:50:08 -05:00
Joe Fioti
dcb2072f36 Partially fixed shapes 2023-09-11 13:39:07 -05:00
Joe Fioti
c5bd1a9ce9 Update README.md 2023-09-11 00:32:30 -05:00
Joe Fioti
da1192bd01 Even more fixes 2023-09-11 00:30:57 -05:00
Joe Fioti
ef3b917f5e More fixes 2023-09-11 00:04:01 -05:00
Joe Fioti
2f32bcbb8f Shape fixes 2023-09-10 23:35:49 -05:00
Joe Fioti
71adf60a71 Fixes 2023-09-10 16:04:38 -05:00
Joe Fioti
8a1c51317c Removed shape functions 2023-09-04 10:41:35 -05:00
Joe Fioti
783d01dd6f Added dyn map to graph 2023-09-04 09:11:17 -05:00
Joe Fioti
8c0567146f Added global dyn dim resolution fn 2023-09-04 09:05:28 -05:00
Joe Fioti
8a4a98fa27 Removed realdim and put symbols in Dim 2023-09-04 09:00:27 -05:00
Joe Fioti
37b363a92f Added dyn shape 2023-09-04 08:43:35 -05:00
Joe Fioti
e9352d0506 Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-04 08:19:44 -05:00
Joe Fioti
6efdcdb2b9 Removed shape resolution 2023-09-04 08:19:32 -05:00
Joe Fioti
eb1355c65a More fixes 2023-09-03 18:46:56 -05:00
Joe Fioti
a135938588 Added graph() function 2023-09-03 00:02:59 -05:00
Joe Fioti
aa3bf1ef51 Minor fixes 2023-09-02 23:52:20 -05:00
Joe Fioti
dcaa13b20e re-added serialization 2023-09-02 23:07:46 -05:00
Joe Fioti
c8ca146d0c Fixes 2023-09-02 20:13:20 -05:00
Joe Fioti
8c73edb584 Re-added shape resolution 2023-09-02 18:51:35 -05:00
Joe Fioti
f6a52704d9 Added InputTensor system 2023-09-02 18:27:02 -05:00
Joe Fioti
40814bc323 More removals 2023-09-02 16:42:27 -05:00
Joe Fioti
c652f8050a Removed old shapetracker 2023-09-02 16:30:26 -05:00
Joe Fioti
c4bf441fc1 Finished first draft of shape tracker 2023-09-02 15:44:16 -05:00
Joe Fioti
8ca9add11e Remvoed movement ops 2023-09-02 11:20:17 -05:00
Joe Fioti
1282de3d05 Added padding and slicing 2023-09-02 10:58:41 -05:00
Joe Fioti
f25c40bb08 Merge 2023-09-01 22:53:05 -05:00
Joe Fioti
3e12aa3492 tmp 2023-09-01 22:37:11 -05:00
Joe Fioti
78696adb53 Changed tracker 2023-09-01 22:15:39 -05:00
Joe Fioti
51f649da8a Partway transition to new shape tracker 2023-09-01 12:50:15 -05:00
Joe Fioti
3a35e59691 Fixed cuda 2023-08-30 00:13:07 -05:00
Joe Fioti
f5098784d7 Moved shape resolution to graph execution loop 2023-08-30 00:03:27 -05:00
Joe Fioti
b3a21eaa52 Fixed mean reduce 2023-08-29 22:05:01 -05:00
Joe Fioti
cc1d92e62f Remvoed function from mean reduce 2023-08-29 21:53:49 -05:00
Joe Fioti
3fdc34e286 Re-added arange with cumsum function 2023-08-29 21:29:52 -05:00
Joe Fioti
10f3eaad39 Changed 100 magic number to usize::MAX 2023-08-29 21:10:31 -05:00
Joe Fioti
ede46bd1e0 Added broken arange function 2023-08-14 21:09:59 -05:00
Joe Fioti
aa48af32ea Merge branch 'main' of https://github.com/jafioti/luminal 2023-08-14 14:29:53 -05:00
Joe Fioti
4acc6d1114 Started pool 2023-08-14 14:29:43 -05:00
Joe Fioti
b781be3cc2 Fixed example 2023-08-14 14:22:17 -05:00
Joe Fioti
d4a04a5055 Removed max op 2023-08-14 14:07:39 -05:00
Joe Fioti
3985301749 Added binary comparisons 2023-08-12 23:10:59 -05:00
Joe Fioti
51b6d2536d Added data function 2023-08-12 22:26:05 -05:00
Joe Fioti
2001353e9e Cuda kernels use valid 2023-08-12 20:02:24 -05:00
Joe Fioti
da8b5f62d2 Added contiguous op 2023-08-12 19:50:49 -05:00
Joe Fioti
1cf47a06c6 Merge branch 'main' of https://github.com/jafioti/luminal 2023-08-12 19:28:02 -05:00
Joe Fioti
2b68b022f9 Generalized unary sequential opt 2023-08-12 19:27:54 -05:00
Joe Fioti
7de2e883b9 Update README.md 2023-08-12 14:08:19 -05:00
Joe Fioti
ebbbfa1998 Added move_outoing_edges 2023-08-11 15:25:01 -05:00
Joe Fioti
7120aede15 Added better readme and llama setup 2023-08-11 15:07:06 -05:00
Joe Fioti
14c93f0e96 Fixed bugs 2023-08-11 14:41:06 -05:00
378 changed files with 94658 additions and 41601 deletions

View File

@@ -0,0 +1,130 @@
---
name: aoti-debug
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
---
# AOTInductor Debugging
Debug errors when compiling and deploying PyTorch models with AOTInductor.
## First Step: Always Check Device and Shape Matching
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
```python
# Compilation -- note the device and shapes
model = MyModel().eval().cuda()
inp = torch.randn(2, 10, device="cuda")
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
# Loading -- device type MUST match compilation
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
# Inference -- device and shapes MUST match
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
```
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
## Current vs Deprecated API
### Current API (use this)
```python
torch._inductor.aoti_compile_and_package() # compile
torch._inductor.aoti_load_package() # load (auto-detects device)
```
### Deprecated API (migrate away)
```python
torch._export.aot_compile() # deprecated
torch._export.aot_load() # deprecated
```
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
## Common Error Patterns
### Device Mismatch Segfault
**Symptom**: Segfault, exception, or crash during load or execution.
**Example errors**:
- `The specified pointer resides on host memory and is not registered with any CUDA device`
- Crash during constant loading
- `Expected out tensor to have device cuda:0, but got cpu instead`
**Solution**: Ensure compile and load use the same device type.
### Input Device Mismatch at Runtime
**Symptom**: RuntimeError during model execution.
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
```
Produces actionable messages like:
```
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
```
## Debugging CUDA Illegal Memory Access (IMA)
### Step 1: Sanity Checks
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
```
Both flags take effect at **compile time** (codegen time).
### Step 2: Make IMA Deterministic
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
Both take effect at **runtime**.
### Step 3: Identify the Problematic Kernel
```bash
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
```
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
To inspect inputs to specific kernels:
```bash
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
```
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
## Environment Variables Reference
| Variable | When | Purpose |
|---|---|---|
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
## Common Sources of Issues
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.

View File

@@ -0,0 +1,195 @@
---
name: pt2-debug
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
---
# PyTorch 2 Compile Debugging
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
## Diagnostic Environment Variables
Pick the right diagnostic based on the error:
| Command | When to use |
|---|---|
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
## Error Triage
Classify the failure and jump to the right section:
| Error Pattern | Category |
|---|---|
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
## Graph Breaks
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
**Diagnose:**
```bash
TORCH_LOGS="graph_breaks" python script.py
```
**Common causes:**
- Data-dependent control flow
- Unsupported Python builtins
- In-place ops on inputs, unsupported dtypes
- Calls to non-traceable functions
**Fix approaches:**
1. Read the graph break message to identify the unsupported operation
2. Check for a decomposition or supported alternative
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
## Backend Compiler Failures
`BackendCompilerFailed` means Inductor crashed during compilation.
**Diagnose with the minifier:**
```bash
# Generate minifier launcher
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
# Run the minifier to get minimal failing graph
python minifier_launcher.py minify
# Run the minimized reproduction
python minifier_launcher.py run
```
**Then inspect:**
```bash
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
```
## Recompilation Issues
Excessive recompilation from guards that are too specific, causing cache misses.
**Diagnose:**
```bash
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
```
**Key config:**
```python
torch._dynamo.config.recompile_limit # default: 8
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
```
**Common causes:**
- Changing tensor shapes without marking them dynamic
- Python scalar values that change between calls
- Global state mutations between calls
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
- Fix the source of guard instability
## Accuracy Issues
Compiled model produces different numerical results than eager mode.
**Diagnose:**
```bash
# Compares compiled vs eager with fp64 reference, dumps repro on failure
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```
**Fix approach:**
1. Get minimal failing graph from the minifier
2. Compare eager vs compiled output at fp64 precision
3. Binary search through ops to find the diverging operation
4. Check for known issues: reduction order, fused kernels, dtype promotions
## Internal Dynamo Errors
`InternalTorchDynamoError` indicates a bug in Dynamo.
**Diagnose:**
```bash
TORCHDYNAMO_VERBOSE=1 python script.py
# or equivalently:
TORCH_LOGS="+dynamo" python script.py
```
**Debug interactively:**
```bash
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
```
## Runtime Crashes
Segfaults and CUDA illegal memory access during execution of compiled code.
**Make crash deterministic:**
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
**Add NaN checks to find the first bad kernel:**
```bash
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
```
**Inductor sync debugging:**
```python
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
```
**Fix approach:**
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
2. Check input shapes, devices, dtypes
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
5. Dynamic shapes are historically a common source of IMA
## Triton Kernel Failures
Triton assertion failures or index-out-of-bounds in generated kernels.
**Diagnose:**
```bash
TORCH_LOGS="output_code,schedule" python script.py
```
**Fix approach:**
1. Get the generated Triton kernel from `output_code` logs
2. Check index computations for off-by-one or wrong stride calculations
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
4. Check if fusion decisions created invalid index combinations
## Distinguish Trace-Time vs Runtime
Many bugs come from confusing these:
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
- **Runtime**: Real tensors, real Python calls.
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
## Using the Minifier
The minifier reduces a failing graph to the smallest reproduction:
```bash
# For compilation failures (level 2)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
python minifier_launcher.py minify
python minifier_launcher.py run
# For accuracy failures (level 4)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```

View File

@@ -0,0 +1,134 @@
---
name: ruff
description:
Guide for using ruff, the extremely fast Python linter and formatter. Use this
when linting, formatting, or fixing Python code.
---
# ruff
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
isort, Black, pyupgrade, autoflake, and dozens of other tools.
## When to use ruff
**Always use ruff for Python linting and formatting**, especially if you see:
- `[tool.ruff]` section in `pyproject.toml`
- A `ruff.toml` or `.ruff.toml` configuration file
However, avoid making unnecessary changes:
- **Don't format unformatted code** - If `ruff format --diff` shows changes
throughout an entire file, the project likely isn't using ruff for formatting.
Skip formatting to avoid obscuring actual changes.
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
relevant to the code you're changing. Only apply fixes to files you're
modifying unless the user explicitly asks for broader fixes.
## How to invoke ruff
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
you use the pinned version
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
one-off checks
- `ruff ...` - Use if ruff is installed globally
## Commands
### Linting
```bash
ruff check . # Check all files in current directory
ruff check path/to/file.py # Check specific file
ruff check --fix . # Auto-fix fixable violations
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
ruff check --watch . # Watch for changes and re-lint
ruff check --select E,F . # Only check specific rules
ruff check --ignore E501 . # Ignore specific rules
ruff rule E501 # Explain a specific rule
ruff linter # List available linters
```
### Formatting
```bash
ruff format . # Format all files
ruff format path/to/file.py # Format specific file
ruff format --check . # Check if files are formatted (no changes)
ruff format --diff . # Show formatting diff without applying
```
## Configuration
Ruff is configured in `pyproject.toml` or `ruff.toml`:
```toml
# pyproject.toml
[tool.ruff.lint]
select = ["E", "F", "I", "UP"] # Enable specific rule sets
ignore = ["E501"] # Ignore specific rules
[tool.ruff.lint.isort]
known-first-party = ["myproject"]
```
## Migrating from other tools
### Black → ruff format
```bash
black . → ruff format .
black --check . → ruff format --check .
black --diff . → ruff format --diff .
```
### Flake8 → ruff check
```bash
flake8 . → ruff check .
flake8 --select E,F . → ruff check --select E,F .
flake8 --ignore E501 . → ruff check --ignore E501 .
```
### isort → ruff check
```bash
isort . → ruff check --select I --fix .
isort --check . → ruff check --select I .
isort --diff . → ruff check --select I --diff .
```
## Common patterns
### Apply lint fixes before formatting
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
structure (e.g., reordering imports), which formatting then cleans up.
```bash
ruff check --fix .
ruff format .
```
### Applying and reviewing unsafe fixes
Ruff categorizes some auto-fixes as "unsafe" because they may change code
behavior, not just style. For example, removing unused imports could break code
that relies on side effects.
```bash
ruff check --fix --unsafe-fixes --diff . # Preview changes first
ruff check --fix --unsafe-fixes . # Apply changes
```
**Always review changes before applying `--unsafe-fixes`:**
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
- Verify the fix doesn't violate those assumptions in your code
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ruff/

135
.agents/skills/ty/SKILL.md Normal file
View File

@@ -0,0 +1,135 @@
---
name: ty
description:
Guide for using ty, the extremely fast Python type checker and language
server. Use this when type checking Python code or setting up type checking in
Python projects.
---
# ty
ty is an extremely fast Python type checker and language server. It replaces
mypy, Pyright, and other type checkers.
## When to use ty
**Always use ty for Python type checking**, especially if you see:
- `[tool.ty]` section in `pyproject.toml`
- A `ty.toml` configuration file
## How to invoke ty
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
use the pinned version or when ty is installed globally and you are in a
project so the virtual environment is updated.
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
checks
## Commands
### Type checking
```bash
ty check # Check all files in current directory
ty check path/to/file.py # Check specific file
ty check src/ # Check specific directory
```
### Rule configuration
```bash
ty check --error possibly-unresolved-reference # Treat as error
ty check --warn division-by-zero # Treat as warning
ty check --ignore unresolved-import # Disable rule
```
### Python version targeting
```bash
ty check --python-version 3.12 # Check against Python 3.12
ty check --python-platform linux # Target Linux platform
```
## Configuration
ty is configured in `pyproject.toml` or `ty.toml`:
```toml
# pyproject.toml
[tool.ty.environment]
python-version = "3.12"
[tool.ty.rules]
possibly-unresolved-reference = "warn"
division-by-zero = "error"
[tool.ty.src]
include = ["src/**/*.py"]
exclude = ["**/migrations/**"]
[tool.ty.terminal]
output-format = "full"
error-on-warning = false
```
### Per-file overrides
Use overrides to apply different rules to specific files, such as relaxing rules
for tests or scripts that have different typing requirements than production
code:
```toml
[[tool.ty.overrides]]
include = ["tests/**", "**/test_*.py"]
[tool.ty.overrides.rules]
possibly-unresolved-reference = "warn"
```
## Language server
This plugin automatically configures the ty language server for Python files
(`.py` and `.pyi`).
## Migrating from other tools
### mypy → ty
```bash
mypy . → ty check
mypy --strict . → ty check --error-on-warning
mypy path/to/file.py → ty check path/to/file.py
```
### Pyright → ty
```bash
pyright . → ty check
pyright path/to/file.py → ty check path/to/file.py
```
## Common patterns
### Don't add ignore comments
Fix type errors instead of suppressing them. Only add ignore comments when
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
prefer rule-specific ignores:
```python
# Good: rule-specific ignore
x = undefined_var # ty: ignore[possibly-unresolved-reference]
# Bad: blanket ty ignore
x = undefined_var # ty: ignore
# Bad: tool agnostic blanket ignore
x = undefined_var # type: ignore
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ty/

182
.agents/skills/uv/SKILL.md Normal file
View File

@@ -0,0 +1,182 @@
---
name: uv
description:
Guide for using uv, the Python package and project manager. Use this when
working with Python projects, scripts, packages, or tools.
---
# uv
uv is an extremely fast Python package and project manager. It replaces pip,
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
## When to use uv
**Always use uv for Python work**, especially if you see:
- The `uv.lock` file
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
Don't use uv in projects managed by other tools:
- Poetry projects (identifiable by `poetry.lock` file)
- PDM projects (identifiable by `pdm.lock` file)
## Choosing the right workflow
### Scripts
**Use when:** Running single Python files and standalone scripts.
**Key commands:**
```bash
uv run script.py # Run a script
uv run --with requests script.py # Run with additional packages
uv add --script script.py requests # Add dependencies inline to the script
```
### Projects
**Use when:** There is a `pyproject.toml` or `uv.lock`
**Key commands:**
```bash
uv init # Create new project
uv add requests # Add dependency
uv remove requests # Remove dependency
uv sync # Install from lockfile
uv run <command> # Run commands in environment
uv run python -c "" # Run Python in project environment
uv run -p 3.12 <command> # Run with specific Python version
```
### Tools
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
installation.
**Key commands:**
```bash
uvx <tool> <args> # Run a tool without installation
uvx <tool>@<version> <args> # Run a specific version of a tool
```
**Important:**
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
well-known tools.
- Only use `uv tool install` only when specifically requested by the user.
### Pip interface
**Use when:** Legacy workflows with `requirements.txt` or manual environment
management, no `uv.lock` present.
**Key commands:**
```bash
uv venv
uv pip install -r requirements.txt
uv pip compile requirements.in -o requirements.txt
uv pip sync requirements.txt
# Platform independent resolution
uv pip compile --universal requirements.in -o requirements.txt
```
**Important:**
- Don't use the pip interface unless clearly needed.
- Don't introduce new `requirements.txt` files.
- Prefer `uv init` for new projects.
## Migrating from other tools
### pyenv → uv python
```bash
pyenv install 3.12 → uv python install 3.12
pyenv versions → uv python list --only-installed
pyenv local 3.12 → uv python pin 3.12
pyenv global 3.12 → uv python install 3.12 --default
```
### pipx → uvx
```bash
pipx run ruff → uvx ruff
pipx install ruff → uv tool install ruff
pipx upgrade ruff → uv tool upgrade ruff
pipx list → uv tool list
```
### pip and pip-tools → uv pip
```bash
pip install package → uv pip install package
pip install -r req.txt → uv pip install -r req.txt
pip freeze → uv pip freeze
pip-compile req.in → uv pip compile req.in
pip-sync req.txt → uv pip sync req.txt
virtualenv .venv → uv venv
```
## Common patterns
### Don't use pip in uv projects
```bash
# Bad
pip install requests
# Good
uv add requests
```
### Don't run python directly
```bash
# Bad
python script.py
# Good
uv run script.py
```
```bash
# Bad
python -c "..."
# Good
uv run python -c "..."
```
```bash
# Bad
python3.12 -c "..."
# Good
uvx python@3.12 -c "..."
```
### Don't manually manage environments in uv projects
```bash
# Bad
python -m venv .venv
source .venv/bin/activate
# Good
uv run <command>
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/uv/llms.txt
The documentation links to specific pages for each of these workflows.

4
.cargo/config.toml Normal file
View File

@@ -0,0 +1,4 @@
[target.aarch64-unknown-linux-gnu]
rustflags = [
"-Ctarget-feature=+fp16,+fhm"
]

View File

@@ -0,0 +1,55 @@
{
"name": "Luminal (CPU)",
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
"initializeCommand": "touch .env",
"runArgs": [
"--env-file", ".env"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "ubuntu",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

View File

@@ -0,0 +1,59 @@
{
"name": "Luminal (CUDA)",
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
"initializeCommand": "touch .env",
"runArgs": [
"--env-file",
".env",
"--runtime=nvidia",
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "ubuntu",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

30
.github/workflows/cuda-clippy.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: CUDA Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
cuda_clippy:
name: CUDA Clippy
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Mark workspace as safe for git
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-clippy --all-files

23
.github/workflows/fmt.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Fmt
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-fmt --all-files

25
.github/workflows/metal-clippy.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: Metal Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_clippy:
name: Metal Clippy
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual cargo-clippy-metal --all-files

45
.github/workflows/modal-examples.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: Modal Examples
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
modal_example:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 70
strategy:
fail-fast: false
matrix:
example: [llama, gemma, qwen, qwen3_moe]
gpu:
- { type: "A100-80GB" }
# To add more GPUs, just append another entry:
# - { type: "H100" }
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
EXAMPLE: ${{ matrix.example }}
GPU_TYPE: ${{ matrix.gpu.type }}
run: modal run ci/modal_example.py

23
.github/workflows/ruff-format.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff Format
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff_format:
name: Ruff Format
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-format --all-files

23
.github/workflows/ruff.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff:
name: Ruff
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-check --all-files

View File

@@ -1,22 +0,0 @@
name: Rust
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose

24
.github/workflows/test-core.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Test Core
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
core_unit_test:
name: Core Unit Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose

35
.github/workflows/test-cuda.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Test CUDA
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
cuda_unit_test:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Cuda Unit Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run CUDA tests on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
run: modal run ci/modal_cargo_test.py

19
.github/workflows/test-metal.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: Test Metal
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_unit_test:
name: Metal Unit Tests
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1

47
.github/workflows/test-python-cuda.yml vendored Normal file
View File

@@ -0,0 +1,47 @@
name: Test Python CUDA
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
python_cuda_tests:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Python CUDA Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 60
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run pytest with CUDA backend on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
- name: Upload Modal pytest profiling artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
retention-days: 7
if-no-files-found: warn

View File

@@ -0,0 +1,28 @@
name: Test Python Native
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
python_native_tests:
name: Python Native Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 45
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Update Rust toolchain
run: rustup update
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml
- name: Run pytest
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"

39
.gitignore vendored
View File

@@ -1,7 +1,42 @@
/target
/crates/**/target
/examples/**/target
.claude-project
.claude-memory
.codex
*.env
.claude/
.DS_Store
.vscode
*.vscode
*.zed
Cargo.lock
*.st
*.npx
*.npz
*.npz
*.model
*.gguf
.claude-project
.claude-memory
.codex
*.pftrace
*.safetensors
*.safetensors.index.json
tokenizer.json
**/.cache
**/proptest-regressions
opencode.json
# Python build artifacts
*.so
*.pyd
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
uv.lock

38
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,38 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.5
hooks:
- id: ruff-check
name: ruff check
files: ^crates/luminal_python/.*\.py$
- id: ruff-format
name: ruff format
files: ^crates/luminal_python/.*\.py$
- repo: local
hooks:
- id: cargo-fmt
name: cargo fmt
entry: cargo fmt --all --check
language: system
pass_filenames: false
files: \.(rs|toml)$
- id: cargo-clippy
name: cargo clippy
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
- id: cargo-clippy-metal
name: cargo clippy metal
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
stages: [manual]
- id: cargo-clippy-cuda-lite
name: cargo clippy cuda_lite
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
language: system
pass_filenames: false
files: \.(rs|toml)$
stages: [manual]

11
AGENTS.md Normal file
View File

@@ -0,0 +1,11 @@
# Contributor Guide
## Structure
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
## Testing Instructions
- Find the CI plan in the .github/workflows folder.
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.

34
CLAUDE.md Normal file
View File

@@ -0,0 +1,34 @@
# Luminal
## Package Management
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
- Use `uv sync` to sync the Python environment
- Never use pip, pip-tools, poetry, or conda
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
- Never generate requirements.txt
## Code Execution
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
## Building the Python Package (Maturin)
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
## Pre-commit
- Run with: `uv run pre-commit run --all-files`
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
## Testing
- **Rust tests**: `cargo test -p <crate_name>`
- **Python tests**: `cd crates/luminal_python && uv run pytest`
- `./run_test.sh` — native backend
- `./run_tests_cuda.sh` — CUDA backend
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions

View File

@@ -1,31 +1,57 @@
[package]
name = "luminal"
version = "0.1.0"
edition = "2021"
version = "0.2.0"
edition.workspace = true
rust-version = "1.85"
description = "Deep learning at the speed of light."
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
#default = ["cuda"]
cuda = ["dep:cudarc"]
[dependencies]
itertools = "0.11.0"
matrixmultiply = "0.3.7"
num-traits = "0.2.16"
petgraph = {path="./resources/petgraph"}
rand = "0.8.5"
strum = { version = "0.25.0", features = ["derive"] }
petgraph = "0.6.4"
rand = "0.9.2"
urlencoding = "2.1.2"
webbrowser = "0.8.10"
open = "5"
dyn-clone = "1.0.12"
cudarc = {version="0.9.13", optional=true}
safetensors = "0.3.1"
memmap2 = "0.7.1"
half = "2.3.1"
half = {version="2.7.1", features=["num-traits"]}
tinyvec = { version = "1.6.0", features = ["serde"] }
colored = "2.0.4"
regex = "1.9.5"
rustc-hash = "2.1.1"
as-any = "0.3.1"
serde = { version = "1.0.202", features = ["derive"] }
generational-box = "0.5.6"
serde_json = "1.0.140"
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
tracing = "0.1.43"
paste = "1.0.15"
pretty-duration = "0.1.1"
anyhow = "1.0"
graphviz-rust = { version = "0.9", default-features = false}
lru = "0.16.2"
[workspace.package]
edition = "2024"
[dev-dependencies]
dfdx = "0.13"
tokenizers = "0.13.3"
candle-core = "0.9.2"
candle-nn = "0.9.2"
ordered-float = "5.1.0"
proptest = "1.9.0"
[workspace]
members = [
"examples/*",
"crates/luminal_nn",
"crates/luminal_cuda_lite",
"crates/luminal_metal",
"crates/luminal_tracing",
"crates/luminal_bench",
"crates/luminal_python/rust",
]
[patch.crates-io]
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }

201
LICENSE-APACHE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

21
LICENSE-MIT Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Joe Fioti
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

150
README.md
View File

@@ -1,76 +1,126 @@
# luminal
![image](https://raw.githubusercontent.com/jafioti/luminal/main/resources/dag.jpeg)
**Deep learning at the speed of light.**
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54PM" src="https://github.com/user-attachments/assets/c5832634-55d5-45b7-ba65-6efe36afce4a" />
Luminal is a deep learning library that prioritizes **static computation** and **operator fusion** to achieve high performance.
<h3 align="center">
Luminal is a high-performance general-purpose inference compiler.
</h3>
[![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/jafioti/luminal/actions)
[![Docs](https://img.shields.io/badge/Documentation-green?style=for-the-badge&color=0D9373)](https://docs.luminalai.com)
[![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
[![discord](https://dcbadge.limes.pink/api/server/APjuwHAbGy)](https://discord.gg/APjuwHAbGy)
## Usage
```rust
use luminal::prelude::*;
// Setup graph and tensors
// Create compute graph
let mut cx = Graph::new();
let a = cx.new_tensor::<R2<3, 1>>("A");
let b = cx.new_tensor::<R2<1, 4>>("B");
let a = cx.tensor((3, 1));
let b = cx.tensor((1, 4));
// Do stuff...
let c = a.matmul(b);
let c = a.matmul(b).output();
// Set inputs and mark outputs
a.set(vec![1.0, 2.0, 3.0]);
b.set(vec![1.0, 2.0, 3.0, 3.0]);
c.mark();
// Compile
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Optimize and run graph
cx.optimize(GenericOptimizer::default());
cx.execute();
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);
rt.set_data(b, vec![1.0, 2.0, 3.0, 3.0]);
// Get result
println!("Result: {:?}", c.retrieve().unwrap().data);
// Run
rt.execute(&cx.dyn_map);
// Get output tensor
println!("Result: {:?}", rt.get_f32(c));
```
## Why does this look so different from other DL libraries?
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. So when you see `x + y`, the addition actually happens right there. This is great for debugging, it works exactly as most developers expect.
## Getting Started
However, this isn't great for performance because what makes sense for a developer doesn't make sense for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
**Llama 3 8B**
Luminal takes a different approach, more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Here everything's static. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, and executed later.
Here's a quick example of how you can run Llama 3 8B locally using Luminal on CUDA:
```bash
cd ./examples/llama
cargo run --release
```
## But Why?
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our optimizers have global knowledge and can do much more aggressive optimization **without any sync points**.
## Features
Of course, we can still split the network into multiple seperate graphs if we want to insert dynamic control flow part-way through, which means this method doesn't preclude optimizations like KV caching, because the KV cached forward pass is just a seperate graph!
### Speed
Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100. The goal is to become the fastest ML framework for any model on any device.
### Simplicity
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
### RISC-style architecture
Everything in Luminal boils down to 14 primitive ops:
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
- Binary - `Add, Mul, Mod, LessThan`
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
These ops are enough to support transformers, convnets, and nearly every popular model.
### Search
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
### Native
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
### Validated against Pytorch
Correctness matters. We write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
## Ideology
### Why does this look so different from other DL libraries?
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. In PyTorch, when you see `x + y`, the addition actually happens right there. This is great for debugging because it works exactly as most developers expect.
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
### Compile everything
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
**But why?**
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
Now we can do:
Some huge benefits are now unlocked:
- Aggressive kernel fusion
- Shape-specific kernels compiled at runtime
- Devices and Dtypes are handled through optimizers (just run the CUDA optimizer to convert the graph to use CUDA kernels, then the fp16 optimizer to convert to half-precision kernels)
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
## RISC-style architecture
Luminal can be ran on new accelerators by implementing 11 primitive ops. Take a look at `src/optimizers/cuda/prim.rs` to see 1-to-1 CUDA translations of the primops.
Accellerators are free to implement their own custom ops, and their own optimizers to convert luminal primitive ops to their bespoke ops.
## Compile-time Shape Checks
All operations are shape checked at compile time, so no more shape mismatches! All credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
## View the Graph
Once you've written all your computation code, run `cx.display_graph()` to see the entire computation graph in all it's glory. Pretty messy looking! Now run `cx.optimize(GeneralOptimizer::default())` and display the graph again. Much better.
## Where are we?
Currently luminal is extremely alpha. Please don't use this in prod.
- Llama 1 is implemented in `examples/llama`. You'll need to follow the instructions in [llama-dfdx](https://github.com/coreylowman/llama-dfdx) to download and convert the llama weights, and point this example loading path at them.
- The llama example shows how to implement a loader for a custom format. Safetensors loaders are already implemented, and are the recommended way to load a model.
- We have a small library of NN modules in `nn`, including transformers.
- A signifigant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the tinygrad ops set.
- Currently there are very few optimizers, so primops are mostly used to run these models, which are very slow.
- Next release will bring a signifigant amount of optimizers which should fuse primops into much faster ops. The aim for 0.2 is to be usably fast, not SOTA yet.
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
- Full training support with graph-based autograd.
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
- We have a small library of NN modules in `luminal_nn`, including transformers.
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
Some things on the roadmap:
- Write common sense cuda ops and optimizer (matmuls, mul-add, etc.)
- Expand the search space to utilize Tensor Cores more flexibly
- Bring cuda to parity with Metal
- Add Blackwell intrinsics, such as TMEM and TMA
- Build a ROCm backend
- Build benchmarking suite to test against other libs
- Write specialized CUDA kernels for full transformer architecture (FlashAttention, etc.)
- Automatic differentiation of graphs
- Beat PT 2.0 perf on LLM training
- Distributed data, pipeline and tensor parallel.
- Beat PT 2.0 perf on LLM inference _and_ training
- Write compiler for quantum photonic retro encabulator
- Build dyson swarm
## License
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.

Binary file not shown.

67
ci/modal_cargo_test.py Normal file
View File

@@ -0,0 +1,67 @@
import modal
import subprocess
import os
import sys
gpu_type = os.environ.get("GPU_TYPE", "T4")
CUDARC_CUDA_VERSION = "12080"
app = modal.App("luminal-ci-cargo-test")
WORKDIR = "/workspace/luminal"
cuda_image = (
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
.apt_install("protobuf-compiler")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=1800, # 30 minutes
)
def run_cargo_test():
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
subprocess.run(["nvidia-smi"], check=True)
# Detect GPU compute capability
result = subprocess.run(
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
capture_output=True,
text=True,
check=True,
)
compute_cap = result.stdout.strip().replace(".", "")
subprocess.run(
[
"cargo", "test",
"-p", "luminal_cuda_lite",
"--verbose",
"--",
"--test-threads=1",
],
cwd=WORKDIR,
env={
**os.environ,
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"CUDA_COMPUTE_CAP": compute_cap,
},
check=True,
)
@app.local_entrypoint()
def main():
run_cargo_test.remote()

67
ci/modal_example.py Normal file
View File

@@ -0,0 +1,67 @@
import modal
import subprocess
import os
example = os.environ.get("EXAMPLE", "llama")
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
CUDARC_CUDA_VERSION = "12080"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
app = modal.App(f"luminal-ci-{example}")
hf_cache = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
WORKDIR = "/workspace/luminal"
cuda_image = (
modal.Image.from_registry(
"nvcr.io/nvidia/pytorch:25.03-py3"
)
.apt_install("protobuf-compiler")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=3600, # 60 minutes
volumes={
HF_CACHE_PATH: hf_cache,
},
)
def run_example(example: str):
"""Build and run a luminal example on a Modal GPU."""
subprocess.run(["nvidia-smi"], check=True)
subprocess.run(
["cargo", "run", "--release"],
cwd=f"{WORKDIR}/examples/{example}",
env={
**os.environ,
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"HF_HOME": HF_CACHE_PATH,
},
check=True,
)
hf_cache.commit()
@app.local_entrypoint()
def main():
run_example.remote(example)

View File

@@ -0,0 +1,34 @@
[package]
name = "luminal_bench"
version = "0.1.0"
edition.workspace = true
description = "Universal benchmark infrastructure for Luminal backends"
license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
luminal = { path = "../.." }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = "0.4"
egraph-serialize = { version = "0.3.0", default-features = false }
# Backend dependencies - optional, enabled via features
luminal_metal = { path = "../luminal_metal", optional = true }
metal = { version = "0.32", optional = true }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
rand = "0.9.2"
[features]
default = []
metal = ["dep:luminal_metal", "dep:metal"]
[[bench]]
name = "micro"
harness = false
[[bench]]
name = "patterns"
harness = false

View File

@@ -0,0 +1,98 @@
# `luminal_bench`
Benchmarks and debugging utilities for Luminal (Criterion benchmarks + egglog lowering debug).
## Running Benchmarks
The benches in this crate are typically run with the Metal backend enabled via a feature flag.
```bash
# L1: micro (single op / HLIR primitive)
cargo bench -p luminal_bench --features metal --bench micro
# L2: patterns (composed patterns)
cargo bench -p luminal_bench --features metal --bench patterns
```
### Outputs (Criterion)
After running, common outputs are under:
- HTML report: `target/criterion/report/index.html`
- micro metrics mapping: `target/criterion/bench_metrics.json`
- micro full report: `target/criterion/bench_report.json`
- patterns metrics mapping: `target/criterion/pattern_metrics.json`
- patterns full report: `target/criterion/pattern_report.json`
These JSON files (constant metrics such as bytes/flops) can be combined with Criterion timing to
compute derived throughput metrics (MBU/MFU/etc.).
## Coverage (Overview)
### L1 micro (single op)
Measures single-op performance for HLIR primitives (currently includes):
- Unary: `Exp2` / `Log2` / `Sin` / `Recip` / `Sqrt`
- Binary: `Add` / `Mul` / `Mod` / `LessThan`
- Indexing: `Gather` / `Cast`
- Reduction: `Sum` / `Max`
### L2 patterns (composed patterns)
Covers common composed patterns (currently includes):
- `MatMul`
- `Softmax`
- `GeLU`
- `Attention`
- `LayerNorm` (currently skipped in the Metal bench: requires unsupported HLIR primitives)
## egglog Debug Tool: `debug_ops`
`examples/debug_ops.rs` is a general egglog / lowering debug tool to help diagnose:
- Why a particular HLIR op failed to lower into backend dialect ops (and cleanup triggers
`No valid graphs present in the e-graph!`)
- Why a particular egglog function fact (e.g. `dtype`) is missing for some nodes
### Common Commands (Metal examples)
```bash
# Default: print summaries (HLIR/egglog op counts + root) and try build_search_space
# (which prints egglog rule match counts)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner
# Explicit op coverage check: provide HLIR:Backend mapping(s)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --inspect-op Add:MetalAdd
# Print full analysis output (HLIR-only + Backend+HLIR)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --analyze --inspect-op Add:MetalAdd
# Trace an egglog function fact for a specific var (HLIR-only)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --trace-fact dtype t24
# Scan all vars whose op-head is Add, find the first missing dtype, then trace it (HLIR-only)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner \
--trace-first-missing-fact dtype --within-op Add
# Inspect a var's eclass/enodes/children and dtype facts (HLIR-only)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --inspect-var t24
# Dump the raw egglog program (the `(let tN ...)` program from `hlir_to_egglog`)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner \
--dump-egglog target/gelu-inner.egg
# Export structured JSON (useful for repro/diffing)
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --json target/debug_ops.json
```
Notes:
- `--trace-fact` can only evaluate functions that exist in the egglog program (e.g. `dtype`).
Many values such as shape/strides are encoded as IR term parameters, not as function facts.
For more options, see:
```bash
cargo run -p luminal_bench --features metal --example debug_ops -- --help
```

View File

@@ -0,0 +1,153 @@
#![allow(unused)]
//! Micro benchmark runner using criterion.
//!
//! Usage and output locations: see `crates/luminal_bench/README.md`.
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use std::time::Duration;
#[cfg(feature = "metal")]
use luminal_bench::{
BenchMetrics, BenchMetricsMap, BenchResultCollector, BenchmarkBackend, BenchmarkPattern,
HardwareSpec, MetalBenchmark, all_micro_patterns,
};
#[cfg(feature = "metal")]
use luminal::prelude::*;
#[cfg(feature = "metal")]
fn run_metal_pattern_benchmark(
c: &mut Criterion,
pattern: &dyn BenchmarkPattern,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
use luminal::hlir::Input;
use luminal::op::{Runtime, RuntimeStats};
use luminal_metal::runtime::MetalRuntime;
use rand::Rng;
let backend_name = MetalBenchmark::name();
let pattern_name = pattern.name();
let group_name = format!("{}/{}", backend_name, pattern_name);
let mut group = c.benchmark_group(&group_name);
for size in pattern.sizes() {
// Build graph and run search once per size; the benchmark loop only measures execution.
let mut cx = Graph::default();
pattern.build_graph(&mut cx, *size);
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let mut rng = rand::rng();
for node in cx.graph.node_indices() {
if let Some(Input { .. }) = (*cx.graph[node]).as_any().downcast_ref::<Input>() {
let data: Vec<f32> = (0..size.value).map(|_| rng.random::<f32>()).collect();
rt.set_data(node, &data);
}
}
let mut rt = cx.search(rt, 5);
rt.allocate_intermediate_buffers(&cx.dyn_map);
let mut bench_metrics = None;
if let Some(stats) = rt.execute_with_stats(&cx.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add(pattern_name, size.name, metrics.clone());
bench_metrics = Some(metrics);
}
let dyn_map = cx.dyn_map.clone();
group.bench_with_input(BenchmarkId::from_parameter(size.name), size, |b, size| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = rt.execute_with_stats(&dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = bench_metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add(pattern_name, size.name, size.value, avg_time_us, metrics);
}
total_time
});
});
}
group.finish();
}
#[cfg(feature = "metal")]
fn metal_micro_benchmarks(c: &mut Criterion) {
let hw = MetalBenchmark::hardware_info();
println!("\n=== Metal Benchmark ===");
println!("Device: {}", hw.device_name);
println!("Memory: {:.1} GB", hw.memory_gb);
if let Some(bw) = hw.peak_bandwidth_gbps {
println!("Peak Bandwidth: {:.0} GB/s", bw);
}
if let Some(tf) = hw.peak_tflops {
println!("Peak Compute: {:.1} TFLOPS", tf);
}
println!();
let hardware_spec = HardwareSpec {
device_name: hw.device_name.clone(),
memory_gb: hw.memory_gb,
peak_bandwidth_gbps: hw.peak_bandwidth_gbps.unwrap_or(100.0),
peak_tflops: hw.peak_tflops.unwrap_or(1.0),
};
let mut metrics_map = BenchMetricsMap::new(hardware_spec.clone());
let collector = BenchResultCollector::new(hardware_spec);
for pattern in all_micro_patterns() {
run_metal_pattern_benchmark(c, pattern.as_ref(), &mut metrics_map, &collector);
}
let metrics_path = std::path::Path::new("target/criterion/bench_metrics.json");
if let Some(parent) = metrics_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Err(e) = metrics_map.save(metrics_path) {
eprintln!("Warning: Failed to save metrics mapping: {}", e);
}
let report = collector.into_report();
report.print_summary();
let report_path = std::path::Path::new("target/criterion/bench_report.json");
if let Err(e) = report.save(report_path) {
eprintln!("Warning: Failed to save full report: {}", e);
} else {
println!("\nReports saved to:");
println!(" - {}", metrics_path.display());
println!(" - {}", report_path.display());
}
}
#[cfg(not(feature = "metal"))]
fn metal_micro_benchmarks(_c: &mut Criterion) {
println!("Metal benchmarks disabled. Run with --features metal");
}
criterion_group! {
name = benches;
config = Criterion::default()
.sample_size(50)
.warm_up_time(std::time::Duration::from_millis(500))
.measurement_time(std::time::Duration::from_secs(2));
targets = metal_micro_benchmarks
}
criterion_main!(benches);

View File

@@ -0,0 +1,514 @@
#![allow(unused)]
//! Pattern benchmark runner using criterion.
//!
//! Usage and output locations: see `crates/luminal_bench/README.md`.
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use std::time::Duration;
#[cfg(feature = "metal")]
use luminal_bench::{
ATTENTION_SIZES, BenchMetrics, BenchMetricsMap, BenchResultCollector, BenchmarkBackend,
HardwareSpec, MATMUL_SIZES, MetalBenchmark, TRANSFORMER_SIZES,
};
#[cfg(feature = "metal")]
use luminal::hlir::Input;
#[cfg(feature = "metal")]
use luminal::op::{Runtime, RuntimeStats};
#[cfg(feature = "metal")]
use luminal::prelude::*;
#[cfg(feature = "metal")]
use luminal_metal::runtime::MetalRuntime;
#[cfg(feature = "metal")]
use rand::Rng;
// ============================================================================
// Helper: Prepare runtime with graph and search (done once per size)
// ============================================================================
#[cfg(feature = "metal")]
struct PreparedBench {
rt: MetalRuntime,
dyn_map: luminal::prelude::FxHashMap<char, usize>,
metrics: Option<BenchMetrics>,
}
#[cfg(feature = "metal")]
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
let mut rng = rand::rng();
for (node, size) in input_sizes {
let data: Vec<f32> = (0..*size).map(|_| rng.random::<f32>()).collect();
rt.set_data(*node, &data);
}
let rt = cx.search(rt, 5);
Some(PreparedBench {
rt,
dyn_map: cx.dyn_map.clone(),
metrics: None,
})
}
// ============================================================================
// MatMul Benchmark
// ============================================================================
#[cfg(feature = "metal")]
fn bench_matmul(
c: &mut Criterion,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
let mut group = c.benchmark_group("metal/matmul");
for size in MATMUL_SIZES {
let size_name = size.name;
let (m, k, n) = (size.m, size.k, size.n);
// Build graph and run search once per size; the benchmark loop only measures execution.
let mut cx = Graph::default();
let a = cx.tensor((m, k));
let b_tensor = cx.tensor((k, n));
let _ = a.matmul(b_tensor).output();
let input_sizes: Vec<(NodeIndex, usize)> = cx
.graph
.node_indices()
.filter_map(|node| {
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
Some((node, m * k.max(k * n)))
} else {
None
}
})
.collect();
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
println!("error: Skipping matmul/{} - search failed", size_name);
continue;
};
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add("matmul", size_name, metrics.clone());
prepared.metrics = Some(metrics);
}
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = prepared.metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add("matmul", size_name, m * k * n, avg_time_us, metrics);
}
total_time
});
});
}
group.finish();
}
// ============================================================================
// Softmax Benchmark
// ============================================================================
#[cfg(feature = "metal")]
fn bench_softmax(
c: &mut Criterion,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
let mut group = c.benchmark_group("metal/softmax");
for size in TRANSFORMER_SIZES {
let size_name = size.name;
let size_value = size.value;
let dim = (size_value as f64).sqrt() as usize;
let rows = size_value / dim;
let cols = dim;
let mut cx = Graph::default();
let x = cx.tensor((rows, cols));
let _ = x.softmax(1).output();
let input_sizes: Vec<(NodeIndex, usize)> = cx
.graph
.node_indices()
.filter_map(|node| {
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
Some((node, size_value))
} else {
None
}
})
.collect();
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
println!("error: Skipping softmax/{} - search failed", size_name);
continue;
};
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add("softmax", size_name, metrics.clone());
prepared.metrics = Some(metrics);
}
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = prepared.metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add("softmax", size_name, size_value, avg_time_us, metrics);
}
total_time
});
});
}
group.finish();
}
// ============================================================================
// LayerNorm Benchmark
// ============================================================================
#[cfg(feature = "metal")]
fn bench_layer_norm(
c: &mut Criterion,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
let mut group = c.benchmark_group("metal/layer_norm");
for size in TRANSFORMER_SIZES {
let size_name = size.name;
let size_value = size.value;
// Typical shape: (batch * seq_len, hidden_dim)
let hidden_dim = 128;
let batch_seq = (size_value / hidden_dim).max(1);
let mut cx = Graph::default();
let x = cx.tensor((batch_seq, hidden_dim));
// LayerNorm along last axis with epsilon
let _ = x.layer_norm(1, 1e-5).output();
let input_sizes: Vec<(NodeIndex, usize)> = cx
.graph
.node_indices()
.filter_map(|node| {
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
Some((node, batch_seq * hidden_dim))
} else {
None
}
})
.collect();
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
println!("error: Skipping layer_norm/{} - search failed", size_name);
continue;
};
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add("layer_norm", size_name, metrics.clone());
prepared.metrics = Some(metrics);
}
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = prepared.metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add(
"layer_norm",
size_name,
batch_seq * hidden_dim,
avg_time_us,
metrics,
);
}
total_time
});
});
}
group.finish();
}
// ============================================================================
// GeLU Benchmark
// ============================================================================
#[cfg(feature = "metal")]
fn bench_gelu(
c: &mut Criterion,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
let mut group = c.benchmark_group("metal/gelu");
for size in TRANSFORMER_SIZES {
let size_name = size.name;
let size_value = size.value;
let mut cx = Graph::default();
let x = cx.tensor(size_value);
let _ = x.gelu().output();
let input_sizes: Vec<(NodeIndex, usize)> = cx
.graph
.node_indices()
.filter_map(|node| {
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
Some((node, size_value))
} else {
None
}
})
.collect();
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
println!("error: Skipping gelu/{} - search failed", size_name);
continue;
};
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add("gelu", size_name, metrics.clone());
prepared.metrics = Some(metrics);
}
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = prepared.metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add("gelu", size_name, size_value, avg_time_us, metrics);
}
total_time
});
});
}
group.finish();
}
// ============================================================================
// Attention Benchmark
// ============================================================================
#[cfg(feature = "metal")]
fn bench_attention(
c: &mut Criterion,
metrics_map: &mut BenchMetricsMap,
collector: &BenchResultCollector,
) {
let mut group = c.benchmark_group("metal/attention");
for (seq_len, head_dim) in ATTENTION_SIZES {
let size_name = format!("{}x{}", seq_len, head_dim);
let seq_len = *seq_len;
let head_dim = *head_dim;
let mut cx = Graph::default();
let q = cx.tensor((seq_len, head_dim));
let k = cx.tensor((seq_len, head_dim));
let v = cx.tensor((seq_len, head_dim));
let scores = q.matmul(k.permute((1, 0)));
let scale = 1.0 / (head_dim as f32).sqrt();
let scaled_scores = scores * scale;
let attn_weights = scaled_scores.softmax(1);
let _ = attn_weights.matmul(v).output();
let input_sizes: Vec<(NodeIndex, usize)> = cx
.graph
.node_indices()
.filter_map(|node| {
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
Some((node, seq_len * head_dim))
} else {
None
}
})
.collect();
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
println!("error: Skipping attention/{} - search failed", size_name);
continue;
};
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
metrics_map.add("attention", &size_name, metrics.clone());
prepared.metrics = Some(metrics);
}
let size_name_clone = size_name.clone();
group.bench_with_input(
BenchmarkId::from_parameter(&size_name),
&(seq_len, head_dim),
|b, _| {
b.iter_custom(|iters| {
let mut total_time = Duration::ZERO;
for _ in 0..iters {
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
total_time +=
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
}
}
if let Some(ref metrics) = prepared.metrics {
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
collector.add(
"attention",
&size_name_clone,
seq_len * head_dim,
avg_time_us,
metrics,
);
}
total_time
});
},
);
}
group.finish();
}
// ============================================================================
// Main Benchmark Entry
// ============================================================================
#[cfg(feature = "metal")]
fn metal_pattern_benchmarks(c: &mut Criterion) {
let hw = MetalBenchmark::hardware_info();
println!("\n=== Metal Pattern Benchmarks ===");
println!("Device: {}", hw.device_name);
println!("Memory: {:.1} GB", hw.memory_gb);
if let Some(bw) = hw.peak_bandwidth_gbps {
println!("Peak Bandwidth: {:.0} GB/s", bw);
}
if let Some(tf) = hw.peak_tflops {
println!("Peak Compute: {:.1} TFLOPS", tf);
}
println!();
let hardware_spec = HardwareSpec {
device_name: hw.device_name.clone(),
memory_gb: hw.memory_gb,
peak_bandwidth_gbps: hw.peak_bandwidth_gbps.unwrap_or(100.0),
peak_tflops: hw.peak_tflops.unwrap_or(1.0),
};
let mut metrics_map = BenchMetricsMap::new(hardware_spec.clone());
let collector = BenchResultCollector::new(hardware_spec);
bench_matmul(c, &mut metrics_map, &collector);
bench_softmax(c, &mut metrics_map, &collector);
bench_layer_norm(c, &mut metrics_map, &collector);
bench_gelu(c, &mut metrics_map, &collector);
bench_attention(c, &mut metrics_map, &collector);
let metrics_path = std::path::Path::new("target/criterion/pattern_metrics.json");
if let Some(parent) = metrics_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Err(e) = metrics_map.save(metrics_path) {
eprintln!("Warning: Failed to save metrics mapping: {}", e);
}
let report = collector.into_report();
report.print_summary();
let report_path = std::path::Path::new("target/criterion/pattern_report.json");
if let Err(e) = report.save(report_path) {
eprintln!("Warning: Failed to save full report: {}", e);
} else {
println!("\nReports saved to:");
println!(" - {}", metrics_path.display());
println!(" - {}", report_path.display());
}
}
#[cfg(not(feature = "metal"))]
fn metal_pattern_benchmarks(_c: &mut Criterion) {
println!("Metal benchmarks disabled. Run with --features metal");
}
criterion_group! {
name = benches;
config = Criterion::default()
.sample_size(30)
.warm_up_time(std::time::Duration::from_millis(500))
.measurement_time(std::time::Duration::from_secs(3));
targets = metal_pattern_benchmarks
}
criterion_main!(benches);

View File

@@ -0,0 +1,586 @@
#![allow(unused)]
//! Debug script to locate which HLIR op(s) fail to lower to a backend dialect,
//! leading to `No valid graphs present in the e-graph!`.
//!
//! This tool is backend-agnostic. The specific backend is selected via feature flags.
//! All core analysis logic lives in `luminal_bench::egglog_debug` module.
//!
//! Usage examples: see `crates/luminal_bench/README.md`.
use luminal::op::IntoEgglogOp;
use luminal::prelude::*;
use luminal::{egglog_utils::hlir_to_egglog, hlir::HLIROps};
use luminal_bench::egglog_debug::{
DebugReport, FactQuery, analyze_hlir_dtype_chain, analyze_hlir_function_chain,
analyze_lowering, analyze_with_ops, inspect_var_hlir, print_dtype_chain, print_function_chain,
print_lowering_analysis, print_var_inspection, summarize_egglog_ops, summarize_hlir_ops,
};
// ============================================================================
// Backend Configuration
// ============================================================================
/// Backend-specific configuration trait.
trait BackendConfig {
type Runtime: luminal::op::Runtime;
const NAME: &'static str;
fn build_search_space(cx: &mut Graph);
}
#[cfg(feature = "metal")]
mod metal_backend {
use super::*;
use luminal_metal::runtime::MetalRuntime;
pub struct MetalConfig;
impl BackendConfig for MetalConfig {
type Runtime = MetalRuntime;
const NAME: &'static str = "Metal";
fn build_search_space(cx: &mut Graph) {
cx.build_search_space::<MetalRuntime>();
}
}
}
#[cfg(feature = "metal")]
use metal_backend::MetalConfig as ActiveBackend;
// Future: Add CUDA backend
// #[cfg(feature = "cuda")]
// mod cuda_backend { ... }
// ============================================================================
// Test Cases
// ============================================================================
#[derive(Clone, Copy, Debug)]
enum Case {
Mul,
Sigmoid,
Tanh,
GeluInner,
Gelu,
LayerNorm,
}
impl Case {
fn all() -> &'static [Case] {
&[
Case::Mul,
Case::Sigmoid,
Case::Tanh,
Case::GeluInner,
Case::Gelu,
Case::LayerNorm,
]
}
fn from_str(s: &str) -> Option<Case> {
match s {
"mul" => Some(Case::Mul),
"sigmoid" => Some(Case::Sigmoid),
"tanh" => Some(Case::Tanh),
"gelu-inner" => Some(Case::GeluInner),
"gelu" => Some(Case::Gelu),
"layer-norm" | "layer_norm" => Some(Case::LayerNorm),
_ => None,
}
}
fn name(&self) -> &'static str {
match self {
Case::Mul => "Mul",
Case::Sigmoid => "Sigmoid",
Case::Tanh => "Tanh",
Case::GeluInner => "GeluInner",
Case::Gelu => "Gelu",
Case::LayerNorm => "LayerNorm",
}
}
fn build(&self, cx: &mut Graph, size: usize) {
let out = match self {
Case::Mul => {
let x = cx.tensor(size);
x.clone() * x
}
Case::Sigmoid => cx.tensor(size).sigmoid(),
Case::Tanh => cx.tensor(size).tanh(),
Case::GeluInner => {
let x = cx.tensor(size);
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
}
Case::Gelu => cx.tensor(size).gelu(),
Case::LayerNorm => {
// Mirror `crates/luminal_bench/src/patterns.rs`: normalize along last axis.
let hidden_dim = 128usize;
let batch_seq = (size / hidden_dim).max(1);
cx.tensor((batch_seq, hidden_dim)).layer_norm(1, 1e-5)
}
};
let _ = out.output();
}
}
// ============================================================================
// CLI Argument Parsing
// ============================================================================
struct Args {
case: Case,
size: usize,
dump_egglog: Option<std::path::PathBuf>,
print_egglog: bool,
analyze: bool,
inspect_vars: Vec<String>,
inspect_ops: Vec<(String, String)>,
trace_facts: Vec<(String, String)>,
trace_first_missing_facts: Vec<TraceFirstMissingFact>,
checks: Vec<Check>,
json_out: Option<std::path::PathBuf>,
all: bool,
}
#[derive(Clone, Debug)]
struct TraceFirstMissingFact {
fn_name: String,
within_op: String,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Check {
MissingBackend,
DType,
Function,
All,
}
fn parse_args() -> Args {
let mut args = Args {
case: Case::Gelu,
size: 262_144,
dump_egglog: None,
print_egglog: false,
analyze: false,
inspect_vars: Vec::new(),
inspect_ops: Vec::new(),
trace_facts: Vec::new(),
trace_first_missing_facts: Vec::new(),
checks: Vec::new(),
json_out: None,
all: false,
};
// If the user writes: --trace-first-missing-fact dtype --within-op Add
// we attach the next --within-op to the last pending request.
let mut pending_within_op_for: Option<usize> = None;
let mut iter = std::env::args().skip(1);
while let Some(arg) = iter.next() {
match arg.as_str() {
"--case" => {
let val = iter.next().expect("Missing value for --case");
args.case = Case::from_str(&val).unwrap_or_else(|| {
panic!(
"Unknown case: {}. Use: mul|sigmoid|tanh|gelu-inner|gelu",
val
)
});
}
"--size" => {
let val = iter.next().expect("Missing value for --size");
args.size = val.parse().expect("Invalid --size value");
}
"--dump-egglog" => {
let val = iter.next().expect("Missing value for --dump-egglog");
args.dump_egglog = Some(val.into());
}
"--print-egglog" => args.print_egglog = true,
"--analyze" => args.analyze = true,
"--trace-fact" => {
let fn_name = iter.next().expect("Missing function name for --trace-fact");
let var = iter.next().expect("Missing variable for --trace-fact");
args.trace_facts.push((fn_name, var));
}
"--trace-first-missing-fact" => {
let fn_name = iter
.next()
.expect("Missing function name for --trace-first-missing-fact");
args.trace_first_missing_facts.push(TraceFirstMissingFact {
fn_name,
within_op: String::new(),
});
pending_within_op_for = Some(args.trace_first_missing_facts.len() - 1);
}
"--within-op" => {
let op = iter.next().expect("Missing op head for --within-op");
let Some(idx) = pending_within_op_for.take() else {
eprintln!("--within-op must follow a --trace-first-missing-fact");
std::process::exit(2);
};
args.trace_first_missing_facts[idx].within_op = op;
}
"--inspect-var" => {
let val = iter.next().expect("Missing value for --inspect-var");
args.inspect_vars.push(val);
}
"--inspect-op" => {
let val = iter.next().expect("Missing value for --inspect-op");
let mut parts = val.split(':');
let hlir = parts.next().unwrap_or("").to_string();
let backend = parts.next().unwrap_or("").to_string();
if hlir.is_empty() || backend.is_empty() || parts.next().is_some() {
eprintln!("Invalid --inspect-op format. Expected HLIR:Backend, got {val}");
std::process::exit(2);
}
args.inspect_ops.push((hlir, backend));
}
"--check" => {
let val = iter.next().expect("Missing value for --check");
let check = match val.as_str() {
"missing-backend" => Check::MissingBackend,
"dtype" => Check::DType,
"fn" | "function" => Check::Function,
"all" => Check::All,
_ => {
eprintln!("Unknown --check {val}. Use: missing-backend|dtype|fn|all");
std::process::exit(2);
}
};
args.checks.push(check);
}
"--json" => {
let val = iter.next().expect("Missing value for --json");
args.json_out = Some(val.into());
}
"--all" => args.all = true,
"--help" | "-h" => {
println!(
"Usage: debug_ops [OPTIONS]\n\n\
Options:\n \
--case <CASE> Test case: mul|sigmoid|tanh|gelu-inner|gelu (default: gelu)\n \
(also: layer-norm)\n \
--size <N> Tensor size (default: 262144)\n \
--all Run all test cases\n \
--analyze Run lowering analysis\n \
--trace-fact FN VAR Trace fact FN for VAR (HLIR-only), e.g. dtype t24\n \
--trace-first-missing-fact FN Find first missing FN within an op-head, then trace it (HLIR-only)\n \
--within-op OPHEAD Used with --trace-first-missing-fact (e.g. Add)\n \
--inspect-var VAR Print detailed eclass + dtype info for VAR (HLIR-only)\n \
--inspect-op HLIR:Backend Check backend coverage for an op mapping\n \
--check KIND Run checks: missing-backend|dtype|fn|all\n \
--json PATH Write JSON report (use '-' for stdout)\n \
--dump-egglog PATH Write egglog program to file\n \
--print-egglog Print egglog program to stdout\n \
--help Show this help"
);
std::process::exit(0);
}
other => {
eprintln!("Unknown argument: {}. Use --help for usage.", other);
std::process::exit(2);
}
}
}
// Expand checks into concrete actions and validate requirements.
if args.checks.contains(&Check::All) {
args.checks = vec![Check::MissingBackend, Check::DType, Check::Function];
}
if args.checks.contains(&Check::DType) {
// Preserve the previous semantics: scan Add for missing dtype, then trace.
let already_has_add_dtype = args
.trace_first_missing_facts
.iter()
.any(|r| r.fn_name == "dtype" && r.within_op == "Add");
if !already_has_add_dtype {
args.trace_first_missing_facts.push(TraceFirstMissingFact {
fn_name: "dtype".to_string(),
within_op: "Add".to_string(),
});
}
}
if args.checks.contains(&Check::MissingBackend) && args.inspect_ops.is_empty() {
eprintln!("--check missing-backend requires at least one --inspect-op HLIR:Backend");
std::process::exit(2);
}
if args.checks.contains(&Check::Function) && args.trace_facts.is_empty() {
eprintln!("--check fn requires at least one --trace-fact FN VAR");
std::process::exit(2);
}
args
}
// ============================================================================
// Main Logic
// ============================================================================
fn run_case<B: BackendConfig>(case: Case, size: usize, args: &Args)
where
B::Runtime: luminal::op::Runtime,
<B::Runtime as luminal::op::Runtime>::Ops: luminal::op::IntoEgglogOp,
{
println!(
"\n=== Case: {} (size={}) [{}] ===",
case.name(),
size,
B::NAME
);
// Build graph
let mut cx = Graph::default();
case.build(&mut cx, size);
// Summarize HLIR
let hlir_counts = summarize_hlir_ops(&cx);
println!("-- HLIR node types --");
for (k, v) in &hlir_counts {
println!(" {}: {}", k, v);
}
// Get egglog program
let (program, root) = hlir_to_egglog(&cx);
// Summarize egglog ops
let egglog_counts = summarize_egglog_ops(&program);
println!("-- Egglog op heads --");
for (k, v) in &egglog_counts {
println!(" {}: {}", k, v);
}
println!("-- Egglog root: {} --", root);
// Dump egglog if requested
if let Some(ref base_path) = args.dump_egglog {
let path = if args.all {
let parent = base_path.parent().unwrap_or(std::path::Path::new("."));
let stem = base_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("debug");
parent.join(format!("{}-{}.egg", stem, case.name()))
} else {
base_path.clone()
};
let content = format!("; hlir_to_egglog dump\n; root: {root}\n{program}");
std::fs::write(&path, content).expect("Failed to write egglog file");
println!("Wrote egglog program to {}", path.display());
}
if args.print_egglog {
println!("-- Egglog program --\n{program}");
}
let find_vars_by_head = |head: &str| -> Vec<String> {
let mut vars = Vec::new();
for line in program.lines() {
let line = line.trim();
if !line.starts_with("(let ") {
continue;
}
let tokens: Vec<&str> = line.split_whitespace().collect();
if tokens.len() >= 3 && tokens[0] == "(let" {
let var = tokens[1].to_string();
let op = tokens[2].trim_start_matches('(');
if op == head {
vars.push(var);
}
}
}
vars
};
// Validate any pending --within-op pairing.
for req in &args.trace_first_missing_facts {
if req.within_op.is_empty() {
eprintln!(
"--trace-first-missing-fact {} requires --within-op OPHEAD",
req.fn_name
);
std::process::exit(2);
}
}
// Prepare fact queries needed for scan-first-missing-fact.
let mut hlir_analysis = None;
let mut backend_analysis = None;
let mut fact_queries: Vec<FactQuery> = Vec::new();
for req in &args.trace_first_missing_facts {
fact_queries.push(FactQuery {
fn_name: req.fn_name.clone(),
vars: find_vars_by_head(&req.within_op),
});
}
// Only compute backend analysis if requested; compute HLIR analysis if needed
// for either --analyze or scan-first-missing-fact.
let need_backend_analysis = args.analyze || !args.inspect_ops.is_empty();
let need_hlir_analysis = args.analyze || !fact_queries.is_empty();
if need_backend_analysis {
let (hlir, backend) =
analyze_lowering::<B::Runtime>(&program, &root, &fact_queries, &args.inspect_ops);
hlir_analysis = Some(hlir);
backend_analysis = Some(backend);
} else if need_hlir_analysis {
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
hlir_analysis = Some(analyze_with_ops(
&program,
&root,
hlir_ops,
"HLIR",
&fact_queries,
&[],
));
}
if args.analyze {
println!("-- Lowering analysis --");
if let Some(ref hlir) = hlir_analysis {
print_lowering_analysis(hlir);
}
if let Some(ref backend) = backend_analysis {
print_lowering_analysis(backend);
}
} else if !args.inspect_ops.is_empty() {
if let Some(ref backend) = backend_analysis {
print_lowering_analysis(backend);
}
}
// Trace facts for explicit variables.
let mut function_traces = Vec::new();
for (fn_name, var) in &args.trace_facts {
if fn_name == "dtype" {
println!("-- Trace dtype chain for {} (HLIR-only) --", var);
let chain = analyze_hlir_dtype_chain(&program, var);
print_dtype_chain(&chain);
// Also record a function-trace entry for JSON output.
function_traces.push(analyze_hlir_function_chain(&program, fn_name, var));
} else {
let trace = analyze_hlir_function_chain(&program, fn_name, var);
print_function_chain(&trace);
function_traces.push(trace);
}
}
// Scan for first missing fact within an op-head, then trace.
for req in &args.trace_first_missing_facts {
let Some(ref hlir) = hlir_analysis else {
println!(
"-- Trace first missing fact (fn={}) within op={} --",
req.fn_name, req.within_op
);
println!(" error Skipped: HLIR analysis did not run");
continue;
};
let vars = find_vars_by_head(&req.within_op);
if vars.is_empty() {
println!(
"-- Trace first missing fact (fn={}) within op={} --",
req.fn_name, req.within_op
);
println!(" √ No matching vars found (op head not present)");
continue;
}
let table = hlir.facts.get(&req.fn_name);
let first_missing = table.and_then(|t| {
vars.iter()
.find_map(|v| t.get(v).and_then(|s| s.is_missing().then(|| v.clone())))
});
if let Some(var) = first_missing {
println!(
"-- Trace first missing fact (fn={}) within op={} --",
req.fn_name, req.within_op
);
println!(" ❌ first missing at: {}", var);
if req.fn_name == "dtype" {
let chain = analyze_hlir_dtype_chain(&program, &var);
print_dtype_chain(&chain);
function_traces.push(analyze_hlir_function_chain(&program, "dtype", &var));
} else {
let trace = analyze_hlir_function_chain(&program, &req.fn_name, &var);
print_function_chain(&trace);
function_traces.push(trace);
}
} else {
println!(
"-- Trace first missing fact (fn={}) within op={} --",
req.fn_name, req.within_op
);
println!(" √ No missing values found");
}
}
let mut var_inspections = Vec::new();
if !args.inspect_vars.is_empty() {
for var in &args.inspect_vars {
let inspection = inspect_var_hlir(&program, var);
print_var_inspection(&inspection);
var_inspections.push(inspection);
}
}
// Try to build search space
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_| {}));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
B::build_search_space(&mut cx);
}));
std::panic::set_hook(prev_hook);
let build_succeeded = result.is_ok();
match result {
Ok(()) => println!("√ build_search_space succeeded"),
Err(_) => println!("❌ build_search_space failed"),
}
if let Some(ref path) = args.json_out {
let report = DebugReport {
case_name: case.name().to_string(),
size,
hlir_counts,
egglog_counts,
hlir_analysis,
backend_analysis,
var_inspections,
function_traces,
build_succeeded,
};
let json = serde_json::to_string_pretty(&report).expect("failed to serialize report");
if path.as_os_str() == "-" {
println!("{}", json);
} else {
std::fs::write(path, json).expect("failed to write json report");
println!("Wrote JSON report to {}", path.display());
}
}
}
#[cfg(feature = "metal")]
fn main() {
let args = parse_args();
println!("=== debug_ops ({}) ===", ActiveBackend::NAME);
println!("Backend: {}", ActiveBackend::NAME);
println!("Tip: Use --analyze for detailed lowering analysis.\n");
if args.all {
for case in Case::all() {
run_case::<ActiveBackend>(*case, args.size, &args);
}
} else {
run_case::<ActiveBackend>(args.case, args.size, &args);
}
}
#[cfg(not(feature = "metal"))]
fn main() {}

View File

@@ -0,0 +1,550 @@
//! Core analysis functions for egglog debugging.
use super::{
DTypeChainAnalysis, DTypeStatus, DependencyGraph, FactStatus, FunctionChainAnalysis,
FunctionTraceEntry,
};
use egraph_serialize::ClassId;
use luminal::egglog_utils;
use luminal::hlir::HLIROps;
use luminal::op::{EgglogOp, IntoEgglogOp, Runtime};
use luminal::prelude::egglog;
use luminal::prelude::egglog::prelude::exprs;
use luminal::prelude::egglog_ast::{RustSpan, Span};
use luminal::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;
/// Analysis result for lowering.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LoweringAnalysis {
pub label: String,
pub root_labels: Vec<String>,
pub output_input_labels: Vec<String>,
/// Optional op-coverage reports (only filled when explicitly requested).
pub op_reports: Vec<OpLoweringReport>,
/// Optional evaluated facts, keyed by function name then variable name.
pub facts: BTreeMap<String, BTreeMap<String, FactStatus>>,
}
/// Query for evaluating a function on a set of variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactQuery {
pub fn_name: String,
pub vars: Vec<String>,
}
/// Missing backend equivalent for a specific HLIR op instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpMissing {
pub class_id: String,
pub op: String,
pub children: Vec<ChildInspection>,
}
/// Report for op lowering coverage.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OpLoweringReport {
pub label: String,
pub hlir_op: String,
pub backend_op: String,
pub total_classes: usize,
pub missing: Vec<OpMissing>,
}
/// Inspection of a specific variable's eclass and dtype facts.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VarInspection {
pub label: String,
pub var: String,
pub let_line: Option<String>,
pub eval_error: Option<String>,
pub class_id: Option<String>,
pub class_type: Option<String>,
pub class_labels: Vec<String>,
pub dtype: Option<String>,
pub enodes: Vec<EnodeInspection>,
}
/// Inspection of an enode within an eclass.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EnodeInspection {
pub label: String,
pub children: Vec<ChildInspection>,
}
/// Inspection of a child eclass.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChildInspection {
pub class_id: String,
pub class_type: String,
pub class_labels: Vec<String>,
pub dtype: Option<String>,
}
fn find_let_line(program: &str, var: &str) -> Option<String> {
for line in program.lines() {
let line = line.trim();
if line.starts_with("(let ") && line.split_whitespace().nth(1) == Some(var) {
return Some(line.to_string());
}
}
None
}
fn run_egraph(program: &str, ops: Vec<Arc<Box<dyn EgglogOp>>>) -> egglog::EGraph {
let code = egglog_utils::full_egglog(program, &ops, false);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code).unwrap();
let _outputs = egraph.run_program(commands).unwrap();
egraph
}
fn annotate_dtypes(graph: &mut DependencyGraph, egraph: &mut egglog::EGraph) {
for node in graph.nodes.values_mut() {
node.dtype = Some(eval_dtype(egraph, &node.var));
}
}
fn class_labels(serialized: &egglog_utils::SerializedEGraph, class_id: &ClassId) -> Vec<String> {
let Some((_, nodes)) = serialized.eclasses.get(class_id) else {
return vec!["<missing>".to_string()];
};
let mut labels: Vec<String> = nodes
.iter()
.filter_map(|node_id| {
serialized
.enodes
.get(node_id)
.map(|(label, _)| label.clone())
})
.collect();
labels.sort();
labels.dedup();
labels
}
fn class_type(serialized: &egglog_utils::SerializedEGraph, class_id: &ClassId) -> String {
serialized
.eclasses
.get(class_id)
.map(|(typ, _)| typ.clone())
.unwrap_or_else(|| "<missing>".to_string())
}
fn collect_dtype_facts(serialized: &egglog_utils::SerializedEGraph) -> FxHashMap<ClassId, String> {
let mut map: FxHashMap<ClassId, String> = FxHashMap::default();
for (node_id, (label, children)) in &serialized.enodes {
if !label.starts_with("dtype") {
continue;
}
if children.is_empty() {
continue;
}
let input_class = children[0].clone();
let dtype_class = serialized.node_to_class[node_id].clone();
let dtype_labels = class_labels(serialized, &dtype_class);
let dtype_label = dtype_labels
.first()
.cloned()
.unwrap_or_else(|| "<unknown>".to_string());
map.insert(input_class, dtype_label);
}
map
}
fn eval_function(egraph: &mut egglog::EGraph, fn_name: &str, var: &str) -> FactStatus {
let expr = exprs::call(fn_name, vec![exprs::var(var)]);
match egraph.eval_expr(&expr) {
Ok((sort, value)) => match egraph.extract_value_to_string(&sort, value) {
Ok((s, _)) => FactStatus::Resolved(s),
Err(_) => FactStatus::Missing("extract-error".to_string()),
},
Err(err) => FactStatus::Missing(format!("{err}")),
}
}
/// Inspect a specific var with a given set of ops.
pub fn inspect_var_with_ops(
program: &str,
ops: Vec<Arc<Box<dyn EgglogOp>>>,
var: &str,
label: &str,
) -> VarInspection {
let mut egraph = run_egraph(program, ops);
let let_line = find_let_line(program, var);
let mut inspection = VarInspection {
label: label.to_string(),
var: var.to_string(),
let_line,
..Default::default()
};
let var_expr = egglog::var!(var.to_string());
let (sort, value) = match egraph.eval_expr(&var_expr) {
Ok(res) => res,
Err(err) => {
inspection.eval_error = Some(format!("{err}"));
return inspection;
}
};
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
let dtype_facts = collect_dtype_facts(&serialized);
let class_id = serialized.roots.first().cloned();
if let Some(class_id) = class_id {
inspection.class_id = Some(format!("{:?}", class_id));
inspection.class_type = Some(class_type(&serialized, &class_id));
inspection.class_labels = class_labels(&serialized, &class_id);
inspection.dtype = dtype_facts.get(&class_id).cloned();
if let Some((_, nodes)) = serialized.eclasses.get(&class_id) {
for node_id in nodes {
let Some((label, children)) = serialized.enodes.get(node_id) else {
continue;
};
let mut enode = EnodeInspection {
label: label.clone(),
..Default::default()
};
for child in children {
let child_labels = class_labels(&serialized, child);
let child_type = class_type(&serialized, child);
let dtype = dtype_facts.get(child).cloned();
enode.children.push(ChildInspection {
class_id: format!("{:?}", child),
class_type: child_type,
class_labels: child_labels,
dtype,
});
}
inspection.enodes.push(enode);
}
}
}
inspection
}
/// Inspect a specific var using HLIR-only ops.
pub fn inspect_var_hlir(program: &str, var: &str) -> VarInspection {
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
inspect_var_with_ops(program, hlir_ops, var, "HLIR")
}
/// Evaluate dtype for a variable in an egraph.
pub fn eval_dtype(egraph: &mut egglog::EGraph, var: &str) -> DTypeStatus {
let expr = egglog::call!("dtype", vec![egglog::var!(var.to_string())]);
match egraph.eval_expr(&expr) {
Ok((sort, value)) => match egraph.extract_value_to_string(&sort, value) {
Ok((s, _)) => DTypeStatus::Resolved(s),
Err(_) => DTypeStatus::Missing("extract-error".to_string()),
},
Err(err) => DTypeStatus::Missing(format!("{err}")),
}
}
/// Analyze backend lowering for a specific HLIR op -> backend op mapping.
pub fn analyze_op_lowering_with_ops(
program: &str,
ops: Vec<Arc<Box<dyn EgglogOp>>>,
hlir_op: &str,
backend_op: &str,
label: &str,
) -> OpLoweringReport {
let mut egraph = run_egraph(program, ops);
let (sort, value) = egraph
.eval_expr(&egglog::var!("t0"))
.or_else(|_| egraph.eval_expr(&egglog::var!("t1")))
.unwrap_or_else(|_| {
panic!("failed to eval any root variable (t0/t1) for op inspection");
});
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
let dtype_facts = collect_dtype_facts(&serialized);
let mut eclass_has_backend: FxHashSet<ClassId> = FxHashSet::default();
for (node_id, (lbl, _)) in &serialized.enodes {
if lbl == backend_op {
eclass_has_backend.insert(serialized.node_to_class[node_id].clone());
}
}
let mut seen_classes: FxHashSet<ClassId> = FxHashSet::default();
let mut missing: Vec<OpMissing> = Vec::new();
for (node_id, (lbl, children)) in &serialized.enodes {
if lbl != hlir_op {
continue;
}
let class_id = &serialized.node_to_class[node_id];
if seen_classes.contains(class_id) {
continue;
}
seen_classes.insert(class_id.clone());
if eclass_has_backend.contains(class_id) {
continue;
}
let mut child_summaries = Vec::new();
for child in children {
let labels = class_labels(&serialized, child);
let typ = class_type(&serialized, child);
let dtype = dtype_facts.get(child).cloned();
child_summaries.push(ChildInspection {
class_id: format!("{:?}", child),
class_type: typ,
class_labels: labels,
dtype,
});
}
missing.push(OpMissing {
class_id: format!("{:?}", class_id),
op: hlir_op.to_string(),
children: child_summaries,
});
}
OpLoweringReport {
label: label.to_string(),
hlir_op: hlir_op.to_string(),
backend_op: backend_op.to_string(),
total_classes: seen_classes.len(),
missing,
}
}
/// Analyze backend lowering for a specific HLIR op using a runtime's ops.
pub fn analyze_backend_op_lowering<R: Runtime>(
program: &str,
hlir_op: &str,
backend_op: &str,
) -> OpLoweringReport
where
R::Ops: IntoEgglogOp,
{
let mut backend_ops = R::Ops::into_vec();
backend_ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
let label = format!(
"{}+HLIR",
std::any::type_name::<R>()
.split("::")
.last()
.unwrap_or("Backend")
);
analyze_op_lowering_with_ops(program, backend_ops, hlir_op, backend_op, &label)
}
/// Analyze lowering with a specific set of ops.
///
/// This is the core analysis function that works with any set of ops.
/// Use `analyze_lowering` for convenience when working with a specific backend.
pub fn analyze_with_ops(
program: &str,
root: &str,
ops: Vec<Arc<Box<dyn EgglogOp>>>,
label: &str,
fact_queries: &[FactQuery],
op_mappings: &[(String, String)],
) -> LoweringAnalysis {
let mut egraph = run_egraph(program, ops);
let (sort, value) = egraph.eval_expr(&egglog::var!(root)).unwrap();
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
let dtype_facts = collect_dtype_facts(&serialized);
let root_class_id = serialized.roots.first().unwrap();
let root_labels = class_labels(&serialized, root_class_id);
let mut analysis = LoweringAnalysis {
label: label.to_string(),
root_labels,
output_input_labels: Vec::new(),
op_reports: Vec::new(),
facts: BTreeMap::new(),
};
// Output input labels (if any Output exists under this root).
for (lbl, children) in serialized.enodes.values() {
if lbl != "Output" || children.is_empty() {
continue;
}
analysis.output_input_labels = class_labels(&serialized, &children[0]);
break;
}
// Op coverage reports (only when explicitly requested).
for (hlir_op, backend_op) in op_mappings {
// Determine which eclasses contain the backend op.
let mut eclass_has_backend: FxHashSet<ClassId> = FxHashSet::default();
for (node_id, (lbl, _)) in &serialized.enodes {
if lbl == backend_op {
eclass_has_backend.insert(serialized.node_to_class[node_id].clone());
}
}
let mut seen_classes: FxHashSet<ClassId> = FxHashSet::default();
let mut missing: Vec<OpMissing> = Vec::new();
for (node_id, (lbl, children)) in &serialized.enodes {
if lbl != hlir_op {
continue;
}
let class_id = &serialized.node_to_class[node_id];
if !seen_classes.insert(class_id.clone()) {
continue;
}
if eclass_has_backend.contains(class_id) {
continue;
}
let mut child_summaries = Vec::new();
for child in children {
child_summaries.push(ChildInspection {
class_id: format!("{:?}", child),
class_type: class_type(&serialized, child),
class_labels: class_labels(&serialized, child),
dtype: dtype_facts.get(child).cloned(),
});
}
missing.push(OpMissing {
class_id: format!("{:?}", class_id),
op: hlir_op.clone(),
children: child_summaries,
});
}
analysis.op_reports.push(OpLoweringReport {
label: label.to_string(),
hlir_op: hlir_op.clone(),
backend_op: backend_op.clone(),
total_classes: seen_classes.len(),
missing,
});
}
// Evaluate requested facts.
for q in fact_queries {
let mut table: BTreeMap<String, FactStatus> = BTreeMap::new();
for var in &q.vars {
table.insert(var.clone(), eval_function(&mut egraph, &q.fn_name, var));
}
analysis.facts.insert(q.fn_name.clone(), table);
}
analysis
}
/// Analyze dtype propagation chain for a specific variable with a given set of ops.
pub fn analyze_dtype_chain_with_ops(
program: &str,
ops: Vec<Arc<Box<dyn EgglogOp>>>,
target: &str,
) -> DTypeChainAnalysis {
let mut egraph = run_egraph(program, ops);
let mut graph = DependencyGraph::from_program(program);
annotate_dtypes(&mut graph, &mut egraph);
DTypeChainAnalysis::analyze(&graph, target)
}
/// Analyze dtype propagation chain for a specific variable using HLIR-only ops.
pub fn analyze_hlir_dtype_chain(program: &str, target: &str) -> DTypeChainAnalysis {
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
analyze_dtype_chain_with_ops(program, hlir_ops, target)
}
/// Analyze function propagation chain for a specific variable with a given set of ops.
pub fn analyze_function_chain_with_ops(
program: &str,
ops: Vec<Arc<Box<dyn EgglogOp>>>,
fn_name: &str,
target: &str,
) -> FunctionChainAnalysis {
let mut egraph = run_egraph(program, ops);
let graph = DependencyGraph::from_program(program);
let trace = graph.trace_back(target, 20);
let mut chain = Vec::new();
let mut first_missing = None;
for entry in trace {
let status = eval_function(&mut egraph, fn_name, &entry.var);
if first_missing.is_none() && status.is_missing() {
first_missing = Some(entry.var.clone());
}
chain.push(FunctionTraceEntry {
depth: entry.depth,
var: entry.var,
op_type: entry.op_type,
status,
});
}
let all_resolved = first_missing.is_none();
FunctionChainAnalysis {
target: target.to_string(),
fn_name: fn_name.to_string(),
chain,
first_missing,
all_resolved,
}
}
/// Analyze function propagation chain for a specific variable using HLIR-only ops.
pub fn analyze_hlir_function_chain(
program: &str,
fn_name: &str,
target: &str,
) -> FunctionChainAnalysis {
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
analyze_function_chain_with_ops(program, hlir_ops, fn_name, target)
}
/// Run full lowering analysis comparing HLIR-only vs Backend+HLIR.
///
/// This is a generic function that works with any backend implementing `Runtime`.
///
/// # Type Parameters
/// * `R` - The runtime type (e.g., `MetalRuntime`, `CudaRuntime`)
///
/// # Arguments
/// * `program` - The egglog program string
/// * `root` - The root variable name
/// * `backend_add_name` - The name of the backend's Add operation (e.g., "MetalAdd")
pub fn analyze_lowering<R: Runtime>(
program: &str,
root: &str,
fact_queries: &[FactQuery],
op_mappings: &[(String, String)],
) -> (LoweringAnalysis, LoweringAnalysis)
where
R::Ops: IntoEgglogOp,
{
// HLIR-only analysis
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
let hlir_analysis = analyze_with_ops(program, root, hlir_ops, "HLIR", fact_queries, &[]);
// Backend+HLIR analysis
let mut backend_ops = R::Ops::into_vec();
backend_ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
let backend_label = format!(
"{}+HLIR",
std::any::type_name::<R>()
.split("::")
.last()
.unwrap_or("Backend")
);
let backend_analysis = analyze_with_ops(
program,
root,
backend_ops,
&backend_label,
fact_queries,
op_mappings,
);
(hlir_analysis, backend_analysis)
}
/// Convenience function for HLIR-only analysis (no backend).
pub fn analyze_hlir_only(program: &str, root: &str) -> LoweringAnalysis {
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
analyze_with_ops(program, root, hlir_ops, "HLIR", &[], &[])
}

View File

@@ -0,0 +1,102 @@
//! Egglog debugging and analysis utilities.
//!
//! This module provides tools for diagnosing egglog lowering issues,
//! particularly when HLIR operations fail to convert to backend implementations.
mod analysis;
mod report;
mod trace;
pub use analysis::*;
pub use report::*;
pub use trace::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Extract the operation head from an egglog expression.
///
/// Example: `(Add t1 t2 ...)` -> `Some("Add")`
pub fn egglog_op_head(code: &str) -> Option<&str> {
let code = code.trim();
code.strip_prefix('(')
.and_then(|s| s.split_whitespace().next())
}
/// Summarize HLIR node types in a graph.
pub fn summarize_hlir_ops(cx: &luminal::prelude::Graph) -> BTreeMap<String, usize> {
let mut counts: BTreeMap<String, usize> = BTreeMap::new();
for node in cx.graph.node_indices() {
let name = cx.graph[node].type_name().to_string();
*counts.entry(name).or_insert(0) += 1;
}
counts
}
/// Summarize egglog operation heads from a program string.
pub fn summarize_egglog_ops(program: &str) -> BTreeMap<String, usize> {
let mut counts: BTreeMap<String, usize> = BTreeMap::new();
for line in program.lines() {
// Parse lines like: (let t1 (Add ...))
let Some(code) = line.splitn(3, ' ').nth(2) else {
continue;
};
let Some(head) = egglog_op_head(code) else {
continue;
};
*counts.entry(head.to_string()).or_insert(0) += 1;
}
counts
}
/// Result of dtype analysis for a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DTypeStatus {
/// dtype was successfully resolved
Resolved(String),
/// dtype lookup failed
Missing(String),
}
impl DTypeStatus {
pub fn is_missing(&self) -> bool {
matches!(self, DTypeStatus::Missing(_))
}
pub fn is_resolved(&self) -> bool {
matches!(self, DTypeStatus::Resolved(_))
}
}
impl std::fmt::Display for DTypeStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DTypeStatus::Resolved(s) => write!(f, "{}", s),
DTypeStatus::Missing(err) => write!(f, "<missing:{}>", err),
}
}
}
/// Result of evaluating an arbitrary egglog function for a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FactStatus {
/// function value was successfully resolved
Resolved(String),
/// function lookup failed
Missing(String),
}
impl FactStatus {
pub fn is_missing(&self) -> bool {
matches!(self, FactStatus::Missing(_))
}
}
impl std::fmt::Display for FactStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FactStatus::Resolved(s) => write!(f, "{}", s),
FactStatus::Missing(err) => write!(f, "<missing:{}>", err),
}
}
}

View File

@@ -0,0 +1,247 @@
//! Formatted output and reporting for egglog debug analysis.
use super::{
DTypeChainAnalysis, DTypeStatus, EnodeInspection, FunctionChainAnalysis, LoweringAnalysis,
OpLoweringReport, TraceEntry, VarInspection,
};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Print HLIR node type summary.
pub fn print_hlir_summary(counts: &BTreeMap<String, usize>) {
println!("-- HLIR node types --");
for (k, v) in counts {
println!(" {}: {}", k, v);
}
}
/// Print egglog op head summary.
pub fn print_egglog_summary(counts: &BTreeMap<String, usize>) {
println!("-- Egglog op heads --");
for (k, v) in counts {
println!(" {}: {}", k, v);
}
}
/// Print lowering analysis results.
pub fn print_lowering_analysis(analysis: &LoweringAnalysis) {
println!("-- {} Analysis --", analysis.label);
println!(" Root eclass labels: {}", analysis.root_labels.join("|"));
if !analysis.output_input_labels.is_empty() {
println!(
" Output input labels: {}",
analysis.output_input_labels.join("|")
);
}
if !analysis.facts.is_empty() {
println!(" Facts:");
for (fn_name, table) in &analysis.facts {
println!(" {}:", fn_name);
for (var, status) in table {
let prefix = if status.is_missing() { "" } else { "" };
println!(" {} {}: {}", prefix, var, status);
}
}
}
if !analysis.op_reports.is_empty() {
for report in &analysis.op_reports {
print_op_lowering_report(report);
}
}
}
/// Print dependency trace as a tree.
pub fn print_trace_tree(trace: &[TraceEntry]) {
println!("-- Dependency trace --");
for entry in trace {
println!("{}", entry.format_tree());
}
}
/// Print dtype chain analysis.
pub fn print_dtype_chain(analysis: &DTypeChainAnalysis) {
println!("-- DType chain analysis for {} --", analysis.target);
if analysis.all_resolved {
println!(" √ All nodes in chain have resolved dtype");
} else if let Some(ref first) = analysis.first_missing {
println!(" ❌ First missing dtype at: {}", first);
}
println!(" Chain:");
for entry in &analysis.chain {
let dtype_str = match &entry.dtype {
Some(DTypeStatus::Resolved(s)) => format!("{}", s),
Some(DTypeStatus::Missing(_)) => "❌ missing".to_string(),
None => "? unknown".to_string(),
};
let indent = " ".repeat(entry.depth + 1);
println!("{}{} ({}) {}", indent, entry.var, entry.op_type, dtype_str);
}
}
/// Print function chain analysis.
pub fn print_function_chain(analysis: &FunctionChainAnalysis) {
println!(
"-- Function chain analysis for {} (fn={}) --",
analysis.target, analysis.fn_name
);
if analysis.all_resolved {
println!(" √ All nodes in chain have resolved value");
} else if let Some(ref first) = analysis.first_missing {
println!(" ❌ First missing at: {}", first);
}
println!(" Chain:");
for entry in &analysis.chain {
println!("{}", entry.format_tree());
}
}
/// Print op lowering report.
pub fn print_op_lowering_report(report: &OpLoweringReport) {
println!(
"-- Op lowering [{}] {} -> {} --",
report.label, report.hlir_op, report.backend_op
);
println!(" total eclasses: {}", report.total_classes);
if report.missing.is_empty() {
println!(" √ All eclasses have backend equivalent");
} else {
println!(" ❌ Missing backend in {} eclasses:", report.missing.len());
for miss in &report.missing {
println!(" - class={} op={}", miss.class_id, miss.op);
for (idx, child) in miss.children.iter().enumerate() {
let labels = if child.class_labels.is_empty() {
"<none>".to_string()
} else {
child.class_labels.join("|")
};
let dtype = child
.dtype
.clone()
.unwrap_or_else(|| "<missing>".to_string());
println!(
" [{}] class={} type={} labels={} dtype={}",
idx, child.class_id, child.class_type, labels, dtype
);
}
}
}
}
fn print_enode(enode: &EnodeInspection) {
println!(" - {}", enode.label);
for (idx, child) in enode.children.iter().enumerate() {
let dtype = child
.dtype
.clone()
.unwrap_or_else(|| "<missing>".to_string());
let labels = if child.class_labels.is_empty() {
"<none>".to_string()
} else {
child.class_labels.join("|")
};
println!(
" [{}] class={} type={} labels={} dtype={}",
idx, child.class_id, child.class_type, labels, dtype
);
}
}
/// Print inspection results for a specific variable.
pub fn print_var_inspection(inspection: &VarInspection) {
println!(
"-- Var inspection [{}] {} --",
inspection.label, inspection.var
);
if let Some(ref line) = inspection.let_line {
println!(" let: {}", line);
}
if let Some(ref err) = inspection.eval_error {
println!(" eval error: {}", err);
return;
}
let class_id = inspection.class_id.as_deref().unwrap_or("<unknown>");
let class_type = inspection.class_type.as_deref().unwrap_or("<unknown>");
let dtype = inspection.dtype.as_deref().unwrap_or("<missing>");
let labels = if inspection.class_labels.is_empty() {
"<none>".to_string()
} else {
inspection.class_labels.join("|")
};
println!(
" class: {} type={} labels={}",
class_id, class_type, labels
);
println!(" dtype: {}", dtype);
println!(" enodes:");
for enode in &inspection.enodes {
print_enode(enode);
}
}
/// Summary report for a debug session.
#[derive(Debug, Serialize, Deserialize)]
pub struct DebugReport {
pub case_name: String,
pub size: usize,
pub hlir_counts: BTreeMap<String, usize>,
pub egglog_counts: BTreeMap<String, usize>,
pub hlir_analysis: Option<LoweringAnalysis>,
pub backend_analysis: Option<LoweringAnalysis>,
pub var_inspections: Vec<VarInspection>,
pub function_traces: Vec<FunctionChainAnalysis>,
pub build_succeeded: bool,
}
impl DebugReport {
/// Print full report to stdout.
pub fn print(&self) {
println!("\n{}", "=".repeat(60));
println!("Case: {} (size={})", self.case_name, self.size);
println!("{}", "=".repeat(60));
print_hlir_summary(&self.hlir_counts);
println!();
print_egglog_summary(&self.egglog_counts);
if let Some(ref analysis) = self.hlir_analysis {
println!();
print_lowering_analysis(analysis);
}
if let Some(ref analysis) = self.backend_analysis {
println!();
print_lowering_analysis(analysis);
}
if !self.function_traces.is_empty() {
for trace in &self.function_traces {
println!();
print_function_chain(trace);
}
}
if !self.var_inspections.is_empty() {
for inspection in &self.var_inspections {
println!();
print_var_inspection(inspection);
}
}
println!();
if self.build_succeeded {
println!("√ build_search_space succeeded");
} else {
println!("❌ build_search_space failed");
}
}
}

View File

@@ -0,0 +1,232 @@
//! Dependency chain tracing for dtype propagation analysis.
use super::{DTypeStatus, FactStatus};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// A node in the dependency graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DepNode {
pub var: String,
pub op_type: String,
pub inputs: Vec<String>,
pub dtype: Option<DTypeStatus>,
}
/// Dependency graph built from egglog program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
pub nodes: HashMap<String, DepNode>,
pub roots: Vec<String>,
}
impl DependencyGraph {
/// Build dependency graph from egglog program.
pub fn from_program(program: &str) -> Self {
let mut graph = DependencyGraph::default();
for line in program.lines() {
let line = line.trim();
if !line.starts_with("(let ") {
continue;
}
// Parse: (let t1 (OpName args...))
let tokens: Vec<&str> = line.split_whitespace().collect();
if tokens.len() < 3 || tokens[0] != "(let" {
continue;
}
let var = tokens[1].to_string();
let op_type = tokens[2].trim_start_matches('(').to_string();
// Extract input variables (t followed by digits)
let mut inputs = Vec::new();
let bytes = line.as_bytes();
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b't' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
let start = i;
i += 1;
while i < bytes.len() && bytes[i].is_ascii_digit() {
i += 1;
}
let found_var = String::from_utf8_lossy(&bytes[start..i]).to_string();
// Don't include self
if found_var != var {
inputs.push(found_var);
}
} else {
i += 1;
}
}
graph.nodes.insert(
var.clone(),
DepNode {
var,
op_type,
inputs,
dtype: None,
},
);
}
// Find roots (nodes that are not inputs to any other node)
let all_inputs: HashSet<String> = graph
.nodes
.values()
.flat_map(|n| n.inputs.iter().cloned())
.collect();
graph.roots = graph
.nodes
.keys()
.filter(|k| !all_inputs.contains(*k))
.cloned()
.collect();
graph
}
/// Trace the dependency chain from a target variable back to inputs.
pub fn trace_back(&self, target: &str, max_depth: usize) -> Vec<TraceEntry> {
let mut result = Vec::new();
self.trace_back_recursive(target, 0, max_depth, &mut result, &mut HashSet::new());
result
}
fn trace_back_recursive(
&self,
var: &str,
depth: usize,
max_depth: usize,
result: &mut Vec<TraceEntry>,
visited: &mut HashSet<String>,
) {
if depth > max_depth || visited.contains(var) {
return;
}
visited.insert(var.to_string());
let node = match self.nodes.get(var) {
Some(n) => n,
None => {
result.push(TraceEntry {
depth,
var: var.to_string(),
op_type: "<unknown>".to_string(),
dtype: None,
});
return;
}
};
result.push(TraceEntry {
depth,
var: node.var.clone(),
op_type: node.op_type.clone(),
dtype: node.dtype.clone(),
});
for input in &node.inputs {
self.trace_back_recursive(input, depth + 1, max_depth, result, visited);
}
}
/// Find the first node in a chain that has missing dtype.
pub fn find_dtype_break(&self, target: &str) -> Option<String> {
let trace = self.trace_back(target, 20);
for entry in trace {
if let Some(DTypeStatus::Missing(_)) = entry.dtype {
return Some(entry.var);
}
}
None
}
}
/// Entry in a trace result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEntry {
pub depth: usize,
pub var: String,
pub op_type: String,
pub dtype: Option<DTypeStatus>,
}
impl TraceEntry {
/// Format as indented tree line.
pub fn format_tree(&self) -> String {
let indent = " ".repeat(self.depth);
let prefix = if self.depth == 0 { "" } else { "├── " };
let dtype_str = match &self.dtype {
Some(d) => format!(" dtype={}", d),
None => String::new(),
};
format!(
"{}{}{} ({}){}",
indent, prefix, self.var, self.op_type, dtype_str
)
}
}
/// Entry in a function trace result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionTraceEntry {
pub depth: usize,
pub var: String,
pub op_type: String,
pub status: FactStatus,
}
impl FunctionTraceEntry {
pub fn format_tree(&self) -> String {
let indent = " ".repeat(self.depth);
let prefix = if self.depth == 0 { "" } else { "├── " };
format!(
"{}{}{} ({}) {}",
indent, prefix, self.var, self.op_type, self.status
)
}
}
/// Result of dtype chain analysis.
#[derive(Debug, Serialize, Deserialize)]
pub struct DTypeChainAnalysis {
pub target: String,
pub chain: Vec<TraceEntry>,
pub first_missing: Option<String>,
pub all_resolved: bool,
}
impl DTypeChainAnalysis {
/// Create from dependency graph and target variable.
pub fn analyze(graph: &DependencyGraph, target: &str) -> Self {
let chain = graph.trace_back(target, 20);
let first_missing = chain
.iter()
.find(|e| matches!(&e.dtype, Some(DTypeStatus::Missing(_))))
.map(|e| e.var.clone());
let all_resolved = chain
.iter()
.all(|e| matches!(&e.dtype, Some(DTypeStatus::Resolved(_)) | None));
DTypeChainAnalysis {
target: target.to_string(),
chain,
first_missing,
all_resolved,
}
}
}
/// Result of function chain analysis.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionChainAnalysis {
pub target: String,
pub fn_name: String,
pub chain: Vec<FunctionTraceEntry>,
pub first_missing: Option<String>,
pub all_resolved: bool,
}

View File

@@ -0,0 +1,95 @@
//! # Luminal Benchmark Infrastructure
//!
//! Universal benchmark framework for Luminal backends.
//!
//! ## Architecture
//!
//! - **BenchmarkBackend**: Trait that backends implement to enable benchmarking
//! - **BenchmarkPattern**: Trait for defining benchmark workloads
//! - **Micro benchmarks (L1)**: Single-operator performance tests (HLIR primitives)
//! - **Pattern benchmarks (L2)**: Composite operator performance tests
//!
//! Usage 和调试方式见 crate 根目录的 `README.md`。
mod metrics;
mod micro;
mod patterns;
/// Egglog debugging and analysis utilities.
/// This module is backend-agnostic; specific backends are selected via feature flags
/// in the debug_ops example.
pub mod egglog_debug;
pub use metrics::*;
pub use micro::*;
pub use patterns::*;
use luminal::op::Runtime;
use luminal::prelude::*;
/// Hardware information for a backend device
#[derive(Debug, Clone)]
pub struct HardwareInfo {
pub device_name: String,
pub memory_gb: f64,
/// Peak memory bandwidth in GB/s (if known)
pub peak_bandwidth_gbps: Option<f64>,
/// Peak compute throughput in TFLOPS (if known)
pub peak_tflops: Option<f64>,
}
/// Trait that backends implement to enable benchmarking
pub trait BenchmarkBackend {
type Runtime: Runtime;
/// Initialize the runtime
fn initialize() -> Self::Runtime;
/// Get backend name (used in reports)
fn name() -> &'static str;
/// Get hardware information
fn hardware_info() -> HardwareInfo;
}
/// Size configuration for benchmarks
#[derive(Debug, Clone, Copy)]
pub struct BenchSize {
pub name: &'static str,
pub value: usize,
}
impl BenchSize {
pub const fn new(name: &'static str, value: usize) -> Self {
Self { name, value }
}
}
/// Standard benchmark sizes for micro benchmarks
pub const MICRO_SIZES: &[BenchSize] = &[
BenchSize::new("1k", 1_000),
BenchSize::new("100k", 100_000),
BenchSize::new("1m", 1_000_000),
BenchSize::new("10m", 10_000_000),
];
/// Trait for defining benchmark workloads (dyn-compatible version)
pub trait BenchmarkPattern {
/// Pattern name (used in reports)
fn name(&self) -> &'static str;
/// Available sizes for this pattern
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
/// Build the computation graph for this pattern
fn build_graph(&self, cx: &mut Graph, size: BenchSize);
}
// Re-export backend implementations when features are enabled
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "metal")]
pub use metal_backend::MetalBenchmark;

View File

@@ -0,0 +1,63 @@
//! Metal backend implementation for benchmarking
use crate::{BenchmarkBackend, HardwareInfo};
use luminal::op::Runtime;
use luminal_metal::runtime::MetalRuntime;
/// Metal benchmark backend
pub struct MetalBenchmark;
impl BenchmarkBackend for MetalBenchmark {
type Runtime = MetalRuntime;
fn initialize() -> Self::Runtime {
MetalRuntime::initialize(())
}
fn name() -> &'static str {
"metal"
}
fn hardware_info() -> HardwareInfo {
// Try to get device info from Metal
let device = metal::Device::system_default().expect("No Metal device found");
let device_name = device.name().to_string();
// Estimate based on common Apple Silicon specs
let (memory_gb, peak_bandwidth_gbps, peak_tflops) = estimate_device_specs(&device_name);
HardwareInfo {
device_name,
memory_gb,
peak_bandwidth_gbps: Some(peak_bandwidth_gbps),
peak_tflops: Some(peak_tflops),
}
}
}
/// Estimate device specs based on device name
fn estimate_device_specs(device_name: &str) -> (f64, f64, f64) {
// Memory (GB), Bandwidth (GB/s), FP32 TFLOPS
if device_name.contains("M3 Max") {
(128.0, 400.0, 14.0)
} else if device_name.contains("M3 Pro") {
(36.0, 200.0, 7.0)
} else if device_name.contains("M3") {
(24.0, 100.0, 3.5)
} else if device_name.contains("M2 Max") {
(96.0, 400.0, 13.6)
} else if device_name.contains("M2 Pro") {
(32.0, 200.0, 6.8)
} else if device_name.contains("M2") {
(24.0, 100.0, 3.6)
} else if device_name.contains("M1 Max") {
(64.0, 400.0, 10.4)
} else if device_name.contains("M1 Pro") {
(32.0, 200.0, 5.2)
} else if device_name.contains("M1") {
(16.0, 68.0, 2.6)
} else {
// Generic fallback
(8.0, 50.0, 1.0)
}
}

View File

@@ -0,0 +1,331 @@
//! Benchmark metrics and mapping
//!
//! Provides a mapping from benchmark names to their constant metrics (bytes, flops).
//! Combined with Criterion's time measurements, this allows computing derived metrics
//! like throughput, MBU, and MFU.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Constant metrics for a single benchmark configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchMetrics {
/// Total bytes transferred (loaded + stored)
pub bytes: usize,
/// Bytes loaded from memory
pub bytes_loaded: usize,
/// Bytes stored to memory
pub bytes_stored: usize,
/// Floating-point operations
pub flops: usize,
}
impl BenchMetrics {
pub fn new(bytes_loaded: usize, bytes_stored: usize, flops: usize) -> Self {
Self {
bytes: bytes_loaded + bytes_stored,
bytes_loaded,
bytes_stored,
flops,
}
}
/// Calculate throughput in GB/s given execution time in microseconds
pub fn throughput_gbps(&self, time_us: f64) -> f64 {
if time_us <= 0.0 {
return 0.0;
}
self.bytes as f64 / time_us / 1000.0
}
/// Calculate TFLOPS given execution time in microseconds
pub fn tflops(&self, time_us: f64) -> f64 {
if time_us <= 0.0 {
return 0.0;
}
self.flops as f64 / time_us / 1_000_000.0
}
/// Calculate MBU given execution time and peak bandwidth
pub fn mbu(&self, time_us: f64, peak_bandwidth_gbps: f64) -> f64 {
self.throughput_gbps(time_us) / peak_bandwidth_gbps * 100.0
}
/// Calculate MFU given execution time and peak TFLOPS
pub fn mfu(&self, time_us: f64, peak_tflops: f64) -> f64 {
self.tflops(time_us) / peak_tflops * 100.0
}
}
/// Hardware specifications for a benchmark target
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareSpec {
pub device_name: String,
pub memory_gb: f64,
pub peak_bandwidth_gbps: f64,
pub peak_tflops: f64,
}
/// Complete benchmark metrics mapping
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchMetricsMap {
/// Hardware specifications
pub hardware: HardwareSpec,
/// Mapping from "pattern/size" to metrics
pub benchmarks: HashMap<String, BenchMetrics>,
}
impl BenchMetricsMap {
pub fn new(hardware: HardwareSpec) -> Self {
Self {
hardware,
benchmarks: HashMap::new(),
}
}
/// Add metrics for a benchmark
pub fn add(&mut self, pattern: &str, size: &str, metrics: BenchMetrics) {
let key = format!("{}/{}", pattern, size);
self.benchmarks.insert(key, metrics);
}
/// Get metrics for a benchmark
pub fn get(&self, pattern: &str, size: &str) -> Option<&BenchMetrics> {
let key = format!("{}/{}", pattern, size);
self.benchmarks.get(&key)
}
/// Export to JSON
pub fn to_json(&self) -> String {
serde_json::to_string_pretty(self).unwrap_or_default()
}
/// Save to file
pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
let json = self.to_json();
std::fs::write(path, json)
}
/// Load from file
pub fn load(path: &std::path::Path) -> std::io::Result<Self> {
let json = std::fs::read_to_string(path)?;
serde_json::from_str(&json)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
}
// ============================================================================
// Legacy types (kept for compatibility)
// ============================================================================
/// Result of a single benchmark run
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchResult {
/// Backend name (e.g., "metal", "cuda")
pub backend: String,
/// Benchmark pattern name (e.g., "add_vec")
pub benchmark: String,
/// Size label (e.g., "1m")
pub size_label: String,
/// Actual size value
pub size_value: usize,
/// Mean execution time in microseconds
pub mean_us: f64,
/// Standard deviation in microseconds
pub std_us: f64,
/// Throughput in GB/s (if applicable)
pub throughput_gbps: Option<f64>,
/// Memory Bandwidth Utilization (if peak bandwidth known)
pub mbu: Option<f64>,
}
impl BenchResult {
/// Calculate throughput given bytes transferred
pub fn with_throughput(mut self, bytes: usize) -> Self {
if self.mean_us > 0.0 {
// bytes / microseconds = MB/s, then convert to GB/s
self.throughput_gbps = Some((bytes as f64) / self.mean_us / 1000.0);
}
self
}
/// Calculate MBU given peak bandwidth
pub fn with_mbu(mut self, peak_bandwidth_gbps: f64) -> Self {
if let Some(throughput) = self.throughput_gbps {
self.mbu = Some(throughput / peak_bandwidth_gbps * 100.0);
}
self
}
}
/// Collection of benchmark results for reporting
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchReport {
pub backend: String,
pub hardware: String,
pub results: Vec<BenchResult>,
}
impl BenchReport {
pub fn new(backend: &str, hardware: &str) -> Self {
Self {
backend: backend.to_string(),
hardware: hardware.to_string(),
results: Vec::new(),
}
}
pub fn add_result(&mut self, result: BenchResult) {
self.results.push(result);
}
/// Export to JSON (for CI integration)
pub fn to_json(&self) -> String {
serde_json::to_string_pretty(self).unwrap_or_default()
}
}
// ============================================================================
// Full Report with Derived Metrics
// ============================================================================
/// Single benchmark result with all metrics (constant + derived)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullBenchResult {
pub pattern: String,
pub size: String,
pub size_value: usize,
/// Execution time in microseconds
pub time_us: f64,
/// Bytes transferred
pub bytes: usize,
/// Floating-point operations
pub flops: usize,
/// Throughput in GB/s
pub throughput_gbps: f64,
/// Memory Bandwidth Utilization (%)
pub mbu_percent: f64,
/// Compute in TFLOPS
pub tflops: f64,
/// Model FLOPs Utilization (%)
pub mfu_percent: f64,
}
/// Full benchmark report with derived metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullBenchReport {
pub hardware: HardwareSpec,
pub timestamp: String,
pub results: Vec<FullBenchResult>,
}
impl FullBenchReport {
pub fn new(hardware: HardwareSpec) -> Self {
let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string();
Self {
hardware,
timestamp,
results: Vec::new(),
}
}
pub fn add_result(&mut self, result: FullBenchResult) {
self.results.push(result);
}
/// Save to JSON file
pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
let json = serde_json::to_string_pretty(self).unwrap_or_default();
std::fs::write(path, json)
}
/// Print summary table to terminal
pub fn print_summary(&self) {
println!("\n{}", "=".repeat(100));
println!("BENCHMARK RESULTS - {}", self.hardware.device_name);
println!(
"Peak Bandwidth: {:.0} GB/s | Peak Compute: {:.1} TFLOPS",
self.hardware.peak_bandwidth_gbps, self.hardware.peak_tflops
);
println!("{}", "=".repeat(100));
println!(
"{:<20} {:>8} {:>12} {:>10} {:>8} {:>10} {:>8}",
"Pattern", "Size", "Time(μs)", "GB/s", "MBU%", "TFLOPS", "MFU%"
);
println!("{}", "-".repeat(100));
for r in &self.results {
println!(
"{:<20} {:>8} {:>12.2} {:>10.2} {:>7.1}% {:>10.4} {:>7.1}%",
r.pattern,
r.size,
r.time_us,
r.throughput_gbps,
r.mbu_percent,
r.tflops,
r.mfu_percent
);
}
println!("{}", "=".repeat(100));
}
}
/// Thread-safe collector for benchmark results
#[derive(Clone)]
pub struct BenchResultCollector {
hardware: HardwareSpec,
results: Arc<Mutex<Vec<FullBenchResult>>>,
}
impl BenchResultCollector {
pub fn new(hardware: HardwareSpec) -> Self {
Self {
hardware,
results: Arc::new(Mutex::new(Vec::new())),
}
}
/// Add a benchmark result
pub fn add(
&self,
pattern: &str,
size: &str,
size_value: usize,
time_us: f64,
metrics: &BenchMetrics,
) {
let throughput_gbps = metrics.throughput_gbps(time_us);
let tflops = metrics.tflops(time_us);
let mbu_percent = metrics.mbu(time_us, self.hardware.peak_bandwidth_gbps);
let mfu_percent = metrics.mfu(time_us, self.hardware.peak_tflops);
let result = FullBenchResult {
pattern: pattern.to_string(),
size: size.to_string(),
size_value,
time_us,
bytes: metrics.bytes,
flops: metrics.flops,
throughput_gbps,
mbu_percent,
tflops,
mfu_percent,
};
self.results.lock().unwrap().push(result);
}
/// Generate full report
pub fn into_report(self) -> FullBenchReport {
let mut report = FullBenchReport::new(self.hardware);
report.results = self.results.lock().unwrap().clone();
// Sort by pattern name, then by size
report.results.sort_by(|a, b| {
a.pattern
.cmp(&b.pattern)
.then_with(|| a.size_value.cmp(&b.size_value))
});
report
}
}

View File

@@ -0,0 +1,338 @@
//! L1 micro benchmark patterns (single-op graphs), used by `benches/micro.rs`.
use crate::{BenchSize, BenchmarkPattern, MICRO_SIZES};
use luminal::prelude::*;
// ============================================================================
// Binary Operators
// ============================================================================
/// Vector addition benchmark: a + b
#[derive(Debug, Default)]
pub struct AddVec;
impl BenchmarkPattern for AddVec {
fn name(&self) -> &'static str {
"add_vec"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let b = cx.tensor(size.value);
let _ = (a + b).output();
}
}
/// Vector multiplication benchmark: a * b
#[derive(Debug, Default)]
pub struct MulVec;
impl BenchmarkPattern for MulVec {
fn name(&self) -> &'static str {
"mul_vec"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let b = cx.tensor(size.value);
let _ = (a * b).output();
}
}
/// Vector modulo benchmark: a % b
#[derive(Debug, Default)]
pub struct ModVec;
impl BenchmarkPattern for ModVec {
fn name(&self) -> &'static str {
"mod_vec"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let b = cx.tensor(size.value);
let _ = (a % b).output();
}
}
/// Vector less-than comparison benchmark: a < b
#[derive(Debug, Default)]
pub struct LessThanVec;
impl BenchmarkPattern for LessThanVec {
fn name(&self) -> &'static str {
"less_than_vec"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let b = cx.tensor(size.value);
let _ = a.lt(b).output();
}
}
// ============================================================================
// Reduction Operators
// ============================================================================
/// Sum reduction benchmark: sum(a)
#[derive(Debug, Default)]
pub struct SumReduce;
impl BenchmarkPattern for SumReduce {
fn name(&self) -> &'static str {
"sum_reduce"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.sum(0).output();
}
}
/// Max reduction benchmark: max(a)
#[derive(Debug, Default)]
pub struct MaxReduce;
impl BenchmarkPattern for MaxReduce {
fn name(&self) -> &'static str {
"max_reduce"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.max(0).output();
}
}
// ============================================================================
// Unary Operators
// ============================================================================
/// Exp2 benchmark: 2^x
#[derive(Debug, Default)]
pub struct Exp2Bench;
impl BenchmarkPattern for Exp2Bench {
fn name(&self) -> &'static str {
"exp2"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.exp2().output();
}
}
/// Log2 benchmark: log2(x)
#[derive(Debug, Default)]
pub struct Log2Bench;
impl BenchmarkPattern for Log2Bench {
fn name(&self) -> &'static str {
"log2"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.log2().output();
}
}
/// Sin benchmark: sin(x)
#[derive(Debug, Default)]
pub struct SinBench;
impl BenchmarkPattern for SinBench {
fn name(&self) -> &'static str {
"sin"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.sin().output();
}
}
/// Recip benchmark: 1/x
#[derive(Debug, Default)]
pub struct RecipBench;
impl BenchmarkPattern for RecipBench {
fn name(&self) -> &'static str {
"recip"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.reciprocal().output();
}
}
/// Sqrt benchmark: sqrt(x)
#[derive(Debug, Default)]
pub struct SqrtBench;
impl BenchmarkPattern for SqrtBench {
fn name(&self) -> &'static str {
"sqrt"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
let _ = a.sqrt().output();
}
}
// ============================================================================
// Indexing Operators
// ============================================================================
/// Gather benchmark: gather(data, indices)
#[derive(Debug, Default)]
pub struct GatherBench;
impl BenchmarkPattern for GatherBench {
fn name(&self) -> &'static str {
"gather"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
// Simple 1D gather: data[indices]
// data: 1D tensor of size.value elements
// indices: 1D tensor selecting num_indices elements
let num_indices = 1024.min(size.value);
let data = cx.tensor(size.value);
// Indices must be integer type for gather operation
let indices = cx.tensor(num_indices).as_dtype(luminal::dtype::DType::Int);
let _ = data.gather(indices).output();
}
}
/// Cast benchmark: type conversion (f32 -> f16 -> f32)
#[derive(Debug, Default)]
pub struct CastBench;
impl BenchmarkPattern for CastBench {
fn name(&self) -> &'static str {
"cast"
}
fn sizes(&self) -> &[BenchSize] {
MICRO_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let a = cx.tensor(size.value);
// Cast to f16 then back to f32 to measure round-trip cost
let _ = a
.cast(luminal::dtype::DType::F16)
.cast(luminal::dtype::DType::F32)
.output();
}
}
// ============================================================================
// Pattern Registry
// ============================================================================
/// Get all micro benchmark patterns (HLIR primitives supported by Metal)
pub fn all_micro_patterns() -> Vec<Box<dyn BenchmarkPattern>> {
vec![
// Binary operators
Box::new(AddVec),
Box::new(MulVec),
Box::new(ModVec),
Box::new(LessThanVec),
// Reduction operators
Box::new(SumReduce),
Box::new(MaxReduce),
// Unary operators
Box::new(Exp2Bench),
Box::new(Log2Bench),
Box::new(SinBench),
Box::new(RecipBench),
Box::new(SqrtBench),
// Indexing operators
Box::new(GatherBench),
// Note: CastBench removed - Metal backend does not implement Cast yet
]
}
/// Calculate bytes transferred for a benchmark pattern
pub fn bytes_for_pattern(pattern_name: &str, size: usize) -> usize {
let elem_size = std::mem::size_of::<f32>();
match pattern_name {
// Binary operators: read 2 inputs + write 1 output = 3 * size * 4 bytes
"add_vec" | "mul_vec" | "mod_vec" | "less_than_vec" => 3 * size * elem_size,
// Reduction operators: read 1 input + write 1 output (scalar)
"sum_reduce" | "max_reduce" => size * elem_size + elem_size,
// Unary operators: read 1 input + write 1 output = 2 * size * 4 bytes
"exp2" | "log2" | "sin" | "recip" | "sqrt" => 2 * size * elem_size,
// Cast: read 1 input (f32) + write intermediate (f16) + read (f16) + write output (f32)
// Simplified: 2 * size * 4 bytes (f32 in + f32 out)
"cast" => 2 * size * elem_size,
// Gather: read indices + read gathered data + write output
// Simple 1D gather: indices (num_indices * 4) + read data + write output
"gather" => {
let num_indices = 1024.min(size);
// Read indices (i32) + read data (random access, ~num_indices elements) + write output
num_indices * elem_size + num_indices * elem_size + num_indices * elem_size
}
_ => 0,
}
}

View File

@@ -0,0 +1,299 @@
//! L2 pattern benchmark patterns (composite graphs), used by `benches/patterns.rs`.
use crate::{BenchSize, BenchmarkPattern};
use luminal::prelude::*;
// ============================================================================
// Size Configurations
// ============================================================================
/// Matrix multiplication size configuration
#[derive(Debug, Clone, Copy)]
pub struct MatMulSize {
pub name: &'static str,
pub m: usize,
pub k: usize,
pub n: usize,
}
impl MatMulSize {
pub const fn new(name: &'static str, m: usize, k: usize, n: usize) -> Self {
Self { name, m, k, n }
}
}
/// Dummy size for patterns that handle sizes internally
pub const CUSTOM_SIZE: &[BenchSize] = &[BenchSize::new("custom", 0)];
/// Standard matrix multiplication sizes
pub const MATMUL_SIZES: &[MatMulSize] = &[
// Square matrices
MatMulSize::new("128x128", 128, 128, 128),
MatMulSize::new("512x512", 512, 512, 512),
MatMulSize::new("1024x1024", 1024, 1024, 1024),
// LLM-like shapes (batch=1, hidden_dim, ffn_dim)
MatMulSize::new("1x4096x4096", 1, 4096, 4096),
// MatMulSize::new("32x4096x4096", 32, 4096, 4096),
];
/// Transformer-like sizes for softmax, layernorm, etc.
pub const TRANSFORMER_SIZES: &[BenchSize] = &[
BenchSize::new("128x128", 128 * 128), // small attention
BenchSize::new("512x512", 512 * 512), // medium attention
BenchSize::new("2048x128", 2048 * 128), // typical seq_len x head_dim
// BenchSize::new("4096x128", 4096 * 128), // long context
];
/// Attention size configurations (seq_len, head_dim)
pub const ATTENTION_SIZES: &[(usize, usize)] = &[
(128, 64), // small: seq=128, head_dim=64
(512, 64), // medium: seq=512, head_dim=64
(1024, 64), // large: seq=1024, head_dim=64
// (2048, 64), // xlarge: seq=2048, head_dim=64
];
// ============================================================================
// MatMul Pattern
// ============================================================================
/// Matrix multiplication benchmark: C = A @ B
#[derive(Debug, Clone, Copy)]
pub struct MatMulBench {
pub size: MatMulSize,
}
impl MatMulBench {
pub fn new(size: MatMulSize) -> Self {
Self { size }
}
}
impl BenchmarkPattern for MatMulBench {
fn name(&self) -> &'static str {
"matmul"
}
fn sizes(&self) -> &[BenchSize] {
CUSTOM_SIZE
}
fn build_graph(&self, cx: &mut Graph, _size: BenchSize) {
let a = cx.tensor((self.size.m, self.size.k));
let b = cx.tensor((self.size.k, self.size.n));
let _ = a.matmul(b).output();
}
}
// ============================================================================
// Softmax Pattern
// ============================================================================
/// Softmax benchmark: softmax(x, axis=-1)
#[derive(Debug, Default)]
pub struct SoftmaxBench;
impl BenchmarkPattern for SoftmaxBench {
fn name(&self) -> &'static str {
"softmax"
}
fn sizes(&self) -> &[BenchSize] {
TRANSFORMER_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
// Reshape to 2D for softmax along last axis
// Assume size.value = rows * cols, use sqrt for balanced shape
let dim = (size.value as f64).sqrt() as usize;
let rows = size.value / dim;
let cols = dim;
let x = cx.tensor((rows, cols));
// Softmax along last axis (axis 1)
let _ = x.softmax(1).output();
}
}
// ============================================================================
// LayerNorm Pattern
// ============================================================================
/// Layer normalization benchmark
#[derive(Debug, Default)]
pub struct LayerNormBench;
impl BenchmarkPattern for LayerNormBench {
fn name(&self) -> &'static str {
"layer_norm"
}
fn sizes(&self) -> &[BenchSize] {
TRANSFORMER_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
// Typical shape: (batch * seq_len, hidden_dim)
let hidden_dim = 128;
let batch_seq = size.value / hidden_dim;
let x = cx.tensor((batch_seq.max(1), hidden_dim));
// LayerNorm along last axis with epsilon
let _ = x.layer_norm(1, 1e-5).output();
}
}
// ============================================================================
// GeLU Pattern
// ============================================================================
/// GeLU activation benchmark
#[derive(Debug, Default)]
pub struct GeLUBench;
impl BenchmarkPattern for GeLUBench {
fn name(&self) -> &'static str {
"gelu"
}
fn sizes(&self) -> &[BenchSize] {
TRANSFORMER_SIZES
}
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
let x = cx.tensor(size.value);
let _ = x.gelu().output();
}
}
// ============================================================================
// Attention Pattern
// ============================================================================
/// Self-attention benchmark: softmax(Q @ K^T / sqrt(d)) @ V
#[derive(Debug, Clone, Copy)]
pub struct AttentionBench {
pub seq_len: usize,
pub head_dim: usize,
}
impl AttentionBench {
pub fn new(seq_len: usize, head_dim: usize) -> Self {
Self { seq_len, head_dim }
}
}
impl Default for AttentionBench {
fn default() -> Self {
Self {
seq_len: 512,
head_dim: 64,
}
}
}
impl BenchmarkPattern for AttentionBench {
fn name(&self) -> &'static str {
"attention"
}
fn sizes(&self) -> &[BenchSize] {
CUSTOM_SIZE
}
fn build_graph(&self, cx: &mut Graph, _size: BenchSize) {
let seq_len = self.seq_len;
let head_dim = self.head_dim;
// Q, K, V tensors: (seq_len, head_dim)
let q = cx.tensor((seq_len, head_dim));
let k = cx.tensor((seq_len, head_dim));
let v = cx.tensor((seq_len, head_dim));
// Attention: softmax(Q @ K^T / sqrt(d)) @ V
// Q @ K^T -> (seq_len, seq_len)
let scores = q.matmul(k.permute((1, 0)));
// Scale by 1/sqrt(head_dim)
let scale = 1.0 / (head_dim as f32).sqrt();
let scaled_scores = scores * scale;
// Softmax along last axis
let attn_weights = scaled_scores.softmax(1);
// @ V -> (seq_len, head_dim)
let _ = attn_weights.matmul(v).output();
}
}
// ============================================================================
// Pattern Registry
// ============================================================================
/// Get all high-priority pattern benchmarks
pub fn all_pattern_benchmarks() -> Vec<Box<dyn BenchmarkPattern>> {
let mut patterns: Vec<Box<dyn BenchmarkPattern>> = vec![];
// MatMul patterns with different sizes
for size in MATMUL_SIZES {
patterns.push(Box::new(MatMulBench::new(*size)));
}
// Softmax
patterns.push(Box::new(SoftmaxBench));
// LayerNorm
patterns.push(Box::new(LayerNormBench));
// GeLU
patterns.push(Box::new(GeLUBench));
// Attention patterns with different sizes
for (seq_len, head_dim) in ATTENTION_SIZES {
patterns.push(Box::new(AttentionBench::new(*seq_len, *head_dim)));
}
patterns
}
/// Calculate bytes transferred for pattern benchmarks
pub fn bytes_for_pattern_bench(
pattern_name: &str,
size: usize,
extra: Option<(usize, usize, usize)>,
) -> usize {
let elem_size = std::mem::size_of::<f32>();
match pattern_name {
"matmul" => {
if let Some((m, k, n)) = extra {
// Read A (m*k) + Read B (k*n) + Write C (m*n)
(m * k + k * n + m * n) * elem_size
} else {
0
}
}
"softmax" => {
// Read input + Write output (same size)
2 * size * elem_size
}
"layer_norm" => {
// Read input + Write output
2 * size * elem_size
}
"gelu" => {
// Read input + Write output
2 * size * elem_size
}
"attention" => {
if let Some((seq_len, head_dim, _)) = extra {
// Q, K, V reads: 3 * seq_len * head_dim
// scores: seq_len * seq_len
// output: seq_len * head_dim
(3 * seq_len * head_dim + seq_len * seq_len + seq_len * head_dim) * elem_size
} else {
0
}
}
_ => 0,
}
}

View File

@@ -0,0 +1,2 @@
[env]
RUST_TEST_THREADS = "1"

View File

@@ -0,0 +1,33 @@
[package]
name = "luminal_cuda_lite"
version = "0.2.0"
edition = "2024"
description = "Cuda compiler for luminal"
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
luminal = { path = "../.." }
luminal_tracing = { path = "../luminal_tracing" }
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
as-any = "0.3.2"
itertools = "0.12.1"
fixedbitset = "0.5.7"
safetensors = "0.7.0"
tracing = "0.1.43"
half = { version = "2.7.1", features = ["num-traits"] }
pretty-duration = "0.1.1"
bytemuck = "1.24.0"
memmap2 = "0.9.9"
uuid = {version="1.19.0", features=["v4"]}
lru = "0.16.2"
libc = "0.2"
colorize = "*"
[dev-dependencies]
candle-core = { version = "0.9.2", features = ["cuda"] }
proptest = "1.9.0"
rand = "0.9.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
num-traits = "0.2"

View File

@@ -0,0 +1,29 @@
## luminal_cuda_lite
This crate contains the CUDA backend for Luminal.
The backend can be broken down into several main types of ops. Starting from the highest level and going lower:
#### Host Ops
Host ops are opaque operations executed from the host (can execute on device, simply launched in an opaque manner). cuBLAS is a good example of this type of op. Luminal can't assume much about these operations since they are so opaque. These ops implement the `HostOp` trait.
#### Kernel Ops
Kernel ops are operations encoded as a kernel and launch parameters. Luminal can put these into CUDA graphs. Cutlass kernels are good examples of these. These ops implement the `KernelOp` trait.
#### Block Ops
Block ops are operations encoded on the threadblock level, which implement an operation that runs for a duration within a single threadblock. These are required to use a fixed number of threads per threadblock (or gate unused threads out), and are given a fixed-size shared memory scratchpad. Luminal can fuse these operations into megakernels. These ops impelement the `BlockOp` trait.
#### Warp Ops
Warp ops are not yet merged. Stay tuned!
#### Thread Ops
Thread ops are not yet merged. Stay tuned!
### Architecture
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.

View File

@@ -0,0 +1,258 @@
use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND, STRING},
extract_expr,
},
op::{EgglogOp, LLIROp},
prelude::{
tracing::{Level, span, trace},
*,
},
};
use crate::{
cudarc::{
cublas::{
CudaBlas,
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
},
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::HostOp,
};
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
// Strip quotes if present (egglog strings are stored with quotes)
let stripped = s.trim_matches('"');
match stripped {
"T" => cublasOperation_t::CUBLAS_OP_T,
"N" => cublasOperation_t::CUBLAS_OP_N,
"C" => cublasOperation_t::CUBLAS_OP_C,
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct CuBlasSgemmV2 {
m: Expression,
n: Expression,
k: Expression,
a_layout: cublasOperation_t,
b_layout: cublasOperation_t,
lda: Expression,
ldb: Expression,
ldc: Expression,
/// Lazily initialized cuBLAS handle - created on first execute
cublas: OnceLock<Arc<CudaBlas>>,
}
// Useless default for IntoEgglogOp
impl Default for CuBlasSgemmV2 {
fn default() -> Self {
Self {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
cublas: OnceLock::new(),
}
}
}
impl EgglogOp for CuBlasSgemmV2 {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"cublasSgemmV2",
&[
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
("a_layout", STRING),
("b_layout", STRING),
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
]
}
#[allow(unused_variables)]
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
let extracted_state = Self {
m,
n,
k,
a_layout,
b_layout,
lda,
ldb,
ldc,
cublas: OnceLock::new(),
};
trace!(?extracted_state);
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, input_enodes)
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for CuBlasSgemmV2 {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
let m = self.m.exec(dyn_map).unwrap() as i32;
let n = self.n.exec(dyn_map).unwrap() as i32;
let k = self.k.exec(dyn_map).unwrap() as i32;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = self.lda.exec(dyn_map).unwrap() as i32;
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
let alpha = 1.0f32;
let beta = 0.0f32;
// Get buffers: output is self_node, inputs are from graph edges
let c_buf = buffers[&self_node];
let a_buf = buffers[&inputs[0]];
let b_buf = buffers[&inputs[1]];
// Get device pointers
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
// Debug: Check buffer sizes
trace!(
"buffer_validation {}=={},{}=={},{}=={}",
a_buf.len(),
m * k * 4,
b_buf.len(),
k * n * 4,
c_buf.len(),
m * n * 4
);
let _sgemm_span = span!(
Level::TRACE,
"cuBLAS_SGEMM_V2",
m,
n,
k,
alpha,
beta,
lda,
ldb,
ldc,
?a_layout,
?b_layout,
)
.entered();
// Use shared cuBLAS handle to avoid per-operation workspace allocation
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
// Set the stream for this operation (cuBLAS handle can work with any stream)
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
unsafe {
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
}
let status = unsafe {
cublasSgemm_v2(
*cublas.handle(),
a_layout,
b_layout,
m,
n,
k,
&alpha as *const f32,
a_ptr as *const f32,
lda,
b_ptr as *const f32,
ldb,
&beta as *const f32,
c_ptr as *mut f32,
ldc,
)
};
stream.synchronize().unwrap();
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return Err(anyhow::anyhow!(
"cuBLAS SGEMM TN failed with status: {:?}",
status
));
}
Ok(())
}
fn output_size(&self) -> Expression {
self.m * self.n
}
fn output_bytes(&self) -> Expression {
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
self.output_size() * 4
}
}

View File

@@ -0,0 +1,72 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:name "cublas sgemm column-major × column-major"
)

View File

@@ -0,0 +1,72 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:name "cublas sgemm column-major × row-major"
)

View File

@@ -0,0 +1,72 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:name "cublas sgemm row-major × column-major"
)

View File

@@ -0,0 +1,72 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)
:name "cublas sgemm row-major"
)

View File

@@ -0,0 +1,133 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "T"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × column-major"
)

View File

@@ -0,0 +1,133 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "T"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × row-major"
)

View File

@@ -0,0 +1,133 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × column-major"
)

View File

@@ -0,0 +1,139 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; In broadcast [batch, m, n, k] space:
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape: [batch, m, n]
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k]
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k]
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "N"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc (contiguous output per batch)
?batch ; batch_count
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × row-major"
)

View File

@@ -0,0 +1,476 @@
use std::sync::{Arc, OnceLock};
use luminal::{
dtype::DType,
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, EXPRESSION, OP_KIND, STRING},
extract_dtype, extract_expr,
},
op::{EgglogOp, LLIROp},
prelude::{
tracing::{Level, span, trace},
*,
},
};
use crate::{
cudarc::{
cublas::sys::cublasOperation_t,
cublaslt::{
CudaBlasLT, MatmulShared,
sys::{
cublasComputeType_t, cublasLtMatmul, cublasLtMatmulAlgoGetHeuristic,
cublasLtMatmulDesc_t, cublasLtMatmulDescCreate, cublasLtMatmulDescDestroy,
cublasLtMatmulDescSetAttribute, cublasLtMatmulHeuristicResult_t,
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t,
cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceDestroy,
cublasLtMatmulPreferenceSetAttribute, cublasLtMatrixLayout_t,
cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy, cudaDataType,
},
},
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::{HostOp, cublas::parse_cublas_op},
};
#[derive(Debug)]
#[allow(dead_code)]
pub struct CuBlasLt {
m: Expression,
n: Expression,
k: Expression,
a_layout: cublasOperation_t,
b_layout: cublasOperation_t,
lda: Expression,
ldb: Expression,
ldc: Expression,
batch_count: Expression,
stride_a: Expression,
stride_b: Expression,
stride_c: Expression,
dtype: DType,
cublaslt: OnceLock<Arc<CudaBlasLT>>,
}
// Useless default for IntoEgglogOp
impl Default for CuBlasLt {
fn default() -> Self {
Self {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N,
b_layout: cublasOperation_t::CUBLAS_OP_T,
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
batch_count: 1.into(),
stride_a: 0.into(),
stride_b: 0.into(),
stride_c: 0.into(),
dtype: DType::F32,
cublaslt: OnceLock::new(),
}
}
}
impl EgglogOp for CuBlasLt {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"cublaslt",
&[
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
("a_layout", STRING),
("b_layout", STRING),
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
("batch_count", EXPRESSION),
("stride_a", EXPRESSION),
("stride_b", EXPRESSION),
("stride_c", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
// (not the Mul eclass), so they survive the cascade.
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
]
}
#[allow(unused_variables)]
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
// Extract batch parameters
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
// Extract dtype from egglog
let dtype = extract_dtype(egraph, kind_children[12]);
let extracted_state = Self {
m,
n,
k,
a_layout,
b_layout,
lda,
ldb,
ldc,
batch_count,
stride_a,
stride_b,
stride_c,
dtype,
cublaslt: OnceLock::new(),
};
trace!(?extracted_state);
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, input_enodes)
}
fn cleanup(&self) -> bool {
false
}
}
/// Convert DType to CUDA types for cuBLAS LT
/// Returns (matrix_dtype, compute_type, scale_dtype)
fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cudaDataType) {
match dtype {
// F64: matrix=f64, compute=f64, scale=f64
DType::F64 => (
cudaDataType::CUDA_R_64F,
cublasComputeType_t::CUBLAS_COMPUTE_64F,
cudaDataType::CUDA_R_64F,
),
// F32: matrix=f32, compute=f32, scale=f32
DType::F32 => (
cudaDataType::CUDA_R_32F,
cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudaDataType::CUDA_R_32F,
),
// F16: matrix=f16, compute=f32 (FP32 accumulation for accuracy), scale=f32
DType::F16 => (
cudaDataType::CUDA_R_16F,
cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudaDataType::CUDA_R_32F,
),
// BF16: matrix=bf16, compute=f32 with tensor cores, scale=f32
DType::Bf16 => (
cudaDataType::CUDA_R_16BF,
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
cudaDataType::CUDA_R_32F,
),
// TF32: stored as f32, use fast TF32 tensor core path
DType::TF32 => (
cudaDataType::CUDA_R_32F,
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32,
cudaDataType::CUDA_R_32F,
),
// FP8 E4M3: matrix=fp8_e4m3, compute=f32, scale=f32
DType::F8E4M3 => (
cudaDataType::CUDA_R_8F_E4M3,
cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudaDataType::CUDA_R_32F,
),
// FP8 E5M2: matrix=fp8_e5m2, compute=f32, scale=f32
DType::F8E5M2 => (
cudaDataType::CUDA_R_8F_E5M2,
cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudaDataType::CUDA_R_32F,
),
DType::Int => panic!("cuBLAS LT does not support integer matmul"),
DType::Bool => panic!("cuBLAS LT does not support bool matmul"),
other => todo!("cuBLAS LT matmul not yet implemented for {other}"),
}
}
impl HostOp for CuBlasLt {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
use crate::cudarc::cublaslt::sys::{
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
};
// GEMM parameters — resolve z→1 for element stride before exec
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
// Get CUDA types based on dtype
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
let element_size = (self.dtype.bits() / 8) as u64;
assert!(
element_size > 0,
"cuBLAS LT does not support sub-byte dtype {}",
self.dtype
);
// Alpha/beta scale values (all dtypes use F32 scale type)
let alpha_f32: f32 = 1.0;
let beta_f32: f32 = 0.0;
// Get buffers: output is self_node, inputs are from graph edges
let c_buf = buffers[&self_node];
let a_buf = buffers[&inputs[0]];
let b_buf = buffers[&inputs[1]];
// Get device pointers
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
// Clamp leading dimensions to minimum valid values.
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
// dimension may be 0 in the egglog representation, but cuBLAS requires
// lda >= rows_of_A and ldb >= rows_of_B.
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N {
m
} else {
k
};
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N {
k
} else {
n
};
let lda = std::cmp::max(lda, a_ld_min as i64);
let ldb = std::cmp::max(ldb, b_ld_min as i64);
let ldc = std::cmp::max(ldc, m as i64);
let _span = span!(
Level::TRACE,
"cuBLASLT",
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
)
.entered();
let cublaslt = self
.cublaslt
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
// Allocate workspace (32 MiB)
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
unsafe {
// Create matmul descriptor (compute_type, scale_type for alpha/beta)
cublasLtMatmulDescCreate(&mut matmul_desc, compute_type, scale_dtype).result()?;
// Set transpose attributes
cublasLtMatmulDescSetAttribute(
matmul_desc,
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
&a_layout as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
cublasLtMatmulDescSetAttribute(
matmul_desc,
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
&b_layout as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
// Create matrix layout descriptors
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
cublasLtMatrixLayoutCreate(&mut a_desc, cuda_dtype, a_rows, a_cols, lda).result()?;
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
// Set batched GEMM attributes if batch_count > 1
if batch_count > 1 {
for (desc, stride) in [(a_desc, stride_a), (b_desc, stride_b), (c_desc, stride_c)] {
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i32>(),
)
.result()?;
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i64>(),
)
.result()?;
}
}
// Create preference and set workspace size
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
// Get heuristic (best algorithm)
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
matmul_desc,
a_desc,
b_desc,
c_desc,
c_desc, // D layout same as C
preference,
1, // Request 1 result
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
cublasLtMatmul(
*cublaslt.handle(),
matmul_desc,
alpha_ptr,
a_ptr as *const std::ffi::c_void,
a_desc,
b_ptr as *const std::ffi::c_void,
b_desc,
beta_ptr,
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
c_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
stream.cu_stream() as *mut _,
)
.result()?;
// Cleanup
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
}
stream.synchronize()?;
Ok(())
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
}
fn output_bytes(&self) -> Expression {
(self.output_size() * self.dtype.bits()).ceil_div(8)
}
}

View File

@@ -0,0 +1,63 @@
use std::{fmt::Debug, sync::Arc};
use crate::cudarc::driver::{CudaSlice, CudaStream};
use luminal::{op::EgglogOp, prelude::*};
mod cublas;
mod cublaslt;
pub mod moe;
pub type Ops = (
// cublas::CuBlasSgemmV2,
cublaslt::CuBlasLt,
moe::GLUMoE,
);
/// Host operations that execute on the CPU but orchestrate GPU work.
///
/// This includes operations like cuBLAS calls and CUDA graph executions.
pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
/// Execute the operation with access to buffers via a map.
///
/// # Arguments
/// * `stream` - The CUDA stream to execute on
/// * `self_node` - The NodeIndex of this op in the llir_graph (used as output buffer)
/// * `inputs` - NodeIndices of input nodes (in edge order from the graph)
/// * `buffers` - Map from NodeIndex to device buffer for all allocated nodes
/// * `dyn_map` - Dynamic dimension values
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()>;
/// Returns the output buffer size in elements.
/// Return 0 if this op doesn't have a single output buffer (e.g., CudaGraphOp).
fn output_size(&self) -> Expression;
/// Returns the output buffer size in bytes (accounts for dtype).
fn output_bytes(&self) -> Expression;
/// Returns additional nodes (beyond graph edges) that this op needs buffers for.
///
/// For most ops, this returns empty (buffers determined by graph edges).
/// For CudaGraphOp, this returns all internal kernel nodes.
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
vec![]
}
/// Returns buffer size requirements for extra nodes (node -> size in elements).
///
/// Called during buffer allocation to ensure all required buffers exist.
/// For CudaGraphOp, this returns sizes for all internal kernel output buffers.
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
FxHashMap::default()
}
/// Returns the name of this host op for stats reporting, or None if not reportable.
fn stats_name(&self) -> Option<&'static str> {
None
}
}

View File

@@ -0,0 +1,128 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Gate/Up split via Iota+Gather (slice semantics)
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; We use wildcards (?_xxx) for shapes/strides we don't extract.
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
; t53: Cast to F32
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
; t54: Iota for within-expert index
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
; t55: Cast within to F32
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
; t57: Cast to Int
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
; t58: Gather gate_up weights
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
; t61: SumReduce over K dimension
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
; ===== Up slice via Iota+Gather =====
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
; ===== SwiGLU: silu(gate) * up =====
; t64: Constant(-1)
(= ?neg1 (Op (Constant -1.000000) (INil)))
; t65: gate * -1
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
; t66: Constant(log2e)
(= ?log2e (Op (Constant 1.442695) (INil)))
; t67: neg_gate * log2e
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
; t68: exp2
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
; t69: Constant(1)
(= ?one (Op (Constant 1.000000) (INil)))
; t70: exp2 + 1
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
; t71: recip
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
; t73: silu(gate) * up
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
; ===== Down expert gather =====
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
; t76: Cast to F32
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
; t77: Iota for within-expert index
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
; t78: Cast within to F32
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
; t79: Add base + within
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
; t80: Cast to Int
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
; t81: Gather down weights
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
; t84: SumReduce
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
)
(
(let ?glumoe (Op (GLUMoE
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range)
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
)

View File

@@ -0,0 +1,662 @@
use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND},
extract_expr,
},
op::{EgglogOp, LLIROp},
prelude::*,
shape::Expression,
};
use crate::{
compile_module_image_for_current_device,
cudarc::{
cublas::sys::cublasOperation_t,
cublaslt::{
CudaBlasLT, MatmulShared,
sys::{
cublasComputeType_t, cublasLtMatmul, cublasLtMatmulAlgoGetHeuristic,
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, cublasLtMatmulDescCreate,
cublasLtMatmulDescDestroy, cublasLtMatmulDescSetAttribute,
cublasLtMatmulHeuristicResult_t, cublasLtMatmulPreference_t,
cublasLtMatmulPreferenceAttributes_t, cublasLtMatmulPreferenceCreate,
cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceSetAttribute,
cublasLtMatrixLayout_t, cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy,
cudaDataType,
},
},
driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
},
host::HostOp,
};
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
/// Fused GLU-MoE HostOp matched via egglog pattern.
///
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
/// + weighted sum) with an efficient cuBLASLt implementation.
///
/// Inputs (graph edges, in order):
/// 0: x [seq, hidden] F32
/// 1: topk_indices [seq, k] Int
/// 2: topk_values [seq, k] F32
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
/// 4: down_w [E, hidden, intermediate] BF16
///
/// Output: [seq, hidden] F32
pub struct GLUMoE {
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
gu_io: Expression,
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
dn_io: Expression,
/// K dimension of gate_up matmul (= hidden)
gu_matmul_k: Expression,
/// K dimension of down matmul (= intermediate)
dn_matmul_k: Expression,
/// K experts to sum over (= top_k)
output_k: Expression,
/// Total elements in a single gate_up expert weight matrix
gu_within_range: Expression,
/// Total elements in a single down expert weight matrix
dn_within_range: Expression,
cublaslt: OnceLock<Arc<CudaBlasLT>>,
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
}
impl Default for GLUMoE {
fn default() -> Self {
Self {
gu_io: Expression::default(),
dn_io: Expression::default(),
gu_matmul_k: Expression::default(),
dn_matmul_k: Expression::default(),
output_k: Expression::default(),
gu_within_range: Expression::default(),
dn_within_range: Expression::default(),
cublaslt: OnceLock::new(),
module: OnceLock::new(),
}
}
}
impl std::fmt::Debug for GLUMoE {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GLUMoE")
.field("gu_io", &self.gu_io)
.field("dn_io", &self.dn_io)
.field("gu_matmul_k", &self.gu_matmul_k)
.field("dn_matmul_k", &self.dn_matmul_k)
.field("output_k", &self.output_k)
.finish()
}
}
impl Clone for GLUMoE {
fn clone(&self) -> Self {
Self {
gu_io: self.gu_io,
dn_io: self.dn_io,
gu_matmul_k: self.gu_matmul_k,
dn_matmul_k: self.dn_matmul_k,
output_k: self.output_k,
gu_within_range: self.gu_within_range,
dn_within_range: self.dn_within_range,
cublaslt: OnceLock::new(),
module: OnceLock::new(),
}
}
}
impl GLUMoE {
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
self.cublaslt
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
}
fn get_kernels(
&self,
stream: &Arc<CudaStream>,
) -> &(Arc<CudaModule>, CudaFunction, CudaFunction) {
self.module.get_or_init(|| {
let src = r#"
#include <cuda_bf16.h>
extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long long out_ptr, int n) {
const float* in_ = (const float*)in_ptr;
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) out[i] = __float2bfloat16(in_[i]);
}
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < intermediate) {
float gate = __bfloat162float(gate_up[i]);
float up = __bfloat162float(gate_up[i + intermediate]);
float silu = gate / (1.0f + expf(-gate));
out[i] = __float2bfloat16(silu * up);
}
}
"#;
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
let swiglu = module.load_function("swiglu_bf16").unwrap();
(module, f32_to_bf16, swiglu)
})
}
}
impl EgglogOp for GLUMoE {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"GLUMoE",
&[
("gu_io", EXPRESSION),
("dn_io", EXPRESSION),
("gu_matmul_k", EXPRESSION),
("dn_matmul_k", EXPRESSION),
("output_k", EXPRESSION),
("gu_within_range", EXPRESSION),
("dn_within_range", EXPRESSION),
],
)
}
fn n_inputs(&self) -> usize {
5
}
fn early_rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
}
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let extracted = GLUMoE {
gu_io,
dn_io,
gu_matmul_k,
dn_matmul_k,
output_k,
gu_within_range,
dn_within_range,
cublaslt: OnceLock::new(),
module: OnceLock::new(),
};
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
(op, input_enodes)
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for GLUMoE {
fn execute(
&self,
stream: &Arc<CudaStream>,
self_node: NodeIndex,
inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// Resolve dimensions
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
let top_k = self.output_k.exec(dyn_map).unwrap();
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
let x_buf = buffers[&inputs[0]];
let seq = x_buf.len() / (hidden * 4);
// Get input/output buffers
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
let output_buf = buffers[&self_node]; // [seq, hidden] F32
// Get raw device pointer addresses
let x_ptr = buf_ptr(x_buf, stream);
let gate_up_ptr = buf_ptr(gate_up_buf, stream);
let down_ptr = buf_ptr(down_buf, stream);
let output_ptr = buf_ptr(output_buf, stream);
let cublaslt = self.get_cublaslt(stream);
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
// Read topk indices and values from GPU
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
// Allocate temp buffers
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
let gate_up_out_buf = unsafe { stream.alloc::<u8>(gate_up_dim * 2)? }; // BF16 per-token
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
let hid_ptr = buf_ptr(&hidden_tmp, stream);
let ws_ptr = buf_ptr(&workspace, stream);
// Cast x F32 → BF16
let n_cast = (seq * hidden) as i32;
let blocks = (n_cast as u32).div_ceil(256);
unsafe {
stream
.launch_builder(f32_to_bf16_fn)
.arg(&x_ptr)
.arg(&xbf16_ptr)
.arg(&n_cast)
.launch(LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})?;
}
// Per-token expert computation
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
// Normalize top-k values per token (norm_topk_prob=true)
let mut normalized_vals = topk_vals_f32.to_vec();
for t in 0..seq {
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
let sum: f32 = row.iter().sum();
if sum > 0.0 {
for v in row.iter_mut() {
*v /= sum;
}
}
}
for t in 0..seq {
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
{
let expert_idx = expert_idx as usize;
// a. Gate+Up matmul (BF16 in, BF16 out)
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
cublas_matmul(
stream,
cublaslt,
ws_ptr,
gate_up_dim as u64,
1,
hidden as u64,
expert_gu_ptr,
cublasOperation_t::CUBLAS_OP_T,
hidden as i64,
x_t_ptr,
cublasOperation_t::CUBLAS_OP_N,
hidden as i64,
gu_out_ptr,
gate_up_dim as i64,
cudaDataType::CUDA_R_16BF,
cublasComputeType_t::CUBLAS_COMPUTE_32F,
1.0f32,
0.0f32,
)?;
// b. SwiGLU kernel (BF16 → BF16)
let moe_int = intermediate as i32;
let swiglu_blocks = (moe_int as u32).div_ceil(256);
unsafe {
stream
.launch_builder(swiglu_fn)
.arg(&gu_out_ptr)
.arg(&hid_ptr)
.arg(&moe_int)
.launch(LaunchConfig {
grid_dim: (swiglu_blocks, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})?;
}
// c. Down matmul (BF16 in → F32 out) with fused accumulate
let expert_down_ptr = down_ptr + expert_idx as u64 * down_stride;
let out_t_ptr = output_ptr + (t * hidden * 4) as u64; // F32
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
cublas_matmul_mixed(
stream,
cublaslt,
ws_ptr,
hidden as u64,
1,
intermediate as u64,
expert_down_ptr,
cublasOperation_t::CUBLAS_OP_T,
intermediate as i64,
hid_ptr,
cublasOperation_t::CUBLAS_OP_N,
intermediate as i64,
out_t_ptr,
hidden as i64,
weight,
beta,
)?;
}
}
stream.synchronize()?;
Ok(())
}
fn output_size(&self) -> Expression {
// Output is [seq, hidden] F32 → seq * hidden elements
// But seq is dynamic. We derive from first input size / hidden.
// Actually, output_bytes is what matters for allocation:
Expression::from('s') * self.gu_matmul_k
}
fn output_bytes(&self) -> Expression {
Expression::from('s') * self.gu_matmul_k * 4 // F32
}
fn stats_name(&self) -> Option<&'static str> {
Some("GLUMoE")
}
}
// ============================================================
// Helpers
// ============================================================
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
let (ptr, _guard) = buf.device_ptr(stream);
ptr
}
#[allow(clippy::too_many_arguments)]
fn cublas_matmul(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
workspace_ptr: u64,
m: u64,
n: u64,
k: u64,
a_ptr: u64,
a_op: cublasOperation_t,
lda: i64,
b_ptr: u64,
b_op: cublasOperation_t,
ldb: i64,
c_ptr: u64,
ldc: i64,
dtype: cudaDataType,
compute: cublasComputeType_t,
alpha: f32,
beta: f32,
) -> anyhow::Result<()> {
let scale_type = cudaDataType::CUDA_R_32F;
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
unsafe {
cublasLtMatmulDescCreate(&mut matmul_desc, compute, scale_type).result()?;
cublasLtMatmulDescSetAttribute(
matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
&a_op as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
cublasLtMatmulDescSetAttribute(
matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
&b_op as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
let (a_rows, a_cols) = if a_op == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_op == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
cublasLtMatrixLayoutCreate(&mut a_desc, dtype, a_rows, a_cols, lda).result()?;
cublasLtMatrixLayoutCreate(&mut b_desc, dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, dtype, m, n, ldc).result()?;
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
matmul_desc,
a_desc,
b_desc,
c_desc,
c_desc,
preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
cublasLtMatmul(
*cublaslt.handle(),
matmul_desc,
&alpha as *const _ as *const std::ffi::c_void,
a_ptr as *const std::ffi::c_void,
a_desc,
b_ptr as *const std::ffi::c_void,
b_desc,
&beta as *const _ as *const std::ffi::c_void,
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
c_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
stream.cu_stream() as *mut _,
)
.result()?;
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn cublas_matmul_mixed(
stream: &Arc<CudaStream>,
cublaslt: &Arc<CudaBlasLT>,
workspace_ptr: u64,
m: u64,
n: u64,
k: u64,
a_ptr: u64,
a_op: cublasOperation_t,
lda: i64,
b_ptr: u64,
b_op: cublasOperation_t,
ldb: i64,
c_ptr: u64,
ldc: i64,
alpha: f32,
beta: f32,
) -> anyhow::Result<()> {
let ab_dtype = cudaDataType::CUDA_R_16BF;
let cd_dtype = cudaDataType::CUDA_R_32F;
let compute = cublasComputeType_t::CUBLAS_COMPUTE_32F;
let scale_type = cudaDataType::CUDA_R_32F;
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut d_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
let mut algo_count: i32 = 0;
unsafe {
cublasLtMatmulDescCreate(&mut matmul_desc, compute, scale_type).result()?;
cublasLtMatmulDescSetAttribute(
matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
&a_op as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
cublasLtMatmulDescSetAttribute(
matmul_desc,
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
&b_op as *const _ as *const std::ffi::c_void,
std::mem::size_of::<cublasOperation_t>(),
)
.result()?;
let (a_rows, a_cols) = if a_op == cublasOperation_t::CUBLAS_OP_N {
(m, k)
} else {
(k, m)
};
let (b_rows, b_cols) = if b_op == cublasOperation_t::CUBLAS_OP_N {
(k, n)
} else {
(n, k)
};
cublasLtMatrixLayoutCreate(&mut a_desc, ab_dtype, a_rows, a_cols, lda).result()?;
cublasLtMatrixLayoutCreate(&mut b_desc, ab_dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, cd_dtype, m, n, ldc).result()?;
cublasLtMatrixLayoutCreate(&mut d_desc, cd_dtype, m, n, ldc).result()?;
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
preference,
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
std::mem::size_of::<usize>(),
)
.result()?;
cublasLtMatmulAlgoGetHeuristic(
*cublaslt.handle(),
matmul_desc,
a_desc,
b_desc,
c_desc,
d_desc,
preference,
1,
&mut heuristic,
&mut algo_count,
)
.result()?;
if algo_count == 0 {
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(d_desc);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
return Err(anyhow::anyhow!(
"No suitable cuBLASLT algorithm found for mixed matmul"
));
}
cublasLtMatmul(
*cublaslt.handle(),
matmul_desc,
&alpha as *const _ as *const std::ffi::c_void,
a_ptr as *const std::ffi::c_void,
a_desc,
b_ptr as *const std::ffi::c_void,
b_desc,
&beta as *const _ as *const std::ffi::c_void,
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
d_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
stream.cu_stream() as *mut _,
)
.result()?;
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(d_desc);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
cublasLtMatrixLayoutDestroy(a_desc);
cublasLtMatmulDescDestroy(matmul_desc);
}
Ok(())
}

View File

@@ -0,0 +1,656 @@
#![allow(clippy::missing_safety_doc, clippy::not_unsafe_ptr_arg_deref)]
//! CUDA Graph API wrappers for explicit graph construction and surgical updates.
use std::ffi::c_void;
use std::mem::MaybeUninit;
use std::sync::Arc;
use cudarc::driver::{
CudaContext, CudaFunction, CudaStream, DriverError,
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
};
/// A CUDA graph that can be modified and instantiated.
pub struct CudaGraphHandle {
pub(crate) cu_graph: CUgraph,
pub(crate) ctx: Arc<CudaContext>,
}
impl CudaGraphHandle {
/// Creates a new empty CUDA graph.
pub fn new(ctx: Arc<CudaContext>) -> Result<Self, DriverError> {
ctx.bind_to_thread()?;
let mut graph = MaybeUninit::uninit();
unsafe {
sys::cuGraphCreate(graph.as_mut_ptr(), 0).result()?;
Ok(Self {
cu_graph: graph.assume_init(),
ctx,
})
}
}
/// Adds a kernel node to the graph. kernel_params must remain valid for graph lifetime.
pub unsafe fn add_kernel_node(
&mut self,
dependencies: &[CUgraphNode],
func: CUfunction,
grid_dim: (u32, u32, u32),
block_dim: (u32, u32, u32),
shared_mem_bytes: u32,
kernel_params: *mut *mut c_void,
) -> Result<CUgraphNode, DriverError> {
let params = sys::CUDA_KERNEL_NODE_PARAMS {
func,
gridDimX: grid_dim.0,
gridDimY: grid_dim.1,
gridDimZ: grid_dim.2,
blockDimX: block_dim.0,
blockDimY: block_dim.1,
blockDimZ: block_dim.2,
sharedMemBytes: shared_mem_bytes,
kernelParams: kernel_params,
extra: std::ptr::null_mut(),
kern: std::ptr::null_mut(), // Not using CUkernel-based launch
ctx: std::ptr::null_mut(), // Use default context
};
let mut node = MaybeUninit::uninit();
unsafe {
sys::cuGraphAddKernelNode_v2(
node.as_mut_ptr(),
self.cu_graph,
dependencies.as_ptr(),
dependencies.len(),
&params,
)
.result()?;
Ok(node.assume_init())
}
}
/// Adds an event record node to the graph for timing.
pub fn add_event_record_node(
&mut self,
dependencies: &[CUgraphNode],
event: CUevent,
) -> Result<CUgraphNode, DriverError> {
let mut node = MaybeUninit::uninit();
unsafe {
sys::cuGraphAddEventRecordNode(
node.as_mut_ptr(),
self.cu_graph,
dependencies.as_ptr(),
dependencies.len(),
event,
)
.result()?;
Ok(node.assume_init())
}
}
/// Instantiates the graph, creating an executable graph.
pub fn instantiate(&self) -> Result<CudaGraphExecHandle, DriverError> {
self.ctx.bind_to_thread()?;
let mut graph_exec = MaybeUninit::uninit();
unsafe {
sys::cuGraphInstantiateWithFlags(graph_exec.as_mut_ptr(), self.cu_graph, 0).result()?;
Ok(CudaGraphExecHandle {
cu_graph_exec: graph_exec.assume_init(),
ctx: self.ctx.clone(),
})
}
}
}
impl Drop for CudaGraphHandle {
fn drop(&mut self) {
let _ = self.ctx.bind_to_thread();
if !self.cu_graph.is_null() {
unsafe {
let _ = sys::cuGraphDestroy(self.cu_graph);
}
}
}
}
/// An instantiated CUDA graph that can be launched and updated.
pub struct CudaGraphExecHandle {
pub(crate) cu_graph_exec: CUgraphExec,
pub(crate) ctx: Arc<CudaContext>,
}
impl CudaGraphExecHandle {
/// Launches the graph on the given stream.
pub fn launch(&self, stream: &CudaStream) -> Result<(), DriverError> {
self.ctx.bind_to_thread()?;
unsafe { sys::cuGraphLaunch(self.cu_graph_exec, stream.cu_stream()).result() }
}
/// Surgically updates a kernel node's parameters without rebuilding the graph.
pub unsafe fn update_kernel_node(
&mut self,
node: CUgraphNode,
func: CUfunction,
grid_dim: (u32, u32, u32),
block_dim: (u32, u32, u32),
shared_mem_bytes: u32,
kernel_params: *mut *mut c_void,
) -> Result<(), DriverError> {
let params = sys::CUDA_KERNEL_NODE_PARAMS {
func,
gridDimX: grid_dim.0,
gridDimY: grid_dim.1,
gridDimZ: grid_dim.2,
blockDimX: block_dim.0,
blockDimY: block_dim.1,
blockDimZ: block_dim.2,
sharedMemBytes: shared_mem_bytes,
kernelParams: kernel_params,
extra: std::ptr::null_mut(),
kern: std::ptr::null_mut(),
ctx: std::ptr::null_mut(),
};
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, &params) }
.result()
}
}
impl Drop for CudaGraphExecHandle {
fn drop(&mut self) {
let _ = self.ctx.bind_to_thread();
if !self.cu_graph_exec.is_null() {
unsafe {
let _ = sys::cuGraphExecDestroy(self.cu_graph_exec);
}
}
}
}
/// Extension trait to get the raw CUfunction handle from CudaFunction.
pub trait CudaFunctionExt {
unsafe fn raw_function(&self) -> CUfunction;
}
impl CudaFunctionExt for CudaFunction {
unsafe fn raw_function(&self) -> CUfunction {
// CudaFunction fields are reordered by Rust - cu_function is at offset 8
debug_assert_eq!(
std::mem::size_of::<CudaFunction>(),
std::mem::size_of::<CUfunction>() + std::mem::size_of::<usize>()
);
unsafe {
let ptr = (self as *const CudaFunction as *const u8).add(8) as *const CUfunction;
std::ptr::read(ptr)
}
}
}
/// Stored kernel parameters that persist for the lifetime of a CUDA graph.
#[derive(Debug)]
pub struct KernelParams {
values: Box<[u64]>,
ptrs: Box<[*mut c_void]>,
/// Index of the dyn_dims pointer in values array (if present)
dyn_dims_idx: Option<usize>,
}
impl KernelParams {
pub fn new(output_ptr: u64, input_ptrs: &[u64]) -> Self {
let mut values: Vec<u64> = Vec::with_capacity(1 + input_ptrs.len());
values.push(output_ptr);
values.extend_from_slice(input_ptrs);
let values = values.into_boxed_slice();
let ptrs: Vec<*mut c_void> = values
.iter()
.map(|v| v as *const u64 as *mut c_void)
.collect();
Self {
values,
ptrs: ptrs.into_boxed_slice(),
dyn_dims_idx: None,
}
}
/// Create kernel params with a dyn_dims pointer as the last parameter.
pub fn with_dyn_dims(output_ptr: u64, input_ptrs: &[u64], dyn_dims_ptr: u64) -> Self {
let mut values: Vec<u64> = Vec::with_capacity(2 + input_ptrs.len());
values.push(output_ptr);
values.extend_from_slice(input_ptrs);
let dyn_dims_idx = values.len();
values.push(dyn_dims_ptr);
let values = values.into_boxed_slice();
let ptrs: Vec<*mut c_void> = values
.iter()
.map(|v| v as *const u64 as *mut c_void)
.collect();
Self {
values,
ptrs: ptrs.into_boxed_slice(),
dyn_dims_idx: Some(dyn_dims_idx),
}
}
pub fn as_cuda_params(&mut self) -> *mut *mut c_void {
self.ptrs.as_mut_ptr()
}
pub fn update_output(&mut self, ptr: u64) {
self.values[0] = ptr;
}
pub fn update_input(&mut self, index: usize, ptr: u64) {
self.values[1 + index] = ptr;
}
/// Update the dyn_dims pointer if this kernel uses one.
pub fn update_dyn_dims(&mut self, ptr: u64) {
if let Some(idx) = self.dyn_dims_idx {
self.values[idx] = ptr;
}
}
}
/// Stored kernel parameters for megakernels that persist for the lifetime of a CUDA graph.
/// Params: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
#[derive(Debug)]
pub struct MegakernelParams {
/// Parameter values: [tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims]
values: Box<[u64]>,
/// Pointer array for CUDA kernel launch
ptrs: Box<[*mut c_void]>,
}
impl MegakernelParams {
/// Create megakernel params with all internal buffer pointers and dyn_dims.
/// Order: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
#[allow(clippy::too_many_arguments)]
pub fn new(
tasks_ptr: u64,
head_ptr: u64,
ready_ptr: u64,
queue_lock_ptr: u64,
timings_ptr: u64,
start_times_ptr: u64,
buffers_ptr: u64,
dyn_dims_ptr: u64,
) -> Self {
let values: Box<[u64]> = vec![
tasks_ptr,
head_ptr,
ready_ptr,
queue_lock_ptr,
timings_ptr,
start_times_ptr,
buffers_ptr,
dyn_dims_ptr,
]
.into_boxed_slice();
let ptrs: Box<[*mut c_void]> = values
.iter()
.map(|v| v as *const u64 as *mut c_void)
.collect();
Self { values, ptrs }
}
pub fn as_cuda_params(&mut self) -> *mut *mut c_void {
// Rebuild pointers (in case struct was moved)
for (i, v) in self.values.iter().enumerate() {
self.ptrs[i] = v as *const u64 as *mut c_void;
}
self.ptrs.as_mut_ptr()
}
/// Update the buffers pointer (index 6).
pub fn update_buffers(&mut self, ptr: u64) {
self.values[6] = ptr;
}
/// Update the dyn_dims pointer (index 7).
pub fn update_dyn_dims(&mut self, ptr: u64) {
self.values[7] = ptr;
}
/// Get the current buffers pointer value.
pub fn buffers_ptr(&self) -> u64 {
self.values[6]
}
}
/// Timing data for a single kernel in a CUDA graph.
#[derive(Clone, Debug)]
pub struct CudaGraphKernelTiming {
pub kernel_name: &'static str,
pub start_ns: u64,
pub end_ns: u64,
}
/// Timing data for a CUDA graph execution.
#[derive(Clone, Debug)]
pub struct CudaGraphTiming {
pub kernel_timings: Vec<CudaGraphKernelTiming>,
/// Time from launch call until first kernel started on GPU
pub launch_latency_ns: u64,
/// Elapsed time (in nanoseconds) from span entry to just before graph launch.
/// This captures the setup overhead (constants, buffers, graph building) that
/// occurs before the GPU actually starts executing.
pub setup_duration_ns: u64,
}
pub fn create_cuda_event(ctx: &Arc<CudaContext>) -> Result<CUevent, DriverError> {
ctx.bind_to_thread()?;
let mut event = MaybeUninit::uninit();
unsafe {
sys::cuEventCreate(
event.as_mut_ptr(),
sys::CUevent_flags::CU_EVENT_DEFAULT as u32,
)
.result()?;
Ok(event.assume_init())
}
}
pub fn destroy_cuda_event(ctx: &Arc<CudaContext>, event: CUevent) {
if !event.is_null() {
let _ = ctx.bind_to_thread();
unsafe {
let _ = sys::cuEventDestroy_v2(event);
}
}
}
pub fn event_elapsed_ms(
ctx: &Arc<CudaContext>,
start: CUevent,
end: CUevent,
) -> Result<f32, DriverError> {
ctx.bind_to_thread()?;
let mut ms: f32 = 0.0;
unsafe {
sys::cuEventElapsedTime_v2(&mut ms, start, end).result()?;
}
Ok(ms)
}
pub fn record_event_on_stream(
ctx: &Arc<CudaContext>,
event: CUevent,
stream: &CudaStream,
) -> Result<(), DriverError> {
ctx.bind_to_thread()?;
unsafe {
sys::cuEventRecord(event, stream.cu_stream()).result()?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{Device, Tensor};
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use proptest::prelude::*;
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::sync::Arc;
use crate::cuda_bandwidth_gbps;
use crate::runtime::CudaRuntime;
use crate::tests::utilities::*;
#[test]
fn test_create_empty_graph() {
let Ok(ctx) = CudaContext::new(0) else { return };
assert!(CudaGraphHandle::new(ctx).is_ok());
}
#[test]
fn test_kernel_params() {
let mut params = KernelParams::new(0x1000, &[0x2000, 0x3000]);
assert!(!params.as_cuda_params().is_null());
params.update_output(0x4000);
params.update_input(0, 0x5000);
}
#[test]
fn test_cuda_function_size() {
assert_eq!(
std::mem::size_of::<CudaFunction>(),
std::mem::size_of::<CUfunction>() + std::mem::size_of::<usize>()
);
}
#[test]
fn test_raw_function_extraction() {
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
let func = module.load_function("test_kernel").unwrap();
let cu_func = unsafe { func.raw_function() };
assert!(!cu_func.is_null());
let mut max_threads: i32 = 0;
let result = unsafe {
sys::cuFuncGetAttribute(
&mut max_threads,
sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
cu_func,
)
};
assert!(result == sys::cudaError_enum::CUDA_SUCCESS);
}
#[test]
fn test_graph_with_kernel() {
use cudarc::driver::{CudaSlice, DevicePtr};
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
let func = module.load_function("test_kernel").unwrap();
let stream = ctx.default_stream();
let output: CudaSlice<f32> = unsafe { stream.alloc(1) }.unwrap();
let mut input: CudaSlice<f32> = unsafe { stream.alloc(1) }.unwrap();
stream.memcpy_htod(&[5.0f32], &mut input).unwrap();
let cu_func = unsafe { func.raw_function() };
let mut graph = CudaGraphHandle::new(ctx.clone()).unwrap();
let mut params =
KernelParams::new(output.device_ptr(&stream).0, &[input.device_ptr(&stream).0]);
let _node = unsafe {
graph.add_kernel_node(
&[],
cu_func,
(1, 1, 1),
(1, 1, 1),
0,
params.as_cuda_params(),
)
}
.unwrap();
let exec = graph.instantiate().unwrap();
exec.launch(&stream).unwrap();
stream.synchronize().unwrap();
let mut result = [0.0f32];
stream.memcpy_dtoh(&output, &mut result).unwrap();
assert_eq!(result[0], 6.0f32);
}
// CUDA Graph Tests
#[test]
fn test_cuda_graph_basic_execution() {
let Some(stream) = get_cuda_stream() else {
return;
};
let size = 1024;
let mut cx = Graph::default();
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = ((a + b) * a + b).output();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
rt.execute(&cx.dyn_map);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&result1, &rt.get_f32(c), tol, tol);
let expected: Vec<f32> = data_a
.iter()
.zip(&data_b)
.map(|(a, b)| (a + b) * a + b)
.collect();
assert_close(&result1, &expected, tol, tol);
}
#[test]
fn test_cuda_graph_multiple_executions() {
let Some(stream) = get_cuda_stream() else {
return;
};
let size = 2048;
let mut cx = Graph::default();
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = (a + b + a + b).output();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
let mut results = Vec::new();
for _ in 0..5 {
rt.execute(&cx.dyn_map);
results.push(rt.get_f32(c));
}
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
for result in &results {
assert_close(result, &results[0], tol, tol);
}
let expected: Vec<f32> = data_a
.iter()
.zip(&data_b)
.map(|(a, b)| a + b + a + b)
.collect();
assert_close(&results[0], &expected, tol, tol);
}
#[test]
fn test_cuda_graph_dyn_dims_surgical_update() {
let Some(stream) = get_cuda_stream() else {
return;
};
let size = 512;
let mut cx = Graph::default();
let a = cx.tensor('s');
let b = cx.tensor('s');
let c = (a + b).output();
let d = (c * a).output();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
.iter()
.zip(&data_b)
.map(|(a, b)| (a + b) * a)
.collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&rt.get_f32(d), &expected, tol, tol);
let size = 1024;
let data_a2 = random_f32_vec(size, 44, -0.5, 0.5);
let data_b2 = random_f32_vec(size, 45, -0.5, 0.5);
rt.set_data(a, data_a2.clone());
rt.set_data(b, data_b2.clone());
cx.set_dim('s', size);
rt.execute(&cx.dyn_map);
let expected2: Vec<f32> = data_a2
.iter()
.zip(&data_b2)
.map(|(a, b)| (a + b) * a)
.collect();
assert_close(&rt.get_f32(d), &expected2, tol, tol);
}
#[test]
fn test_single_kernel_in_graph() {
let Some(stream) = get_cuda_stream() else {
return;
};
let size = 1024;
let mut cx = Graph::default();
let a = cx.tensor(size);
let b = cx.tensor(size);
let c = (a + b).output();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&rt.get_f32(c), &expected, tol, tol);
assert!(rt.last_kernel_stats.iter().any(|s| s.name == "CudaGraph"));
}
#[test]
fn test_cuda_graph_chain_performance() {
let Some(stream) = get_cuda_stream() else {
return;
};
let size = 4096;
let mut cx = Graph::default();
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let mut result = a + b;
for _ in 0..5 {
result += a;
result *= b;
}
let output = result.output();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
for _ in 0..10 {
rt.execute(&cx.dyn_map);
}
let mut expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
for _ in 0..5 {
expected = expected.iter().zip(&data_a).map(|(r, a)| r + a).collect();
expected = expected.iter().zip(&data_b).map(|(r, b)| r * b).collect();
}
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,289 @@
#![allow(unused)]
use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream};
use luminal::prelude::*;
use luminal_tracing::schema::{
self as schema, TrackEvent, debug_annotation::NameField, trace_packet, track_event,
};
use uuid::Uuid;
pub mod cuda_graph;
pub mod hlir;
pub mod other_ops;
pub use cuda_graph::*;
pub type Ops = (hlir::Ops, other_ops::Ops);
/// Build a mapping from interned string IDs to their string values for a given sequence.
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
use luminal_tracing::schema::trace_packet;
let mut interned: std::collections::HashMap<(u32, u64), String> =
std::collections::HashMap::new();
for packet in &trace.packet {
let seq_id = match &packet.optional_trusted_packet_sequence_id {
Some(trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(seq)) => {
*seq
}
_ => 0,
};
// interned_data is a field on TracePacket, not a Data variant
if let Some(data) = &packet.interned_data {
for entry in &data.debug_annotation_names {
if let Some(name) = &entry.name {
interned.insert((seq_id, entry.iid()), name.clone());
}
}
}
}
interned
}
/// Check if a debug annotation has key "id" and the given UUID value.
fn annotation_matches_id(
a: &schema::DebugAnnotation,
id: &Uuid,
interned: &std::collections::HashMap<(u32, u64), String>,
seq_id: u32,
) -> bool {
let key_matches = match &a.name_field {
Some(NameField::Name(k)) => k == "id",
Some(NameField::NameIid(iid)) => interned
.get(&(seq_id, *iid))
.map(|s| s == "id")
.unwrap_or(false),
None => false,
};
if !key_matches {
return false;
}
match &a.value {
Some(luminal_tracing::schema::debug_annotation::Value::StringValue(v)) => {
*v == format!("{id}")
}
_ => false,
}
}
/// Record CUDA graph kernel timings as nested slices in perfetto trace
pub fn record_cuda_graph_timings(
trace: &schema::Trace,
cuda_graph_timings: &[(CudaGraphTiming, Uuid)],
) -> Vec<schema::TracePacket> {
use luminal_tracing::schema::{trace_packet, track_descriptor};
// Build interned string lookup table
let interned = build_interned_strings(trace);
let mut packets = Vec::new();
for (graph_timing, id) in cuda_graph_timings {
let parent_info = trace.packet.iter().find_map(|p| {
let seq_id = match &p.optional_trusted_packet_sequence_id {
Some(trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
seq,
)) => *seq,
_ => 0,
};
match &p.data {
Some(trace_packet::Data::TrackEvent(TrackEvent {
r#type: ty,
track_uuid,
debug_annotations,
..
})) if *ty == Some(track_event::Type::SliceBegin as i32)
&& debug_annotations
.iter()
.any(|a| annotation_matches_id(a, id, &interned, seq_id)) =>
{
Some((p.timestamp?, p.timestamp_clock_id?, (*track_uuid)?, seq_id))
}
_ => None,
}
});
let Some((span_start_time, clock_id, track_uuid, sequence_id)) = parent_info else {
continue;
};
// Use span_start_time + setup_duration + launch_latency as the base for kernel timings.
// - setup_duration_ns: time spent on host between span entry and launch call
// - launch_latency_ns: GPU-side time from launch to first kernel execution
// This ensures kernel spans are accurately positioned within the cuda_graph span.
let base_time =
span_start_time + graph_timing.setup_duration_ns + graph_timing.launch_latency_ns;
for kernel_timing in &graph_timing.kernel_timings {
packets.push(schema::TracePacket {
timestamp: Some(base_time + kernel_timing.start_ns),
timestamp_clock_id: Some(clock_id),
optional_trusted_packet_sequence_id: Some(
trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
sequence_id,
),
),
data: Some(trace_packet::Data::TrackEvent(schema::TrackEvent {
track_uuid: Some(track_uuid),
r#type: Some(track_event::Type::SliceBegin as i32),
name_field: Some(track_event::NameField::Name(
kernel_timing.kernel_name.to_owned(),
)),
..Default::default()
})),
..Default::default()
});
packets.push(schema::TracePacket {
timestamp: Some(base_time + kernel_timing.end_ns),
timestamp_clock_id: Some(clock_id),
optional_trusted_packet_sequence_id: Some(
trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
sequence_id,
),
),
data: Some(trace_packet::Data::TrackEvent(schema::TrackEvent {
track_uuid: Some(track_uuid),
r#type: Some(track_event::Type::SliceEnd as i32),
name_field: Some(track_event::NameField::Name(
kernel_timing.kernel_name.to_owned(),
)),
..Default::default()
})),
..Default::default()
});
}
}
packets
}
pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
#[allow(clippy::type_complexity)]
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
);
/// Returns the output buffer size in elements.
fn output_size(&self) -> Expression;
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.output_size().dyn_vars().into_iter().collect()
}
/// Returns the output buffer size in bytes (accounts for dtype).
fn output_bytes(&self) -> Expression;
/// Returns the DType of this kernel's output buffer.
/// Used by has_nan_outputs to interpret buffer bytes correctly.
/// Default: F32 (most kernels output float).
fn output_dtype(&self) -> DType {
DType::F32
}
/// Returns the number of bytes this kernel will load from global memory.
fn bytes_loaded(&self) -> Expression {
0.into()
}
/// Returns the number of bytes this kernel will store to global memory.
fn bytes_stored(&self) -> Expression {
0.into()
}
/// Returns the number of floating point operations this kernel performs.
fn flops(&self) -> Expression {
0.into()
}
/// Returns the name of this kernel for profiling display.
fn kernel_name(&self) -> &'static str {
"Unknown"
}
/// Allocate internal buffers this kernel needs. Called once during graph building.
/// Default: no internal buffers.
fn allocate_internal_buffers(
&self,
_stream: &Arc<CudaStream>,
_dyn_map: &FxHashMap<char, usize>,
) -> Vec<CudaSlice<u8>> {
vec![]
}
/// Returns the set of dynamic dimensions that affect internal buffer sizes.
/// When any of these dimensions change, internal buffers should be reallocated.
/// Default: empty set (no dimensions affect internal buffers).
fn internal_buffer_dyn_dims(&self) -> FxHashSet<char> {
FxHashSet::default()
}
/// Build kernel parameters. Returns the u64 values to pass to the kernel.
/// Default: [output_ptr, input_ptrs..., dyn_dims_ptr (if non-zero)]
fn build_params(
&self,
_stream: &Arc<CudaStream>,
output_ptr: u64,
input_ptrs: &[u64],
_internal_bufs: &[CudaSlice<u8>],
dyn_dims_ptr: u64,
) -> Vec<u64> {
let mut params = vec![output_ptr];
params.extend_from_slice(input_ptrs);
if dyn_dims_ptr != 0 {
params.push(dyn_dims_ptr);
}
params
}
/// Called before each kernel execution. Update internal state if needed.
/// `all_buffer_ptrs` contains pointers for all buffers this kernel might use.
/// `constants` are device constants returned by compile() that may need updating.
fn pre_execute(
&self,
_stream: &Arc<CudaStream>,
_internal_bufs: &mut [CudaSlice<u8>],
_constants: &mut FxHashMap<char, CudaSlice<u8>>,
_all_buffer_ptrs: &FxHashMap<NodeIndex, u64>,
_dyn_map: &FxHashMap<char, usize>,
) {
}
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
fn output_aliases_input(&self) -> Option<usize> {
None
}
/// If this kernel's output is derived from one of its inputs (copy-then-modify
/// or in-place write), return that input index. Used by `resolve_data_node` to
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
/// roundtrip pattern.
///
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
/// (like Scatter which copies dest→output then scatters into it).
fn output_data_input(&self) -> Option<usize> {
self.output_aliases_input()
}
/// Returns indices of internal buffers containing timing data, if any.
/// Returns (timings_idx, start_times_idx, sm_count).
fn timing_buffer_indices(&self) -> Option<(usize, usize, usize)> {
None
}
}
luminal::impl_into_ops!(KernelOp);
// Kernel to host op compilation
mod to_host;
pub use to_host::{CudaGraphOp, kernel_to_host};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,839 @@
//! Compiles KernelOp subgraphs into HostOp (CudaGraphOp).
//!
//! CudaGraphOp wraps a subgraph of KernelOps into a single executable unit
//! that can be executed like any other HostOp.
use std::cell::RefCell;
use std::sync::Arc;
use cudarc::driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, sys::CUgraphNode,
};
use itertools::Itertools;
use luminal::{
egglog_utils::{api::Rule, base::OP_KIND},
graph::LLIRGraph,
op::{EgglogOp, LLIROp},
prelude::{
petgraph::{Direction, algo::toposort, visit::EdgeRef},
*,
},
};
use tracing::{Level, enabled, span};
use crate::{
host::HostOp,
kernel::{
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
destroy_cuda_event,
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
},
runtime::partition_marked_convex,
};
/// A compiled kernel within a CudaGraphOp.
#[derive(Debug)]
struct CompiledKernel {
/// The node index in the original llir_graph
node: NodeIndex,
/// The compiled CUDA function
function: CudaFunction,
/// Launch grid dimensions (blocks)
grid: (Expression, Expression, Expression),
/// Launch block dimensions (threads)
block: (Expression, Expression, Expression),
/// Shared memory size
shared_mem: Expression,
/// Input node indices (for buffer lookup)
inputs: Vec<NodeIndex>,
/// Reference to the KernelOp for trait methods
kernel_op: Arc<Box<dyn KernelOp>>,
/// Internal buffers allocated for this kernel
internal_bufs: Vec<CudaSlice<u8>>,
/// Device constants from compile()
constants: FxHashMap<char, CudaSlice<u8>>,
/// Graph node handle (set after graph is built)
graph_node: Option<CUgraphNode>,
/// Kernel name for profiling
kernel_name: &'static str,
}
impl CompiledKernel {
#[allow(clippy::too_many_arguments)]
fn new(
node: NodeIndex,
function: CudaFunction,
grid: (Expression, Expression, Expression),
block: (Expression, Expression, Expression),
shared_mem: Expression,
inputs: Vec<NodeIndex>,
kernel_op: Arc<Box<dyn KernelOp>>,
constants: FxHashMap<char, CudaSlice<u8>>,
kernel_name: &'static str,
) -> Self {
Self {
node,
function,
grid,
block,
shared_mem,
inputs,
kernel_op,
internal_bufs: Vec::new(),
constants,
graph_node: None,
kernel_name,
}
}
}
/// Unified kernel params that can hold any number of u64 values.
struct UnifiedKernelParams {
values: Vec<u64>,
ptrs: Vec<*mut std::ffi::c_void>,
}
impl UnifiedKernelParams {
fn new(values: Vec<u64>) -> Self {
let ptrs = values
.iter()
.map(|v| v as *const u64 as *mut std::ffi::c_void)
.collect();
Self { values, ptrs }
}
fn as_cuda_params(&mut self) -> *mut *mut std::ffi::c_void {
// Rebuild pointers (in case struct was moved)
for (i, v) in self.values.iter().enumerate() {
self.ptrs[i] = v as *const u64 as *mut std::ffi::c_void;
}
self.ptrs.as_mut_ptr()
}
}
/// Mutable state for CudaGraphOp that needs interior mutability.
struct CudaGraphOpState {
/// Compiled kernels in topological order
kernels: Vec<CompiledKernel>,
/// Shared device buffer for dynamic dimensions
dyn_dims_buffer: Option<CudaSlice<i32>>,
/// CUDA graph handle
cuda_graph: Option<CudaGraphHandle>,
/// CUDA graph exec handle
cuda_graph_exec: Option<CudaGraphExecHandle>,
/// Mapping from kernel node to graph node
node_to_graph_node: FxHashMap<NodeIndex, CUgraphNode>,
/// Kernel params for each kernel
kernel_params: Vec<UnifiedKernelParams>,
/// Last dynamic dimension values (for change detection)
last_dyn_values: FxHashMap<char, usize>,
/// Last buffer pointers (for change detection)
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
/// Timing events for profiling
timing_events: Vec<cudarc::driver::sys::CUevent>,
}
impl CudaGraphOpState {
fn new(kernels: Vec<CompiledKernel>) -> Self {
Self {
kernels,
dyn_dims_buffer: None,
cuda_graph: None,
cuda_graph_exec: None,
node_to_graph_node: FxHashMap::default(),
kernel_params: Vec::new(),
last_dyn_values: FxHashMap::default(),
last_buffer_ptrs: FxHashMap::default(),
timing_events: Vec::new(),
}
}
}
/// A CUDA graph operation that implements HostOp.
///
/// This wraps a subgraph of KernelOps into a single executable CUDA graph.
/// It manages graph building, execution, and dynamic updates.
pub struct CudaGraphOp {
/// All nodes that this graph needs buffers for (kernels + their inputs)
buffer_nodes: Vec<NodeIndex>,
/// Buffer size requirements for extra nodes (node -> size in elements)
buffer_sizes: FxHashMap<NodeIndex, Expression>,
/// Dynamic dimensions used by this graph (sorted alphabetically)
dyn_dims_order: Vec<char>,
/// The CUDA stream (needed for operations)
stream: Arc<CudaStream>,
/// Mutable state wrapped in RefCell for interior mutability
state: RefCell<CudaGraphOpState>,
}
impl CudaGraphOp {
fn new(
buffer_nodes: Vec<NodeIndex>,
buffer_sizes: FxHashMap<NodeIndex, Expression>,
dyn_dims_order: Vec<char>,
stream: Arc<CudaStream>,
state: CudaGraphOpState,
) -> Self {
Self {
buffer_nodes,
buffer_sizes,
dyn_dims_order,
stream,
state: RefCell::new(state),
}
}
}
impl std::fmt::Debug for CudaGraphOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let state = self.state.borrow();
f.debug_struct("CudaGraphOp")
.field("n_kernels", &state.kernels.len())
.field("n_buffer_nodes", &self.buffer_nodes.len())
.finish()
}
}
impl EgglogOp for CudaGraphOp {
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
}
fn rewrites(&self) -> Vec<Rule> {
vec![]
}
fn extract<'a>(
&'a self,
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
_kind_children: &[&'a luminal::prelude::ENodeId],
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
panic!("CudaGraphOp should not be extracted from egglog")
}
fn cleanup(&self) -> bool {
false
}
}
impl HostOp for CudaGraphOp {
fn execute(
&self,
stream: &Arc<CudaStream>,
_self_node: NodeIndex,
_inputs: &[NodeIndex],
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
self.execute_internal(stream, buffers, dyn_map)
}
fn output_size(&self) -> Expression {
// CudaGraphOp doesn't have a single output - individual kernels have outputs
0.into()
}
fn output_bytes(&self) -> Expression {
// CudaGraphOp doesn't have a single output - individual kernels have outputs
0.into()
}
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
// Only return nodes that actually have buffers
// Filter out nodes in buffer_sizes with size 0 (like MegakernelOps)
// Keep nodes not in buffer_sizes (external inputs that have their own buffers)
self.buffer_nodes
.iter()
.filter(|n| {
match self.buffer_sizes.get(n) {
Some(size) => size.exec(&FxHashMap::default()).unwrap_or(1) != 0,
None => true, // Not a kernel output, might be an external input
}
})
.copied()
.collect()
}
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
self.buffer_sizes.clone()
}
fn stats_name(&self) -> Option<&'static str> {
Some("CudaGraph")
}
}
impl CudaGraphOp {
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
fn execute_internal(
&self,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let mut state = self.state.borrow_mut();
let _span = span!(Level::TRACE, "cuda_graph", kernels = state.kernels.len()).entered();
// Check if dyn_map changed
let dyn_map_changed = dyn_map.len() != state.last_dyn_values.len()
|| dyn_map
.iter()
.any(|(k, v)| state.last_dyn_values.get(k) != Some(v));
// Check if any kernel's internal buffer dimensions changed
let mut needs_internal_realloc = false;
for kernel in state.kernels.iter() {
let internal_dims = kernel.kernel_op.internal_buffer_dyn_dims();
if internal_dims
.iter()
.any(|d| dyn_map.get(d) != state.last_dyn_values.get(d))
{
needs_internal_realloc = true;
break;
}
}
// Reallocate internal buffers if needed
if needs_internal_realloc {
for kernel in state.kernels.iter_mut() {
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
}
}
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
if dyn_map_changed || needs_internal_realloc {
state.cuda_graph = None;
state.cuda_graph_exec = None;
state.node_to_graph_node.clear();
state.kernel_params.clear();
}
// Allocate dyn_dims_buffer if needed
if !self.dyn_dims_order.is_empty() && state.dyn_dims_buffer.is_none() {
state.dyn_dims_buffer = Some(
stream
.alloc_zeros::<i32>(self.dyn_dims_order.len())
.expect("Failed to allocate dyn_dims buffer"),
);
}
// Update shared dyn_dims buffer if dyn_map changed
if dyn_map_changed && !self.dyn_dims_order.is_empty() {
let values: Vec<i32> = self
.dyn_dims_order
.iter()
.map(|d| dyn_map.get(d).copied().unwrap_or(0) as i32)
.collect();
if let Some(buf) = state.dyn_dims_buffer.as_mut() {
stream.memcpy_htod(&values, buf)?;
}
}
// Build CUDA graph if needed
if state.cuda_graph.is_none() {
self.build_graph(&mut state, stream, buffers, dyn_map)?;
}
// Collect current buffer pointers
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
}
}
// Apply output-aliases-input
for kernel in state.kernels.iter() {
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
{
current_buffer_ptrs.insert(kernel.node, input_ptr);
}
}
// Always call pre_execute for each kernel to reset internal state
// (e.g., MegakernelOps need work queue, head, barriers, lock reset every execution)
for idx in 0..state.kernels.len() {
let kernel = &mut state.kernels[idx];
kernel.kernel_op.pre_execute(
stream,
&mut kernel.internal_bufs,
&mut kernel.constants,
&current_buffer_ptrs,
dyn_map,
);
}
// Check if we need to update the graph
let buffer_ptrs_changed = current_buffer_ptrs != state.last_buffer_ptrs;
let needs_update = dyn_map_changed || buffer_ptrs_changed;
if needs_update {
// Update kernel params
let dyn_dims_ptr = state
.dyn_dims_buffer
.as_ref()
.map(|buf| buf.device_ptr(stream).0)
.unwrap_or(0);
// Build params for each kernel first
let num_kernels = state.kernels.len();
for idx in 0..num_kernels {
let kernel = &state.kernels[idx];
let output_ptr = current_buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
let input_ptrs: Vec<u64> = kernel
.inputs
.iter()
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
);
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
}
// Now update CUDA graph nodes
state
.cuda_graph_exec
.as_ref()
.unwrap()
.ctx
.bind_to_thread()?;
for idx in 0..num_kernels {
let kernel = &state.kernels[idx];
let graph_node = state.node_to_graph_node[&kernel.node];
let grid_dim = (
kernel.grid.0.exec(dyn_map).unwrap() as u32,
kernel.grid.1.exec(dyn_map).unwrap() as u32,
kernel.grid.2.exec(dyn_map).unwrap() as u32,
);
let block_dim = (
kernel.block.0.exec(dyn_map).unwrap() as u32,
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let cu_func = unsafe { kernel.function.raw_function() };
// Get params pointer first to avoid borrowing state twice
let params_ptr = state.kernel_params[idx].as_cuda_params();
let exec = state.cuda_graph_exec.as_mut().unwrap();
unsafe {
exec.update_kernel_node(
graph_node, cu_func, grid_dim, block_dim, shared_mem, params_ptr,
)?;
}
}
state.last_dyn_values = dyn_map.clone();
state.last_buffer_ptrs = current_buffer_ptrs;
}
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
Ok(())
}
/// Build the CUDA graph from compiled kernels.
fn build_graph(
&self,
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
stream: &Arc<CudaStream>,
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
let ctx = stream.context().clone();
let mut graph = CudaGraphHandle::new(ctx.clone())?;
let num_kernels = state.kernels.len();
state.kernel_params.clear();
state.kernel_params.reserve(num_kernels);
let tracing_enabled = enabled!(Level::TRACE);
if tracing_enabled {
let needed_events = num_kernels + 1;
while state.timing_events.len() < needed_events {
state.timing_events.push(create_cuda_event(&ctx)?);
}
}
// Collect buffer pointers
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
for &node in &self.buffer_nodes {
if let Some(buf) = buffers.get(&node) {
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
}
}
let dyn_dims_ptr = state
.dyn_dims_buffer
.as_ref()
.map(|buf| buf.device_ptr(stream).0)
.unwrap_or(0);
graph.ctx.bind_to_thread()?;
let mut prev_graph_node: Option<CUgraphNode> = None;
for idx in 0..num_kernels {
// Allocate internal buffers if not already done
{
let kernel = &mut state.kernels[idx];
if kernel.internal_bufs.is_empty() {
kernel.internal_bufs =
kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
}
}
// Call pre_execute to initialize internal state (e.g., populate buffer array for MegakernelOps)
{
let kernel = &mut state.kernels[idx];
kernel.kernel_op.pre_execute(
stream,
&mut kernel.internal_bufs,
&mut kernel.constants,
&buffer_ptrs,
dyn_map,
);
}
let kernel = &state.kernels[idx];
let grid_dim = (
kernel.grid.0.exec(dyn_map).unwrap() as u32,
kernel.grid.1.exec(dyn_map).unwrap() as u32,
kernel.grid.2.exec(dyn_map).unwrap() as u32,
);
let block_dim = (
kernel.block.0.exec(dyn_map).unwrap() as u32,
kernel.block.1.exec(dyn_map).unwrap() as u32,
kernel.block.2.exec(dyn_map).unwrap() as u32,
);
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
let input_ptrs: Vec<u64> = kernel
.inputs
.iter()
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
let param_values = kernel.kernel_op.build_params(
stream,
output_ptr,
&input_ptrs,
&kernel.internal_bufs,
dyn_dims_ptr,
);
let mut params = UnifiedKernelParams::new(param_values);
let cu_func = unsafe { kernel.function.raw_function() };
let kernel_node = kernel.node;
// Get timing event for this index (separate access from kernels)
let timing_event = if tracing_enabled {
Some(state.timing_events[idx])
} else {
None
};
let deps: &[CUgraphNode] = match (&prev_graph_node, timing_event) {
(Some(prev), Some(event)) => {
let event_node = graph.add_event_record_node(&[*prev], event)?;
prev_graph_node = Some(event_node);
std::slice::from_ref(prev_graph_node.as_ref().unwrap())
}
(None, Some(event)) => {
let event_node = graph.add_event_record_node(&[], event)?;
prev_graph_node = Some(event_node);
std::slice::from_ref(prev_graph_node.as_ref().unwrap())
}
(Some(prev), None) => std::slice::from_ref(prev),
(None, None) => &[],
};
let graph_node = unsafe {
graph.add_kernel_node(
deps,
cu_func,
grid_dim,
block_dim,
shared_mem,
params.as_cuda_params(),
)?
};
state.node_to_graph_node.insert(kernel_node, graph_node);
state.kernels[idx].graph_node = Some(graph_node);
state.kernel_params.push(params);
prev_graph_node = Some(graph_node);
}
if tracing_enabled && let Some(prev) = prev_graph_node {
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}
let exec = graph.instantiate()?;
state.cuda_graph = Some(graph);
state.cuda_graph_exec = Some(exec);
state.last_dyn_values = dyn_map.clone();
state.last_buffer_ptrs = buffer_ptrs;
Ok(())
}
}
impl Drop for CudaGraphOp {
fn drop(&mut self) {
let mut state = self.state.borrow_mut();
// Destroy timing events first
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
if let Some(ctx) = ctx {
for event in state.timing_events.drain(..) {
destroy_cuda_event(&ctx, event);
}
}
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
// so it must be destroyed first to avoid dangling pointer issues.
drop(state.cuda_graph_exec.take());
drop(state.cuda_graph.take());
// Now safe to free dynamically allocated GPU buffers
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
// Constants point to __constant__ memory in the CUDA module,
// not dynamically allocated — must not be freed.
for kernel in state.kernels.iter_mut() {
let constants = std::mem::take(&mut kernel.constants);
for (_k, v) in constants {
std::mem::forget(v);
}
}
}
}
/// Compile KernelOp subgraphs in the LLIR graph into CudaGraphOps.
///
/// This function:
/// 1. Finds all KernelOp nodes in the graph
/// 2. Partitions them into convex subgraphs
/// 3. For each subgraph, creates a CudaGraphOp (which implements HostOp)
/// 4. Adds the CudaGraphOp node to the llir_graph with appropriate edges
///
/// Note: KernelOp nodes remain in the graph for buffer allocation and edge tracking.
/// Their execution is handled by the CudaGraphOp via the CUDA graph API.
#[allow(clippy::type_complexity)]
pub fn kernel_to_host(
llir_graph: &mut LLIRGraph,
cuda_stream: &Arc<CudaStream>,
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) {
let _span = span!(Level::TRACE, "kernel_to_host").entered();
let kernel_ops_in_graph = llir_graph
.node_indices()
.filter(|n| llir_graph[*n].to_dialect::<dyn KernelOp>().is_some())
.collect::<FxHashSet<_>>();
if kernel_ops_in_graph.is_empty() {
return;
}
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
// Track all CudaGraphOp nodes and their subgraphs for edge creation
let mut cuda_graph_subgraphs: Vec<(NodeIndex, FxHashSet<NodeIndex>)> = Vec::new();
for subgraph in kernel_subgraphs {
// Compile kernels in topological order
let topo_order: Vec<_> = toposort(&*llir_graph, None)
.unwrap()
.into_iter()
.filter(|n| subgraph.contains(n))
.collect();
let mut all_dyn_dims = FxHashSet::default();
let mut all_buffer_nodes = FxHashSet::default();
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
.unwrap();
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
}
// Set global dyn dims ordering so compiles use consistent indices
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
global_dyn_dims.sort();
if !global_dyn_dims.is_empty() {
set_global_dyn_dims(global_dyn_dims.clone());
}
// Compile all kernels with global ordering for correct dyn_dims indices
let mut kernels = Vec::with_capacity(topo_order.len());
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
.unwrap();
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
kernel_op_ref.compile(cuda_stream, kernel_cache);
// Collect inputs from graph edges
let mut inputs: Vec<NodeIndex> = llir_graph
.edges_directed(*kernel_node_idx, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect_vec();
// Collect buffer nodes and sizes
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
let output_size = kernel_op_ref.output_size();
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
all_buffer_nodes.insert(*kernel_node_idx);
all_buffer_sizes.insert(*kernel_node_idx, output_size);
}
all_buffer_nodes.extend(inputs.iter().copied());
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
kernels.push(CompiledKernel::new(
*kernel_node_idx,
kernel_function,
grid,
block,
shared_mem,
inputs,
kernel_op.clone(),
constants,
kernel_op.kernel_name(),
));
}
// Get the possibly-extended global ordering (kernels may have discovered new dims)
let final_global = get_global_dyn_dims();
// Clear global ordering now that all kernels are compiled
clear_global_dyn_dims();
// Use the final global ordering if it was extended during compilation
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
final_order
} else {
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
dims.sort();
dims
};
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
// Create CudaGraphOp with RefCell for interior mutability
let state = CudaGraphOpState::new(kernels);
let cuda_graph_op = CudaGraphOp::new(
buffer_nodes,
all_buffer_sizes,
dyn_dims_order,
cuda_stream.clone(),
state,
);
// Add CudaGraphOp to llir_graph as a HostOp
let cuda_graph_node =
llir_graph.add_node(LLIROp::new(Box::new(cuda_graph_op) as Box<dyn HostOp>));
// Track which kernel nodes belong to this CudaGraphOp
for kernel_node in &subgraph {
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
}
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
// Find external inputs: nodes outside subgraph that have edges into subgraph
let external_inputs: FxHashSet<NodeIndex> = subgraph
.iter()
.flat_map(|&node| {
llir_graph
.edges_directed(node, Direction::Incoming)
.map(|e| e.source())
.filter(|src| !subgraph.contains(src))
})
.collect();
// Add edges from external inputs to CudaGraphOp
for input in &external_inputs {
llir_graph.add_edge(*input, cuda_graph_node, ());
}
// Note: We intentionally keep the kernel nodes in the graph.
// They are needed for:
// 1. Buffer allocation (their output_size determines buffer sizes)
// 2. Edge tracking (other ops like cuBLAS reference specific kernel outputs)
// The CudaGraphOp handles their execution via the CUDA graph API.
}
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
// This ensures proper execution ordering when a kernel in one CudaGraphOp
// produces output consumed by a kernel in another CudaGraphOp.
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
// Find external consumers that are kernels belonging to other CudaGraphOps
for producer_node in subgraph {
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
let consumer = edge.target();
if subgraph.contains(&consumer) {
continue; // Same subgraph
}
// Check if consumer is a kernel in another CudaGraphOp
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer)
&& consumer_cuda_graph != *cuda_graph_node
{
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
}
// Also add edges to HostOps (like cuBLAS ops) that consume our outputs
if llir_graph[consumer]
.to_dialect::<dyn super::super::host::HostOp>()
.is_some()
{
edges_to_add.push((*cuda_graph_node, consumer));
}
}
}
}
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
let topo = toposort(&*llir_graph, None).unwrap();
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
for (i, n) in topo.iter().enumerate() {
topo_pos.insert(*n, i);
}
for (src, dst) in edges_to_add {
// Only add forward edges (src before dst in topo order) to avoid creating cycles
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
if src_pos >= dst_pos {
continue; // Skip back-edges
}
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
llir_graph.add_edge(src, dst, ());
}
}
}

View File

@@ -0,0 +1,309 @@
pub mod host;
pub mod kernel;
pub mod runtime;
use std::{
ffi::{CStr, CString},
path::Path,
sync::Arc,
};
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::{
driver::{CudaContext, DriverError, sys as driver_sys},
nvrtc::{
Ptx,
result::{self as nvrtc_result, NvrtcError},
sys as nvrtc_sys,
},
};
use luminal::dtype::DType;
fn cuda_dtype(dtype: DType) -> &'static str {
match dtype {
DType::F64 => "double",
DType::F32 => "float",
DType::F16 => "half",
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",
DType::U8 => "unsigned char",
DType::Bool => "unsigned char",
DType::F8E4M3 => "__nv_fp8_e4m3",
DType::F8E5M2 => "__nv_fp8_e5m2",
DType::F8UE8M0 => "__nv_fp8_e8m0",
DType::F6E2M3 => "__nv_fp6_e2m3",
DType::F6E3M2 => "__nv_fp6_e3m2",
DType::F4E2M1 => "__nv_fp4_e2m1",
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
}
}
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
#[derive(Debug)]
pub(crate) enum CudaModuleImageCompileFailure {
ComputeCapability(DriverError),
Nvrtc {
stage: &'static str,
error: NvrtcError,
},
NoModuleImageProduced,
}
#[derive(Debug)]
pub(crate) struct CudaModuleImageCompileError {
pub target_arch: Option<String>,
pub driver_version: Option<i32>,
pub runtime_version: Option<i32>,
pub nvrtc_options: Vec<String>,
pub nvrtc_log: Option<String>,
pub failure: CudaModuleImageCompileFailure,
}
impl std::fmt::Display for CudaModuleImageCompileError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to compile CUDA module image")?;
if let Some(target_arch) = &self.target_arch {
write!(f, " for {target_arch}")?;
}
match &self.failure {
CudaModuleImageCompileFailure::ComputeCapability(error) => {
write!(f, ": failed to query compute capability: {error}")?;
}
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
write!(f, ": NVRTC {stage} failed: {error}")?;
}
CudaModuleImageCompileFailure::NoModuleImageProduced => {
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
}
}
if let Some(version) = self.driver_version {
write!(f, " | driver {}", format_cuda_version(version))?;
}
if let Some(version) = self.runtime_version {
write!(f, " | runtime {}", format_cuda_version(version))?;
}
if !self.nvrtc_options.is_empty() {
write!(f, " | options {:?}", self.nvrtc_options)?;
}
if let Some(log) = &self.nvrtc_log {
write!(f, " | log: {log}")?;
}
Ok(())
}
}
impl std::error::Error for CudaModuleImageCompileError {}
fn format_cuda_version(version: i32) -> String {
format!("{}.{}", version / 1000, (version % 1000) / 10)
}
fn cuda_nvrtc_include_paths() -> Vec<String> {
let mut include_paths = Vec::new();
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
if let Ok(root) = std::env::var(env_var) {
let path = format!("{root}/include");
if Path::new(&path).exists() && !include_paths.contains(&path) {
include_paths.push(path);
}
}
}
for path in CUDA_NVRTC_INCLUDE_PATHS {
let path = path.to_string();
if Path::new(&path).exists() && !include_paths.contains(&path) {
include_paths.push(path);
}
}
include_paths
}
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
let mut driver_version = 0;
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
.result()
.ok()
.map(|_| driver_version);
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
// resolves newer libcudart symbols that may not exist in the installed runtime.
(driver_version, None)
}
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
let mut options = cuda_nvrtc_include_paths()
.into_iter()
.map(|path| format!("--include-path={path}"))
.collect::<Vec<_>>();
options.push(format!("--gpu-architecture={target_arch}"));
options
}
fn build_module_image_compile_error(
target_arch: Option<String>,
driver_version: Option<i32>,
runtime_version: Option<i32>,
nvrtc_options: &[String],
nvrtc_log: Option<String>,
failure: CudaModuleImageCompileFailure,
) -> CudaModuleImageCompileError {
CudaModuleImageCompileError {
target_arch,
driver_version,
runtime_version,
nvrtc_options: nvrtc_options.to_vec(),
nvrtc_log,
failure,
}
}
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
if raw.is_empty() {
return None;
}
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
.to_string_lossy()
.trim_end_matches('\0')
.trim()
.to_string();
if log.is_empty() { None } else { Some(log) }
}
#[allow(clippy::slow_vector_initialization)]
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
let mut cubin_size = 0usize;
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
if cubin_size == 0 {
return Ok(Vec::new());
}
let mut cubin = Vec::with_capacity(cubin_size);
cubin.resize(cubin_size, 0);
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
}
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
ctx: &Arc<CudaContext>,
src: S,
) -> Result<Ptx, CudaModuleImageCompileError> {
let (driver_version, runtime_version) = cuda_driver_diagnostics();
let (major, minor) = ctx.compute_capability().map_err(|error| {
build_module_image_compile_error(
None,
driver_version,
runtime_version,
&[],
None,
CudaModuleImageCompileFailure::ComputeCapability(error),
)
})?;
let target_arch = format!("sm_{major}{minor}");
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
let source = CString::new(src.as_ref().as_bytes())
.expect("CUDA source code cannot contain null terminators");
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
build_module_image_compile_error(
Some(target_arch.clone()),
driver_version,
runtime_version,
&nvrtc_options,
None,
CudaModuleImageCompileFailure::Nvrtc {
stage: "create_program",
error,
},
)
})?;
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
let nvrtc_log = read_nvrtc_log(program);
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "compile_program",
error,
},
));
}
let nvrtc_log = read_nvrtc_log(program);
let cubin = match get_cubin(program) {
Ok(cubin) => cubin,
Err(error) => {
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "get_cubin",
error,
},
));
}
};
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "destroy_program",
error,
},
));
}
if cubin.is_empty() {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::NoModuleImageProduced,
));
}
Ok(Ptx::from_binary(cubin))
}
/// Returns the bandwidth of the device in GB/s
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 273,
"NVIDIA H100 PCIe" => 2_000,
"NVIDIA H100 SXM" => 3_350,
_ => return None,
})
}
/// Returns the bandwidth of the device in TFLOPs
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {
"NVIDIA Thor" => 125, // forced to use tf32 flops
"NVIDIA H100 PCIe" => 756,
"NVIDIA H100 SXM" => 989,
_ => return None,
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,344 @@
use crate::runtime::CudaRuntime;
use crate::tests::utilities::*;
use luminal::prelude::*;
use rand::{SeedableRng, rngs::SmallRng};
/// Helper: build a simple graph with dynamic dim 's' that does element-wise computation.
/// Returns (cx, input_node, output_node).
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
let mut cx = Graph::default();
let a = cx.tensor(('s', 4));
let b = (a + a).output();
(cx, a.id, b.id)
}
/// Helper: build a matmul graph with dynamic dim 's'.
/// Computes (s, K) @ (K, N) -> (s, N)
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
let mut cx = Graph::default();
let a = cx.tensor(('s', k));
let b = cx.tensor((k, n));
let c = a.matmul(b).output();
(cx, a.id, b.id, c.id)
}
#[test]
fn test_bucket_dispatch_simple() {
// Tests that bucketed compilation produces correct results for different dim values
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
// Set dummy input for search
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Test bucket 1: s=1
cx.set_dim('s', 1);
let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..4], &expected, 1e-5, 1e-5);
// Test bucket 2: s=3
cx.set_dim('s', 3);
let input_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..12], &expected, 1e-5, 1e-5);
}
#[test]
fn test_bucket_matmul_dynamic() {
// Tests matmul with bucketed dynamic dim
let Some(stream) = get_cuda_stream() else {
return;
};
let k = 8;
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 100, -1.0, 1.0);
let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Execute at s=1
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_s1 = rt.get_f32(c);
// Compute reference for s=1 (1xK @ KxN -> 1xN)
let mut expected_s1 = vec![0.0f32; n];
for j in 0..n {
for i in 0..k {
expected_s1[j] += a_data[i] * b_data[i * n + j];
}
}
assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);
// Execute at s=4
cx.set_dim('s', 4);
let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
rt.set_data(a, a_data_4.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_s4 = rt.get_f32(c);
// Compute reference for s=4 (4xK @ KxN -> 4xN)
let mut expected_s4 = vec![0.0f32; 4 * n];
for row in 0..4 {
for j in 0..n {
for i in 0..k {
expected_s4[row * n + j] += a_data_4[row * k + i] * b_data[i * n + j];
}
}
}
assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
}
#[test]
fn test_bucket_results_match_unbucketed() {
// Tests that bucketed results match non-bucketed results for the same graph
let Some(stream) = get_cuda_stream() else {
return;
};
let seed = 42u64;
// Non-bucketed run
let (mut cx1, a1, b1) = build_dynamic_add_graph();
cx1.set_dim('s', 3);
cx1.build_search_space::<CudaRuntime>();
let mut rt1 = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
rt1.set_data(a1, input_data.clone());
let mut rng1 = SmallRng::seed_from_u64(seed);
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
rt1.set_data(a1, input_data.clone());
rt1.execute(&cx1.dyn_map);
let result_unbucketed = rt1.get_f32(b1);
// Bucketed run with bucket that covers s=3
let (mut cx2, a2, b2) = build_dynamic_add_graph();
cx2.set_dim('s', 3);
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
cx2.build_search_space::<CudaRuntime>();
let mut rt2 = CudaRuntime::initialize(stream.clone());
rt2.set_data(a2, input_data.clone());
let mut rng2 = SmallRng::seed_from_u64(seed);
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
rt2.set_data(a2, input_data.clone());
rt2.execute(&cx2.dyn_map);
let result_bucketed = rt2.get_f32(b2);
// Results should match — same graph, same search seed, same dyn_map
assert_eq!(result_unbucketed.len(), result_bucketed.len());
assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
}
#[test]
#[should_panic(expected = "No bucket matches")]
fn test_bucket_out_of_range_panics() {
let Some(stream) = get_cuda_stream() else {
// Can't trigger panic without GPU, skip gracefully
panic!("No bucket matches dyn_map");
};
let (mut cx, a, _b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
// s=10 is outside all buckets — should panic
cx.set_dim('s', 10);
rt.set_data(a, vec![1.0f32; 40]);
rt.execute(&cx.dyn_map);
}
#[test]
fn test_bucket_no_buckets_backward_compat() {
// No buckets set → should behave identically to old path
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim('s', 2);
// No set_dim_buckets call
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
rt.set_data(a, input_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
assert_close(&result[..8], &expected, 1e-5, 1e-5);
}
#[test]
fn test_bucket_representative_override() {
// Tests that custom representative works
let bucket = DimBucket::new(2, 32).representative(16);
assert_eq!(bucket.representative_value(), 16);
let bucket_default = DimBucket::new(2, 32);
assert_eq!(bucket_default.representative_value(), 17); // (2+32)/2 = 17
let exact = DimBucket::new(1, 1);
assert_eq!(exact.representative_value(), 1);
}
#[test]
fn test_bucket_switch_preserves_weights() {
// Tests that switching between buckets still sees the correct weight data
let Some(stream) = get_cuda_stream() else {
return;
};
let k = 4;
let n = 4;
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
let a_data = random_f32_vec(k, 300, -1.0, 1.0);
let b_data = random_f32_vec(k * n, 301, -1.0, 1.0);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 5, &mut rng);
// Execute with bucket 1 (s=1)
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_1a = rt.get_f32(c);
// Switch to bucket 2 (s=3)
cx.set_dim('s', 3);
let a_data_3 = random_f32_vec(3 * k, 302, -1.0, 1.0);
rt.set_data(a, a_data_3.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_3 = rt.get_f32(c);
// Switch back to bucket 1 (s=1) — weights should still work
cx.set_dim('s', 1);
rt.set_data(a, a_data.clone());
rt.set_data(b_tensor, b_data.clone());
rt.execute(&cx.dyn_map);
let result_1b = rt.get_f32(c);
// First and last s=1 results should match exactly
assert_close(&result_1a[..n], &result_1b[..n], 1e-6, 1e-6);
// Verify s=3 result correctness
let mut expected_3 = vec![0.0f32; 3 * n];
for row in 0..3 {
for j in 0..n {
for i in 0..k {
expected_3[row * n + j] += a_data_3[row * k + i] * b_data[i * n + j];
}
}
}
assert_close(&result_3[..3 * n], &expected_3, 1e-4, 1e-4);
}
#[test]
fn test_bucket_multiple_executions_same_bucket() {
// Tests multiple executions within the same bucket with different dim values
let Some(stream) = get_cuda_stream() else {
return;
};
let (mut cx, a, b) = build_dynamic_add_graph();
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_rng(rt, 3, &mut rng);
// Execute at different sizes within the same bucket
for s in [1, 2, 4, 8] {
cx.set_dim('s', s);
let n = s * 4;
let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
rt.set_data(a, input.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(b);
let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
assert_close(&result[..n], &expected, 1e-5, 1e-5);
}
}
#[test]
#[should_panic(expected = "Overlapping buckets")]
fn test_bucket_overlapping_ranges_panics() {
let mut cx = Graph::default();
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
}
#[test]
fn test_dim_bucket_contains() {
let b = DimBucket::new(2, 10);
assert!(!b.contains(1));
assert!(b.contains(2));
assert!(b.contains(5));
assert!(b.contains(10));
assert!(!b.contains(11));
// Exact bucket
let exact = DimBucket::new(3, 3);
assert!(!exact.contains(2));
assert!(exact.contains(3));
assert!(!exact.contains(4));
}

View File

@@ -0,0 +1,416 @@
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use rand::SeedableRng;
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice, validate_choice_set};
use crate::kernel::KernelOp;
use crate::runtime::CudaRuntime;
/// Helper: build search space and extract all possible kernel names across many random choices.
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
cx.build_search_space::<CudaRuntime>();
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
let custom_ops = &cx.custom_ops;
let mut all_names = Vec::new();
// Try many random extractions to cover both alternatives
for _ in 0..20 {
let choices = random_initial_choice(egraph, &mut rand::rng());
let mut list_cache = Default::default();
let mut expr_cache = Default::default();
let llir = egglog_to_llir(
egraph,
choices,
ops,
custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
for op in llir.node_weights() {
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
let name = k.kernel_name().to_string();
if !all_names.contains(&name) {
all_names.push(name);
}
}
}
}
all_names
}
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
/// the ConsumedBuffer (not in any other ICons).
#[test]
fn test_scatter_nocopy_selected_when_dest_unshared() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
// scatter src into dest at indexes
let _result = src.scatter(indexes, dest).output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should be available (dest is not shared)
assert!(
names.iter().any(|n| n == "ScatterNoCopy"),
"Expected ScatterNoCopy to be available but got: {:?}",
names
);
}
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
/// cleanup rule should fire, deleting the ConsumedBuffer. This makes KernelScatterNoCopy
/// invalid, so it should NOT appear in any extraction.
#[test]
fn test_scatter_nocopy_not_selected_when_dest_shared() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let mut cx = Graph::default();
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
let dest = cx.tensor(10).persist();
let src = cx.tensor(3).persist();
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
// scatter src into dest at indexes
let scatter_result = src.scatter(indexes, dest);
// Also use dest directly in another op (add with itself) — this makes dest shared
let _dest_also_used = (dest + dest).output();
let _result = scatter_result.output();
let names = extract_all_kernel_names(&mut cx);
println!("All possible kernels: {:?}", names);
// KernelScatterNoCopy should NOT be available (dest is shared with the add op)
assert!(
!names.iter().any(|n| n == "ScatterNoCopy"),
"ScatterNoCopy should NOT be available when dest is shared, got: {:?}",
names
);
// Regular KernelScatter should be present
assert!(
names.iter().any(|n| n == "Scatter"),
"Expected Scatter but got: {:?}",
names
);
}
/// Actually execute the scatter and verify correctness.
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
#[test]
fn test_scatter_execution_correctness() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// dest: [0.0, 1.0, 2.0, 3.0, 4.0]
let dest = cx.tensor(5).persist();
// src: [10.0, 20.0, 30.0]
let src = cx.tensor(3).persist();
// indexes: [1, 3, 4]
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<CudaRuntime>();
let egraph = cx.egraph().expect("egraph not built");
let ops = cx.egglog_ops().expect("ops not built");
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
// Try many random extractions to cover both Scatter and ScatterNoCopy
let mut rng = rand::rng();
let mut tested_scatter = false;
let mut tested_nocopy = false;
for _ in 0..50 {
let choices = random_initial_choice(egraph, &mut rng);
if validate_choice_set(egraph, &choices, ops).is_err() {
continue;
}
let mut list_cache = Default::default();
let mut expr_cache = Default::default();
let llir = egglog_to_llir(
egraph,
choices,
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
// Check which scatter variant was selected
let mut has_nocopy = false;
let mut has_scatter = false;
for op in llir.node_weights() {
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
match k.kernel_name() {
"ScatterNoCopy" => has_nocopy = true,
"Scatter" => has_scatter = true,
_ => {}
}
}
}
let mut rt = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir);
rt.set_data(dest, vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
rt.set_data(src, vec![10.0f32, 20.0, 30.0]);
rt.set_data(indexes, vec![1i32, 3, 4]);
rt.execute(&cx.dyn_map);
let actual = rt.get_f32(result);
let variant = if has_nocopy {
tested_nocopy = true;
"ScatterNoCopy"
} else if has_scatter {
tested_scatter = true;
"Scatter"
} else {
"Unknown"
};
assert_eq!(
actual, expected,
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
actual, expected
);
}
println!(
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
tested_scatter, tested_nocopy
);
assert!(
tested_nocopy,
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
);
}
/// Test the KV-cache round-trip pattern: scatter → remove_buffer → set_buffer → scatter again.
/// This mimics how the llama model uses scatter for KV cache updates.
#[test]
fn test_scatter_kv_cache_roundtrip() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// KV cache: [5] elements (simulating a small cache)
let cache_in = cx.named_tensor("cache", 5).persist();
// New value to scatter: [1] element
let src = cx.tensor(1).persist();
// Index: [1] element (position to write)
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// scatter src into cache at index position
let cache_out = src.scatter(indexes, cache_in);
// Also read the scatter output (simulates attention reading from cache)
let read_out = (cache_out + 0.0).output();
// Return cache for round-trip
let cache_output = cache_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
// Must set input data BEFORE search (profiler needs valid buffers)
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, 5);
// Print which scatter variant was selected
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Selected: {}", k.kernel_name());
}
}
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let read1 = rt.get_f32(read_out);
println!("After step 1 (scatter 10.0 at pos 0): {:?}", read1);
assert_eq!(
read1,
vec![10.0, 0.0, 0.0, 0.0, 0.0],
"Step 1 read_out mismatch"
);
// Round-trip: remove cache output buffer, set as new cache input
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 2: Scatter 20.0 at position 1
rt.set_data(src, vec![20.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let read2 = rt.get_f32(read_out);
println!("After step 2 (scatter 20.0 at pos 1): {:?}", read2);
assert_eq!(
read2,
vec![10.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 read_out mismatch"
);
// Round-trip again
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 3: Scatter 30.0 at position 2
rt.set_data(src, vec![30.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let read3 = rt.get_f32(read_out);
println!("After step 3 (scatter 30.0 at pos 2): {:?}", read3);
assert_eq!(
read3,
vec![10.0, 20.0, 30.0, 0.0, 0.0],
"Step 3 read_out mismatch"
);
}
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
/// Also verifies graph_break interaction.
#[test]
fn test_scatter_dual_cache_with_graph_break() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// Two caches (like K and V)
let k_cache = cx.named_tensor("k_cache", 5).persist();
let v_cache = cx.named_tensor("v_cache", 5).persist();
// Input values
let k_new = cx.tensor(1).persist();
let v_new = cx.tensor(1).persist();
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// Scatter into both caches
let k_out = k_new.scatter(indexes, k_cache);
let v_out = v_new.scatter(indexes, v_cache);
// Read both (simulates attention using the scattered caches)
let k_read = k_out + 0.0;
let v_read = v_out + 0.0;
// Compute something from the scattered values (simulates attention output)
let attn = k_read * v_read;
// Output everything
let attn_out = attn.output();
let k_cache_out = k_out.output();
let v_cache_out = v_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
// Use seeded search for deterministic scatter variant selection.
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_rng(rt, 5, &mut rng);
// Print selected variants
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Dual test selected: {}", k.kernel_name());
}
}
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let attn1 = rt.get_f32(attn_out);
println!("Attn step 1: {:?}", attn1);
// k=[2,0,0,0,0], v=[3,0,0,0,0], attn = k*v = [6,0,0,0,0]
assert_eq!(attn1, vec![6.0, 0.0, 0.0, 0.0, 0.0], "Step 1 attn mismatch");
// Round-trip
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 2: scatter k=4.0, v=5.0 at position 1
rt.set_data(k_new, vec![4.0f32]);
rt.set_data(v_new, vec![5.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let attn2 = rt.get_f32(attn_out);
println!("Attn step 2: {:?}", attn2);
// k=[2,4,0,0,0], v=[3,5,0,0,0], attn = k*v = [6,20,0,0,0]
assert_eq!(
attn2,
vec![6.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 attn mismatch"
);
// Round-trip
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 3: scatter k=6.0, v=7.0 at position 2
rt.set_data(k_new, vec![6.0f32]);
rt.set_data(v_new, vec![7.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let attn3 = rt.get_f32(attn_out);
println!("Attn step 3: {:?}", attn3);
// k=[2,4,6,0,0], v=[3,5,7,0,0], attn = k*v = [6,20,42,0,0]
assert_eq!(
attn3,
vec![6.0, 20.0, 42.0, 0.0, 0.0],
"Step 3 attn mismatch"
);
}

View File

@@ -0,0 +1,14 @@
pub mod utilities;
#[cfg(test)]
mod bucket_tests;
#[cfg(test)]
mod consumed_buffer_tests;
#[cfg(test)]
mod model_fuzz;
#[cfg(test)]
mod op_functional_tests;
#[cfg(test)]
mod performance_tests;
#[cfg(test)]
mod transformer;

View File

@@ -0,0 +1,685 @@
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
//!
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
//! reference to catch incorrect HLIR kernel fallback rewrites.
use luminal::prelude::*;
use super::utilities::{assert_close, fuzz_genomes, get_cuda_stream, random_f32_vec};
use crate::runtime::CudaRuntime;
/// Number of genomes to fuzz per test (higher than default GENOME_FUZZ_COUNT=20).
const FUZZ_COUNT: usize = 100;
// ============================================================================
// RMSNorm helper (used by all three models)
// ============================================================================
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
let normed = x.std_norm(x.shape.last_axis(), eps);
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
}
fn rms_norm_ref(
x: &candle_core::Tensor,
weight: &candle_core::Tensor,
eps: f64,
) -> candle_core::Tensor {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
let normed = x.broadcast_mul(&rsqrt).unwrap();
normed
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
.unwrap()
}
// ============================================================================
// SwiGLU MLP helper (used by all three models)
// ============================================================================
fn swiglu_mlp(
x: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
) -> GraphTensor {
let gate = x.matmul(w_gate.t()).swish();
let up = x.matmul(w_up.t());
(gate * up).matmul(w_down.t())
}
fn swiglu_mlp_ref(
x: &candle_core::Tensor,
w_gate: &candle_core::Tensor,
w_up: &candle_core::Tensor,
w_down: &candle_core::Tensor,
) -> candle_core::Tensor {
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
let up = x.matmul(&w_up.t().unwrap()).unwrap();
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
}
// ============================================================================
// Generic test functions
// ============================================================================
/// Test a SwiGLU MLP block at given dimensions with genome fuzzing.
fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
/// Test RMSNorm + matmul projection at given dimensions with genome fuzzing.
fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let norm_w = cx.tensor(hidden);
let proj_w = cx.tensor((proj_dim, hidden));
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), hidden, &device).unwrap();
let ref_proj =
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_norm, eps as f64);
let expected = normed.matmul(&ref_proj.t().unwrap()).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
/// Test a full transformer layer (norm -> proj -> norm -> MLP) without attention.
fn fuzz_layer_no_attn(
seq: usize,
hidden: usize,
intermediate: usize,
proj_dim: usize,
eps: f32,
seed: u64,
) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let attn_norm_w = cx.tensor(hidden);
let proj_w = cx.tensor((proj_dim, hidden));
let o_proj_w = cx.tensor((hidden, proj_dim));
let mlp_norm_w = cx.tensor(hidden);
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let normed = rms_norm(input, attn_norm_w, eps);
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
let x = input + proj_out;
let mlp_normed = rms_norm(x, mlp_norm_w, eps);
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
let out = (x + mlp_out).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let attn_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
let o_proj_data = random_f32_vec(hidden * proj_dim, seed + 3, -0.3, 0.3);
let mlp_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 4, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let gate_data = random_f32_vec(intermediate * hidden, seed + 5, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 6, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 7, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Candle reference
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_attn_norm =
candle_core::Tensor::from_vec(attn_norm_data.clone(), hidden, &device).unwrap();
let ref_proj =
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
let ref_o_proj =
candle_core::Tensor::from_vec(o_proj_data.clone(), (hidden, proj_dim), &device).unwrap();
let ref_mlp_norm =
candle_core::Tensor::from_vec(mlp_norm_data.clone(), hidden, &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, eps as f64);
let proj_out = normed
.matmul(&ref_proj.t().unwrap())
.unwrap()
.matmul(&ref_o_proj.t().unwrap())
.unwrap();
let x_ref = (&ref_input + proj_out).unwrap();
let mlp_normed = rms_norm_ref(&x_ref, &ref_mlp_norm, eps as f64);
let mlp_out = swiglu_mlp_ref(&mlp_normed, &ref_gate, &ref_up, &ref_down);
let expected_t = (x_ref + mlp_out).unwrap();
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-2, 2e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
2e-2,
2e-2,
FUZZ_COUNT,
seed,
);
}
/// Test a SwiGLU MLP with HLIR-only to specifically verify
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((seq, hidden));
let w_gate = cx.tensor((intermediate, hidden));
let w_up = cx.tensor((intermediate, hidden));
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_up =
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
// ============================================================================
// Llama-specific tests
// Llama 3 8B: HIDDEN=4096, INTERMEDIATE=14336, HEAD_DIM=128
// Using scaled-down dims that preserve architectural ratios
// ============================================================================
mod llama {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 256;
const INTERMEDIATE: usize = 896; // ~3.5x hidden, matching 14336/4096
const PROJ_DIM: usize = 256; // Q_DIM == HIDDEN for llama
const EPS: f32 = 1e-5;
#[test]
fn fuzz_llama_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
}
#[test]
fn fuzz_llama_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
}
#[test]
fn fuzz_llama_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
}
#[test]
fn fuzz_llama_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
}
#[test]
fn fuzz_llama_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
}
/// Force HLIR-only (no block ops) to specifically test the fallback path.
#[test]
fn fuzz_llama_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
}
}
// ============================================================================
// Gemma-specific tests
// Gemma 3 4B: HIDDEN=2560, INTERMEDIATE=10240, HEAD_DIM=256, Q_DIM=2048
// Key difference: Q_DIM != HIDDEN, and 4 extra RMSNorm layers per block
// ============================================================================
mod gemma {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 320; // divisible by 8 (N_HEADS)
const INTERMEDIATE: usize = 1280; // 4x hidden, matching 10240/2560
const Q_DIM: usize = 256; // scaled from 2048 (N_HEADS * HEAD_DIM)
const EPS: f32 = 1e-6;
#[test]
fn fuzz_gemma_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
}
#[test]
fn fuzz_gemma_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
}
#[test]
fn fuzz_gemma_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
}
/// Gemma has extra post-attention and post-feedforward norms.
#[test]
fn fuzz_gemma_layer_full_norms() {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let attn_norm_w = cx.tensor(HIDDEN);
let post_attn_norm_w = cx.tensor(HIDDEN);
let pre_ff_norm_w = cx.tensor(HIDDEN);
let post_ff_norm_w = cx.tensor(HIDDEN);
let proj_w = cx.tensor((Q_DIM, HIDDEN));
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
let normed = rms_norm(input, attn_norm_w, EPS);
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
let x = input + attn_normed;
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
let mlp_out = swiglu_mlp(ff_normed, w_gate, w_up, w_down);
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
let out = (x + mlp_normed).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 800u64;
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
let attn_norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let post_attn_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 2, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let pre_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 3, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let post_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 4, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let proj_data = random_f32_vec(Q_DIM * HIDDEN, seed + 5, -0.3, 0.3);
let o_proj_data = random_f32_vec(HIDDEN * Q_DIM, seed + 6, -0.3, 0.3);
let gate_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 7, -0.3, 0.3);
let up_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 8, -0.3, 0.3);
let down_data = random_f32_vec(HIDDEN * INTERMEDIATE, seed + 9, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(post_attn_norm_w, post_attn_data.clone());
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
rt.set_data(post_ff_norm_w, post_ff_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Candle reference
let device = candle_core::Device::Cpu;
let t = |data: &[f32], shape: &[usize]| {
candle_core::Tensor::from_vec(data.to_vec(), shape, &device).unwrap()
};
let ref_input = t(&input_data, &[SEQ, HIDDEN]);
let ref_attn_norm = t(&attn_norm_data, &[HIDDEN]);
let ref_post_attn = t(&post_attn_data, &[HIDDEN]);
let ref_pre_ff = t(&pre_ff_data, &[HIDDEN]);
let ref_post_ff = t(&post_ff_data, &[HIDDEN]);
let ref_proj = t(&proj_data, &[Q_DIM, HIDDEN]);
let ref_o_proj = t(&o_proj_data, &[HIDDEN, Q_DIM]);
let ref_gate = t(&gate_data, &[INTERMEDIATE, HIDDEN]);
let ref_up = t(&up_data, &[INTERMEDIATE, HIDDEN]);
let ref_down = t(&down_data, &[HIDDEN, INTERMEDIATE]);
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, EPS as f64);
let proj_out = normed
.matmul(&ref_proj.t().unwrap())
.unwrap()
.matmul(&ref_o_proj.t().unwrap())
.unwrap();
let attn_normed = rms_norm_ref(&proj_out, &ref_post_attn, EPS as f64);
let x_ref = (&ref_input + attn_normed).unwrap();
let ff_normed = rms_norm_ref(&x_ref, &ref_pre_ff, EPS as f64);
let mlp_out = swiglu_mlp_ref(&ff_normed, &ref_gate, &ref_up, &ref_down);
let mlp_normed = rms_norm_ref(&mlp_out, &ref_post_ff, EPS as f64);
let expected_t = (x_ref + mlp_normed).unwrap();
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 2e-2, 2e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(attn_norm_w, attn_norm_data.clone());
rt.set_data(post_attn_norm_w, post_attn_data.clone());
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
rt.set_data(post_ff_norm_w, post_ff_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt.set_data(o_proj_w, o_proj_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
},
out.id,
&expected,
2e-2,
2e-2,
FUZZ_COUNT,
seed,
);
}
#[test]
fn fuzz_gemma_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
}
/// Force HLIR-only to test fallback path with Gemma dimensions.
#[test]
fn fuzz_gemma_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
}
}
// ============================================================================
// Qwen-specific tests
// Qwen3-4B: HIDDEN=2560, INTERMEDIATE=9728, HEAD_DIM=128, Q_DIM=4096
// Key difference: Q_DIM > HIDDEN, tied embeddings (lm_head = embedding.t())
// ============================================================================
mod qwen {
use super::*;
const SEQ: usize = 4;
const HIDDEN: usize = 256;
const INTERMEDIATE: usize = 768; // ~3x hidden, matching 9728/2560
const Q_DIM: usize = 512; // scaled from 4096 (Q_DIM > HIDDEN)
const EPS: f32 = 1e-6;
#[test]
fn fuzz_qwen_mlp() {
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
}
#[test]
fn fuzz_qwen_norm_proj() {
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
}
#[test]
fn fuzz_qwen_layer() {
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
}
/// Qwen uses tied embeddings: lm_head = embedding^T
#[test]
fn fuzz_qwen_lm_head() {
let Some(stream) = get_cuda_stream() else {
return;
};
const VOCAB: usize = 512;
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let norm_w = cx.tensor(HIDDEN);
let embedding = cx.tensor((VOCAB, HIDDEN));
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let seed = 1300u64;
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
let norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
let emb_data = random_f32_vec(VOCAB * HIDDEN, seed + 2, -0.3, 0.3);
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.clone(), (SEQ, HIDDEN), &device).unwrap();
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), HIDDEN, &device).unwrap();
let ref_emb =
candle_core::Tensor::from_vec(emb_data.clone(), (VOCAB, HIDDEN), &device).unwrap();
let normed = rms_norm_ref(&ref_input, &ref_norm, EPS as f64);
let expected = normed.matmul(&ref_emb.t().unwrap()).unwrap();
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
},
out.id,
&expected,
1e-2,
1e-2,
FUZZ_COUNT,
seed,
);
}
#[test]
fn fuzz_qwen_mlp_seq1() {
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
}
#[test]
fn fuzz_qwen_mlp_seq7() {
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
}
/// Force HLIR-only to test fallback path with Qwen dimensions.
#[test]
fn fuzz_qwen_mlp_hlir_only() {
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
}
}

View File

@@ -0,0 +1,606 @@
use candle_core::{Device, Tensor};
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use proptest::prelude::*;
use luminal::egglog_utils::{
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
};
use crate::runtime::CudaRuntime;
#[allow(unused_imports)]
use super::utilities::{
GENOME_FUZZ_COUNT, TOLERANCE_SAFETY_FACTOR, assert_close, dtype_epsilon, fuzz_genomes,
gen_slice_range, get_cuda_stream, gpu_supports_dtype, random_f32_vec, random_i32_vec,
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
};
proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
#[test]
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
}
#[test]
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
}
#[test]
fn test_matmul(
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
(1usize..128, 1usize..128, 1usize..128, any::<bool>(), any::<bool>(),
any::<(bool, bool)>(), any::<(bool, bool)>(), any::<(bool, bool)>(),
prop::sample::select(&[luminal::dtype::DType::F32, luminal::dtype::DType::F16, luminal::dtype::DType::Bf16]))
.prop_perturb(|(m, n, k, a_cm, b_cm, m_sl, k_sl, n_sl, dt), mut rng| {
(m, n, k, a_cm, b_cm,
gen_slice_range(m, m_sl.0, m_sl.1, &mut rng),
gen_slice_range(k, k_sl.0, k_sl.1, &mut rng),
gen_slice_range(n, n_sl.0, n_sl.1, &mut rng),
dt)
}),
seed in any::<u64>()
) {
prop_assume!(gpu_supports_dtype(dtype), "GPU does not support {:?}", dtype);
let (m_start, m_end) = m_slice;
let (k_start, k_end) = k_slice;
let (n_start, n_end) = n_slice;
let effective_m = m_end - m_start;
let effective_k = k_end - k_start;
let effective_n = n_end - n_start;
// Column-major achieved by storing transposed then calling .t()
let (a_shape, b_shape): ((usize, usize), (usize, usize)) = match (a_col_major, b_col_major) {
(false, false) => ((m, k), (k, n)), // Rm x Rm
(false, true) => ((m, k), (n, k)), // Rm x Cm
(true, false) => ((k, m), (k, n)), // Cm x Rm
(true, true) => ((k, m), (n, k)), // Cm x Cm
};
let candle_dtype = to_candle_dtype(dtype);
let luminal_op = move |a: GraphTensor, b: GraphTensor| {
let a = a.cast(dtype);
let b = b.cast(dtype);
let a = if a_col_major { a.t() } else { a };
let b = if b_col_major { b.t() } else { b };
// After transpose: A is (m, k), B is (k, n)
let a = a.slice((m_start..m_end, k_start..k_end));
let b = b.slice((k_start..k_end, n_start..n_end));
a.matmul(b).cast(luminal::dtype::DType::F32)
};
let candle_op = move |a: Tensor, b: Tensor| {
let a = a.to_dtype(candle_dtype).unwrap();
let b = b.to_dtype(candle_dtype).unwrap();
let a = if a_col_major { a.t().unwrap() } else { a };
let b = if b_col_major { b.t().unwrap() } else { b };
// After transpose: A is (m, k), B is (k, n)
let a = a.narrow(0, m_start, effective_m).unwrap()
.narrow(1, k_start, effective_k).unwrap()
.contiguous().unwrap();
let b = b.narrow(0, k_start, effective_k).unwrap()
.narrow(1, n_start, effective_n).unwrap()
.contiguous().unwrap();
a.matmul(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap()
};
// Matmul tolerance: rtol scales with sqrt(k) for accumulated rounding error
let eps = dtype_epsilon(dtype);
let sqrt_k = (effective_k as f32).sqrt();
let rtol = eps * sqrt_k;
let atol = 5.0 * eps;
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
}
// Unary ops tests
#[test]
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
}
#[test]
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// log2(x) = ln(x) / ln(2)
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
}
#[test]
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
}
#[test]
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
}
#[test]
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
}
// Binary ops tests
#[test]
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
test_mod(x, x, |a, b| a % b, seed);
test_mod((y, x), (y, x), |a, b| a % b, seed);
}
#[test]
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
}
}
#[allow(dead_code)]
fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
let total = rows * cols;
let mut cx = Graph::default();
let input = cx.tensor((rows, cols));
let sorted_dim0 = input.stable_argsort(0, true).output(); // descend
let sorted_dim1 = input.stable_argsort(1, false).output(); // ascend
// random and unique data using seed
let data: Vec<f32> = random_f32_vec(total, seed, 0.0, 1.0);
let sorted_cols: Vec<Vec<i32>> = (0..cols)
.map(|col| {
let mut indices: Vec<i32> = (0..rows as i32).collect();
indices.sort_by(|&a, &b| {
let va = data[(a as usize) * cols + col];
let vb = data[(b as usize) * cols + col];
vb.partial_cmp(&va).unwrap()
});
indices
})
.collect();
let expected_dim0: Vec<i32> = (0..rows)
.flat_map(|row| {
(0..cols)
.map(|col| sorted_cols[col][row])
.collect::<Vec<_>>()
})
.collect();
let expected_dim1: Vec<i32> = (0..rows)
.flat_map(|row| {
let mut indices: Vec<i32> = (0..cols as i32).collect();
indices.sort_by(|&a, &b| {
let va = data[row * cols + (a as usize)];
let vb = data[row * cols + (b as usize)];
va.partial_cmp(&vb).unwrap()
});
indices
})
.collect();
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data);
rt = cx.search(rt, 10);
rt.execute(&cx.dyn_map);
let out_dim0 = rt.get_i32(sorted_dim0.id);
let out_dim1 = rt.get_i32(sorted_dim1.id);
assert_eq!(out_dim0.len(), expected_dim0.len(), "dim0 length mismatch");
assert_eq!(out_dim1.len(), expected_dim1.len(), "dim1 length mismatch");
// Debug: check for out-of-range values (indices should be 0..rows for dim0, 0..cols for dim1)
let max_valid_dim0 = rows as i32 - 1;
let max_valid_dim1 = cols as i32 - 1;
let bad_dim0: Vec<_> = out_dim0
.iter()
.enumerate()
.filter(|&(_, &v)| v < 0 || v > max_valid_dim0)
.take(10)
.collect();
let bad_dim1: Vec<_> = out_dim1
.iter()
.enumerate()
.filter(|&(_, &v)| v < 0 || v > max_valid_dim1)
.take(10)
.collect();
if !bad_dim0.is_empty() {
panic!(
"dim0 has out-of-range values (valid: 0-{max_valid_dim0}): {:?}\nFirst 20 values: {:?}",
bad_dim0,
&out_dim0[..20.min(out_dim0.len())]
);
}
if !bad_dim1.is_empty() {
panic!(
"dim1 has out-of-range values (valid: 0-{max_valid_dim1}): {:?}",
bad_dim1
);
}
for i in 0..out_dim0.len() {
assert_eq!(
out_dim0[i], expected_dim0[i],
"dim0 mismatch at {i}: got {}, expected {}",
out_dim0[i], expected_dim0[i]
);
}
for i in 0..out_dim1.len() {
assert_eq!(
out_dim1[i], expected_dim1[i],
"dim1 mismatch at {i}: got {}, expected {}",
out_dim1[i], expected_dim1[i]
);
}
}
// NOTE: Argsort proptest disabled due to pre-existing bug where argsort output shape
// through e-graph compilation returns only `rows` elements instead of `rows * cols`.
// proptest! {
// #![proptest_config(ProptestConfig::with_cases(10))]
// #[test]
// fn test_argsort(seed in any::<u64>()) {
// run_argsort_test(5, 500, seed);
// }
// }
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
#[test]
#[allow(clippy::approx_constant, clippy::excessive_precision)]
pub fn test_cast_f16_edge_cases() {
use luminal::dtype::DType;
// Fixed edge-case values that exercise F16 behavior
let edge_cases: Vec<f32> = vec![
0.0,
1.0,
-1.0,
0.5,
0.333333333, // Will truncate: F16 can't represent 1/3 exactly
0.1, // Will truncate: 0.1 isn't exact in binary
1.0009765625, // Exactly representable in F16 (1 + 1/1024)
1.00048828125, // Rounds to 1.0 in F16 (1 + 1/2048, below F16 precision)
1.0007324219, // Between two F16 values, will round
-3.140625, // Exactly representable
3.14159265, // Pi - will truncate
65504.0, // Max normal F16
-65504.0, // Min normal F16
0.000060976, // Near F16 min positive normal
1e-7, // Subnormal in F16
100.0,
-100.0,
12.345678, // Arbitrary value requiring truncation
];
// Generator that ignores seed and returns edge cases
let gen_edge_cases = |_n: usize, _seed: u64| edge_cases.clone();
test_unary_cuda(
edge_cases.len(),
|a| a.cast(DType::F16).cast(DType::F32),
|a| {
a.to_dtype(candle_core::DType::F16)
.unwrap()
.to_dtype(candle_core::DType::F32)
.unwrap()
},
gen_edge_cases,
0,
);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
#[test]
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
use luminal::dtype::DType;
// Use range beyond F16 limits so some values overflow to infinity
let f16_max = half::f16::MAX.to_f32();
let gen_lambda = |n, s| random_f32_vec(n, s, -2.0 * f16_max, 2.0 * f16_max);
test_unary_cuda(
size,
|a| a.cast(DType::F16).cast(DType::F32),
|a| {
a.to_dtype(candle_core::DType::F16)
.unwrap()
.to_dtype(candle_core::DType::F32)
.unwrap()
},
gen_lambda,
seed,
);
}
}
/// Fuzz test that generates many random genomes and verifies they all produce correct results.
/// This tests the genetic algorithm search by validating each genome individually.
/// Uses proptest seed for reproducibility - if this test fails, proptest will print the seed
/// which can be used to reproduce the failure.
fn fuzz_test_cuda_genomes_impl(seed: u64) {
use rand::SeedableRng;
use rand::rngs::StdRng;
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping test");
return;
};
println!("Running fuzz_test_cuda_genomes with seed: {}", seed);
// Build a graph with operations that have rewrite alternatives
let mut cx = Graph::default();
let a = cx.tensor((4, 8));
let b = cx.tensor((8, 4));
let c = cx.tensor((4, 4));
// Matmul + add + relu creates opportunities for rewrites
let d = a.matmul(b);
let e = (d + c).relu();
let out = e.output();
cx.build_search_space::<CudaRuntime>();
let egraph = cx.egraph().unwrap();
let ops = cx.egglog_ops().unwrap();
// Count mutable eclasses
let mutable_eclasses: usize = egraph
.eclasses
.iter()
.filter(|(_, (label, enodes))| {
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
})
.count();
println!(
"CUDA search space: {} total eclasses, {} mutable",
egraph.eclasses.len(),
mutable_eclasses
);
// Use seeded RNG for full reproducibility
let mut rng = StdRng::seed_from_u64(seed);
// Generate test data with seeded RNG (reproducible)
let a_data: Vec<f32> = (0..32).map(|_| rng.random::<f32>()).collect();
let b_data: Vec<f32> = (0..32).map(|_| rng.random::<f32>()).collect();
let c_data: Vec<f32> = (0..16).map(|_| rng.random::<f32>()).collect();
// Compute reference result using candle
let device = Device::Cpu;
let ref_a = Tensor::from_vec(a_data.clone(), (4, 8), &device).unwrap();
let ref_b = Tensor::from_vec(b_data.clone(), (8, 4), &device).unwrap();
let ref_c = Tensor::from_vec(c_data.clone(), (4, 4), &device).unwrap();
let ref_d = ref_a.matmul(&ref_b).unwrap();
let ref_e = (&ref_d + &ref_c).unwrap().relu().unwrap();
let expected: Vec<f32> = ref_e.flatten_all().unwrap().to_vec1().unwrap();
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
// Test initial genome
let initial = random_initial_choice(egraph, &mut rng);
prev_selected.insert(hash_choice_set(&initial));
if let Err(e) = validate_choice_set(egraph, &initial, ops) {
panic!("Initial genome invalid: {}", e);
}
// Extract and execute initial genome
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
let llir_graph = egglog_to_llir(
egraph,
initial.clone(),
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
let mut rt: CudaRuntime = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir_graph);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt.set_data(c, c_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&result, &expected, tol, tol);
println!("Initial genome: correct");
// If no mutable eclasses, only one valid graph exists
if mutable_eclasses == 0 {
println!("No mutable eclasses, only one valid graph - test passed");
return;
}
// Generate and test many genomes
let mut base = initial;
let mut tested = 0;
let target = 50;
for _generation in 0..100 {
let offspring = extract_generation(egraph, &base, 10, 2, &mut prev_selected, &mut rng);
if offspring.is_empty() {
println!("Search space exhausted");
break;
}
for genome in offspring {
// Validate
if let Err(e) = validate_choice_set(egraph, &genome, ops) {
panic!("Invalid genome: {}", e);
}
// Extract and execute
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
let llir_graph = egglog_to_llir(
egraph,
genome.clone(),
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
// Create fresh runtime for this genome
let mut rt: CudaRuntime = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir_graph);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt.set_data(c, c_data.clone());
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Verify correctness
assert_close(&result, &expected, tol, tol);
tested += 1;
base = genome;
if tested >= target {
break;
}
}
if tested >= target {
break;
}
}
println!(
"Fuzz test: verified {} genomes produce correct results",
tested
);
assert!(tested > 0, "No genomes were tested");
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(3))]
#[test]
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
fuzz_test_cuda_genomes_impl(seed);
}
}
fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64) {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping test");
return;
};
let mut cx = Graph::default();
let token_ids = cx.tensor(seq_len).as_dtype(luminal::dtype::DType::Int);
let embed_table = cx.tensor((vocab_size, embed_dim));
let output = embed_table
.gather(
(token_ids * embed_dim).expand_dim(1, embed_dim)
+ cx.arange(embed_dim).expand_dim(0, seq_len),
)
.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let token_data: Vec<i32> = random_i32_vec(seq_len, seed, 0, vocab_size as i32 - 1);
let embed_data: Vec<f32> = random_f32_vec(vocab_size * embed_dim, seed, -0.5, 0.5);
rt.set_data(token_ids, token_data.clone());
rt.set_data(embed_table, embed_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);
let mut expected = vec![0.0f32; seq_len * embed_dim];
for i in 0..seq_len {
let tid = token_data[i] as usize;
for j in 0..embed_dim {
expected[i * embed_dim + j] = embed_data[tid * embed_dim + j];
}
}
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&result, &expected, tol, tol);
// Fuzz genomes: verify multiple graph rewrites produce consistent results
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(token_ids, token_data.clone());
rt.set_data(embed_table, embed_data.clone());
},
output.id,
&expected,
tol,
tol,
GENOME_FUZZ_COUNT,
seed,
);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(5))]
#[test]
fn test_embed_proptest(
vocab_size in 10usize..200,
embed_dim in 8usize..128,
seq_len in 1usize..32,
seed in any::<u64>(),
) {
run_embed_test(vocab_size, embed_dim, seq_len, seed);
}
}

View File

@@ -0,0 +1,94 @@
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use tracing::{Level, enabled};
use crate::cuda_bandwidth_gbps;
use crate::runtime::CudaRuntime;
/// Test that measures bandwidth utilization for a large element-wise add kernel.
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
#[test]
pub fn kernel_add_bandwidth_test() {
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
let size = 64 * 1024 * 1024;
let mut cx = Graph::default();
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let output = (a + b).output();
// Generate test data
let data_a: Vec<f32> = (0..size).map(|i| (i % 1000) as f32 * 0.001).collect();
let data_b: Vec<f32> = (0..size)
.map(|i| ((i + 500) % 1000) as f32 * 0.001)
.collect();
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
rt = cx.search(rt, 5);
// Warm up
rt.execute(&cx.dyn_map);
// Run and measure
rt.execute(&cx.dyn_map);
// Print stats
println!("\n=== Large KernelAdd Bandwidth Test ===");
println!(
"Tensor size: {} elements ({} MB per tensor)",
size,
size * 4 / 1024 / 1024
);
println!(
"Total memory traffic: {} MB (2 reads + 1 write)",
size * 4 * 3 / 1024 / 1024
);
if enabled!(Level::INFO) {
rt.print_execution_stats();
}
// Verify correctness (spot check)
let result = rt.get_f32(output);
for i in [0, size / 2, size - 1] {
let expected = data_a[i] + data_b[i];
let got = result[i];
assert!(
(got - expected).abs() < 1e-5,
"Mismatch at {}: expected {}, got {}",
i,
expected,
got
);
}
// Check bandwidth is reasonable (at least 50% of peak for large kernels)
if let Some(peak_bw) = cuda_bandwidth_gbps(&ctx) {
for stat in &rt.last_kernel_stats {
let total_bytes = stat.bytes_loaded + stat.bytes_stored;
if stat.name == "Add" && total_bytes > 0 {
let utilization = stat.bandwidth_gbps / peak_bw as f64 * 100.0;
println!(
"\nAdd kernel achieved {:.1} GB/s ({:.1}% of {:.0} GB/s peak)",
stat.bandwidth_gbps, utilization, peak_bw
);
println!(
" Loaded: {} bytes, Stored: {} bytes",
stat.bytes_loaded, stat.bytes_stored
);
// Large adds should achieve decent bandwidth
assert!(
utilization > 50.0,
"Bandwidth utilization too low: {:.1}%",
utilization
);
}
}
}
}

View File

@@ -0,0 +1,510 @@
//! Fuzz tests for small transformer models on CUDA.
//!
//! Builds a mini Llama-like transformer (RMSNorm + causal self-attention + SwiGLU MLP)
//! and verifies CUDA execution against a CPU reference implementation using candle.
use luminal::prelude::*;
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
use crate::runtime::CudaRuntime;
// ---- Tiny Llama hyperparameters ----
const SEQ: usize = 4;
const HIDDEN: usize = 16;
const INTERMEDIATE: usize = 32;
// ---- Graph-based mini transformer (Luminal) ----
/// RMSNorm: x * rsqrt(mean(x^2) + eps), optionally scaled by weight
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
let normed = x.std_norm(x.shape.last_axis(), eps);
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
}
/// Build self-attention using a simple single-head approach.
/// Input: (seq, hidden), outputs: (seq, hidden)
fn self_attention(
x: GraphTensor,
wq: GraphTensor,
wk: GraphTensor,
wv: GraphTensor,
wo: GraphTensor,
) -> GraphTensor {
// Project to Q, K, V: (seq, hidden) @ (hidden, hidden)^T = (seq, hidden)
let q = x.matmul(wq.t());
let k = x.matmul(wk.t());
let v = x.matmul(wv.t());
// Simple single-head scaled dot-product attention (no causal mask for simplicity)
let scale = 1.0 / (HIDDEN as f32).sqrt();
let scores = q.matmul(k.t()) * scale; // (seq, seq)
let attn_weights = scores.softmax(1); // softmax over key dim
// Apply attention to values and output projection
attn_weights.matmul(v).matmul(wo.t())
}
/// SwiGLU MLP: down(swish(gate(x)) * up(x))
fn swiglu_mlp(
x: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
) -> GraphTensor {
let gate = x.matmul(w_gate.t()).swish();
let up = x.matmul(w_up.t());
(gate * up).matmul(w_down.t())
}
/// Build a single transformer layer on the graph.
struct MiniTransformerLayer {
attn_norm_w: GraphTensor,
wq: GraphTensor,
wk: GraphTensor,
wv: GraphTensor,
wo: GraphTensor,
mlp_norm_w: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
}
impl MiniTransformerLayer {
fn init(cx: &mut Graph) -> Self {
Self {
attn_norm_w: cx.tensor(HIDDEN),
wq: cx.tensor((HIDDEN, HIDDEN)),
wk: cx.tensor((HIDDEN, HIDDEN)),
wv: cx.tensor((HIDDEN, HIDDEN)),
wo: cx.tensor((HIDDEN, HIDDEN)),
mlp_norm_w: cx.tensor(HIDDEN),
w_gate: cx.tensor((INTERMEDIATE, HIDDEN)),
w_up: cx.tensor((INTERMEDIATE, HIDDEN)),
w_down: cx.tensor((HIDDEN, INTERMEDIATE)),
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
// Pre-norm attention with residual
let normed = rms_norm(x, self.attn_norm_w, 1e-5);
let attn_out = self_attention(normed, self.wq, self.wk, self.wv, self.wo);
let x = x + attn_out;
// Pre-norm MLP with residual
let normed = rms_norm(x, self.mlp_norm_w, 1e-5);
let mlp_out = swiglu_mlp(normed, self.w_gate, self.w_up, self.w_down);
x + mlp_out
}
/// Return all weight tensors and their sizes for data loading
fn weights(&self) -> Vec<(GraphTensor, usize)> {
vec![
(self.attn_norm_w, HIDDEN),
(self.wq, HIDDEN * HIDDEN),
(self.wk, HIDDEN * HIDDEN),
(self.wv, HIDDEN * HIDDEN),
(self.wo, HIDDEN * HIDDEN),
(self.mlp_norm_w, HIDDEN),
(self.w_gate, INTERMEDIATE * HIDDEN),
(self.w_up, INTERMEDIATE * HIDDEN),
(self.w_down, HIDDEN * INTERMEDIATE),
]
}
}
// ---- Candle CPU reference ----
/// CPU reference for RMSNorm using candle
fn rms_norm_ref(
x: &candle_core::Tensor,
weight: &candle_core::Tensor,
eps: f64,
) -> candle_core::Tensor {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
let normed = x.broadcast_mul(&rsqrt).unwrap();
normed
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
.unwrap()
}
/// CPU reference for self-attention (single-head, no causal mask)
fn self_attention_ref(
x: &candle_core::Tensor,
wq: &candle_core::Tensor,
wk: &candle_core::Tensor,
wv: &candle_core::Tensor,
wo: &candle_core::Tensor,
) -> candle_core::Tensor {
let q = x.matmul(&wq.t().unwrap()).unwrap();
let k = x.matmul(&wk.t().unwrap()).unwrap();
let v = x.matmul(&wv.t().unwrap()).unwrap();
let scale = 1.0 / (HIDDEN as f64).sqrt();
let scores = q.matmul(&k.t().unwrap()).unwrap();
let scores = (scores * scale).unwrap();
// Softmax over key dimension (dim 1)
let max_val = scores.max(1).unwrap().unsqueeze(1).unwrap();
let shifted = scores.broadcast_sub(&max_val).unwrap();
let exps = shifted.exp().unwrap();
let sum_exps = exps.sum(1).unwrap().unsqueeze(1).unwrap();
let attn_weights = exps.broadcast_div(&sum_exps).unwrap();
attn_weights
.matmul(&v)
.unwrap()
.matmul(&wo.t().unwrap())
.unwrap()
}
/// CPU reference for SwiGLU MLP
fn swiglu_mlp_ref(
x: &candle_core::Tensor,
w_gate: &candle_core::Tensor,
w_up: &candle_core::Tensor,
w_down: &candle_core::Tensor,
) -> candle_core::Tensor {
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
let up = x.matmul(&w_up.t().unwrap()).unwrap();
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
}
/// CPU reference for one transformer layer
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &candle_core::Tensor,
attn_norm_w: &candle_core::Tensor,
wq: &candle_core::Tensor,
wk: &candle_core::Tensor,
wv: &candle_core::Tensor,
wo: &candle_core::Tensor,
mlp_norm_w: &candle_core::Tensor,
w_gate: &candle_core::Tensor,
w_up: &candle_core::Tensor,
w_down: &candle_core::Tensor,
) -> candle_core::Tensor {
let normed = rms_norm_ref(x, attn_norm_w, 1e-5);
let attn_out = self_attention_ref(&normed, wq, wk, wv, wo);
let x = (x + attn_out).unwrap();
let normed = rms_norm_ref(&x, mlp_norm_w, 1e-5);
let mlp_out = swiglu_mlp_ref(&normed, w_gate, w_up, w_down);
(x + mlp_out).unwrap()
}
// ---- Helper to generate weight data for a layer ----
fn generate_layer_weights(
layer: &MiniTransformerLayer,
base_seed: u64,
) -> Vec<(GraphTensor, Vec<f32>)> {
layer
.weights()
.iter()
.enumerate()
.map(|(i, (tensor, size))| {
let data = random_f32_vec(*size, base_seed + i as u64, -0.5, 0.5);
// RMSNorm weights should be initialized to ~1.0
let data = if *size == HIDDEN {
data.iter().map(|x| x + 1.0).collect::<Vec<_>>()
} else {
data
};
(*tensor, data)
})
.collect()
}
fn build_candle_ref(input_data: &[f32], weight_data: &[(GraphTensor, Vec<f32>)]) -> Vec<f32> {
let device = candle_core::Device::Cpu;
let ref_input =
candle_core::Tensor::from_vec(input_data.to_vec(), (SEQ, HIDDEN), &device).unwrap();
// weight_data: [attn_norm_w, wq, wk, wv, wo, mlp_norm_w, w_gate, w_up, w_down]
let w = |idx: usize, shape: &[usize]| {
candle_core::Tensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
};
let ref_attn_norm_w = w(0, &[HIDDEN]);
let ref_wq = w(1, &[HIDDEN, HIDDEN]);
let ref_wk = w(2, &[HIDDEN, HIDDEN]);
let ref_wv = w(3, &[HIDDEN, HIDDEN]);
let ref_wo = w(4, &[HIDDEN, HIDDEN]);
let ref_mlp_norm_w = w(5, &[HIDDEN]);
let ref_w_gate = w(6, &[INTERMEDIATE, HIDDEN]);
let ref_w_up = w(7, &[INTERMEDIATE, HIDDEN]);
let ref_w_down = w(8, &[HIDDEN, INTERMEDIATE]);
let expected = transformer_layer_ref(
&ref_input,
&ref_attn_norm_w,
&ref_wq,
&ref_wk,
&ref_wv,
&ref_wo,
&ref_mlp_norm_w,
&ref_w_gate,
&ref_w_up,
&ref_w_down,
);
expected.flatten_all().unwrap().to_vec1().unwrap()
}
// ---- Tests ----
/// Test a single transformer layer on CUDA against candle CPU reference.
#[test]
fn test_mini_transformer_layer() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let layer = MiniTransformerLayer::init(&mut cx);
let out = layer.forward(input).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
rt.set_data(input, input_data.clone());
let weight_data = generate_layer_weights(&layer, 100);
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data.clone());
}
// Use minimal search iterations to avoid excessive graph rewriting
// which can cause float drift through softmax/RMSNorm reordering
rt = cx.search(rt, 1);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let expected = build_candle_ref(&input_data, &weight_data);
assert_close(&result, &expected, 1e-2, 1e-2);
}
/// Test a two-layer transformer on CUDA against candle CPU reference.
#[test]
fn test_mini_transformer_two_layers() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let layer1 = MiniTransformerLayer::init(&mut cx);
let layer2 = MiniTransformerLayer::init(&mut cx);
let x = layer1.forward(input).graph_break();
let out = layer2.forward(x).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
rt.set_data(input, input_data.clone());
let layer1_weights = generate_layer_weights(&layer1, 200);
let layer2_weights = generate_layer_weights(&layer2, 300);
for (tensor, data) in layer1_weights.iter().chain(layer2_weights.iter()) {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, 1);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
// Run two layers on CPU reference
let device = candle_core::Device::Cpu;
let mut ref_x = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
for weights in [&layer1_weights, &layer2_weights] {
let w = |idx: usize, shape: &[usize]| {
candle_core::Tensor::from_vec(weights[idx].1.clone(), shape, &device).unwrap()
};
ref_x = transformer_layer_ref(
&ref_x,
&w(0, &[HIDDEN]),
&w(1, &[HIDDEN, HIDDEN]),
&w(2, &[HIDDEN, HIDDEN]),
&w(3, &[HIDDEN, HIDDEN]),
&w(4, &[HIDDEN, HIDDEN]),
&w(5, &[HIDDEN]),
&w(6, &[INTERMEDIATE, HIDDEN]),
&w(7, &[INTERMEDIATE, HIDDEN]),
&w(8, &[HIDDEN, INTERMEDIATE]),
);
}
let expected: Vec<f32> = ref_x.flatten_all().unwrap().to_vec1().unwrap();
// Two layers accumulate more drift
assert_close(&result, &expected, 2e-2, 2e-2);
}
/// Test the transformer with multiple random data seeds to catch data-dependent bugs.
#[test]
fn test_transformer_multi_seed() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
for seed in [42u64, 99, 777] {
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let layer = MiniTransformerLayer::init(&mut cx);
let out = layer.forward(input).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
rt.set_data(input, input_data.clone());
let weight_data = generate_layer_weights(&layer, seed + 100);
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, 1);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let expected = build_candle_ref(&input_data, &weight_data);
assert_close(&result, &expected, 1e-2, 1e-2);
}
}
/// Test just the RMSNorm component on CUDA
#[test]
fn test_rms_norm_cuda() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let weight = cx.tensor(HIDDEN);
let out = rms_norm(input, weight, 1e-5).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 1, -0.5, 0.5);
let weight_data: Vec<f32> = random_f32_vec(HIDDEN, 2, -0.5, 0.5)
.iter()
.map(|x| x + 1.0)
.collect();
rt.set_data(input, input_data.clone());
rt.set_data(weight, weight_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
let ref_weight = candle_core::Tensor::from_vec(weight_data, HIDDEN, &device).unwrap();
let expected = rms_norm_ref(&ref_input, &ref_weight, 1e-5);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-3, 1e-3);
}
/// Test just the self-attention on CUDA
#[test]
fn test_self_attention_cuda() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let wq = cx.tensor((HIDDEN, HIDDEN));
let wk = cx.tensor((HIDDEN, HIDDEN));
let wv = cx.tensor((HIDDEN, HIDDEN));
let wo = cx.tensor((HIDDEN, HIDDEN));
let out = self_attention(input, wq, wk, wv, wo).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 10, -0.5, 0.5);
let wq_data = random_f32_vec(HIDDEN * HIDDEN, 11, -0.5, 0.5);
let wk_data = random_f32_vec(HIDDEN * HIDDEN, 12, -0.5, 0.5);
let wv_data = random_f32_vec(HIDDEN * HIDDEN, 13, -0.5, 0.5);
let wo_data = random_f32_vec(HIDDEN * HIDDEN, 14, -0.5, 0.5);
rt.set_data(input, input_data.clone());
rt.set_data(wq, wq_data.clone());
rt.set_data(wk, wk_data.clone());
rt.set_data(wv, wv_data.clone());
rt.set_data(wo, wo_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
let ref_wq = candle_core::Tensor::from_vec(wq_data, (HIDDEN, HIDDEN), &device).unwrap();
let ref_wk = candle_core::Tensor::from_vec(wk_data, (HIDDEN, HIDDEN), &device).unwrap();
let ref_wv = candle_core::Tensor::from_vec(wv_data, (HIDDEN, HIDDEN), &device).unwrap();
let ref_wo = candle_core::Tensor::from_vec(wo_data, (HIDDEN, HIDDEN), &device).unwrap();
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-2, 1e-2);
}
/// Test just the SwiGLU MLP on CUDA
#[test]
fn test_swiglu_mlp_cuda() {
let Some(stream) = get_cuda_stream() else {
println!("CUDA not available, skipping");
return;
};
let mut cx = Graph::default();
let input = cx.tensor((SEQ, HIDDEN));
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let input_data = random_f32_vec(SEQ * HIDDEN, 20, -0.5, 0.5);
let gate_data = random_f32_vec(INTERMEDIATE * HIDDEN, 21, -0.5, 0.5);
let up_data = random_f32_vec(INTERMEDIATE * HIDDEN, 22, -0.5, 0.5);
let down_data = random_f32_vec(HIDDEN * INTERMEDIATE, 23, -0.5, 0.5);
rt.set_data(input, input_data.clone());
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let device = candle_core::Device::Cpu;
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
let ref_gate =
candle_core::Tensor::from_vec(gate_data, (INTERMEDIATE, HIDDEN), &device).unwrap();
let ref_up = candle_core::Tensor::from_vec(up_data, (INTERMEDIATE, HIDDEN), &device).unwrap();
let ref_down =
candle_core::Tensor::from_vec(down_data, (HIDDEN, INTERMEDIATE), &device).unwrap();
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
assert_close(&result, &expected, 1e-3, 1e-3);
}

View File

@@ -0,0 +1,496 @@
use candle_core::{Device, Tensor, WithDType};
use cudarc::driver::CudaContext;
use half::{bf16, f16};
use luminal::egglog_utils::{
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
};
use luminal::prelude::*;
use num_traits::{Num, Signed};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::sync::Arc;
use crate::runtime::{CudaRuntime, ToCudaInput};
/// Safety factor multiplied with epsilon for tolerance calculations
pub const TOLERANCE_SAFETY_FACTOR: f32 = 2.0;
/// Number of genomes to fuzz per op test invocation.
pub const GENOME_FUZZ_COUNT: usize = 20;
/// Trait for test-compatible data types that can be used in generic test functions.
/// Bridges luminal's runtime types with candle's tensor types.
pub trait TestDType:
Clone + Sized + WithDType + PartialEq + Copy + std::fmt::Debug + 'static
where
Vec<Self>: ToCudaInput,
{
/// The corresponding luminal DType
const DTYPE: luminal::dtype::DType;
/// Retrieve data from the runtime in this dtype
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self>;
/// Extract a Vec from a candle Tensor
fn candle_to_vec(tensor: &Tensor) -> Vec<Self>;
/// Compare two result vectors. Float types use tolerance; exact types use equality.
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32);
}
impl TestDType for f32 {
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F32;
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
rt.get_f32(id)
}
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
tensor.to_vec1::<f32>().unwrap()
}
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
assert_close(a, b, rtol, atol);
}
}
impl TestDType for f16 {
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F16;
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
rt.get_f16(id)
}
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
tensor.to_vec1::<f16>().unwrap()
}
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
assert_close(a, b, f16::from_f32(rtol), f16::from_f32(atol));
}
}
impl TestDType for bf16 {
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Bf16;
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
rt.get_bf16(id)
}
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
tensor.to_vec1::<bf16>().unwrap()
}
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
assert_close(a, b, bf16::from_f32(rtol), bf16::from_f32(atol));
}
}
impl TestDType for i32 {
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Int;
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
rt.get_i32(id)
}
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
tensor.to_vec1::<i32>().unwrap()
}
fn assert_match(a: &[Self], b: &[Self], _rtol: f32, _atol: f32) {
assert_eq!(a, b);
}
}
#[allow(dead_code)]
pub fn random_i32_vec(n: usize, seed: u64, low: i32, high: i32) -> Vec<i32> {
let mut rng = StdRng::seed_from_u64(seed);
(0..n).map(|_| rng.random_range(low..=high)).collect()
}
pub fn random_f32_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<f32> {
let mut rng = StdRng::seed_from_u64(seed);
(0..n).map(|_| rng.random_range(low..high)).collect()
}
/// Assert two vectors are close following NumPy/PyTorch conventions.
/// Formula: |a - b| <= atol + rtol * |b|
/// Generic version that works with any Float type (f32, f16, bf16).
pub fn assert_close<T: Num + Signed + PartialOrd + Copy + std::fmt::Display>(
a_vec: &[T],
b_vec: &[T],
rtol: T,
atol: T,
) {
assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match");
for (i, (a, b)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
let diff = (*a - *b).abs();
let tolerance = atol + rtol * b.abs();
if diff > tolerance {
panic!("{a} is not close to {b}, index {i}, diff: {diff}, tolerance: {tolerance}");
}
}
}
pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
let ctx = CudaContext::new(0).ok()?;
ctx.bind_to_thread().ok()?;
Some(ctx.default_stream())
}
/// Get the GPU compute capability as (major, minor).
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
let ctx = CudaContext::new(0).ok()?;
ctx.compute_capability().ok()
}
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
let Some((major, _)) = gpu_compute_cap() else {
return false;
};
match dtype {
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
luminal::dtype::DType::F4E2M1
| luminal::dtype::DType::F8E4M3
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
_ => true,
}
}
/// Machine epsilon for each dtype (approximate)
pub fn dtype_epsilon(dtype: luminal::dtype::DType) -> f32 {
match dtype {
luminal::dtype::DType::F32 => 1.19e-7, // 2^-23
luminal::dtype::DType::F16 => 9.77e-4, // 2^-10
luminal::dtype::DType::Bf16 => 7.81e-3, // 2^-7
luminal::dtype::DType::Int => 0.0,
luminal::dtype::DType::Bool => 0.0,
other => todo!("dtype_epsilon not implemented for {other}"),
}
}
/// Map a luminal DType to the corresponding candle DType.
pub fn to_candle_dtype(dtype: luminal::dtype::DType) -> candle_core::DType {
match dtype {
luminal::dtype::DType::F32 => candle_core::DType::F32,
luminal::dtype::DType::F16 => candle_core::DType::F16,
luminal::dtype::DType::Bf16 => candle_core::DType::BF16,
luminal::dtype::DType::Int => candle_core::DType::I32,
luminal::dtype::DType::Bool => candle_core::DType::U8,
other => todo!("candle dtype mapping not implemented for {other}"),
}
}
/// Base unary test function with input generator (CUDA version)
/// Generic over dtype T - comparison happens in native precision.
pub fn test_unary_cuda<T: TestDType>(
shape: impl ToShape,
func: impl Fn(GraphTensor) -> GraphTensor,
ref_func: impl Fn(Tensor) -> Tensor,
generator: impl Fn(usize, u64) -> Vec<T>,
seed: u64,
) where
Vec<T>: ToCudaInput,
{
let Some(stream) = get_cuda_stream() else {
return;
};
let shape: Vec<usize> = shape
.to_shape()
.into_iter()
.map(|e| e.to_usize().unwrap())
.collect();
let n_elements: usize = shape.iter().product();
let mut cx = Graph::default();
let a = cx.tensor(shape.clone());
let b = func(a).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = generator(n_elements, seed);
rt.set_data(a, input_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, b.id);
// Reference using candle on CUDA
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
let ref_a = Tensor::from_slice(&input_data, shape, &device).unwrap();
let ref_b = ref_func(ref_a).flatten_all().unwrap();
let ref_vec = T::candle_to_vec(&ref_b);
let eps = dtype_epsilon(<T as TestDType>::DTYPE);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
T::assert_match(&result, &ref_vec, tol, tol);
// Fuzz genomes: verify multiple graph rewrites produce consistent results
fuzz_genomes::<T>(
&cx,
&stream,
|rt| rt.set_data(a, input_data.clone()),
b.id,
&ref_vec,
tol,
tol,
GENOME_FUZZ_COUNT,
seed,
);
}
/// Base binary test function with input generators
/// Generic over dtype T - comparison happens in native precision.
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
#[allow(clippy::too_many_arguments)]
pub fn test_binary_cuda<T: TestDType>(
a_shape: impl ToShape,
b_shape: impl ToShape,
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
ref_func: impl Fn(Tensor, Tensor) -> Tensor,
a_generator: impl Fn(usize, u64) -> Vec<T>,
b_generator: impl Fn(usize, u64) -> Vec<T>,
seed: u64,
rtol: f32,
atol: f32,
) where
Vec<T>: ToCudaInput,
{
let Some(stream) = get_cuda_stream() else {
return;
};
let a_shape: Vec<usize> = a_shape
.to_shape()
.into_iter()
.map(|e| e.to_usize().unwrap())
.collect();
let b_shape: Vec<usize> = b_shape
.to_shape()
.into_iter()
.map(|e| e.to_usize().unwrap())
.collect();
let a_elements: usize = a_shape.iter().product();
let b_elements: usize = b_shape.iter().product();
let mut cx = Graph::default();
let a: GraphTensor = cx.tensor(a_shape.clone());
let b = cx.tensor(b_shape.clone());
let c = func(a, b).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let a_data = a_generator(a_elements, seed);
let b_data = b_generator(b_elements, seed.wrapping_add(1));
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, c.id);
// Reference using candle on CUDA
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
let ref_a = Tensor::from_slice(&a_data, a_shape, &device).unwrap();
let ref_b = Tensor::from_slice(&b_data, b_shape, &device).unwrap();
let ref_c = ref_func(ref_a, ref_b).flatten_all().unwrap();
let ref_vec = T::candle_to_vec(&ref_c);
T::assert_match(&result, &ref_vec, rtol, atol);
// Fuzz genomes: verify multiple graph rewrites produce consistent results
fuzz_genomes::<T>(
&cx,
&stream,
|rt| {
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
},
c.id,
&ref_vec,
rtol,
atol,
GENOME_FUZZ_COUNT,
seed,
);
}
/// Test mod operation with element-wise reference using Rust's % operator
pub fn test_mod(
a_shape: impl ToShape,
b_shape: impl ToShape,
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
seed: u64,
) {
let Some(stream) = get_cuda_stream() else {
return;
};
let a_shape: Vec<usize> = a_shape
.to_shape()
.into_iter()
.map(|e| e.to_usize().unwrap())
.collect();
let b_shape: Vec<usize> = b_shape
.to_shape()
.into_iter()
.map(|e| e.to_usize().unwrap())
.collect();
let a_elements: usize = a_shape.iter().product();
let b_elements: usize = b_shape.iter().product();
let mut cx = Graph::default();
let a = cx.tensor(a_shape.clone());
let b = cx.tensor(b_shape.clone());
let c = func(a, b).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
// Generate divisor values away from zero (0.1 to 0.5) to avoid division issues
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(c);
// Reference: Rust's % operator matches CUDA's fmodf (IEEE 754 remainder)
let expected: Vec<f32> = a_data
.iter()
.zip(b_data.iter())
.map(|(x, y)| x % y)
.collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let rtol = eps * TOLERANCE_SAFETY_FACTOR;
let atol = eps * TOLERANCE_SAFETY_FACTOR;
assert_close(&result, &expected, rtol, atol);
// Fuzz genomes: verify multiple graph rewrites produce consistent results
fuzz_genomes::<f32>(
&cx,
&stream,
|rt| {
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
},
c.id,
&expected,
rtol,
atol,
GENOME_FUZZ_COUNT,
seed,
);
}
/// Generate a slice range for an axis of given size.
/// If do_start is true, randomly choose a start offset (leaving at least 1 element).
/// If do_end is true, randomly choose an end before the axis end.
pub fn gen_slice_range(
size: usize,
do_start: bool,
do_end: bool,
rng: &mut impl Rng,
) -> (usize, usize) {
let start = if do_start && size > 1 {
rng.random_range(0..size)
} else {
0
};
let remaining = size - start;
let end = if do_end && remaining > 1 {
start + rng.random_range(1..remaining)
} else {
size
};
(start, end)
}
/// Fuzz test multiple genomes from the e-graph search space.
///
/// After a graph has been built and compared against a reference, this function
/// extracts random genomes via mutation and verifies they all produce results
/// matching the expected reference output. This catches bugs where graph rewrites
/// produce incorrect computation.
///
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
#[allow(clippy::too_many_arguments)]
pub fn fuzz_genomes<T: TestDType>(
cx: &Graph,
stream: &Arc<cudarc::driver::CudaStream>,
setup_inputs: impl Fn(&mut CudaRuntime),
output_id: NodeIndex,
expected: &[T],
rtol: f32,
atol: f32,
num_genomes: usize,
seed: u64,
) where
Vec<T>: ToCudaInput,
{
let Some(egraph) = cx.egraph() else {
return;
};
let Some(ops) = cx.egglog_ops() else {
return;
};
// Check if there are alternative genomes to explore
let mutable_eclasses: usize = egraph
.eclasses
.iter()
.filter(|(_, (label, enodes))| {
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
})
.count();
if mutable_eclasses == 0 {
return; // Only one valid graph, nothing to fuzz
}
// Use a different seed offset to avoid correlating with the search seed
let mut rng = StdRng::seed_from_u64(seed.wrapping_add(7777));
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
let initial = random_initial_choice(egraph, &mut rng);
prev_selected.insert(hash_choice_set(&initial));
let mut base = initial;
let mut tested = 0;
for _ in 0..100 {
let offspring = extract_generation(egraph, &base, 10, 2, &mut prev_selected, &mut rng);
if offspring.is_empty() {
break;
}
for genome in offspring {
if validate_choice_set(egraph, &genome, ops).is_err() {
continue;
}
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
let llir_graph = egglog_to_llir(
egraph,
genome.clone(),
ops,
&cx.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
let mut rt = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir_graph);
setup_inputs(&mut rt);
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, output_id);
T::assert_match(&result, expected, rtol, atol);
tested += 1;
base = genome;
if tested >= num_genomes {
return;
}
}
}
}

View File

@@ -0,0 +1,22 @@
[package]
name = "luminal_metal"
version = "0.2.0"
edition = "2021"
description = "Metal backend for luminal"
license = "MIT OR Apache-2.0"
[dependencies]
luminal = { path = "../.." }
metal = "0.31"
objc = "0.2"
as-any = "0.3.2"
itertools = "0.12.1"
half = "2.7.1"
tracing = "0.1.43"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
proptest = "1.9.0"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }

View File

@@ -0,0 +1,227 @@
use super::{MetalMulInfo, MetalSumReduceInfo};
use luminal::prelude::*;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MetalMatmulFamily {
#[default]
Naive,
RegularTiled,
}
#[derive(Debug, Clone)]
pub struct MatmulDescriptor {
pub m: Expression,
pub n: Expression,
pub k: Expression,
pub batch_shape: Vec<Expression>,
pub lhs_strides: Vec<Expression>,
pub rhs_strides: Vec<Expression>,
pub out_strides: Vec<Expression>,
pub transpose_lhs: bool,
pub transpose_rhs: bool,
}
impl MatmulDescriptor {
pub fn from_mul_and_sum(
mul_info: &MetalMulInfo,
sum_info: &MetalSumReduceInfo,
) -> Option<Self> {
let zero = Expression::from(0);
let z = Expression::from('z');
let is_simple_2d_matmul = mul_info.shape.len() == 3
&& sum_info.shape.len() == 2
&& mul_info.a_strides.len() == 3
&& mul_info.b_strides.len() == 3
&& sum_info.strides.len() == 2
&& mul_info.shape[0] == sum_info.shape[0]
&& mul_info.shape[1] == sum_info.shape[1]
&& mul_info.shape[2] == sum_info.iters
&& mul_info.a_strides[1] == zero
&& mul_info.a_strides[2] == z
&& mul_info.b_strides[0] == zero
&& mul_info.b_strides[1] == z
&& sum_info.strides[1] == z
&& sum_info.iter_stride == z;
if !is_simple_2d_matmul {
return None;
}
Some(Self {
m: sum_info.shape[0],
n: sum_info.shape[1],
k: sum_info.iters,
batch_shape: Vec::new(),
lhs_strides: mul_info.a_strides.clone(),
rhs_strides: mul_info.b_strides.clone(),
out_strides: sum_info.strides.clone(),
transpose_lhs: false,
transpose_rhs: false,
})
}
}
#[derive(Debug, Clone)]
pub struct MatmulPlan {
pub family: MetalMatmulFamily,
pub m: Expression,
pub n: Expression,
pub k: Expression,
pub lda: Expression,
pub ldb: Expression,
pub ldd: Expression,
pub batch_size: u32,
pub batch_stride_a: u32,
pub batch_stride_b: u32,
pub batch_stride_d: u32,
pub bm: u16,
pub bn: u16,
pub bk: u16,
pub wm: u16,
pub wn: u16,
}
#[derive(Debug, Default, Clone, Copy)]
pub struct MetalMatmulPlanner;
impl MetalMatmulPlanner {
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
let family = if desc.batch_shape.is_empty()
&& desc.m.as_num().is_some_and(|m| m >= 32)
&& desc.n.as_num().is_some_and(|n| n >= 32)
&& desc.k.as_num().is_some_and(|k| k >= 32)
{
MetalMatmulFamily::RegularTiled
} else {
MetalMatmulFamily::Naive
};
MatmulPlan {
family,
m: desc.m,
n: desc.n,
k: desc.k,
lda: desc.lhs_strides[0],
ldb: desc.rhs_strides[2],
ldd: desc.out_strides[0],
batch_size: 1,
batch_stride_a: 0,
batch_stride_b: 0,
batch_stride_d: 0,
bm: 16,
bn: 16,
bk: 8,
wm: 2,
wn: 2,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn descriptor_recovers_simple_2d_matmul() {
let mul = MetalMulInfo {
shape: vec![
Expression::from(4),
Expression::from(8),
Expression::from(16),
],
a_strides: vec![
Expression::from('z') * 16,
Expression::from(0),
Expression::from('z'),
],
b_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 8,
],
output_strides: vec![
Expression::from('z') * 16,
Expression::from('z') * 8,
Expression::from('z'),
],
};
let sum = MetalSumReduceInfo {
shape: vec![Expression::from(4), Expression::from(8)],
strides: vec![Expression::from('z') * 8, Expression::from('z')],
iters: Expression::from(16),
iter_stride: Expression::from('z'),
};
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
assert_eq!(desc.m, Expression::from(4));
assert_eq!(desc.n, Expression::from(8));
assert_eq!(desc.k, Expression::from(16));
}
#[test]
fn planner_keeps_small_problems_on_naive_path() {
let desc = MatmulDescriptor {
m: Expression::from(4),
n: Expression::from(8),
k: Expression::from(16),
batch_shape: Vec::new(),
lhs_strides: vec![
Expression::from('z') * 16,
Expression::from(0),
Expression::from('z'),
],
rhs_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 8,
],
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
transpose_lhs: false,
transpose_rhs: false,
};
let planner = MetalMatmulPlanner;
let plan = planner.plan(&desc);
assert_eq!(plan.family, MetalMatmulFamily::Naive);
assert_eq!(plan.bm, 16);
assert_eq!(plan.bn, 16);
assert_eq!(plan.bk, 8);
assert_eq!(plan.wm, 2);
assert_eq!(plan.wn, 2);
assert_eq!(plan.lda, Expression::from('z') * 16);
assert_eq!(plan.ldb, Expression::from('z') * 8);
assert_eq!(plan.ldd, Expression::from('z') * 8);
}
#[test]
fn planner_promotes_large_problems_to_regular_tiled() {
let desc = MatmulDescriptor {
m: Expression::from(64),
n: Expression::from(64),
k: Expression::from(64),
batch_shape: Vec::new(),
lhs_strides: vec![
Expression::from('z') * 64,
Expression::from(0),
Expression::from('z'),
],
rhs_strides: vec![
Expression::from(0),
Expression::from('z'),
Expression::from('z') * 64,
],
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
transpose_lhs: false,
transpose_rhs: false,
};
let planner = MetalMatmulPlanner;
let plan = planner.plan(&desc);
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
assert_eq!(plan.bm, 16);
assert_eq!(plan.bn, 16);
assert_eq!(plan.bk, 8);
assert_eq!(plan.wm, 2);
assert_eq!(plan.wn, 2);
}
}

View File

@@ -0,0 +1,81 @@
mod matmul;
mod ops;
pub use matmul::*;
pub use ops::*;
use luminal::dtype::DType;
use luminal::op::EgglogOp;
use luminal::prelude::*;
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
pub const DYN_SLOT_COUNT: usize = 26;
#[derive(Debug, Clone)]
pub struct MetalMulInfo {
pub shape: Vec<Expression>,
pub a_strides: Vec<Expression>,
pub b_strides: Vec<Expression>,
pub output_strides: Vec<Expression>,
}
#[derive(Debug, Clone)]
pub struct MetalSumReduceInfo {
pub shape: Vec<Expression>,
pub strides: Vec<Expression>,
pub iters: Expression,
pub iter_stride: Expression,
}
pub trait MetalKernelOp: EgglogOp {
fn compile(
&self,
device: &Device,
input_dtypes: &[DType],
output_dtype: DType,
) -> ComputePipelineState;
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
input_dtypes.first().copied().unwrap_or(DType::F32)
}
fn output_size(&self) -> Expression;
fn encode(
&self,
encoder: &ComputeCommandEncoderRef,
pipeline: &ComputePipelineState,
inputs: &[&Buffer],
output: &Buffer,
dyn_map: &FxHashMap<char, usize>,
);
// ========================================================================
// Performance Metrics for MBU/MFU Calculation
// ========================================================================
fn bytes_loaded(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
0
}
fn bytes_stored(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
0
}
fn flops(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
0
}
fn mul_info(&self) -> Option<MetalMulInfo> {
None
}
fn sum_reduce_info(&self) -> Option<MetalSumReduceInfo> {
None
}
fn is_matmul(&self) -> bool {
false
}
}
luminal::impl_into_ops!(MetalKernelOp);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,12 @@
pub mod kernel;
pub mod runtime;
#[cfg(test)]
mod tests;
pub use metal::{Buffer, Device, MTLResourceOptions};
pub use objc::rc::autoreleasepool;
pub use runtime::MetalRuntime;
// Re-export kernel ops
pub use kernel::MetalOps;

View File

@@ -0,0 +1,555 @@
use crate::kernel::{
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
};
use half::f16;
use itertools::Itertools;
use luminal::{
dtype::DType,
graph::LLIRGraph,
hlir::{Input, NativeData, Output},
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
prelude::{
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
FxHashMap, NodeIndex, ToId,
},
};
use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptions};
use objc::runtime::Object;
use std::time::Duration;
pub struct MetalRuntime {
device: Device,
command_queue: CommandQueue,
/// Host-side input tensors provided by the user.
input_data: FxHashMap<NodeIndex, NativeData>,
/// Buffers for HLIR input tensors (set by user)
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
/// Buffers for LLIR intermediate/output tensors
pub buffers: FxHashMap<NodeIndex, Buffer>,
/// Dynamic dimensions table (a-z), shared across all kernels.
dyn_buffer: Buffer,
/// The current LLIR graph
llir_graph: LLIRGraph,
/// Inferred runtime dtype for each LLIR node.
node_dtypes: FxHashMap<NodeIndex, DType>,
/// Compiled pipeline states for each kernel node
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
}
impl MetalRuntime {
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
let mut graph = llir_graph.clone();
let planner = MetalMatmulPlanner;
let mut rewrites = Vec::new();
for sum_node in graph.node_indices().collect::<Vec<_>>() {
let Some(sum_info) = graph[sum_node]
.to_dialect::<dyn MetalKernelOp>()
.and_then(|op| op.sum_reduce_info())
else {
continue;
};
let input_edges: Vec<_> = graph
.edges_directed(sum_node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
if input_edges.len() != 1 {
continue;
}
let mul_node = input_edges[0];
let Some(mul_info) = graph[mul_node]
.to_dialect::<dyn MetalKernelOp>()
.and_then(|op| op.mul_info())
else {
continue;
};
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
continue;
};
let mul_inputs: Vec<_> = graph
.edges_directed(mul_node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
if mul_inputs.len() != 2 {
continue;
}
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
}
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
graph[sum_node] =
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
m: plan.m,
n: plan.n,
k: plan.k,
lda: plan.lda,
ldb: plan.ldb,
ldd: plan.ldd,
family: plan.family,
bm: plan.bm,
bn: plan.bn,
bk: plan.bk,
wm: plan.wm,
wn: plan.wn,
batch_size: plan.batch_size,
batch_stride_a: plan.batch_stride_a,
batch_stride_b: plan.batch_stride_b,
batch_stride_d: plan.batch_stride_d,
}));
graph.remove_node(mul_node);
graph.add_edge(mul_inputs[0], sum_node, ());
graph.add_edge(mul_inputs[1], sum_node, ());
}
graph
}
#[cfg(test)]
pub(crate) fn contains_matmul(&self) -> bool {
self.llir_graph.node_indices().any(|node| {
self.llir_graph[node]
.to_dialect::<dyn MetalKernelOp>()
.is_some_and(|op| op.is_matmul())
})
}
#[cfg(test)]
pub(crate) fn debug_kernel_ops(&self) -> Vec<String> {
self.llir_graph
.node_indices()
.filter_map(|node| {
self.llir_graph[node]
.to_dialect::<dyn MetalKernelOp>()
.map(|op| format!("{op:?}"))
})
.collect()
}
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
self.input_data.insert(id.to_id(), data.into());
}
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
let id = id.to_id();
let output_id = self
.llir_graph
.node_indices()
.find(|n| {
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
*node == id.index()
} else {
false
}
})
.expect("Cannot find output tensor!");
let data_id = self
.llir_graph
.neighbors_directed(output_id, Direction::Incoming)
.next()
.unwrap();
let buffer = self
.buffers
.get(&data_id)
.or_else(|| {
// If data_id is an Input node, get from hlir_buffers
if let Some(Input { node, .. }) = self.llir_graph[data_id].to_op::<Input>() {
self.hlir_buffers.get(&NodeIndex::new(*node))
} else {
None
}
})
.expect("Cannot find tensor in runtime!");
let dtype = self
.node_dtypes
.get(&data_id)
.copied()
.or_else(|| {
self.llir_graph[data_id]
.to_op::<Input>()
.map(|inp| inp.dtype)
})
.unwrap_or(DType::F32);
unsafe {
match dtype {
DType::F16 => {
let ptr = buffer.contents() as *const f16;
let len = buffer.length() as usize / std::mem::size_of::<f16>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| v.to_f32())
.collect()
}
DType::Int => {
let ptr = buffer.contents() as *const i32;
let len = buffer.length() as usize / std::mem::size_of::<i32>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| *v as f32)
.collect()
}
_ => {
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
std::slice::from_raw_parts(ptr, len).to_vec()
}
}
}
}
}
impl Runtime for MetalRuntime {
type Ops = crate::kernel::MetalOps;
type CompileArg = ();
type ExecReturn = ();
type ProfileMetric = Duration;
fn initialize(_: Self::CompileArg) -> Self {
let device = Device::system_default().expect("No Metal device found!");
let command_queue = device.new_command_queue();
let dyn_buffer = device.new_buffer(
(DYN_SLOT_COUNT * std::mem::size_of::<i32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
Self {
device,
command_queue,
input_data: FxHashMap::default(),
hlir_buffers: FxHashMap::default(),
buffers: FxHashMap::default(),
dyn_buffer,
llir_graph: StableGraph::default(),
node_dtypes: FxHashMap::default(),
pipelines: FxHashMap::default(),
}
}
#[tracing::instrument(skip_all)]
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
self.pipelines.clear();
self.buffers.clear();
self.hlir_buffers.clear();
self.node_dtypes.clear();
self.llir_graph = Self::fuse_matmuls(llir_graph);
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
for node in topo_order {
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
self.node_dtypes.insert(node, input.dtype);
let hlir_id = NodeIndex::new(input.node);
if let Some(data) = self.input_data.get(&hlir_id) {
let buffer = self.create_input_buffer(data, input.dtype);
self.hlir_buffers.insert(hlir_id, buffer);
}
continue;
}
if self.llir_graph[node].to_op::<Output>().is_some() {
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let input_dtypes: Vec<DType> = input_nodes
.iter()
.map(|n| {
self.node_dtypes
.get(n)
.copied()
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
})
.collect();
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
self.node_dtypes.insert(node, output_dtype);
self.pipelines.insert(node, pipeline);
}
}
}
#[tracing::instrument(skip_all)]
fn profile(
&mut self,
llir_graph: &LLIRGraph,
dyn_map: &FxHashMap<char, usize>,
trials: usize,
) -> (Self::ProfileMetric, String) {
self.load_llir(llir_graph);
self.allocate_intermediate_buffers(dyn_map);
let trials = trials.max(1);
let mut duration = Duration::default();
for _ in 0..trials {
let start = std::time::Instant::now();
self.execute(dyn_map);
duration += start.elapsed();
}
duration /= trials as u32;
(duration, format!("{:.2?}", duration))
}
#[tracing::instrument(skip_all)]
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) -> Self::ExecReturn {
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
.llir_graph
.node_indices()
.filter_map(|n| {
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
Some((n, NodeIndex::new(*node)))
} else {
None
}
})
.collect();
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
|| self.llir_graph[node].to_op::<Output>().is_some()
{
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let input_buffers: Vec<&Buffer> = input_nodes
.iter()
.map(|&n| {
if let Some(hlir_node) = llir_to_hlir.get(&n) {
self.hlir_buffers
.get(hlir_node)
.expect("Input buffer not set!")
} else {
self.buffers
.get(&n)
.expect("Intermediate buffer not found!")
}
})
.collect();
let output_buffer = self
.buffers
.get(&node)
.expect("Output buffer not allocated!");
// Bind dyn dims right after the output slot:
// [inputs..., output, dyn, bytes...]
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
}
}
impl RuntimeStats for MetalRuntime {
fn execute_with_stats(&mut self, dyn_map: &FxHashMap<char, usize>) -> Option<ExecutionStats> {
let mut total_bytes_loaded = 0usize;
let mut total_bytes_stored = 0usize;
let mut total_flops = 0usize;
for node in self.llir_graph.node_indices() {
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
total_bytes_loaded += kernel_op.bytes_loaded(dyn_map);
total_bytes_stored += kernel_op.bytes_stored(dyn_map);
total_flops += kernel_op.flops(dyn_map);
}
}
let (time_us, timing_method) = self.execute_timed(dyn_map);
Some(ExecutionStats::with_timing_method(
time_us,
total_bytes_loaded,
total_bytes_stored,
total_flops,
timing_method,
))
}
}
impl MetalRuntime {
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
match dtype {
DType::F32 => {
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::F16 => {
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::Int => {
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
unsupported => panic!("Metal input dtype {unsupported:?} is not supported yet"),
}
}
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
for node in self.llir_graph.node_indices() {
if self.llir_graph[node].to_op::<Input>().is_some() {
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let size = kernel_op.output_size().exec(dyn_map).unwrap();
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
let buffer = self.device.new_buffer(
(size * dtype.bits().div_ceil(8)) as u64,
MTLResourceOptions::StorageModeShared,
);
self.buffers.insert(node, buffer);
}
}
}
fn update_dyn_buffer(&mut self, dyn_map: &FxHashMap<char, usize>) {
let ptr = self.dyn_buffer.contents() as *mut i32;
unsafe {
for idx in 0..DYN_SLOT_COUNT {
*ptr.add(idx) = 0;
}
for (&symbol, &value) in dyn_map {
if symbol.is_ascii_lowercase() {
let slot = (symbol as u8 - b'a') as usize;
if slot < DYN_SLOT_COUNT {
*ptr.add(slot) = value as i32;
}
}
}
}
}
/// Execute and return GPU-side execution time in microseconds.
fn execute_timed(&mut self, dyn_map: &FxHashMap<char, usize>) -> (f64, TimingMethod) {
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
.llir_graph
.node_indices()
.filter_map(|n| {
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
Some((n, NodeIndex::new(*node)))
} else {
None
}
})
.collect();
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
|| self.llir_graph[node].to_op::<Output>().is_some()
{
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let input_buffers: Vec<&Buffer> = input_nodes
.iter()
.map(|&n| {
if let Some(hlir_node) = llir_to_hlir.get(&n) {
self.hlir_buffers
.get(hlir_node)
.expect("Input buffer not set!")
} else {
self.buffers
.get(&n)
.expect("Intermediate buffer not found!")
}
})
.collect();
let output_buffer = self
.buffers
.get(&node)
.expect("Output buffer not allocated!");
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
// gpuStartTime and gpuEndTime are available on macOS 10.15+
let gpu_start: f64 = unsafe {
use objc::{msg_send, sel, sel_impl};
let ptr = command_buffer as *const _ as *mut Object;
msg_send![ptr, GPUStartTime]
};
let gpu_end: f64 = unsafe {
use objc::{msg_send, sel, sel_impl};
let ptr = command_buffer as *const _ as *mut Object;
msg_send![ptr, GPUEndTime]
};
let gpu_time_seconds = gpu_end - gpu_start;
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
(gpu_time_us, TimingMethod::DeviceTimestamp)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
[package]
name = "luminal_nn"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
itertools = "0.12.1"
luminal = { path = "../.." }
rustc-hash = "1.1.0"
rand = "0.9.2"
[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
paste = "1.0.14"
candle-core = "0.9.2-alpha.1"

View File

@@ -0,0 +1,125 @@
use luminal::prelude::*;
/// Rectified Linear Unit activation function
#[derive(Default)]
pub struct ReLU;
impl ReLU {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
input.relu()
}
}
/// Gaussian Error Linear Unit activation function
#[derive(Default)]
pub struct GeLU;
impl GeLU {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
input.gelu()
}
}
/// Sigmoid activation function
#[derive(Default)]
pub struct Sigmoid;
impl Sigmoid {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
input.sigmoid()
}
}
/// Swish activation function
#[derive(Default)]
pub struct Swish;
impl Swish {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
input.swish()
}
}
/// Tanh activation function
#[derive(Default)]
pub struct Tanh;
impl Tanh {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
input.tanh()
}
}
// #[cfg(test)]
// mod tests {
// use super::ReLU;
// use crate::Linear;
// use dfdx::prelude::{Module as DfdxModule, *};
// use luminal::{
// prelude::{Module, *},
// tests::assert_close,
// };
// #[test]
// fn test_relu_and_linear() {
// // Test single and batch, unoptimized and optimized
// let mut cx = Graph::new();
// let batch = cx.tensor((2, 3)).set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
// let a = cx.tensor(3).set(vec![1.0, 2.0, 3.0]);
// let model = (
// Linear::new(3, 4, false, &mut cx),
// ReLU,
// Linear::new(4, 2, false, &mut cx),
// );
// model
// .0
// .weight
// .set(vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.]);
// model.2.weight.set(vec![1., 2., 3., 1., 2., 3., 1., 2.]);
// let mut b = model.forward(a).retrieve();
// let mut batch_out = model.forward(batch).retrieve();
// cx.execute();
// let unoptimized_b = b.data();
// let unoptimized_batch_out = batch_out.data();
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
// cx.execute();
// assert_close(&unoptimized_b, &b.data());
// assert_close(&unoptimized_batch_out, &batch_out.data());
// // Test against dfdx
// let dev = Cpu::default();
// let mut model = <(
// dfdx::nn::modules::builders::UnbiasedLinear<3, 4>,
// dfdx::nn::modules::builders::ReLU,
// dfdx::nn::modules::builders::UnbiasedLinear<4, 2>,
// )>::build_on_device(&dev);
// // Set weights
// model.0.weight = dev
// .tensor_from_vec(
// vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.],
// (dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>),
// )
// .permute();
// model.2.weight = dev
// .tensor_from_vec(
// vec![1., 2., 3., 1., 2., 3., 1., 2.],
// (dfdx::shapes::Const::<4>, dfdx::shapes::Const::<2>),
// )
// .permute();
// let a = dev.tensor_from_vec(vec![1.0, 2.0, 3.0], (dfdx::shapes::Const::<3>,));
// let d_batch = dev.tensor_from_vec(
// vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
// (dfdx::shapes::Const::<2>, dfdx::shapes::Const::<3>),
// );
// let out = model.forward(a);
// let d_batch_out = model.forward(d_batch);
// assert_close(&unoptimized_b, &out.as_vec());
// assert_close(&unoptimized_batch_out, &d_batch_out.as_vec());
// }
// }

View File

@@ -0,0 +1,451 @@
use luminal::prelude::*;
use luminal::shape::Expression;
/// Gather entire rows from a 2D tensor using row indices.
///
/// - `data`: (R, D) tensor
/// - `indices`: (N,) Int tensor of row indices
/// - `d`: the number of columns (D), must match data's second dimension
///
/// Returns: (N, D) tensor where output[i] = data[indices[i]]
pub fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
assert_eq!(indices.dtype, DType::Int);
let n = indices.dims1();
// base[i] = indices[i] * D → flat starting position for each row
let base = (indices * d).expand_dim(1, d); // (N, D) broadcast along cols
// col[j] = j → column offsets 0..D
let col = data.graph().arange(d as i32).expand_dim(0, n); // (N, D) broadcast along rows
// flat_idx[i,j] = indices[i] * D + j
let flat_idx = base + col;
data.gather(flat_idx)
}
/// Scatter entire rows into a 2D tensor using row indices.
///
/// - `src`: (N, D) tensor of values to write
/// - `indices`: (N,) Int tensor of destination row indices
/// - `dest`: (R, D) tensor to write into (copied first, then overwritten at index positions)
/// - `d`: the number of columns (D)
///
/// Returns: (R, D) tensor where output = copy(dest); output[indices[i]] = src[i]
pub fn scatter_rows(
src: GraphTensor,
indices: GraphTensor,
dest: GraphTensor,
d: usize,
) -> GraphTensor {
assert_eq!(indices.dtype, DType::Int);
let n = indices.dims1();
// Same index expansion as gather_rows
let base = (indices * d).expand_dim(1, d);
let col = src.graph().arange(d as i32).expand_dim(0, n);
let flat_idx = base + col;
src.scatter(flat_idx, dest)
}
/// Pure HLIR paged attention for one layer with causal masking.
///
/// Inputs:
/// - `q`: (s, hidden) f32 — query vectors
/// - `k_new`: (s, kv_dim) f32 — new key vectors
/// - `v_new`: (s, kv_dim) f32 — new value vectors
/// - `k_cache`: (num_slots, kv_dim) f32 — key cache (preallocated)
/// - `v_cache`: (num_slots, kv_dim) f32 — value cache (preallocated)
/// - `gather_idx`: (ctx_len,) Int — which cache slots to read
/// - `scatter_idx`: (s,) Int — which cache slots to write new KV into
/// - `prev_seq`: number of previously cached tokens (for causal mask offset)
/// - `n_heads`: number of query heads
/// - `n_kv_heads`: number of KV heads (for GQA)
/// - `head_dim`: dimension per head
///
/// Returns: (attn_out, k_cache_new, v_cache_new)
/// - `attn_out`: (s, hidden) f32
/// - `k_cache_new`: (num_slots, kv_dim) f32
/// - `v_cache_new`: (num_slots, kv_dim) f32
#[allow(clippy::too_many_arguments)]
pub fn paged_attention(
q: GraphTensor,
k_new: GraphTensor,
v_new: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
gather_idx: GraphTensor,
scatter_idx: GraphTensor,
prev_seq: Expression,
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let kv_dim = n_kv_heads * head_dim;
let kv_groups = n_heads / n_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let s = q.dims()[0];
let ctx = gather_idx.dims()[0];
let cx = q.graph();
// ── Phase 1: Write new KV into cache ──
let k_cache = scatter_rows(k_new, scatter_idx, k_cache, kv_dim);
let v_cache = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
// ── Phase 2: Gather context KV from cache ──
let k = gather_rows(k_cache, gather_idx, kv_dim); // (ctx, kv_dim)
let v = gather_rows(v_cache, gather_idx, kv_dim); // (ctx, kv_dim)
// ── Phase 3: Reshape for multi-head attention ──
// Q: (s, hidden) → (s, n_heads, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
// → (n_kv_heads, kv_groups, s, head_dim)
let q = q
.split_dims(1, head_dim) // (s, n_heads, head_dim)
.split_dims(1, kv_groups) // (s, n_kv_heads, kv_groups, head_dim)
.permute((1, 2, 0, 3)); // (n_kv_heads, kv_groups, s, head_dim)
// K: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, head_dim, ctx)
let k = k
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
.permute((1, 2, 0)); // (n_kv_heads, head_dim, ctx)
// V: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, ctx, head_dim)
let v = v
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
.permute((1, 0, 2)); // (n_kv_heads, ctx, head_dim)
// ── Phase 4: Attention ──
// Broadcast K, V over kv_groups dimension
let k = k.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, head_dim, ctx)
let v = v.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, ctx, head_dim)
// QK^T: (n_kv_heads, kv_groups, s, head_dim) @ (n_kv_heads, kv_groups, head_dim, ctx)
// → (n_kv_heads, kv_groups, s, ctx)
let scores = q.matmul(k) * scale;
// Build causal mask: query at position prev_seq+i can attend to context j iff j <= prev_seq+i.
// row_vals[i] = prev_seq + i, col_vals[j] = j
// mask[i,j] = -1e9 where row_vals[i] < col_vals[j], else 0
let z = Expression::from('z');
let row_vals = cx.iota(z + prev_seq, s).expand_dim(1, ctx); // (s, ctx)
let col_vals = cx.arange(ctx).expand_dim(0, s); // (s, ctx)
let mask = row_vals
.cast(DType::F32)
.lt(col_vals.cast(DType::F32))
.cast(DType::F32)
* -1e9;
// Broadcast (s, ctx) → (n_kv_heads, kv_groups, s, ctx)
let mask = mask.expand_dim(0, n_kv_heads).expand_dim(1, kv_groups);
let scores = scores + mask;
// Softmax over context dimension (axis 3)
let weights = scores.softmax(3);
// Weighted sum: (n_kv_heads, kv_groups, s, ctx) @ (n_kv_heads, kv_groups, ctx, head_dim)
// → (n_kv_heads, kv_groups, s, head_dim)
let out = weights.matmul(v);
// ── Phase 5: Reshape output ──
// (n_kv_heads, kv_groups, s, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
let mut out = out.permute((2, 0, 1, 3));
out.shape = ShapeTracker::new((s, n_heads * head_dim));
(out, k_cache, v_cache)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gather_rows() {
let mut cx = Graph::new();
let data = cx.tensor((4, 3)); // 4 rows, 3 cols
let indices = cx.tensor(3).as_dtype(DType::Int);
let result = gather_rows(data, indices, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
rt.set_data(
data.id,
vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
);
// Gather rows 0, 2, 3
rt.set_data(indices.id, vec![0, 2, 3]);
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(result.id),
vec![1., 2., 3., 7., 8., 9., 10., 11., 12.]
);
}
#[test]
fn test_scatter_rows() {
let mut cx = Graph::new();
let src = cx.tensor((2, 3));
let indices = cx.tensor(2).as_dtype(DType::Int);
let dest = cx.tensor((4, 3));
let result = scatter_rows(src, indices, dest, 3).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
rt.set_data(indices.id, vec![1, 3]);
rt.set_data(dest.id, vec![0.; 12]);
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(result.id),
vec![0., 0., 0., 10., 20., 30., 0., 0., 0., 40., 50., 60.]
);
}
#[test]
fn test_scatter_then_gather_roundtrip() {
let mut cx = Graph::new();
let kv_new = cx.tensor((2, 4)); // 2 new rows, dim=4
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
let cache = cx.tensor((6, 4)); // 6 slots
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
// Scatter new rows into cache, then gather them back
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
rt.set_data(cache.id, vec![0.; 24]); // Zero cache
rt.set_data(gather_idx.id, vec![1, 4]); // Read back from same slots
rt.execute(&cx.dyn_map);
assert_eq!(
*rt.get_f32(gathered.id),
vec![1., 2., 3., 4., 5., 6., 7., 8.]
);
}
#[test]
fn test_paged_attention_shape_and_cache_update() {
// Minimal config: n_heads=2, n_kv_heads=2, head_dim=2, kv_groups=1
// hidden = 4, kv_dim = 4
let n_heads = 2;
let n_kv_heads = 2;
let head_dim = 2;
let hidden = n_heads * head_dim; // 4
let kv_dim = n_kv_heads * head_dim; // 4
let num_slots = 8;
let mut cx = Graph::new();
let q = cx.tensor((1, hidden)); // 1 new token
let k_new = cx.tensor((1, kv_dim));
let v_new = cx.tensor((1, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context tokens
let scatter_idx = cx.tensor(1).as_dtype(DType::Int); // 1 new token
// prev_seq=2: this is the 3rd token (positions 0,1 cached, position 2 is new)
let (attn_out, k_cache_new, v_cache_new) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
2.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
let k_cache_new = k_cache_new.output();
let v_cache_new = v_cache_new.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
rt.set_data(q.id, vec![1., 0., 1., 0.]);
// k_new = [0.5, 0.5, 0.5, 0.5]
rt.set_data(k_new.id, vec![0.5, 0.5, 0.5, 0.5]);
// v_new = [1, 2, 3, 4]
rt.set_data(v_new.id, vec![1., 2., 3., 4.]);
// Zero caches
rt.set_data(k_cache.id, vec![0.; num_slots * kv_dim]);
rt.set_data(v_cache.id, vec![0.; num_slots * kv_dim]);
// Scatter new KV to slot 2
rt.set_data(scatter_idx.id, vec![2]);
// Gather context from slots 0, 1, 2 (slots 0,1 are zeros, slot 2 is the new KV)
rt.set_data(gather_idx.id, vec![0, 1, 2]);
rt.execute(&cx.dyn_map);
// Verify output shape: (1, hidden=4)
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), hidden);
// Verify KV cache was updated: k_cache_new should have [0.5, 0.5, 0.5, 0.5] at slot 2
let k_out = rt.get_f32(k_cache_new.id);
assert_eq!(k_out.len(), num_slots * kv_dim);
// Slot 2 is at offset 2*4=8..12
assert_eq!(&k_out[8..12], &[0.5, 0.5, 0.5, 0.5]);
// Slot 0 should still be zeros
assert_eq!(&k_out[0..4], &[0., 0., 0., 0.]);
let v_out = rt.get_f32(v_cache_new.id);
assert_eq!(&v_out[8..12], &[1., 2., 3., 4.]);
}
#[test]
fn test_paged_attention_known_values() {
// Test with values where we can compute expected attention output.
// n_heads=1, n_kv_heads=1, head_dim=2, kv_groups=1
// hidden=2, kv_dim=2
let n_heads = 1;
let n_kv_heads = 1;
let head_dim = 2;
let hidden = 2;
let kv_dim = 2;
let num_slots = 4;
let mut cx = Graph::new();
let q = cx.tensor((1, hidden));
let k_new = cx.tensor((1, kv_dim));
let v_new = cx.tensor((1, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
let scatter_idx = cx.tensor(1).as_dtype(DType::Int);
// prev_seq=1: 1 cached token + 1 new token, context len=2
// Query at absolute position 1 can attend to context positions 0 and 1
let (attn_out, _, _) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
1.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
// K cached at slot 0: [1, 0]
// K new (written to slot 1): [0, 1]
// V cached at slot 0: [10, 20]
// V new (written to slot 1): [30, 40]
// Q: [1, 1]
let mut k_cache_data = vec![0.; num_slots * kv_dim];
k_cache_data[0] = 1.;
k_cache_data[1] = 0.; // slot 0 K = [1, 0]
let mut v_cache_data = vec![0.; num_slots * kv_dim];
v_cache_data[0] = 10.;
v_cache_data[1] = 20.; // slot 0 V = [10, 20]
rt.set_data(q.id, vec![1., 1.]);
rt.set_data(k_new.id, vec![0., 1.]); // new K = [0, 1]
rt.set_data(v_new.id, vec![30., 40.]); // new V = [30, 40]
rt.set_data(k_cache.id, k_cache_data);
rt.set_data(v_cache.id, v_cache_data);
rt.set_data(scatter_idx.id, vec![1]); // write to slot 1
rt.set_data(gather_idx.id, vec![0, 1]); // gather slots 0, 1
rt.execute(&cx.dyn_map);
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), hidden);
let expected = vec![20.0, 30.0];
for (a, b) in out.iter().zip(&expected) {
assert!((a - b).abs() < 0.1, "Expected {expected:?}, got {out:?}");
}
}
#[test]
fn test_paged_attention_causal_mask() {
// Verify that the causal mask blocks future positions.
// n_heads=1, n_kv_heads=1, head_dim=2
let n_heads = 1;
let n_kv_heads = 1;
let head_dim = 2;
let hidden = 2;
let kv_dim = 2;
let num_slots = 4;
let mut cx = Graph::new();
let q = cx.tensor((2, hidden)); // 2 new tokens
let k_new = cx.tensor((2, kv_dim));
let v_new = cx.tensor((2, kv_dim));
let k_cache = cx.tensor((num_slots, kv_dim));
let v_cache = cx.tensor((num_slots, kv_dim));
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context (1 cached + 2 new)
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
// prev_seq=1: 1 cached token, 2 new tokens → context len=3
// Query 0 at absolute pos 1: can see ctx 0,1 (not 2)
// Query 1 at absolute pos 2: can see ctx 0,1,2
let (attn_out, _, _) = paged_attention(
q,
k_new,
v_new,
k_cache,
v_cache,
gather_idx,
scatter_idx,
1.into(),
n_heads,
n_kv_heads,
head_dim,
);
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
// Cache has 1 token at slot 0
let mut k_cache_data = vec![0.; num_slots * kv_dim];
k_cache_data[0] = 1.;
k_cache_data[1] = 0.; // slot 0: K=[1,0]
let mut v_cache_data = vec![0.; num_slots * kv_dim];
v_cache_data[0] = 100.;
v_cache_data[1] = 0.; // slot 0: V=[100,0]
// 2 new tokens
rt.set_data(q.id, vec![1., 0., 0., 1.]);
rt.set_data(k_new.id, vec![0., 1., 1., 1.]); // token0 K=[0,1], token1 K=[1,1]
rt.set_data(v_new.id, vec![0., 10., 0., 20.]); // token0 V=[0,10], token1 V=[0,20]
rt.set_data(k_cache.id, k_cache_data);
rt.set_data(v_cache.id, v_cache_data);
rt.set_data(scatter_idx.id, vec![1, 2]); // write to slots 1, 2
rt.set_data(gather_idx.id, vec![0, 1, 2]); // gather all 3
rt.execute(&cx.dyn_map);
let out = rt.get_f32(attn_out.id);
assert_eq!(out.len(), 2 * hidden);
// Token 0 (abs pos 1): attends to ctx 0,1 only (ctx 2 is masked)
// Token 1 (abs pos 2): attends to ctx 0,1,2
// Verify output has valid (non-NaN, non-inf) values and correct length
for val in out.iter() {
assert!(val.is_finite(), "Output contains non-finite value: {}", val);
}
}
}

View File

@@ -0,0 +1,408 @@
use luminal::prelude::*;
/// Generic N-dimensional convolution layer implemented with the GraphTensor `unfold` helper.
///
/// The layer expects inputs shaped like `[batch..., channels, spatial...]` where the number of
/// spatial dimensions is greater than zero. The kernel configuration controls how many spatial
/// axes are convolved (N) and must be shorter than the input rank (K): `K > N` is asserted.
pub struct ConvND {
pub weight: GraphTensor, // (ch_out, ch_in * kernel_product)
pub bias: Option<GraphTensor>,
kernel: Vec<usize>,
stride: Vec<usize>,
dilation: Vec<usize>,
padding: Vec<usize>,
ch_in: usize,
ch_out: usize,
}
impl ConvND {
#[allow(clippy::too_many_arguments)]
pub fn new(
ch_in: usize,
ch_out: usize,
kernel: Vec<usize>,
stride: Vec<usize>,
dilation: Vec<usize>,
padding: Vec<usize>,
bias: bool,
cx: &mut Graph,
) -> Self {
assert!(
!kernel.is_empty(),
"ConvND requires at least one spatial dimension in the kernel",
);
let k = kernel.len();
assert_eq!(
stride.len(),
k,
"Stride dimensions ({}) must match kernel dimensions ({k})",
stride.len()
);
assert_eq!(
dilation.len(),
k,
"Dilation dimensions ({}) must match kernel dimensions ({k})",
dilation.len()
);
assert_eq!(
padding.len(),
k,
"Padding dimensions ({}) must match kernel dimensions ({k})",
padding.len()
);
let kernel_product: usize = kernel.iter().product();
Self {
weight: cx
.named_tensor("ConvWeight", (ch_out, ch_in * kernel_product))
.persist(),
bias: if bias {
Some(cx.named_tensor("ConvBias", ch_out).persist())
} else {
None
},
kernel,
stride,
dilation,
padding,
ch_in,
ch_out,
}
}
/// Apply convolution to an input shaped `[batch..., channels, spatial...]`.
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
let input_dims = input.dims();
let rank = input_dims.len();
let spatial = self.kernel.len();
assert!(
rank > spatial,
"ConvND expects input rank ({rank}) to be greater than kernel dims ({spatial})",
);
let batch_len = rank - spatial - 1;
assert_eq!(
input_dims[batch_len],
Expression::from(self.ch_in),
"Input channel dimension ({}) must match ch_in ({})",
input_dims[batch_len],
self.ch_in
);
assert_eq!(
self.weight.dims()[0],
Expression::from(self.ch_out),
"Weight output channels ({}) must match ch_out ({})",
self.weight.dims()[0],
self.ch_out
);
// Pad only the spatial dimensions.
let mut padding = vec![(Expression::from(0), Expression::from(0)); rank];
for (i, pad) in self.padding.iter().enumerate() {
let axis = batch_len + 1 + i;
padding[axis] = (Expression::from(*pad), Expression::from(*pad));
}
let padded = input.pad(padding, 0.0);
// Build unfold parameters with ones for non-spatial axes.
let mut kernel_shape = vec![1; rank];
let mut stride_shape = vec![1; rank];
let mut dilation_shape = vec![1; rank];
for i in 0..spatial {
let axis = batch_len + 1 + i;
kernel_shape[axis] = self.kernel[i];
stride_shape[axis] = self.stride[i];
dilation_shape[axis] = self.dilation[i];
}
let unfolded = padded.unfold(kernel_shape, stride_shape, dilation_shape);
// Move window dimensions to the front for easier indexing.
let mut order: Vec<usize> = (rank..2 * rank).collect();
order.extend(0..rank);
let unfolded = unfolded.permute(order);
let unfolded_dims = unfolded.dims();
// Capture output spatial dimensions from the unfolded view.
let output_dims: Vec<Expression> =
unfolded_dims[batch_len + 1..batch_len + 1 + spatial].to_vec();
// Reorder to [batch..., out..., channels, kernel_spatial..., kernel_batch..., kernel_channel].
let mut order2 = Vec::with_capacity(2 * rank);
// window batch dims
order2.extend(0..batch_len);
// window spatial dims (outputs)
order2.extend(batch_len + 1..batch_len + 1 + spatial);
// window channel dim
order2.push(batch_len);
// kernel spatial dims
order2.extend(rank + batch_len + 1..rank + batch_len + 1 + spatial);
// kernel batch dims and kernel channel dim (to be merged away)
order2.extend(rank..rank + batch_len + 1);
let mut patches = unfolded.permute(order2);
// Drop kernel axes for batch + channel by merging them into the previous dimension.
for _ in 0..=batch_len {
let last = patches.dims().len();
patches = patches.merge_dims(last - 2, last - 1);
}
// Flatten channel and kernel spatial dimensions together.
for _ in 0..spatial {
let channel_axis = batch_len + spatial;
patches = patches.merge_dims(channel_axis, channel_axis + 1);
}
// Collapse batch dimensions into one and output dimensions into one for matmul.
for _ in 1..batch_len {
patches = patches.merge_dims(0, 1);
}
for _ in 1..spatial {
patches = patches.merge_dims(1, 2);
}
let mut out = patches.matmul(self.weight.permute((1, 0)));
// Restore batch and spatial dimensions.
for dim in self.input_batch_dims(&input_dims, batch_len).iter().rev() {
out = out.split_dims(0, *dim);
}
for dim in output_dims.iter().rev() {
out = out.split_dims(batch_len, *dim);
}
// Move channel dimension ahead of the spatial axes: [batch..., ch_out, spatial...]
let mut final_order: Vec<usize> = (0..batch_len).collect();
final_order.push(batch_len + spatial);
final_order.extend(batch_len..batch_len + spatial);
out = out.permute(final_order);
if let Some(_b) = self.bias {
todo!()
// out += b.expand(out.shape);
}
out
}
fn input_batch_dims(&self, input_dims: &[Expression], batch_len: usize) -> Vec<Expression> {
input_dims[..batch_len].to_vec()
}
pub fn infer_output_shape(&self, input: &[usize]) -> Vec<usize> {
let rank = input.len();
let spatial = self.kernel.len();
assert!(rank > spatial, "expected input rank > spatial dims");
let batch_len = rank - spatial - 1;
assert_eq!(
input[batch_len], self.ch_in,
"input channel dimension does not match ch_in",
);
let batch_prefix = &input[..batch_len];
let spatial_dims = &input[batch_len + 1..];
let out_spatial: Vec<usize> = spatial_dims
.iter()
.zip(
self.kernel
.iter()
.zip(self.stride.iter())
.zip(self.dilation.iter())
.zip(self.padding.iter()),
)
.map(|(dim, (((k, s), d), p))| (dim + 2 * p - d * (k - 1) - 1) / s + 1)
.collect();
let mut shape = batch_prefix.to_vec();
shape.push(self.ch_out);
shape.extend(out_spatial);
shape
}
}
#[cfg(test)]
mod tests {
use super::ConvND;
use candle_core::{Device, Tensor};
fn assert_close(a: &[f32], b: &[f32]) {
assert_eq!(
a.len(),
b.len(),
"length mismatch: {} vs {}",
a.len(),
b.len()
);
for (idx, (lhs, rhs)) in a.iter().zip(b.iter()).enumerate() {
let diff = (lhs - rhs).abs();
if diff > 1e-4 {
panic!("values differ at {idx}: {lhs} vs {rhs} (diff {diff})");
}
}
}
fn candle_conv1d_output(
conv: &ConvND,
input: &[f32],
width: usize,
weight: &[f32],
bias: Option<&[f32]>,
) -> candle_core::Result<Vec<f32>> {
let device = Device::Cpu;
let input = Tensor::from_vec(input.to_vec(), (1, conv.ch_in, width), &device)?;
let weight = Tensor::from_vec(
weight.to_vec(),
(conv.ch_out, conv.ch_in, conv.kernel[0]),
&device,
)?;
let bias = match bias {
Some(b) => Some(Tensor::from_vec(b.to_vec(), conv.ch_out, &device)?),
None => None,
};
let output = input.conv1d(
&weight,
conv.padding[0],
conv.stride[0],
conv.dilation[0],
1,
)?;
let output = match bias {
Some(bias) => {
let bias = bias.reshape((1, conv.ch_out, 1))?;
output.broadcast_add(&bias)?
}
None => output,
};
output.flatten_all()?.to_vec1::<f32>()
}
fn candle_conv2d_output(
conv: &ConvND,
input: &[f32],
height: usize,
width: usize,
weight: &[f32],
bias: Option<&[f32]>,
) -> candle_core::Result<Vec<f32>> {
let device = Device::Cpu;
let input = Tensor::from_vec(input.to_vec(), (1, conv.ch_in, height, width), &device)?;
let weight = Tensor::from_vec(
weight.to_vec(),
(conv.ch_out, conv.ch_in, conv.kernel[0], conv.kernel[1]),
&device,
)?;
let bias = match bias {
Some(b) => Some(Tensor::from_vec(b.to_vec(), conv.ch_out, &device)?),
None => None,
};
assert_eq!(
conv.padding[0], conv.padding[1],
"Candle conv2d only supports equal padding"
);
assert_eq!(
conv.stride[0], conv.stride[1],
"Candle conv2d only supports equal stride"
);
assert_eq!(
conv.dilation[0], conv.dilation[1],
"Candle conv2d only supports equal dilation"
);
let output = input.conv2d(
&weight,
conv.padding[0],
conv.stride[0],
conv.dilation[0],
1,
)?;
let output = match bias {
Some(bias) => {
let bias = bias.reshape((1, conv.ch_out, 1, 1))?;
output.broadcast_add(&bias)?
}
None => output,
};
output.flatten_all()?.to_vec1::<f32>()
}
#[test]
fn conv1d_values_match_expected_window_sums() -> candle_core::Result<()> {
let mut cx = luminal::graph::Graph::new();
let conv = ConvND::new(1, 1, vec![3], vec![1], vec![1], vec![1], true, &mut cx);
let input = [1., 2., 3., 4., 5.];
let weight = [1., 1., 1.];
let bias = [0.5];
let out = candle_conv1d_output(&conv, &input, input.len(), &weight, Some(&bias))?;
assert_close(&out, &[3.5, 6.5, 9.5, 12.5, 9.5]);
Ok(())
}
#[test]
fn conv2d_values_accumulate_across_channels() -> candle_core::Result<()> {
let mut cx = luminal::graph::Graph::new();
let conv = ConvND::new(
2,
1,
vec![2, 2],
vec![1, 1],
vec![1, 1],
vec![0, 0],
true,
&mut cx,
);
let input = [
1., 2., 3., 4., 5., 6., 7., 8., 9., // channel 0
9., 8., 7., 6., 5., 4., 3., 2., 1., // channel 1
];
let weight = [1., 1., 1., 1., 2., 2., 2., 2.];
let bias = [0.25];
let out = candle_conv2d_output(&conv, &input, 3, 3, &weight, Some(&bias))?;
assert_close(&out, &[68.25, 64.25, 56.25, 52.25]);
Ok(())
}
#[test]
fn conv1d_shapes_follow_stride_and_padding() {
let mut cx = luminal::graph::Graph::new();
let conv = ConvND::new(1, 1, vec![3], vec![2], vec![1], vec![1], false, &mut cx);
// expected length: floor((padded_len - dilation*(k-1) -1)/stride +1)
// padded_len = 7 + 2 = 9
// effective kernel = 3
// => (9 -3)/2 +1 = 4
let inferred = conv.infer_output_shape(&[2, 1, 7]);
assert_eq!(inferred, vec![2, 1, 4]);
}
#[test]
fn conv2d_shapes_follow_stride_and_padding() {
let mut cx = luminal::graph::Graph::new();
let conv = ConvND::new(
3,
2,
vec![2, 3],
vec![1, 2],
vec![1, 1],
vec![0, 1],
true,
&mut cx,
);
// height: (5 - dilation*(2-1) -1 + 0 +0)/1 +1 = 4
// width: (6 - dilation*(3-1) -1 + 1 +1)/2 +1 = 3
let inferred = conv.infer_output_shape(&[1, 3, 5, 6]);
assert_eq!(inferred, vec![1, 2, 4, 3]);
}
}

View File

@@ -0,0 +1,116 @@
// use luminal::{prelude::*, tests::random_vec};
// pub struct Embedding {
// permute: bool,
// pub weight: GraphTensor, // n embeddings x embedding dim
// embedding_dim: usize,
// }
// impl Embedding {
// pub fn new(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
// Self {
// weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
// permute: false,
// embedding_dim,
// }
// }
// pub fn new_permuted(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
// Self {
// weight: cx.named_tensor("Embedding Weight", (embedding_dim, n_embeddings)),
// permute: true,
// embedding_dim,
// }
// }
// pub fn initialize(self) -> Self {
// self.weight.set(random_vec(
// self.weight.shape.n_elements().to_usize().unwrap(),
// ));
// self
// }
// }
// impl SerializeModule for Embedding {
// fn serialize(&self, s: &mut luminal::module::Serializer) {
// s.tensor("weight", self.weight);
// }
// }
// impl Module<GraphTensor> for Embedding {
// type Output = GraphTensor;
// fn forward(&self, input: GraphTensor) -> Self::Output {
// // Flatten batches
// let batch_size = input.shape.n_elements();
// let inp = input.reshape(batch_size);
// // Gather
// let out = if self.permute {
// self.weight.permute((1, 0)).gather(inp)
// } else {
// self.weight.gather(inp)
// };
// // Unflatten
// let mut new_shape = input.dims();
// new_shape.push(self.embedding_dim.into());
// out.reshape(new_shape)
// }
// }
// impl Embedding {
// // Reverse from embedding to token distribution
// pub fn reverse(&self, input: GraphTensor) -> GraphTensor {
// if self.permute {
// input.matmul(self.weight)
// } else {
// input.matmul(self.weight.permute((1, 0)))
// }
// }
// }
// #[cfg(test)]
// mod tests {
// use dfdx::{
// prelude::Module as DfdxModule,
// tensor::{Cpu, TensorFromVec},
// };
// use luminal::prelude::Module;
// use super::Embedding;
// use dfdx::nn::BuildOnDevice;
// luminal::test_imports!();
// #[test]
// fn test_embedding() {
// let mut cx = Graph::new();
// let batch = cx.tensor((2, 3)).set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
// let a = cx.tensor(3).set(vec![1.0, 0.0, 1.0]).retrieve();
// let model = Embedding::new(3, 4, &mut cx).initialize();
// model
// .weight
// .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
// let mut b = model.forward(a).retrieve();
// let mut batch_out = model.forward(batch).retrieve();
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
// cx.execute();
// let d_dev = Cpu::default();
// let mut d_model = <dfdx::nn::modules::builders::Embedding<3, 4>>::build_on_device(&d_dev);
// d_model.weight = d_dev.tensor_from_vec(
// vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.],
// (DConst::<3>, DConst::<4>),
// );
// let d_a = d_dev.tensor_from_vec(vec![1, 0, 1], (DConst::<3>,));
// let d_batch = d_dev.tensor_from_vec(vec![1, 0, 2, 1, 0, 1], (DConst::<2>, DConst::<3>));
// let d_b = d_model.forward(d_a);
// let d_batch_out = d_model.forward(d_batch);
// assert_close(&b.data(), &d_b.as_vec());
// assert_close(&batch_out.data(), &d_batch_out.as_vec());
// }
// }

View File

@@ -0,0 +1,18 @@
#![allow(unused_imports)]
mod activation;
pub use activation::*;
mod convolution;
pub use convolution::*;
mod embedding;
pub use embedding::*;
mod linear;
pub use linear::*;
mod norm;
pub use norm::*;
mod pooling;
pub use pooling::*;
mod moe;
pub use moe::*;
mod attention;
pub use attention::*;

View File

@@ -0,0 +1,76 @@
use luminal::prelude::*;
/// A simple unbiased linear layer
pub struct Linear {
pub weight: GraphTensor,
pub bias: Option<GraphTensor>,
permute: bool,
}
impl Linear {
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (inp, out)).persist(),
bias: if bias {
Some(cx.named_tensor("Bias", out).persist())
} else {
None
},
permute: false,
}
}
pub fn new_permuted(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (out, inp)).persist(),
bias: if bias {
Some(cx.named_tensor("Bias", out).persist())
} else {
None
},
permute: true,
}
}
}
impl Linear {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
let output = input.matmul(if self.permute {
self.weight.permute((1, 0))
} else {
self.weight
});
if let Some(_bias) = self.bias {
todo!()
// output += bias.expand(output.shape);
}
output
}
}
// #[cfg(test)]
// mod tests {
// use super::Linear;
// use luminal::{prelude::*, tests::assert_close};
// #[test]
// fn test_linear() {
// let mut cx = Graph::new();
// let batch = cx.tensor((2, 3)).set([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
// let a = cx.tensor(3).set([1.0, 2.0, 3.0]);
// let model = Linear::new(3, 4, false, &mut cx).init_rand();
// let mut b = model.forward(a).retrieve();
// let mut batch_out = model.forward(batch).retrieve();
// cx.execute();
// let unoptimized_b = b.data();
// let unoptimized_batch_out = batch_out.data();
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
// cx.execute();
// assert_close(&unoptimized_b, &b.data());
// assert_close(&unoptimized_batch_out, &batch_out.data());
// }
// }

View File

@@ -0,0 +1,513 @@
use luminal::prelude::*;
/// A layer of E experts and a router
pub struct MoE {
pub expert_weights: GraphTensor, // [E, in, out]
pub router: GraphTensor, // [in, E]
pub k: usize,
}
impl MoE {
pub fn forward(&self, activations: GraphTensor) -> GraphTensor {
let n = activations.dims().len();
let e_dim = *self.router.dims().last().unwrap();
let (_, in_size, out_size) = self.expert_weights.dims3();
let io = in_size * out_size;
let k_expr = Expression::from(self.k);
// 1. Routing probabilities: [batch.., E]
let routing_weights = activations.matmul(self.router).softmax(n - 1);
// 2. Top-k expert indices: [batch.., k] (Int)
let top_k_indices = routing_weights.topk_indexes(self.k, n - 1);
// 3. Gather top-k routing values: [batch.., k]
// flat_idx = batch_row * E + expert_idx
// iota(z / k * E) gives batch_row * E at each position in [batch.., k]
let row_offsets = activations
.graph()
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
let routing_flat_idx =
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
let top_k_values = routing_weights.gather(routing_flat_idx); // [batch.., k]
// 4. Gather expert weight matrices: [batch.., k, in, out]
// flat_idx[.., ki, i, o] = expert_idx[.., ki] * in*out + i * out + o
let base = (top_k_indices * io).cast(DType::F32); // [batch.., k]
let within = activations
.graph()
.iota(Expression::from('z'), (in_size, out_size))
.cast(DType::F32); // [in, out] values 0..in*out-1
// Expand base to [batch.., k, in, out]
let n_base = base.dims().len();
let exp_base = base
.expand_dim(n_base, in_size)
.expand_dim(n_base + 1, out_size);
// Expand within to [batch.., k, in, out]
let mut exp_within = within;
for (i, dim) in base.dims().iter().enumerate() {
exp_within = exp_within.expand_dim(i, *dim);
}
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
let gathered = self.expert_weights.gather(expert_flat_idx); // [batch.., k, in, out]
// 5. Batched matmul: [batch.., k, 1, in] @ [batch.., k, in, out] → [batch.., k, out]
let expanded_act = activations
.expand_dim(n - 1, self.k) // [batch.., k, in]
.unsqueeze(n); // [batch.., k, 1, in]
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
(expert_out * weights_exp).sum(n - 1)
}
}
#[cfg(test)]
mod tests {
use super::MoE;
use luminal::prelude::*;
use rand::{rng, Rng};
fn random_vec(n: usize) -> Vec<f32> {
let mut r = rng();
(0..n).map(|_| r.random_range(-0.5..0.5)).collect()
}
fn assert_close(a: &[f32], b: &[f32]) {
assert_eq!(
a.len(),
b.len(),
"length mismatch: {} vs {}",
a.len(),
b.len()
);
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
let diff = (x - y).abs();
if diff > 1e-3 {
panic!(
"{x} is not close to {y} at index {i}, diff={diff}\n actual: {a:?}\n expected: {b:?}"
);
}
}
}
/// Reference MoE computation for a single input vector.
/// input: [in_dim], router: [in_dim, n_experts] (row-major),
/// expert_weights: [n_experts, in_dim, out_dim] (row-major)
fn moe_reference_1d(
input: &[f32],
router: &[f32],
expert_weights: &[f32],
n_experts: usize,
in_dim: usize,
out_dim: usize,
k: usize,
) -> Vec<f32> {
// 1. Router logits: input @ router → [n_experts]
let mut logits = vec![0.0f32; n_experts];
for e in 0..n_experts {
for i in 0..in_dim {
logits[e] += input[i] * router[i * n_experts + e];
}
}
// 2. Softmax
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|x| (x - max_l).exp()).collect();
let sum_e: f32 = exps.iter().sum();
let probs: Vec<f32> = exps.iter().map(|x| x / sum_e).collect();
// 3. Top-k indices (descending by probability)
let mut indices: Vec<usize> = (0..n_experts).collect();
indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
let top_k_idx = &indices[..k];
let top_k_w: Vec<f32> = top_k_idx.iter().map(|&i| probs[i]).collect();
// 4. Weighted sum of expert outputs (no renormalization, matching code intent)
let mut output = vec![0.0f32; out_dim];
for (ki, &eidx) in top_k_idx.iter().enumerate() {
for o in 0..out_dim {
let mut val = 0.0f32;
for i in 0..in_dim {
val += input[i] * expert_weights[eidx * in_dim * out_dim + i * out_dim + o];
}
output[o] += top_k_w[ki] * val;
}
}
output
}
/// Reference MoE for batched input [batch, in_dim]
#[allow(clippy::too_many_arguments)]
fn moe_reference_batch(
input: &[f32],
router: &[f32],
expert_weights: &[f32],
n_experts: usize,
in_dim: usize,
out_dim: usize,
k: usize,
batch: usize,
) -> Vec<f32> {
let mut output = Vec::with_capacity(batch * out_dim);
for b in 0..batch {
let inp = &input[b * in_dim..(b + 1) * in_dim];
let out = moe_reference_1d(inp, router, expert_weights, n_experts, in_dim, out_dim, k);
output.extend_from_slice(&out);
}
output
}
// ── Test: 1D input, k=1, strongly-routed to expert 0 ────────────────
#[test]
fn test_moe_1d_k1() {
let n_experts = 2;
let in_dim = 3;
let out_dim = 2;
let k = 1;
let mut cx = Graph::new();
let input = cx.tensor(in_dim);
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
let router_w = cx.tensor((in_dim, n_experts));
let moe = MoE {
expert_weights: expert_w,
router: router_w,
k,
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let input_data = vec![1.0, 2.0, 3.0];
// Router strongly favors expert 0
let router_data = vec![
10.0, -10.0, // feature 0
10.0, -10.0, // feature 1
10.0, -10.0, // feature 2
];
// Expert 0: simple linear, Expert 1: different
let expert_data = vec![
// Expert 0: [3x2]
1.0, 0.0, 0.0, 1.0, 1.0, 1.0, // Expert 1: [3x2]
2.0, 0.0, 0.0, 2.0, 2.0, 2.0,
];
rt.set_data(input.id, input_data.clone());
rt.set_data(router_w.id, router_data.clone());
rt.set_data(expert_w.id, expert_data.clone());
rt.execute(&cx.dyn_map);
let expected = moe_reference_1d(
&input_data,
&router_data,
&expert_data,
n_experts,
in_dim,
out_dim,
k,
);
// With strong routing to expert 0: output ≈ [1,2,3]@[[1,0],[0,1],[1,1]] = [4, 5]
assert_close(rt.get_f32(output.id), &expected);
}
// ── Test: 1D input, k=E (all experts selected) ─────────────────────
#[test]
fn test_moe_1d_k_equals_e() {
let n_experts = 3;
let in_dim = 2;
let out_dim = 2;
let k = 3; // select all experts
let mut cx = Graph::new();
let input = cx.tensor(in_dim);
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
let router_w = cx.tensor((in_dim, n_experts));
let moe = MoE {
expert_weights: expert_w,
router: router_w,
k,
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let input_data = vec![1.0, 1.0];
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
let router_data = vec![0.01, 0.02, 0.03, 0.01, 0.02, 0.03];
// Each expert: identity-scaled by index+1
let expert_data = vec![
// Expert 0: identity
1.0, 0.0, 0.0, 1.0, // Expert 1: 2x
2.0, 0.0, 0.0, 2.0, // Expert 2: 3x
3.0, 0.0, 0.0, 3.0,
];
rt.set_data(input.id, input_data.clone());
rt.set_data(router_w.id, router_data.clone());
rt.set_data(expert_w.id, expert_data.clone());
rt.execute(&cx.dyn_map);
let expected = moe_reference_1d(
&input_data,
&router_data,
&expert_data,
n_experts,
in_dim,
out_dim,
k,
);
// Equal routing: each expert weight = 1/3
// output = 1/3 * [1,1] + 1/3 * [2,2] + 1/3 * [3,3] = [2, 2]
assert_close(rt.get_f32(output.id), &expected);
}
// ── Test: 2D batched input ──────────────────────────────────────────
#[test]
fn test_moe_batched() {
let n_experts = 2;
let in_dim = 3;
let out_dim = 2;
let k = 1;
let batch = 2;
let mut cx = Graph::new();
let input = cx.tensor((batch, in_dim));
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
let router_w = cx.tensor((in_dim, n_experts));
let moe = MoE {
expert_weights: expert_w,
router: router_w,
k,
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let input_data = vec![
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
0.0, 1.0, 0.0, // batch 1: routes to expert via feature 1
];
// Router: feature 0 → expert 0, feature 1 → expert 1
let router_data = vec![
10.0, -10.0, // feature 0 → expert 0
-10.0, 10.0, // feature 1 → expert 1
0.0, 0.0, // feature 2 → neutral
];
let expert_data = vec![
// Expert 0: [3x2]
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // Expert 1: [3x2]
7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
];
rt.set_data(input.id, input_data.clone());
rt.set_data(router_w.id, router_data.clone());
rt.set_data(expert_w.id, expert_data.clone());
rt.execute(&cx.dyn_map);
let expected = moe_reference_batch(
&input_data,
&router_data,
&expert_data,
n_experts,
in_dim,
out_dim,
k,
batch,
);
assert_close(rt.get_f32(output.id), &expected);
}
// ── Test: random inputs with k=2 ────────────────────────────────────
#[test]
fn test_moe_random_k2() {
let n_experts = 4;
let in_dim = 8;
let out_dim = 4;
let k = 2;
let mut cx = Graph::new();
let input = cx.tensor(in_dim);
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
let router_w = cx.tensor((in_dim, n_experts));
let moe = MoE {
expert_weights: expert_w,
router: router_w,
k,
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let input_data = random_vec(in_dim);
let router_data = random_vec(in_dim * n_experts);
let expert_data = random_vec(n_experts * in_dim * out_dim);
rt.set_data(input.id, input_data.clone());
rt.set_data(router_w.id, router_data.clone());
rt.set_data(expert_w.id, expert_data.clone());
rt.execute(&cx.dyn_map);
let expected = moe_reference_1d(
&input_data,
&router_data,
&expert_data,
n_experts,
in_dim,
out_dim,
k,
);
assert_close(rt.get_f32(output.id), &expected);
}
// ── Test: batched random inputs ─────────────────────────────────────
#[test]
fn test_moe_batched_random() {
let n_experts = 3;
let in_dim = 4;
let out_dim = 3;
let k = 2;
let batch = 4;
let mut cx = Graph::new();
let input = cx.tensor((batch, in_dim));
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
let router_w = cx.tensor((in_dim, n_experts));
let moe = MoE {
expert_weights: expert_w,
router: router_w,
k,
};
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let input_data = random_vec(batch * in_dim);
let router_data = random_vec(in_dim * n_experts);
let expert_data = random_vec(n_experts * in_dim * out_dim);
rt.set_data(input.id, input_data.clone());
rt.set_data(router_w.id, router_data.clone());
rt.set_data(expert_w.id, expert_data.clone());
rt.execute(&cx.dyn_map);
let expected = moe_reference_batch(
&input_data,
&router_data,
&expert_data,
n_experts,
in_dim,
out_dim,
k,
batch,
);
assert_close(rt.get_f32(output.id), &expected);
}
/// Dump the egglog HLIR for a QwenMoE-style GLU-MoE pattern.
/// This helps identify the exact pattern for the GLUMoE backend HostOp.
#[test]
fn dump_glu_moe_egglog() {
use luminal::dtype::DType;
use luminal::egglog_utils::hlir_to_egglog;
let n_experts = 4;
let hidden = 8;
let intermediate = 4;
let top_k: usize = 2;
let mut cx = Graph::new();
// Input tensors
let x = cx.tensor(('s', hidden));
let router = cx.tensor((n_experts, hidden));
let gate_up_weights = cx
.tensor((n_experts, intermediate * 2, hidden))
.as_dtype(DType::Bf16);
let down_weights = cx
.tensor((n_experts, hidden, intermediate))
.as_dtype(DType::Bf16);
let n = x.dims().len(); // 2
let e_dim = *router.dims().first().unwrap(); // E
let k_expr = luminal::shape::Expression::from(top_k);
// 1. Router: softmax(x @ router^T) → [s, E]
let routing_weights = x.matmul(router.t()).softmax(n - 1);
// 2. TopK expert selection → [s, k] (Int)
let top_k_indices = routing_weights.topk_indexes(top_k, n - 1);
// 3. Gather top-k routing values → [s, k]
let row_offsets = cx.iota(
luminal::shape::Expression::from('z') / k_expr * e_dim,
top_k_indices.dims(),
);
let routing_flat_idx =
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
let top_k_values = routing_weights.gather(routing_flat_idx);
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
let gate_up_gathered =
gather_experts_test(x, top_k_indices, gate_up_weights).cast(DType::F32);
let x_exp = x.expand_dim(n - 1, top_k).unsqueeze(n); // [s, k, 1, H]
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n); // [s, k, intermediate*2]
// 5. SwiGLU: silu(gate) * up → [s, k, intermediate]
let gate = gate_up_out.slice((.., .., ..intermediate));
let up = gate_up_out.slice((.., .., intermediate..));
let hidden_act = gate.silu() * up;
// 6. Gather down expert weights → [s, k, H, intermediate]
let down_gathered = gather_experts_test(x, top_k_indices, down_weights).cast(DType::F32);
let hidden_exp = hidden_act.unsqueeze(2); // [s, k, 1, intermediate]
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
// 7. Weighted sum over k experts → [s, H]
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
let _output = (down_out * weights_exp).sum(n - 1).output();
// Dump the HLIR to egglog
let (program, root) = hlir_to_egglog(&cx);
println!("=== GLU-MoE HLIR Egglog Dump ===");
println!("Root: {root}");
println!("{program}");
}
/// Helper: gather expert weight matrices using topk indices.
fn gather_experts_test(
graph_source: GraphTensor,
top_k_indices: GraphTensor,
weights: GraphTensor,
) -> GraphTensor {
let (_, d1, d2) = weights.dims3();
let io = d1 * d2;
let base = (top_k_indices * io).cast(DType::F32);
let within = graph_source
.graph()
.iota(luminal::shape::Expression::from('z'), (d1, d2))
.cast(DType::F32);
let n_base = base.dims().len();
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
let mut exp_within = within;
for (i, dim) in base.dims().iter().enumerate() {
exp_within = exp_within.expand_dim(i, *dim);
}
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
weights.gather(expert_flat_idx)
}
}

View File

@@ -0,0 +1,44 @@
use luminal::prelude::*;
/// A simple layer norm with an optional weight and bias
#[derive(Default)]
pub struct LayerNorm {
pub weight: Option<GraphTensor>,
pub bias: Option<GraphTensor>,
mean_norm: bool,
epsilon: f32,
}
impl LayerNorm {
pub fn new(
dim: usize,
weight: Option<&str>,
bias: Option<&str>,
mean_norm: bool,
epsilon: f32,
cx: &mut Graph,
) -> Self {
Self {
weight: weight.map(|w| cx.named_tensor(w, dim).persist()),
bias: bias.map(|b| cx.named_tensor(b, dim).persist()),
mean_norm,
epsilon,
}
}
}
impl LayerNorm {
pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
if self.mean_norm {
input = input.mean_norm(input.shape.last_axis());
}
input = input.std_norm(input.shape.last_axis(), self.epsilon);
if let Some(w) = self.weight {
input *= w.expand_lhs(&input.dims()[..input.dims().len() - 1]);
}
if let Some(b) = self.bias {
input += b.expand_lhs(&input.dims()[..input.dims().len() - 1]);
}
input
}
}

View File

@@ -0,0 +1,106 @@
// use luminal::prelude::*;
// pub struct AvgPool2D {
// kernel: (usize, usize),
// stride: (usize, usize),
// }
// impl AvgPool2D {
// pub fn new(kernel: (usize, usize), stride: (usize, usize)) -> Self {
// Self { kernel, stride }
// }
// }
// impl SerializeModule for AvgPool2D {
// fn serialize(&self, _s: &mut luminal::module::Serializer) {
// // No parameters to serialize for average pooling
// }
// }
// impl AvgPool2D {
// pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
// // Input: (batch (optional), ch_in, dimx_in, dimy_in)
// let mut expanded = false;
// if input.shape.len() == 3 {
// // Expand batch
// input = input.expand_dim(0, 1);
// expanded = true;
// }
// let (batch, ch_in, dimx_in, dimy_in) = input.dims4();
// let dimx_out = ((dimx_in - self.kernel.0) / self.stride.0 + 1).simplify();
// let dimy_out = ((dimy_in - self.kernel.1) / self.stride.1 + 1).simplify();
// let output = input
// .pool_last_dim(self.kernel.1, self.stride.1, 1) // dilation = 1 for pooling
// .permute((0, 1, 3, 4, 2))
// .pool_last_dim(self.kernel.0, self.stride.0, 1)
// .permute((0, 1, 5, 3, 4, 2))
// .reshape((
// batch,
// ch_in,
// self.kernel.0 * self.kernel.1,
// dimx_out * dimy_out,
// ))
// .mean(2) // Average over the kernel dimension
// .reshape((batch, ch_in, dimx_out, dimy_out));
// if expanded {
// output.reshape((ch_in, dimx_out, dimy_out))
// } else {
// output
// }
// }
// }
// pub struct AdaptiveAvgPool2D {
// output_size: (usize, usize),
// }
// impl AdaptiveAvgPool2D {
// pub fn new(output_size: (usize, usize)) -> Self {
// Self { output_size }
// }
// }
// impl SerializeModule for AdaptiveAvgPool2D {
// fn serialize(&self, _s: &mut luminal::module::Serializer) {
// // No learnable parameters
// }
// }
// impl AdaptiveAvgPool2D {
// pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
// let mut expanded = false;
// // Handle missing batch dimension
// if input.shape.len() == 3 {
// input = input.expand_dim(0, 1);
// expanded = true;
// }
// // Extract dimensions
// let (batch, ch, h_in, w_in) = input.dims4();
// let (h_out, w_out) = self.output_size;
// let stride_h = (h_in / h_out).simplify();
// let stride_w = (w_in / w_out).simplify();
// let kernel_h = (h_in - (h_out - 1) * stride_h).simplify();
// let kernel_w = (w_in - (w_out - 1) * stride_w).simplify();
// // Two-stage pooling (Y then X), followed by averaging over the kernel window
// let mut output = input
// .pool_last_dim(kernel_w, stride_w, 1)
// .permute((0, 1, 3, 4, 2))
// .pool_last_dim(kernel_h, stride_h, 1)
// .permute((0, 1, 5, 3, 4, 2))
// .reshape((batch, ch, kernel_h * kernel_w, h_out * w_out))
// .mean(2)
// .reshape((batch, ch, h_out, w_out));
// // Remove batch dim if it was originally absent
// if expanded {
// output = output.reshape((ch, h_out, w_out));
// }
// output
// }
// }

5
crates/luminal_python/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
*.onnx
__pycache__/
*.pyc
uv.lock
.venv

View File

@@ -0,0 +1,120 @@
## Python Environment
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
## Lessons Learned
At the end of any session that involved a hard or non-obvious bug, append an entry to
`LessonsLearned.md` in this directory. A "hard bug" means any bug that required significant
investigation — intermittent failures, wrong output without a crash, egglog/optimizer issues,
or anything that took more than a few minutes to locate.
Each entry should cover:
1. **What the symptom was** (test failure, wrong output, panic, etc.)
2. **What the actual root cause was** (the specific code/logic that was wrong)
3. **Why it was hard to find** (what made it non-obvious or intermittent)
4. **The fix** (what changed and why it works)
5. **A general principle** extracted from the bug — something that helps avoid the same
class of mistake in future code
The goal is to build a living record of codebase-specific pitfalls that future sessions can
consult before writing new egglog rules, CUDA kernels, or optimizer passes.
1. If you want to run tests:
- `./run_test.sh` - runs tests with the native backend
- `./run_tests_cuda.sh` - runs tests with the CUDA backend
## Testing Best Practices
### Overview
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
### Test Pattern (CORRECT)
All tests should follow this standard pattern:
```python
def test_operation():
"""Brief description of what operation is being tested."""
# 1. Instantiate PyTorch model
model: torch.nn.Module = OperationTestModel()
# 2. Compile with luminal backend
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
# 3. Create test input
x: torch.Tensor = torch.tensor([...]) # or torch.rand(...)
# 4. Run both original and compiled versions
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
# 5. Verify outputs match
assert torch.allclose(output, original)
```
### Test Models
- Define test model classes in `tests/test_models.py`
- Each model should be a simple `torch.nn.Module` that demonstrates one operation or pattern
- Use clear, descriptive class names (e.g., `AddTestModel`, `TransposeTestModel`)
- Include docstrings explaining what the model tests
Example:
```python
class AddTestModel(torch.nn.Module):
"""Tests element-wise addition."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + x
```
### What NOT to Do
**❌ DO NOT create ONNX files directly in tests:**
```python
# WRONG - bypasses the PyTorch integration
model_path = create_onnx_model(...)
graph_result = luminal.process_onnx(model_path, backend='native')
```
**✓ DO create PyTorch models and use torch.compile:**
```python
# CORRECT - tests actual user workflow
model: torch.nn.Module = MyTestModel()
model_compiled = torch.compile(model, backend=luminal_backend)
```
### Rationale
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
- **User-facing API**: Tests use the same API that users will use (torch.compile)
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
### Special Cases
**Testing constants:**
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1.0, 2.0, 3.0])
return x + constant
```
**Testing type casts:**
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
```
**Testing complex operations:**
Chain operations naturally in PyTorch - ONNX export handles the conversion:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
transposed = x.transpose(0, 1)
scaled = transposed * 2.0
return scaled + 1.0
```

View File

@@ -0,0 +1,758 @@
# Lessons Learned
This file documents hard bugs encountered in this codebase, their root causes, and principles
to prevent similar issues in the future.
---
## 2026-02-24 — Intermittent CUDA Backend Failures: Embed False Match + Batched Matmul Dimension Drop
### Background: Why the Failures Were Intermittent
Both bugs only appeared on roughly 50% of test runs. The source of non-determinism is
`FxHashMap` (a fixed-seed hash map). The egglog optimizer's `SerializedEGraph::new` builds
`Vec<NodeId>` orderings for each e-class by iterating a `FxHashMap`, producing non-deterministic
node orderings. `random_initial_choice()` in `src/egglog_utils/mod.rs` then randomly picks one
e-node per e-class as the starting representation for the profiling phase. The combination means
some runs pick a correct kernel and some pick a broken one from the same e-class.
**Lesson**: When a test fails intermittently at a roughly 50% rate, suspect the egglog extractor
choosing between two e-nodes in the same e-class — one correct, one broken. The fix is always in
the broken e-node's rewrite rule.
---
### Bug 1: `test_gather_elements` — KernelEmbed and RowEmbed False Match
**Files changed**:
- `crates/luminal_cuda/src/kernel/hlir.rs` (KernelEmbed, 4 rules)
- `crates/luminal_cuda/src/block/ops.rs` (RowEmbed, 4 rules)
#### What happened
`gather_elements` (axis-aware gather) decomposes into a flat gather by computing:
```
flat_idx = Add(
Mul(indices, stride[axis]),
Mul(Expand(Iota(dim_size)), stride[non_axis])
)
```
`KernelEmbed` and `RowEmbed` are optimized embedding lookup kernels. A genuine embedding
lookup produces:
```
flat_idx = Add(
Mul(Cast(token_ids), embed_dim),
Iota(embed_dim) ← bare Iota, the position within an embedding row
)
```
The egglog rewrite rules for both ops matched `Add(?mul_result, ?iota_result)` where
`?iota_result` was **unconstrained** — it could bind to anything, including
`Mul(Expand(Iota(n)), stride)` from `gather_elements`. This created a `KernelEmbed`/`RowEmbed`
node in the same e-class as the `Gather` node. When the extractor picked it, `build_payload`
called `flatten_mul_strides(range, token_stride)` which asserted `range.len() == token_stride.len()`:
- `range` came from `RemoveNthFromEnd(idx_shape, 0)` → length 1
- `token_stride` came from the indices strides → length 2
- Assertion failed → panic.
#### The fix
Add `(= ?iota_result (Iota ?iota_expr ?iota_range))` to all 8 rules, requiring the positional
component to be a bare `Iota` node:
```egglog
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
(= ?iota_result (Iota ?iota_expr ?iota_range)) ← added
(= ?mul_result (Mul ...))
```
#### Investigation note
The initial plan correctly identified `KernelEmbed` as faulty, but missed `RowEmbed`. The two
ops are structurally identical but live in different parts of the codebase (`kernel/` vs
`block/`). The second bug was only discovered when the backtrace pointed to
`RowEmbed::build_payload` instead of `KernelEmbed::compile`. Always search for sibling
implementations when fixing a pattern-matching bug in one op.
---
### Bug 2: `test_matmul_batched` — CuBlasLt Drops Batch Dimension
**Files changed**:
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg`
#### What happened
The luminal frontend decomposes `(2,3,4) @ (2,4,5)` into:
```rust
let w = rhs.permute((0, 2, 1)); // (2,4,5) → (2,5,4)
let mul = self.expand_dim(2, d) // (2,3,4) → (2,3,5,4)
* w.expand_dim(1, b); // (2,5,4) → (2,3,5,4)
mul.sum(3) // → (2,3,5), correct out_shape
```
All four cublaslt rewrite rules extracted `m` and `n` from the output shape using
`nth_from_end`, which succeeds for any rank:
```egglog
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
For `out_shape = [2, 3, 5]`: `?m = 3`, `?n = 5`. The batch dim `2` is never extracted or
stored. The rules also validated stride patterns using `nth_from_end` on the stride arrays —
but for this batched case, **all stride checks coincidentally passed** because the last three
strides of the 4D expanded tensors happened to satisfy the 2D row/column-major patterns.
The resulting `CuBlasLt` node had `output_size() = m * n = 15`. The batch dimension was
silently discarded. The runtime allocated a 15-element output buffer, cuBLAS wrote a 3×5
result, and the test got back 15 values instead of 30.
#### The fix
Add `(= (len ?out_shape) 2)` to all 4 rules:
```egglog
(= (len ?out_shape) 2) ← added: cuBLAS is 2D only
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
`len` counts elements in the `ECons`-list shape. With this constraint, any `Sum` node with a
3D+ output shape (batched matmul) is not matched by cuBLAS rules and falls through to
`KernelSumReduce + KernelMul` (or the tiling block ops), which correctly use
`out_shape.iter().product()` for their output sizes.
Note: `TileMatmulSplitK` and `TileMatmulFullSplit` do NOT need this fix — their `output_size()`
already returns `untiled_range.iter().product()` which includes all dimensions.
---
### General Principle: Always Constrain Shape Rank in Egglog Rules
Both bugs share the same structural cause: **egglog rewrite rules that used `nth_from_end` to
extract dimensions from a shape list without constraining the list's length.** Since
`nth_from_end` silently succeeds for any list with enough trailing elements, rules written for
2D tensors accidentally matched higher-rank tensors.
**Rule for writing egglog rewrite rules in this codebase**:
> If a rule is designed for a specific tensor rank, always add an explicit
> `(= (len ?shape) N)` constraint. If a rule is designed to handle arbitrary ranks but an
> op's output only covers a subset of dimensions (like cuBLAS covering only the last 2),
> that is a correctness bug — either implement strided batched cuBLAS or add the rank
> constraint and fall back to a kernel that handles all dimensions.
---
### Debugging Intermittent CUDA Failures: Effective Approach
The investigation used extensive `eprintln!` debug logging to trace which kernels were compiled
vs. skipped. Key observations:
1. **In the passing case**: `KernelSumReduce::compile()` was called, kernels were allocated.
2. **In the failing case**: `KernelSumReduce::compile()` was never called, yet output was produced.
This asymmetry pointed to a `HostOp` path (cuBLAS) executing instead of the `KernelOp` path,
which narrowed the search to cublaslt rewrite rules. The HLIR-level `SumReduce::to_egglog` log
confirmed the correct HLIR node existed — the bug was in the e-graph optimization choosing
a different (broken) e-node from the same e-class.
**Effective debug strategy for egglog non-determinism bugs**:
1. Add logging at compile time for each kernel type (`KernelFoo::compile`, `HostFoo::execute`)
2. Compare passing vs. failing runs to see which kernels are/aren't invoked
3. The missing kernel's e-class contains a broken alternative — find it via the egglog rewrite rules
4. Check the op that *is* executing — its `output_size()` reveals what's wrong with the false match
---
## 2026-02-25 — OneHot Test Panic: Cast(Int→F32) Produces Int Output
### What the symptom was
`test_onehot` panicked at `src/hlir.rs:1625` in `get_f32()`: the output buffer was
`NativeData::Int` instead of the expected `NativeData::F32`.
### What the actual root cause was
The Cast parser's `* 1.0` workaround for `Int → F32` casts used `input * one_expanded`
(Int GraphTensor on the left, F32 constant on the right). However, `Mul for GraphTensor`
always uses `self.dtype` (the **left** operand's dtype) for the result, and the native
runtime's `Mul::execute` dispatches on the **first** input's `NativeData` variant. So
`Int * F32` produced `DType::Int` / `NativeData::Int` — the exact opposite of the intended
F32 output.
### Why it was hard to find
1. **The OneHot parser was a red herring**: The initial plan assumed the OneHot ONNX node
was being parsed, but `torch.onnx.export` decomposes `one_hot` into
`Unsqueeze → Equal → Cast(Bool→Int) → Cast(Int→F32)`. The OneHot parser was never called.
2. **The `* 1.0` workaround looked correct**: It was used successfully in many other parsers,
but those all had F32 inputs (where `F32 * F32 = F32`). The Int→F32 case was the only
path where the left operand was Int.
3. **Operand order matters silently**: Nothing warns about mixed-dtype Mul — it just takes
the left operand's dtype.
### The fix
In `ops_parse/unary.rs` `parse_cast_node`, split the combined condition into two cases:
- **No-op cast** (`cast_result.id == input.id`): `input * one_expanded` — preserves dtype
- **Int source** (`input.dtype == DType::Int`): `one_expanded * input` — F32 on the left
ensures F32 output
### General principle
**In luminal, binary op dtype is always the LEFT operand's dtype.** When constructing
`GraphTensor * constant_float(1.0)` for type materialization, always put the operand
whose dtype you want to preserve on the LEFT side. When converting Int→F32, the F32
constant must be the left operand.
---
## 2026-02-26 — ScatterND Fails on CUDA: "does not produce an egraph"
### What the symptom was
`test_scatter_nd` passed on native backend but failed on CUDA with "does not produce an
egraph". The CUDA compilation could not extract a valid program from the e-graph.
### What the actual root cause was
`scatter_nd` in `movement.rs` does `indices * 1` (line 353) to materialize the tensor for
reshaping. The `* 1` dispatches to `Mul<S: Into<Expression>>`, which creates a `constant(1)`
`Iota(1,1)``DType::Int`. But the ONNX parser creates all tensors as `DType::F32`
(via `named_tensor()` in `compiled_graph.rs:70`), so indices arrive as F32. This produces
`Mul(F32, Int)` — mixed dtypes.
The HLIR Mul dtype rule (`hlir.rs:886-888`) uses `(= ?dty (dtype ?lhs))` and
`(= ?dty (dtype ?rhs))` with the same `?dty` variable, requiring both inputs to have
matching dtypes. `F32 != Int` → the rule never fires → the Mul node gets **no dtype**.
Every downstream op checks `(= ?dty (dtype ?upstream))`. Without dtype on the Mul, no
CUDA kernel rewrite rules fire for any downstream op (KernelMul, KernelAdd, KernelLessThan,
etc.). When `cleanup_hlir` runs (enabled for CUDA, disabled for native), it deletes all
unrewritten HLIR ops, leaving empty e-classes → egraph extraction fails.
### Why it was hard to find
1. **Works on native**: `cleanup_hlir = false` for NativeRuntime, so unrewritten HLIR ops
are never deleted. NativeOp dispatches on actual runtime data, not egglog dtype.
2. **Cascading failure**: The root cause (missing dtype on one Mul) silently propagated
through every downstream op, making it look like a systemic CUDA issue rather than a
single dtype mismatch.
3. **`scatter_elements` works fine**: The sibling op already cast indices via
`(idx_f32 + (is_neg * adj)).cast(DType::Int)`, so only `scatter_nd` had this bug.
### The fix
Added `let indices = indices.cast(DType::Int);` at the top of `scatter_nd` in
`movement.rs`, before any arithmetic on indices. `GraphTensor::cast()` short-circuits
when `self.dtype == dtype`, so this is safe for callers already passing Int indices.
Also added the same cast in `parse_scatter_nd_node` for explicitness.
### General principle
**Always cast index tensors to `DType::Int` before arithmetic in graph-building code.**
ONNX tensors arrive as F32 from the Python bridge. Any `indices * stride` or
`indices * 1` will produce `Mul(F32, Int)` which breaks HLIR dtype propagation on CUDA.
The pattern `let indices = indices.cast(DType::Int);` at the top of any index-consuming
function is defensive and free (no-op when already Int).
---
## 2026-03-04 — Dynamic Shapes: Empty Buffer for BOOL Scalar Initializer
### What the symptom was
`test_hf_llama_decode_loop_dynamic` panicked at `bin_fn: a index 0 out of bounds (a.len=0), shape=[1, 1, 4, 4], strides=[0, 0, 0, 0]`. An Input node labeled `"new_ones"` had an empty buffer at runtime.
### What the actual root cause was
Two issues combined:
1. **`load_tensor_floats` didn't handle ONNX data_type=9 (BOOL)**. The `new_ones` initializer was a BOOL scalar (1 byte in `raw_data`). `load_tensor_floats` fell through to the fallback case, which tried `chunks_exact(4)` on 1 byte → produced 0 chunks → returned empty vec `[]`. The buffer was set with empty data.
2. **Scalar initializers with empty `dims` created 0-dimensional tensors**. ONNX represents scalars with `dims=[]`. The initializer loop computed `shape = init.dims.iter().map(|&d| d as usize).collect()` → empty vec `[]`, then called `named_tensor(name, [])` which created a tensor with 0 dimensions instead of the intended scalar `[1]`.
### Why it was hard to find
1. **Misdiagnosed as ConstantOfShape issue**: The original plan targeted `ConstantOfShape` with dynamic shapes. The shape `[1,1,4,4]` with strides `[0,0,0,0]` looked like a broadcast from a constant fill. But `parse_constant_of_shape` was never called — the `new_ones` tensor came from an ONNX initializer, not a computation node.
2. **The BOOL data type is unusual**: Most ONNX tensors are FLOAT, INT32, or INT64. BOOL initializers only appear in specific patterns (like `torch.ones()` in attention mask computation). `load_initializer_as_f32` already handled BOOL, but its sibling `load_tensor_floats` didn't.
3. **Empty vec is valid data**: `set_data(node_id, [])` doesn't panic — it silently sets an empty buffer. The error only manifests later when a downstream op tries to read index 0.
### The fix
1. Added `data_type=9` (BOOL) handling to `load_tensor_floats` in `util.rs` — same logic as `load_initializer_as_f32`: 1 byte per element, non-zero → 1.0, zero → 0.0.
2. In `compiled_graph.rs`, initializer tensor creation: if `shape.is_empty()`, set `shape = vec![1]` (scalar representation in luminal).
### General principle
**Keep data loading functions in sync.** `load_tensor_floats` and `load_initializer_as_f32` serve the same purpose (loading ONNX TensorProto data as f32) but had different data type coverage. When adding a new data type to one, check and update the other. Better yet, refactor them into a single function.
**ONNX scalars have `dims=[]`, luminal scalars have shape `[1]`.** Always convert empty dims to `[1]` when creating luminal tensors from ONNX data.
---
## 2026-03-04 — Where Node Missing Broadcast: KernelMul flatten_strides Panic on CUDA
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` panicked at `flatten_strides` with `left: 4, right: 1` during
CUDA `KernelMul::compile`. The `KernelMul` had `out_shape=[1, 1, a, a]` but `b_stride=[z]` (1D).
### What the actual root cause was
`parse_where_node` called `x.cond(condition, y)` without broadcasting the inputs to matching ranks.
The ONNX Where op for the attention mask had condition=[1,1,a,a] (4D), x=[1] (scalar), y=[1] (scalar).
Luminal's `cond` doesn't auto-broadcast — it passes the shape trackers directly to the HLIR node.
The resulting Mul had input A with 4D strides and input B with 1D strides.
### Why it was hard to find
1. **Only triggered by 1B model**: The tiny model's Where inputs all had matching ranks (no scalars).
2. **CUDA-only**: The native runtime's `bin_fn` uses `StridedIterator` which handles mismatched
strides more gracefully. CUDA's `KernelMul::compile` calls `flatten_strides` which asserts
`range.len() == strides.len()`.
3. **Delayed crash**: The mismatch was created during ONNX parsing but only manifested during
CUDA kernel compilation (graph search phase).
### The fix
Added numpy-style broadcasting to `parse_where_node`: compute the broadcast shape across all 3
inputs, then `broadcast_to_expr` each to the common shape before calling `cond`.
### General principle
**ONNX binary/ternary ops all use numpy broadcasting.** When parsing ONNX ops that take multiple
tensor inputs (Where, Add, Mul, etc.), always broadcast all inputs to a common shape BEFORE
calling the luminal graph operation. Luminal graph ops do NOT auto-broadcast — they expect inputs
with matching shape tracker dimensions.
---
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
the value at column 0 of each row (all three top-k positions got the same value).
Native backend was fine.
2. **Root cause**: `gather_elements` was called with a non-contiguous index tensor produced by
`argsort(axis=1) → slice_along(..k, axis=1)`. The slice creates a ShapeTracker view of the
[4,8] argsort buffer with dims [4,3] and strides [8,1]. When this flowed through the
gather_elements Int arithmetic chain (cast, multiply, add) and into the final Gather CUDA
kernel, the non-contiguous strides caused incorrect index reads for later rows.
3. **Why it was hard to find**: `test_topk_indices` passed (it only tests argsort+slice, not
the downstream gather_elements). A standalone `test_gather_elements` with constant indices
also passed because constant indices are contiguous. The bug only manifested when runtime-
computed non-contiguous indices were used with data of a different size along the gather axis.
4. **Fix**: In `parse_topk_node`, compute `gather_elements(x, full_argsort, axis)` with the
full [4,8] argsort result (same size as data), then slice the gathered values to [4,3].
This ensures gather_elements always operates on same-sized contiguous tensors.
5. **General principle**: When building graph operations that chain shape-tracker views
(slice, transpose, etc.) into downstream HLIR ops on CUDA, prefer operating on full
contiguous tensors first and slicing the result afterward. Non-contiguous views flowing
through multiple CUDA kernels can trigger stride-related bugs in the egglog-compiled code.
---
## 2026-03-07 — Non-deterministic CUDA_ERROR_ILLEGAL_ADDRESS: Multiple Missing Rank Constraints
### What the symptom was
`test_hf_llama_tiny` on CUDA failed ~70% of runs with `CUDA_ERROR_ILLEGAL_ADDRESS`. Failures
were non-deterministic due to egglog's `FxHashMap` iteration order in `random_initial_choice()`.
### What the actual root cause was
**Multiple** matmul egglog rules lacked `(= (len ?out_shape) 2)` constraints:
1. `TileMatmulSplitK` in `block/ops.rs` (disabled via comment but rule still registered)
2. `TileMatmulFullSplit` in `block/ops.rs`
3. All 4 `sgemm_v2_*.egg` rules in `host/cublas/`
The `cublaslt_*.egg` rules already had the constraint. When egglog picked TileMatmul or sgemm
for a 3D+ batched matmul, the generated CUDA kernels accessed out-of-bounds memory.
Additionally, `KernelEmbed` in `kernel/hlir.rs` had an output indexing bug:
`out[out_offset * embed_dim + embed_idx]` should be `out[out_offset + embed_idx]` because
`out_offset` already includes the embed_dim factor from `flatten_strides`.
**Most critically**, the KernelEmbed and RowEmbed "with cast" egglog rules passed the
**pre-cast** float token_ids (`?token_ids`) to the embed kernel instead of the **post-cast**
int token_ids (`?token_ids_cast`). The CUDA kernel reads token_ids as `const int*`, so float
data gets reinterpreted as enormous garbage integers, causing out-of-bounds embed table access.
### Why it was hard to find
1. **Multiple independent bug sources**: The ~70% failure rate was caused by three separate bugs
(matmul rank, embed output indexing, embed pre-cast input). Each fix only reduced the rate
partially, making it seem like each fix was insufficient.
2. **CudaGraph wrapping**: The crash occurred inside `CudaGraphOp::execute_internal` which
batches multiple kernels via CUDA graphs. The error just said "CudaGraph" — it
didn't identify which kernel crashed. Adding per-kernel debug launches was essential.
3. **Cascading failures**: When the Megakernel (containing RowEmbed with the pre-cast bug)
corrupted the embed output, the NEXT CudaGraph group's kernels crashed reading the garbage.
This made the Megakernel appear to be the victim, not the source.
4. **The pre-cast bug only crashes SOMETIMES**: Egglog's random choice determines whether
KernelEmbed/RowEmbed is selected (crash) or the generic Gather path is used (works).
Float token_id 1.0 (= 0x3F800000 = 1065353216 as int) produces an astronomically large
embed table index, causing ILLEGAL_ADDRESS.
### The fix
- Added `(= (len ?out_shape) 2)` to TileMatmulSplitK, TileMatmulFullSplit, and all 4 sgemm_v2 rules
- Fixed KernelEmbed output indexing: `out[out_offset + embed_idx]`
- **Fixed KernelEmbed/RowEmbed "with cast" rules**: Changed input from `?token_ids` to
`?token_ids_cast` — using the post-Cast int tensor instead of the pre-Cast float tensor
### Results
Failure rate: ~70% → 0% (20/20 passing). All three bugs needed to be fixed together.
### General principle
**When an egglog rule matches a sub-expression chain (like Cast→Mul→Add), be precise about
which intermediate result becomes each input.** The "with cast" embed rules matched
`Cast(?token_ids, ...)` to verify the Cast existed, but then passed `?token_ids` (the Cast
INPUT) instead of `?token_ids_cast` (the Cast OUTPUT) to the embed kernel. The kernel expects
int data, so the pre-cast float data was reinterpreted as garbage ints.
**Always search for sibling implementations**: KernelEmbed (in `kernel/hlir.rs`) and RowEmbed
(in `block/ops.rs`) had the SAME bug in their "with cast" rules. Fixing one without the other
only reduces the failure rate — both must be fixed.
---
## 2026-03-09 — TileMatmulFullSplit Matches Element-wise Square+Sum from LayerNorm
### What the symptom was
`test_qwen_image_transformer_tiny` on CUDA produced NaN in specific output rows. The failure
was non-deterministic (~85% failure rate) due to egglog's random e-class extraction picking
TileMatmulFullSplit for some operations.
### What the actual root cause was
The `TileMatmulFullSplit` rewrite rule in `block/ops.rs` matched any `Mul + Sum` pattern with
a 2D output, contiguous K-strides, and F32 inputs. This correctly matched real matmuls, but
ALSO matched the element-wise `x * x + Sum(last_dim)` pattern from LayerNorm/RMSNorm
(Pow(x, 2) → ReduceMean).
For a [1, 4, 64] activation tensor `x`:
- `Mul(x, x)` shape: [1, 4, 64], strides: [256z, 64z, z] for both inputs
- `Sum(dim=2)` output: [1, 4], len=2 ✓
TileMatmulFullSplit interpreted this as a [1, 64] × [64, 4] → [1, 4] matmul with:
- A = row 0 of x (64 elements), B = same buffer at column offsets
The kernel computed `C[j] = sum_k x[k] * x[j*64+k]` (cross-products) instead of the correct
`C[j] = sum_k x[j*64+k]^2` (squared sums). This produced subtly wrong values for j > 0
(correct for j=0 since cross-product with self = squared sum). These wrong values propagated
through LayerNorm → downstream operations → softmax → NaN.
Key diagnostic: adding `printf` to the kernel showed `a_ptr == b_ptr` (same buffer for both
inputs), confirming the kernel was operating on `x * x` not a real matmul.
### Why it was hard to find
1. **Individual op tests passed**: Simple Gemm tests, attention tests, and all other bisection
tests passed because they didn't have the specific `x*x → Sum` pattern.
2. **Non-deterministic**: The bug only manifested when egglog selected TileMatmulFullSplit
over the kernel fallback for the square+sum operation.
3. **No NaN from TileMatmulFullSplit itself**: The kernel produced wrong-but-finite values.
NaN only appeared downstream through softmax (exp(large) → ∞ → ∞/∞ = NaN).
4. **Systematic elimination needed**: Had to disable all block ops, then enable one at a time,
to narrow down TileMatmulFullSplit as the culprit.
### The fix
Added matmul broadcast constraints to both `TileMatmulFullSplit` and `TileMatmulSplitK` rules:
```egglog
; Assert proper matmul broadcast pattern:
; A is broadcast over N (a_n_stride = 0), B is broadcast over M (b_m_stride = 0)
(= ?a_n_stride (MNum 0))
(= ?b_m_stride (MNum 0))
```
In a real matmul `[M, K] × [K, N]`, the Mul is created by expanding dims:
- A is broadcast over N → a_n_stride = 0
- B is broadcast over M → b_m_stride = 0
In element-wise `x * x`, both strides are identical (non-zero for all dims), so the
constraints correctly reject it. The cuBLAS `.egg` rules already had these constraints.
### General principle
**Matmul Mul+Sum patterns have specific broadcast structure: one input is broadcast over M
and the other over N.** When writing egglog rules that match `Mul + Sum` patterns for matmul
optimization, always verify the broadcast pattern (`a_n_stride = 0` and `b_m_stride = 0`).
This prevents matching element-wise operations like `x*x → sum` that happen to have a 2D
output and contiguous strides.
---
## 2026-03-09 — Conv3D Permute Axis Mismatch in ONNX Conv Parser
### Symptom
`test_qwen_image_vae_decoder_tiny` panicked with:
> Permute axes (5) doesn't match shape axes (6)
at `src/shape/tracker.rs:153`, during `parse_conv_node`.
### Root cause
The Conv parser's unfold → matmul algorithm used two consecutive permutes with incorrect
index calculations. After unfold produces a 2N-dimensional tensor
`[win_0..win_{N-1}, k_0..k_{N-1}]`, the first permute swapped kernel dims to the front.
But the second permute's index math still assumed the original (pre-first-permute) ordering,
confusing kernel dimensions with window dimensions. Additionally:
1. `output_spatial_dims` was captured from wrong indices (kernel dims instead of window
spatial dims)
2. The `split_dims` loop iterated `spatial` times instead of `spatial-1`, creating a
spurious size-1 dimension
3. The final permute array had `1+spatial` elements for a tensor with `2+spatial` dims
For Conv2D (spatial=2) this was never caught because the xfail'd VAE decoder test was the
only test exercising the Conv parser — the transformer tests don't use Conv ONNX nodes.
### Why it was hard to find
The Conv parser was written and the VAE test immediately xfail'd due to a *different* bug
(`merge_dims` being `todo!()`). Once `merge_dims` was implemented, the Conv parser's own
bugs surfaced for the first time.
### Fix
Rewrote the unfold → matmul section with a single correct permute:
1. **One permute** to `[N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]`
— groups batch | output spatial | channel+kernel
2. **Capture** `output_spatial_dims` from correct indices `[1..1+spatial]`
3. **Merge** all channel+kernel dims from the end into one
4. **Merge** spatial dims into one → `[N, spatial_product, C_in*kernel_product]`
5. **Matmul**`[N, spatial_product, C_out]`
6. **Split** spatial back with `spatial-1` splits (not `spatial`)
7. **Permute** C_out to position 1 with correct `2+spatial` element array
### General principle
**When chaining permutes on high-dimensional tensors, prefer a single combined permute.**
Multiple permutes with hand-computed index arrays are error-prone because each permute
redefines what indices mean. A single permute from the original layout to the target layout
is easier to verify and less likely to confuse source/destination ordering. Also, ensure
`split_dims` loop counts match: splitting N dims out of a product requires N-1 splits
(the outermost dim is the quotient, not split out separately).
---
## 2026-03-18 — CUDA Search Rejects All Candidates: Zero Dummy Data Causes NaN for Div/Pow/Mod/Erf
### What the symptom was
6 CUDA tests (`test_pow`, `test_pow_broadcast`, `test_div`, `test_mod`, `test_mod_broadcast`,
`test_erf`) consistently failed with `Failed to find a viable initial genome for group 0 after
100 attempts`. All 6 passed on native backend.
### What the actual root cause was
The CUDA two-phase initialization in `build_cuda_backend` set ALL input tensor buffers to
`0.0f32` as dummy data for profiling. When `torch.compile` decomposes a model, it passes
model weights as additional ONNX graph inputs (not initializers). Since there were no ONNX
initializers to overwrite the zeros, weight buffers stayed all-zero during search.
Operations with zero inputs produced NaN:
- `fmod(0, 0) = NaN` (Mod test)
- `weight * recip(0) = weight * inf` → with any zero weight → `0 * inf = NaN` (Div test)
- `abs(0).log() = log(0) = -inf` → downstream NaN (Pow test)
- `sign(0)` chain → operations on zero inputs (Erf test)
The `has_nan_outputs` check rejected every candidate genome, exhausting all 100 attempts.
### Why it was hard to find
1. **No panic, no crash — silent NaN rejection**: The error message said "Failed to find a
viable initial genome" which suggested an egglog rewrite issue, not a data issue.
2. **Works on native**: `NativeRuntime::has_nan_outputs()` returns `false` by default (no NaN
check), so zero inputs never caused problems on native.
3. **torch.compile vs direct export difference**: Directly exporting a model via
`torch.onnx.export(model, ...)` produces initializers. But `torch.compile`'s backend
receives a `GraphModule` where weights are graph inputs, not initializers. The ONNX file
from `torch.compile` has 0 initializers.
4. **CudaRuntime's own `allocate_dummy_input` already uses 1.0**: The runtime knew zeros
were problematic (comment: "Zero inputs often hide numerical issues"), but the
`compiled_graph.rs` code used `0.0f32` independently.
### The fix
Changed dummy data from `vec![0.0f32; n_elements]` to `vec![1.0f32; n_elements]` in
`build_cuda_backend`. Using 1.0 is numerically safe: `fmod(1,1)=0`, `recip(1)=1`,
`log(1)=0`, `exp(1)≈2.7` — no NaN or inf. Profiling timing is unaffected (same number
of FLOPs and memory accesses).
### General principle
**Use small non-zero values (1.0) for dummy profiling data, never zeros.** Zero is a
singularity for many floating-point operations (division, log, fmod with zero divisor).
The CUDA runtime's `allocate_dummy_input` already followed this principle — the ONNX
pipeline's `build_cuda_backend` was inconsistent. When creating dummy data for GPU
profiling, always match the runtime's safer default.
---
## 2026-03-18 — Dynamic Decode Loop Fails: HLIR Weight Buffers Consumed After First Execute
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` passed step 0 (seq_len=6) but panicked on step 1
(seq_len=7) with `no entry found for key` at `cublaslt/mod.rs:294` — the CuBlasLt op couldn't
find its weight input buffer.
### What the actual root cause was
**Two bugs:**
1. **Missing `)` in egglog rule** (`luminal_cuda_lite/src/kernel/hlir.rs:3042`): The fourth
KernelEmbed rule ("kernel embed with mul reversed") had 3 closing parens after `INil` instead
of 4. The missing `)` failed to close the `(= ?mul_result ...)` form. This caused an egglog
parse error during search, caught by `catch_unwind`. The rule was dead code — it never fired,
but the parse error consumed a search iteration.
2. **HLIR buffer consumption killed weight buffers** (`luminal_cuda_lite/src/runtime.rs:1010-1057`):
After each `execute()`, the runtime removed all HLIR buffers (weights, constants) except those
directly connected to Output nodes. This was intended to free one-shot input data, but it also
deleted all 168 weight buffers. On the next `graph.run()`, CuBlasLt couldn't find any of its
weight inputs — `hlir_buffers` had 1 entry (the just-set `input_ids`) instead of 169.
### Why it was hard to find
1. **Misdirection by the egglog syntax error**: The plan identified the missing `)` as THE cause.
Fixing it allowed the rule to parse correctly, but the real runtime failure was independent.
2. **Step 0 always succeeds**: The weight consumption happens AFTER a successful execution. So
the first `graph.run()` works perfectly — all 169 HLIR buffers exist. The panic only occurs
on the second call, after consumption has cleared 168 of them.
3. **The consumption code was deliberately designed**: Comments said "weight tensors must have
`.persist()` to survive." The ONNX pipeline didn't call `.persist()` on weights, but this
had never been a problem before because single-shot inference only calls `execute()` once.
4. **Search phase panics masked by `catch_unwind`**: The same "no entry found for key" error
occurred during profiling of search candidates, but was silently caught. This made it look
like only certain LLIR variants had the issue, not all of them.
5. **Debug output needed 4 iterations to find**: The first debug showed which NodeIndex was
missing, the second showed it was an Input node, the third showed the HLIR mapping, and
the fourth revealed `hlir_buffers_count` dropping from 169 to 1 between steps.
### The fix
1. Added missing `)` to the KernelEmbed egglog rule at `hlir.rs:3042`.
2. In `compiled_graph.rs`, added `.persist()` calls on all weight/constant tensors (anything
not in `input_names`) after `process_onnx_nodes` completes. `.persist()` creates an Output
node connected to the Input, which the consumption code recognizes as "do not consume."
User inputs (like `input_ids`) are intentionally NOT persisted — they are consumed after
each `execute()` and re-set via `set_input()` before the next call.
### General principle
**Mark weight/constant tensors as persistent in the graph-building pipeline.** The runtime's
`execute()` consumes all HLIR buffers not connected to Output nodes. This is correct behavior
for one-shot user inputs, but weights must survive across calls. Always call `.persist()` on
tensors that should outlive a single execution. In the ONNX pipeline, the distinction is clear:
`input_names` (user-provided data per step) vs everything else (weights/constants loaded once).
---
## 2026-03-20 — PT2 CUDA Search Rejects All Candidates: Integer Buffers Misinterpreted as Float NaN
### What the symptom was
`test_hf_llama_tiny` on CUDA via PT2 failed with:
`pyo3_runtime.PanicException: Failed to find a viable initial genome for group 0 after 100 attempts`
The search tried 100 different egglog rewrites and ALL were rejected by the `has_nan_outputs` check.
### What the actual root cause was
**Two issues, both required to fix:**
1. **Integer buffers misinterpreted as float in NaN check.** `has_nan_outputs` in
`luminal_cuda_lite/src/runtime.rs` checks ALL `self.buffers` by reinterpreting raw bytes
as `f32` and calling `is_nan()`. The PT2-translated graph has integer intermediate
buffers (from `arange`, `cast(Int)`, integer arithmetic for embedding index computation).
Certain valid `i32` bit patterns (e.g., large integers from `token_id * hidden_dim`)
have exponent=0xFF and non-zero mantissa when reinterpreted as f32 — matching the
IEEE 754 NaN pattern. This caused false NaN rejections for EVERY candidate genome.
2. **Real weights/constants loaded before search contain -inf.** The PT2 path loaded real
safetensors weights and model constants (including the causal attention mask with `-inf`
values) BEFORE the search. While the ONNX path also loads real initializer data before
search, the PT2 graph's different structure (more explicit integer operations) made the
integer NaN false-positive the blocking issue.
### Why it was hard to find
- The original plan diagnosed this as the same zero-dummy-data bug fixed on 2026-03-18.
Changing `0.0` to `1.0` was insufficient because the root cause was different.
- `has_nan_outputs` checking ALL intermediate buffers (not just outputs) masked the real
issue — the NaN was in integer index-computation buffers, not in the model's float outputs.
- The ONNX-translated graph didn't have this problem because it doesn't produce as many
integer intermediate buffers (ONNX embedding uses different ops).
- The NaN pattern was identical across all 100 search attempts, which was the key clue:
it was deterministic and independent of egglog rewrite choices, pointing to input data
or buffer interpretation rather than graph optimization issues.
### The fix
Four changes:
1. **`luminal_cuda_lite/src/kernel/mod.rs`** (`KernelOp` trait): Added `output_dtype()`
method with default `DType::F32`. Each kernel now reports its actual output dtype.
2. **`luminal_cuda_lite/src/kernel/hlir.rs`** and **`other_ops.rs`**: Overrode
`output_dtype()` in all kernels with a `dtype` field (returns `self.dtype`), plus
special cases: `KernelIota``DType::Int`, `KernelLessThan``DType::Bool`,
`KernelCast``self.out_dtype`.
3. **`luminal_cuda_lite/src/runtime.rs`** (`has_nan_outputs`): Replaced fragile
`format!("{:?}").contains("dtype: Int")` string matching with proper
`op.to_dialect::<dyn KernelOp>().output_dtype()` check. Only F32 buffers are
checked for NaN; integer and bool buffers are skipped.
4. **`rust/src/pt2_compiled_model.rs`** (`init_cuda_runtime`): Set ALL input nodes
(weights, constants, user inputs) to `vec![1.0f32; n_elements]` before search via
new `set_all_inputs_dummy_cuda` function, then reload real data after search.
This prevents any -inf values from the causal mask from polluting intermediate
float computations during profiling.
### General principle
**Never reinterpret integer buffer bytes as float for NaN checking.** When a graph has
mixed-dtype operations (float model computation + integer index computation), raw byte
buffers from integer kernels contain valid i32 values that look like NaN when cast to f32.
The search's `has_nan_outputs` must be dtype-aware — use the kernel's `output_dtype()`
method rather than string-matching on Debug output. Additionally, when diagnosing "all
candidates rejected" during search, check whether the rejection is from actual float NaN
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
identical across all attempts (dtype issue) vs varying (actual numerical issue).
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.

View File

View File

@@ -0,0 +1,369 @@
"""Run pytest on Modal with a dynamically selected GPU.
Usage:
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
uv run modal run modal_pytest_runner.py --gpu T4 tests/
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
"""
import argparse
import json
import os
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import modal
from modal.volume import FileEntryType
app = modal.App("luminal-tests")
DEFAULT_TIMEOUT = 30 * 60
CUDARC_CUDA_VERSION = "12080"
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
PROJECT_DIR = "/root/luminal/crates/luminal_python"
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
SRC_PATH = f"{PROJECT_DIR}/src"
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
HF_TOKEN_ENV_KEY = "HF_TOKEN"
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
HF_CACHE_VOLUME = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
image = (
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
.uv_sync(
str(LOCAL_PROJECT_DIR),
frozen=False,
groups=["dev"],
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
)
.workdir(PROJECT_DIR)
.add_local_dir(
str(LOCAL_PROJECT_DIR.parent.parent),
remote_path="/root/luminal",
copy=True,
ignore=[
".git",
".claude-project",
".cargo-local",
"**/.venv",
"**/.pytest_cache",
"**/__pycache__",
"**/luminal_artifacts",
"**/target",
"docs",
],
)
)
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _hf_token_secret() -> modal.Secret | None:
hf_token = os.environ.get(HF_TOKEN_ENV_KEY)
if not hf_token:
return None
return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: hf_token})
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
return any(arg == flag for arg in pytest_args)
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
return (
cli_profile
or _has_pytest_flag(pytest_args, "--profile")
or _has_pytest_flag(pytest_args, "--profile-svg")
)
def _run_id() -> str:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
def _prepare_scratch_dir(scratch_dir: Path) -> None:
scratch_dir.mkdir(parents=True, exist_ok=True)
linked_names = {
".venv",
".pytest_cache",
"__pycache__",
"luminal_artifacts",
"prof",
}
for entry in Path(PROJECT_DIR).iterdir():
if entry.name in linked_names:
continue
target = scratch_dir / entry.name
if target.exists() or target.is_symlink():
continue
target.symlink_to(entry, target_is_directory=entry.is_dir())
def _default_profile_output_dir(run_id: str) -> Path:
return (LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT / run_id).resolve()
def _prepare_local_profile_dir(output_dir: Path) -> None:
if output_dir.exists() and not output_dir.is_dir():
raise NotADirectoryError(f"{output_dir} is not a directory")
output_dir.mkdir(parents=True, exist_ok=True)
prof_dir = output_dir / "prof"
if prof_dir.exists():
shutil.rmtree(prof_dir)
manifest_path = output_dir / "manifest.json"
if manifest_path.exists():
manifest_path.unlink()
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
_prepare_local_profile_dir(output_dir)
for entry in entries:
relative_path = Path(entry.path).relative_to(run_id)
if relative_path == Path("."):
continue
destination = output_dir / relative_path
if entry.type == FileEntryType.DIRECTORY:
destination.mkdir(parents=True, exist_ok=True)
continue
if entry.type != FileEntryType.FILE:
continue
destination.parent.mkdir(parents=True, exist_ok=True)
with destination.open("wb") as handle:
for chunk in PROFILE_VOLUME.read_file(entry.path):
handle.write(chunk)
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
try:
PROFILE_VOLUME.remove_file(run_id, recursive=True)
except FileNotFoundError:
return
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
class TestRunner:
@modal.method()
def run(
self,
pytest_args: list[str],
pytest_addopts: str = "",
profile_enabled: bool = False,
) -> dict[str, Any]:
started_at = _utc_now()
run_id = _run_id() if profile_enabled else None
scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id if run_id else None
if scratch_dir is not None:
_prepare_scratch_dir(scratch_dir)
env = os.environ.copy()
existing = env.get("PYTHONPATH")
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
env["LUMINAL_BACKEND"] = "cuda"
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
env["HF_HOME"] = HF_CACHE_PATH
if pytest_addopts:
env["PYTEST_ADDOPTS"] = pytest_addopts
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
dot_available = shutil.which("dot") is not None
sanitized_pytest_args = [
arg for arg in pytest_args if arg not in {"--profile", "--profile-svg"}
]
if profile_enabled:
sanitized_pytest_args.append("--profile")
if dot_available:
sanitized_pytest_args.append("--profile-svg")
elif original_svg_requested:
print(
"Graphviz 'dot' is unavailable in the Modal container; "
"falling back to raw .prof artifacts only.",
file=sys.stderr,
)
svg_requested = profile_enabled and dot_available
cmd = [
"uv",
"run",
"--project",
PROJECT_DIR,
"--group",
"dev",
"--reinstall-package",
"luminal_python",
"python",
"-m",
"pytest",
*sanitized_pytest_args,
]
exit_code = subprocess.run(
cmd,
env=env,
cwd=str(scratch_dir) if scratch_dir is not None else PROJECT_DIR,
).returncode
HF_CACHE_VOLUME.commit()
finished_at = _utc_now()
if not profile_enabled:
return {
"exit_code": exit_code,
"run_id": None,
"profile_enabled": False,
"remote_profile_dir": None,
"local_default_dirname": None,
}
volume_root = Path(PROFILE_VOLUME_PATH)
if not volume_root.exists():
raise RuntimeError(
"Profiling requested but the profile volume is not mounted."
)
remote_run_dir = volume_root / run_id
remote_run_dir.mkdir(parents=True, exist_ok=True)
prof_dir = scratch_dir / "prof"
if prof_dir.is_dir():
shutil.copytree(prof_dir, remote_run_dir / "prof")
svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()
manifest = {
"exit_code": exit_code,
"finished_at": finished_at,
"profile_enabled": True,
"pytest_args": sanitized_pytest_args,
"run_id": run_id,
"started_at": started_at,
"svg_generated": svg_generated,
"svg_requested": svg_requested,
}
(remote_run_dir / "manifest.json").write_text(
json.dumps(manifest, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
PROFILE_VOLUME.commit()
return {
"exit_code": exit_code,
"run_id": run_id,
"profile_enabled": True,
"remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
"local_default_dirname": run_id,
"svg_generated": svg_generated,
"svg_requested": svg_requested,
}
def _parse_cli_args(
cli_args: tuple[str, ...],
) -> tuple[str, int | None, bool, str | None, list[str]]:
parser = argparse.ArgumentParser(
prog="modal run modal_pytest_runner.py",
add_help=False,
allow_abbrev=False,
description="Run pytest on Modal with a dynamically selected GPU.",
)
parser.add_argument(
"--gpu",
required=True,
help="GPU type to request from Modal (for example: A100, T4, H100).",
)
parser.add_argument(
"--timeout",
type=int,
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Enable pytest-profiling and download the resulting artifacts locally.",
)
parser.add_argument(
"--profile-output-dir",
help="Directory to download profiling artifacts into when profiling is enabled.",
)
parsed, pytest_args = parser.parse_known_args(cli_args)
if pytest_args and pytest_args[0] == "--":
pytest_args = pytest_args[1:]
if not pytest_args:
pytest_args = ["tests/"]
return (
parsed.gpu,
parsed.timeout,
parsed.profile,
parsed.profile_output_dir,
pytest_args,
)
@app.local_entrypoint()
def main(*cli_args: str):
gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
cli_args
)
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
runner_options = {"gpu": gpu}
hf_token_secret = _hf_token_secret()
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
if timeout is not None:
runner_options["timeout"] = timeout
if profile_enabled:
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
runner_options["volumes"] = runner_volumes
if hf_token_secret is not None:
runner_options["secrets"] = [hf_token_secret]
runner = TestRunner.with_options(**runner_options)()
result = runner.run.remote(
pytest_args=pytest_args,
pytest_addopts=pytest_addopts,
profile_enabled=profile_enabled,
)
if result["profile_enabled"] and result["run_id"] is not None:
if profile_output_dir:
output_dir = Path(profile_output_dir).expanduser().resolve()
else:
output_dir = _default_profile_output_dir(result["local_default_dirname"])
try:
_download_profile_artifacts(result["run_id"], output_dir)
print(f"Profile artifacts downloaded to {output_dir}")
_cleanup_remote_profile_artifacts(result["run_id"])
except FileNotFoundError as exc:
print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
except OSError as exc:
print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
sys.exit(result["exit_code"])

View File

@@ -0,0 +1,65 @@
[project]
name = "luminal_python"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.0.2",
"torch>=2.10.0",
"onnx",
"onnxscript",
"safetensors",
"flash-attn-3>=3.0.0",
]
[tool.uv]
no-build-isolation-package = ["flash-attn"]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
flash-attn-3 = { index = "pytorch-cu128" }
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.maturin]
python-source = "src"
manifest-path = "rust/Cargo.toml"
module-name = "luminal.luminal"
[tool.pytest.ini_options]
markers = [
"slow: tests that download large models or require pre-generated artifacts",
]
[dependency-groups]
dev = [
"maturin>=1.0,<2.0",
"maturin-import-hook>=0.3.0",
"pytest>=9.0.2",
"pytest-profiling",
"snakeviz",
"pytest-randomly>=4.0.1",
"transformers>=5.5.0,<6",
"diffusers>=0.35.0",
"onnxsim",
"tiktoken>=0.12.0",
"pydantic>=2.12.5",
"psutil>=7.2.2",
"modal>=1.3.5",
"pillow",
"flash-attn>=2.8.3",
]
flash-attention-4 = [
"nvidia-cutlass-dsl==4.1.0",
]

View File

@@ -0,0 +1,44 @@
#!/bin/bash
set -e
echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
# ── Phase 1: Native Backend ─────────────────────────────────
echo ""
echo "=== Phase 1: Building native backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml
echo ""
echo "--- 1a: Native + ONNX ---"
uv run pytest $NATIVE_TESTS -v
echo ""
echo "--- 1b: Native + PT2 ---"
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
# ── Phase 2: CUDA Backend ───────────────────────────────────
echo ""
echo "=== Phase 2: Building CUDA backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
echo ""
echo "--- 2a: CUDA + ONNX ---"
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "--- 2b: CUDA + PT2 ---"
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "=========================================="
echo " All tests passed!"
echo "=========================================="

Some files were not shown because too many files have changed in this diff Show More