Compare commits

..

774 Commits
0.1 ... 0.2

Author SHA1 Message Date
Joe Fioti
be4fb7dd9f Changed expressionstorage vis 2024-03-01 14:16:04 -06:00
Joe Fioti
6b2c216e45 Update README.md 2024-03-01 11:05:17 -06:00
Joe Fioti
d309b4f338 Update README.md 2024-03-01 07:57:39 -06:00
Joe Fioti
d562a0321b Update README.md 2024-03-01 07:21:43 -06:00
Joe Fioti
1b61fd2e4d Pushed to cratesio 2024-02-29 22:52:52 -06:00
Joe Fioti
e0c9b2b1ff removed luminal_macro target 2024-02-29 22:42:31 -06:00
Joe Fioti
d976f71585 Begin publishing crates 2024-02-29 22:41:47 -06:00
Joe Fioti
00acf7aebe removed diffs 2024-02-29 14:41:14 -06:00
Joe Fioti
29751edf20 Passing metal tests 2024-02-29 13:52:23 -06:00
Joe Fioti
235db905da Update README.md 2024-02-29 12:37:57 -06:00
Joe Fioti
6337d90bce Update README.md 2024-02-29 12:37:36 -06:00
Joe Fioti
8c434b5081 Update README.md 2024-02-29 12:36:45 -06:00
Joe Fioti
84ba491b56 readme 2024-02-28 22:16:42 -06:00
Joe Fioti
c8044504c5 Refactored tests / api 2024-02-28 22:16:05 -06:00
Joe Fioti
9132ad8d94 Updated docs 2024-02-28 21:54:24 -06:00
Joe Fioti
c20b257657 tweak 2024-02-28 21:12:04 -06:00
Joe Fioti
89d9bbe105 Merge branch 'main' of https://github.com/jafioti/luminal 2024-02-28 21:10:05 -06:00
Joe Fioti
2f34d413e1 CPU mistral 2024-02-28 21:09:57 -06:00
Joe Fioti
a9875fde4d Fixed matvec test 2024-02-28 10:58:29 -06:00
Joe Fioti
c19e211629 readme 2024-02-27 22:30:23 -06:00
Joe Fioti
3e49033616 Match llama and mistral 2024-02-27 22:26:55 -06:00
Joe Fioti
f0135920aa Bring llama closer to mistral 2024-02-27 21:51:41 -06:00
Joe Fioti
8776a1c3de Small changess 2024-02-27 20:07:27 -06:00
Joe Fioti
b6022900b0 Refined llama 2024-02-27 19:46:10 -06:00
Joe Fioti
cc94802ed0 Update README.md 2024-02-27 18:35:58 -06:00
Joe Fioti
ed2fc61c73 Fixed test 2024-02-27 18:29:33 -06:00
Joe Fioti
ebbc86a312 CPU mistral 2024-02-27 18:11:05 -06:00
Joe Fioti
b065cdd22b Small tweaks 2024-02-27 16:12:58 -06:00
Joe Fioti
c7a0944eda gitignore 2024-02-27 16:05:17 -06:00
Joe Fioti
b2d6a48eab Updated llama 2024-02-27 16:04:05 -06:00
Joe Fioti
ca71f0dd16 CPU mistral 2024-02-27 15:56:04 -06:00
Joe Fioti
94306b086a feature change 2024-02-27 15:48:12 -06:00
Joe Fioti
7f2b9cf336 Mistral working on cuda 2024-02-27 15:44:30 -06:00
Joe Fioti
df1f5c3ca8 Fixed mistral metal 2024-02-27 12:53:48 -06:00
Joe Fioti
ffb7e4c706 Updated gitignore 2024-02-27 11:41:08 -06:00
Joe Fioti
598e303649 Split crates 2024-02-27 11:40:35 -06:00
Joe Fioti
866dfb7804 Spun compilers into crates 2024-02-27 10:43:25 -06:00
Joe Fioti
1cfefed1ce Update README.md 2024-02-27 00:14:08 -06:00
Joe Fioti
d5c6ef451c Update README.md 2024-02-27 00:11:25 -06:00
Joe Fioti
e363385d3f Update readme 2024-02-27 00:08:38 -06:00
Joe Fioti
258d1be49f updated deps 2024-02-26 23:37:40 -06:00
Joe Fioti
18f15d98e2 re-added sqrt 2024-02-26 23:11:58 -06:00
Joe Fioti
0b5ce105a3 removed sqrt 2024-02-26 17:16:55 -06:00
Joe Fioti
d66bf3412a Fixed cuda 2024-02-26 17:10:01 -06:00
Joe Fioti
5ede551fcb Merge pull request #24 from jafioti/q8
Q8 Weight Quantization
2024-02-26 14:33:26 -06:00
Joe Fioti
e10a19668c tweaks 2024-02-26 13:16:18 -06:00
Joe Fioti
881fa13a13 Working fused rope kernel 2024-02-26 12:58:55 -06:00
Joe Fioti
e279912bb8 Started rope metal 2024-02-23 16:42:32 -06:00
Joe Fioti
6dc2a996d2 furthur cleanup 2024-02-23 15:04:03 -06:00
Joe Fioti
7ee1cad15c Cleanup 2024-02-23 14:52:19 -06:00
Joe Fioti
b8a0f08cea Working 8bit mistral 2024-02-23 14:10:18 -06:00
Joe Fioti
edc96f3626 comparisons 2024-02-20 22:09:50 -06:00
Joe Fioti
dd369f18a9 temp 2024-02-17 11:47:05 -06:00
Joe Fioti
9326fe3cc8 broke af 2024-02-17 00:15:39 -06:00
Joe Fioti
3bd99c9f24 Kernel cleanup 2024-02-13 00:10:35 -06:00
Joe Fioti
bd56364160 Batched matvec 2024-02-13 00:05:12 -06:00
Joe Fioti
9547004247 Added quantized matvec 2024-02-12 14:27:11 -06:00
Joe Fioti
647f119d3c Fixed compiler macros 2024-02-09 11:53:00 -06:00
Joe Fioti
8952443ebd refactor metal compiler 2024-02-07 12:21:15 -06:00
Joe Fioti
5947e5cd3d refactor metal compiler 2024-02-07 12:12:58 -06:00
Joe Fioti
10d94710f7 Change feature flags 2024-02-07 12:02:56 -06:00
Joe Fioti
d13af7c562 Remove local cudarc fork 2024-02-07 12:01:39 -06:00
Joe Fioti
c2bbe446da Merged cuda compilers 2024-02-07 11:25:33 -06:00
Joe Fioti
b0a732e5b0 chagen readme 2024-02-07 10:56:37 -06:00
Joe Fioti
59cf7998c9 Fixed cuda tests 2024-02-07 10:12:44 -06:00
Joe Fioti
a6f38be402 Changed features 2024-02-06 21:52:19 -06:00
Joe Fioti
bc92e3137f Fixed many cuda bugs 2024-02-06 21:48:13 -06:00
Joe Fioti
30310a173d Update CONTRIBUTING.md 2024-02-06 17:22:08 -06:00
Joe Fioti
c00935b451 Addded contributing 2024-02-06 17:19:14 -06:00
Joe Fioti
15e4ee6aa3 fix doctests 2024-02-06 09:53:40 -06:00
Joe Fioti
9ec1e75fe6 tweak 2024-02-04 13:29:31 -06:00
Joe Fioti
5898076da5 Added documentation 2024-02-01 22:11:40 -06:00
Joe Fioti
5b17c1880e bug fix 2024-01-29 17:11:37 -06:00
Joe Fioti
1afea6bd86 renaming 2024-01-29 15:49:30 -06:00
Joe Fioti
8dff3619b9 Fixed speed 2024-01-29 15:46:19 -06:00
Joe Fioti
111452a68e Single mistral graph 2024-01-29 15:11:46 -06:00
Joe Fioti
d147ed5063 Mistral 10tps on M1 pro 2024-01-29 09:51:39 -06:00
Joe Fioti
162859dedb small changes 2024-01-28 16:17:49 -06:00
Joe Fioti
56de7fa4c3 small chagnes 2024-01-27 20:09:34 -06:00
Joe Fioti
7cc02dd51d core optimizations 2024-01-27 12:44:52 -06:00
Joe Fioti
e5963f1c9a Update Cargo.toml 2024-01-26 22:21:13 -06:00
Joe Fioti
9d32721ca7 dep changes 2024-01-26 21:07:00 -06:00
Joe Fioti
bc6b8fb283 Small kernel change 2024-01-26 19:57:24 -06:00
Joe Fioti
12381b2624 Changed tril triu api 2024-01-26 18:54:14 -06:00
Joe Fioti
2821145268 Removed isize 2024-01-26 17:40:15 -06:00
Joe Fioti
959528efad Added matmul support for repeated B batches 2024-01-26 17:32:39 -06:00
Joe Fioti
6a5a45eeae Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-25 21:22:14 -06:00
Joe Fioti
4166e27055 gemm refactor 2024-01-25 21:22:06 -06:00
Joe Fioti
f55cf6c0f7 Merge pull request #17 from TheSeamau5/debug
Small Improvements to main
2024-01-25 09:15:59 -06:00
Hassan Hayat
6ddabf2995 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 05:01:46 +01:00
Joe Fioti
54461a6d33 non-contiguous rotate 2024-01-24 21:24:26 -06:00
Hassan Hayat
b5d6f424d9 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 04:22:24 +01:00
Joe Fioti
f846af5901 rotation speedup 2024-01-24 21:20:43 -06:00
Hassan Hayat
f9c766dca7 Merge remote-tracking branch 'upstream/main' into debug 2024-01-25 03:41:07 +01:00
Joe Fioti
218db50c79 Small improvements to std_norm 2024-01-24 15:54:18 -06:00
Hassan Hayat
3fddb7e5a8 Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 23:22:32 +01:00
Joe Fioti
7bd8de272b steel matmul is ass 2024-01-23 16:21:02 -06:00
Joe Fioti
80915d3f3a Fixed rotate compiler 2024-01-23 15:19:38 -06:00
Joe Fioti
791f1395d5 Small changes 2024-01-23 14:26:15 -06:00
Hassan Hayat
b5a13381a9 Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 18:34:23 +01:00
Joe Fioti
c64e408471 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-23 10:19:44 -06:00
Joe Fioti
b1770a0b0e Added broken rotate op 2024-01-23 10:19:38 -06:00
Hassan Hayat
37dc4428af Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 09:03:28 +01:00
Joe Fioti
2d198b6be7 Rename LICENSE to LICENSE-APACHE 2024-01-22 21:47:39 -06:00
Joe Fioti
67e8e439c0 Create LICENSE 2024-01-22 21:47:26 -06:00
Joe Fioti
908d2c9222 Create LICENSE-MIT 2024-01-22 21:47:12 -06:00
Joe Fioti
c401a95af2 Update README.md 2024-01-22 21:46:00 -06:00
Joe Fioti
e2864d852f Update README.md 2024-01-22 21:44:53 -06:00
Hassan Hayat
f043ba2d5e Merge remote-tracking branch 'upstream/main' into debug 2024-01-23 04:44:01 +01:00
Joe Fioti
cf8412d3bf Small gemv change 2024-01-22 16:24:21 -06:00
Joe Fioti
5b4bde0070 Change how metal imports work gemv 2024-01-22 16:19:13 -06:00
Joe Fioti
9fead8dad3 Change how metal imports work 2024-01-22 16:17:13 -06:00
Joe Fioti
0d44507f3c Fused softmax kernel 2024-01-22 15:26:22 -06:00
Hassan Hayat
3272749663 Merge remote-tracking branch 'upstream/main' into debug 2024-01-22 19:37:46 +01:00
Joe Fioti
5f917dcbcf Removed one contiguous call 2024-01-22 10:54:29 -06:00
Hassan Hayat
85a08aca3f Merge remote-tracking branch 'upstream/main' into debug 2024-01-22 11:01:24 +01:00
Joe Fioti
192858edf1 Simplified mistral 2024-01-21 23:58:14 -06:00
Joe Fioti
9a5e6f6e69 Simplified mistral 2024-01-21 23:57:15 -06:00
Joe Fioti
6884bd010d Moved CSE to pre generic compiler 2024-01-21 23:40:50 -06:00
Hassan Hayat
9dd852c27e move clap to dev dependencies 2024-01-22 00:27:52 +01:00
Hassan Hayat
198fe76cb3 Update main.rs 2024-01-22 00:12:21 +01:00
Hassan Hayat
9696c4ce09 Improvements to main 2024-01-22 00:09:59 +01:00
Joe Fioti
9a2f8fadd3 Reorg 2024-01-21 13:30:33 -06:00
Joe Fioti
b59fefaa11 Reorganizing 2024-01-21 12:13:47 -06:00
Joe Fioti
8348d06902 reorganized tests 2024-01-21 12:06:13 -06:00
Joe Fioti
8f7f6a6ab3 Commonized metal compiler 2024-01-21 12:03:10 -06:00
Joe Fioti
13e6dc6da5 More commonalities 2024-01-21 11:54:18 -06:00
Joe Fioti
244711d46e Commonized matmul 2024-01-21 11:30:44 -06:00
Joe Fioti
9695bcef84 Fused constants 2024-01-21 10:33:47 -06:00
Joe Fioti
2f20b9959c Removed custom swish kernel 2024-01-21 10:22:04 -06:00
Joe Fioti
308938ec02 Fixed elementwise fusion 2024-01-21 10:07:15 -06:00
Joe Fioti
b1c435b6be Fixed matmuls 2024-01-21 09:32:59 -06:00
Joe Fioti
4219d8ec7b Fixed layer norm 2024-01-20 21:54:24 -06:00
Joe Fioti
8bd7598678 Generalized matmul compiler 2024-01-19 23:37:37 -06:00
Joe Fioti
e89bdbb612 Closer to working elementwise fusion 2024-01-19 17:45:55 -06:00
Joe Fioti
ebb0df6c69 Disabled elementwise fusion 2024-01-19 11:21:30 -06:00
Joe Fioti
8f2d13df3d Enabled elementwise on metal prims 2024-01-16 17:12:56 -06:00
Joe Fioti
69c207b599 Fixed fusion bugs 2024-01-16 17:06:29 -06:00
Joe Fioti
fa04b05b5d Custom fn util 2024-01-16 15:44:29 -06:00
Joe Fioti
54912c4f6a Initial version of elementwise fusion 2024-01-16 15:36:58 -06:00
Joe Fioti
1c0f525e57 Added looped compiler 2024-01-16 09:03:41 -06:00
Joe Fioti
26c0de512f Unified matmul and matvec 2024-01-15 21:25:43 -06:00
Joe Fioti
0c27cb02a8 util functions 2024-01-15 11:57:15 -06:00
Joe Fioti
b822800ffe export node index 2024-01-13 10:24:01 -06:00
Joe Fioti
b54da0ddde bring in line with ggml kernel 2024-01-12 16:23:09 -06:00
Joe Fioti
9295ff8d72 Changed matvec 2024-01-12 16:16:04 -06:00
Joe Fioti
e5dcff3f34 Test commit 2024-01-12 13:28:04 -06:00
Joe Fioti
a1acd5883b Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:26:32 -06:00
Joe Fioti
556e386621 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:23:08 -06:00
Joe Fioti
9f9256f08a Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-12 13:23:08 -06:00
Joe Fioti
f3c53c1193 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
9f668ee333 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
617ef95c09 Test commit 2024-01-12 13:17:38 -06:00
Joe Fioti
c539946c25 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
7e9f1c7fc0 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
cf0e6ad2f6 Test commit 2024-01-12 13:14:34 -06:00
Joe Fioti
9813b188f3 reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
bf7c1c5608 reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
ec09c0202b reversed mistral weight transpose 2024-01-12 11:54:38 -06:00
Joe Fioti
71365cf2d4 Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
481d074f5a Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
a240e2adc8 Added mlx matvec 2024-01-12 11:24:10 -06:00
Joe Fioti
c3643925ef removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
a6b368fa14 removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
ab9df3d94e removed kernel hashmap 2024-01-11 22:34:00 -06:00
Joe Fioti
c727113351 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
d203df40d5 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
c506d1e783 Added support for transpose in matmul 2024-01-11 22:01:21 -06:00
Joe Fioti
56ce86f194 Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
54a8ebc60d Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
b3e07bd638 Fixed somewhat 2024-01-11 21:19:23 -06:00
Joe Fioti
94a6a0a9e9 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
fb279c9ee6 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
3ae34ad3b3 unified matmuls 2024-01-11 21:07:05 -06:00
Joe Fioti
6b08212df8 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
03d2d02d00 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
0f09b19199 MLX matmul 2024-01-11 17:13:32 -06:00
Joe Fioti
fcf232699f Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
1ed89b5656 Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
69da97727b Added cumprod 2024-01-10 18:15:58 -06:00
Joe Fioti
9edf9cdc0b Fixed swish compiler 2024-01-10 15:46:53 -06:00
Joe Fioti
2f13fd6100 Fixed swish compiler 2024-01-10 15:46:53 -06:00
Joe Fioti
ed278c9be3 Merge pull request #12 from TheSeamau5/matmul
Minor improvement to f16 matmul, Longer prompt and token generation for testing
2024-01-10 12:38:32 -06:00
Joe Fioti
9e04457895 Merge pull request #12 from TheSeamau5/matmul
Minor improvement to f16 matmul, Longer prompt and token generation for testing
2024-01-10 12:38:32 -06:00
Joe Fioti
e6c4291db6 Update other.rs 2024-01-10 12:36:34 -06:00
Joe Fioti
f62e6ad85e Update other.rs 2024-01-10 12:36:34 -06:00
Hassan Hayat
0ba62fde38 Minor improvement to f16 matmul, Longer prompt and token generation for testing 2024-01-10 12:31:43 -06:00
Hassan Hayat
d62f2e217a Minor improvement to f16 matmul, Longer prompt and token generation for testing 2024-01-10 12:31:43 -06:00
Joe Fioti
f385ea287e Fix 2024-01-10 12:30:25 -06:00
Joe Fioti
140ee69480 Fix 2024-01-10 12:30:25 -06:00
Joe Fioti
2c93b7788c Simplified copy compiler 2024-01-10 12:29:43 -06:00
Joe Fioti
4fdc8f38eb Simplified copy compiler 2024-01-10 12:29:43 -06:00
Joe Fioti
c0645fe35e Small changes 2024-01-10 09:18:42 -06:00
Joe Fioti
5b5812defa Small changes 2024-01-10 09:18:42 -06:00
Joe Fioti
349e3d2472 Merge pull request #11 from TheSeamau5/fix_swish
Fix swish
2024-01-10 09:12:27 -06:00
Joe Fioti
fa67608d48 Merge pull request #11 from TheSeamau5/fix_swish
Fix swish
2024-01-10 09:12:27 -06:00
Hassan Hayat
527c20d146 Fix swish 2024-01-10 00:52:04 -06:00
Hassan Hayat
ff1da67423 Fix swish 2024-01-10 00:52:04 -06:00
Joe Fioti
efd7489a1c Small kernel simplifications 2024-01-09 22:22:36 -06:00
Joe Fioti
4dd7cd7cfd Small kernel simplifications 2024-01-09 22:22:36 -06:00
Joe Fioti
33274b905e Fixed 2024-01-09 22:04:12 -06:00
Joe Fioti
3670378bc6 Fixed 2024-01-09 22:04:12 -06:00
Joe Fioti
275180be20 Improvements to vecmat 2024-01-09 22:02:47 -06:00
Joe Fioti
40a62e70be Improvements to vecmat 2024-01-09 22:02:47 -06:00
Joe Fioti
95462aa89e Shapetracker hack 2024-01-09 10:02:32 -06:00
Joe Fioti
7a9f9e04d0 Shapetracker hack 2024-01-09 10:02:32 -06:00
Joe Fioti
cf35b286f2 organization 2024-01-09 09:56:45 -06:00
Joe Fioti
e1cf44a4e0 organization 2024-01-09 09:56:45 -06:00
Joe Fioti
b891b8b595 Added unused softmax op 2024-01-09 00:54:02 -06:00
Joe Fioti
67366e1a2f Added unused softmax op 2024-01-09 00:54:02 -06:00
Joe Fioti
ee8206e2ca Improved compiler matching 2024-01-08 23:29:09 -06:00
Joe Fioti
5cdc559241 Improved compiler matching 2024-01-08 23:29:09 -06:00
Joe Fioti
daa7166534 Small cse improvement 2024-01-08 16:30:04 -06:00
Joe Fioti
2cf0bc29c8 Small cse improvement 2024-01-08 16:30:04 -06:00
Joe Fioti
139ae0ddad ggml rms norm 2024-01-08 12:46:17 -06:00
Joe Fioti
703f4d3847 ggml rms norm 2024-01-08 12:46:17 -06:00
Joe Fioti
d79042d334 dyn symbols in ops 2024-01-07 16:42:45 -06:00
Joe Fioti
f9b52f0058 dyn symbols in ops 2024-01-07 16:42:45 -06:00
Joe Fioti
5b50192830 Swish op 2024-01-07 14:13:43 -06:00
Joe Fioti
ae431e0dd4 Swish op 2024-01-07 14:13:43 -06:00
Joe Fioti
35626309ac Fixed generic compiler 2024-01-07 00:06:02 -06:00
Joe Fioti
a38168a91c Fixed generic compiler 2024-01-07 00:06:02 -06:00
Joe Fioti
64ebab654f Small changes 2024-01-06 23:38:52 -06:00
Joe Fioti
ec0ea40bbe Small changes 2024-01-06 23:38:52 -06:00
Joe Fioti
49ae10a25e Fixes 2024-01-06 22:26:56 -06:00
Joe Fioti
1a1ba5216b Fixes 2024-01-06 22:26:56 -06:00
Joe Fioti
0bbc6215d8 named structs 2024-01-06 22:11:06 -06:00
Joe Fioti
4e5300c4d4 named structs 2024-01-06 22:11:06 -06:00
Joe Fioti
166d4a12a5 Small 2024-01-06 12:11:48 -06:00
Joe Fioti
e4f90c304b Small 2024-01-06 12:11:48 -06:00
Joe Fioti
fa966c8c7c Contiguous elimination 2024-01-06 12:11:17 -06:00
Joe Fioti
9a0261acd2 Contiguous elimination 2024-01-06 12:11:17 -06:00
Joe Fioti
743bacb125 Fixed graph selector bug and added broken contiguous elimination 2024-01-05 22:58:44 -06:00
Joe Fioti
d0afd42eb2 Fixed graph selector bug and added broken contiguous elimination 2024-01-05 22:58:44 -06:00
Joe Fioti
4c9691c49d Fast mistral loading 2024-01-05 10:13:01 -06:00
Joe Fioti
9aaff41dfa Fast mistral loading 2024-01-05 10:13:01 -06:00
Joe Fioti
a8b6508155 Fast mistral loading 2024-01-05 10:11:29 -06:00
Joe Fioti
a23e536fa0 Fast mistral loading 2024-01-05 10:11:29 -06:00
Joe Fioti
e654f3e72d No copy metal buffers 2024-01-03 20:39:55 -05:00
Joe Fioti
1a6ce5df82 No copy metal buffers 2024-01-03 20:39:55 -05:00
Joe Fioti
a6cd8d9b0f Small changes 2024-01-03 19:46:54 -05:00
Joe Fioti
8a62e090a3 Small changes 2024-01-03 19:46:54 -05:00
Joe Fioti
b550de47e4 reinterpret entire array at once 2024-01-03 19:29:40 -05:00
Joe Fioti
5bc2477352 reinterpret entire array at once 2024-01-03 19:29:40 -05:00
Joe Fioti
370973108d Changed embedding test 2024-01-03 19:22:03 -05:00
Joe Fioti
88ed1ded6d Changed embedding test 2024-01-03 19:22:03 -05:00
Joe Fioti
e9b8a883d0 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-03 19:19:34 -05:00
Joe Fioti
4a7db75715 Merge branch 'main' of https://github.com/jafioti/luminal 2024-01-03 19:19:34 -05:00
Joe Fioti
72b3cba68b removed embedding init 2024-01-03 19:19:33 -05:00
Joe Fioti
e7c78e9b46 removed embedding init 2024-01-03 19:19:33 -05:00
Joe Fioti
0bc32b9c92 Changed graphselector and to_ids 2024-01-03 19:16:28 -05:00
Joe Fioti
9b81ef2326 Changed graphselector and to_ids 2024-01-03 19:16:28 -05:00
Joe Fioti
cfc8e7dae2 Update README.md 2024-01-03 12:24:44 -05:00
Joe Fioti
09666f93ab Update README.md 2024-01-03 12:24:44 -05:00
Joe Fioti
b489a86fa9 Removed petgraph fork 2024-01-03 11:14:31 -05:00
Joe Fioti
4d4338fb58 Removed petgraph fork 2024-01-03 11:14:31 -05:00
Joe Fioti
805ebb1931 Fixed mistral and llama 2024-01-02 20:20:03 -05:00
Joe Fioti
a57b316216 Fixed mistral and llama 2024-01-02 20:20:03 -05:00
Joe Fioti
94e08ae947 reenabled metal tests 2024-01-02 13:34:02 -05:00
Joe Fioti
21aee96114 reenabled metal tests 2024-01-02 13:34:02 -05:00
Joe Fioti
ac802a3273 tests pasing 2024-01-02 13:13:29 -05:00
Joe Fioti
70f4fff5c2 tests pasing 2024-01-02 13:13:29 -05:00
Joe Fioti
f2e1c17c8c Remoded id_remap 2024-01-02 13:05:36 -05:00
Joe Fioti
9493c11a53 Remoded id_remap 2024-01-02 13:05:36 -05:00
Joe Fioti
7c72d5b06f Started adding remap infra 2024-01-02 12:48:50 -05:00
Joe Fioti
a15cfbae65 Started adding remap infra 2024-01-02 12:48:50 -05:00
Joe Fioti
34ab545763 Small changes 2024-01-01 21:26:04 -05:00
Joe Fioti
e67d3e6598 Small changes 2024-01-01 21:26:04 -05:00
Joe Fioti
621536a1dd Fixed cse 2024-01-01 14:04:59 -05:00
Joe Fioti
6d9f9176cd Fixed cse 2024-01-01 14:04:59 -05:00
Joe Fioti
2e81b54446 Arange fix 2024-01-01 12:04:32 -05:00
Joe Fioti
38acdf315e Arange fix 2024-01-01 12:04:32 -05:00
Joe Fioti
30dff8597c Fix 2024-01-01 11:32:11 -05:00
Joe Fioti
2ebd5f2deb Fix 2024-01-01 11:32:11 -05:00
Joe Fioti
162b8c38a1 Changeed hl_ops 2024-01-01 11:31:42 -05:00
Joe Fioti
1fb155ddfd Changeed hl_ops 2024-01-01 11:31:42 -05:00
Joe Fioti
241b9f527b Optimized storage compiler 2024-01-01 00:02:06 -05:00
Joe Fioti
53dc4dd9df Optimized storage compiler 2024-01-01 00:02:06 -05:00
Joe Fioti
7c7558fcb3 Optimized graph selector 2023-12-31 23:18:06 -05:00
Joe Fioti
5262e32346 Optimized graph selector 2023-12-31 23:18:06 -05:00
Joe Fioti
664fad5f84 Changed graph searcher 2023-12-31 13:33:40 -05:00
Joe Fioti
4c3e530ef3 Changed graph searcher 2023-12-31 13:33:40 -05:00
Joe Fioti
d582111d04 Working limited reuse 2023-12-30 17:27:03 -05:00
Joe Fioti
e9384dc714 Working limited reuse 2023-12-30 17:27:03 -05:00
Joe Fioti
b97da50c9d Update README.md 2023-12-30 14:18:52 -05:00
Joe Fioti
517124b424 Update README.md 2023-12-30 14:18:52 -05:00
Joe Fioti
0ef3121ac6 Update readme 2023-12-30 14:12:58 -05:00
Joe Fioti
542f74f404 Update readme 2023-12-30 14:12:58 -05:00
Joe Fioti
8662ba864d Update readme 2023-12-30 14:12:06 -05:00
Joe Fioti
0fc68006d5 Update readme 2023-12-30 14:12:06 -05:00
Joe Fioti
eac3a57b6d Update readme 2023-12-30 14:11:12 -05:00
Joe Fioti
f46bc1cb99 Update readme 2023-12-30 14:11:12 -05:00
Joe Fioti
8f7004c4c3 Merge pull request #10 from TheSeamau5/mistral
Shell script to download mistral that actually works
2023-12-30 13:59:05 -05:00
Joe Fioti
185facb1d5 Merge pull request #10 from TheSeamau5/mistral
Shell script to download mistral that actually works
2023-12-30 13:59:05 -05:00
Joe Fioti
07d0febef1 Fixed memory leak in shared storage buffers 2023-12-30 13:52:02 -05:00
Joe Fioti
d35a40eacb Fixed memory leak in shared storage buffers 2023-12-30 13:52:02 -05:00
Hassan Hayat
8a744e6035 Shell script that actually works 2023-12-30 02:26:31 -06:00
Hassan Hayat
50b47f8610 Shell script that actually works 2023-12-30 02:26:31 -06:00
Joe Fioti
c1af144891 Cleaned up llama 2023-12-29 21:19:19 -05:00
Joe Fioti
dd123fec89 Cleaned up llama 2023-12-29 21:19:19 -05:00
Joe Fioti
92cca97a76 Updates 2023-12-29 20:42:25 -05:00
Joe Fioti
a5d01c7576 Updates 2023-12-29 20:42:25 -05:00
Joe Fioti
84fbf805c3 Merge 2023-12-29 14:54:50 -05:00
Joe Fioti
51545ee82c Merge 2023-12-29 14:54:50 -05:00
Joe Fioti
e16771035f Small changes 2023-12-29 14:51:14 -05:00
Joe Fioti
10ee2c7343 Small changes 2023-12-29 14:51:14 -05:00
Joe Fioti
f637fff192 Merge pull request #9 from TheSeamau5/mistral
Add support for Mistral
2023-12-29 14:45:58 -05:00
Joe Fioti
3e0cafbae3 Merge pull request #9 from TheSeamau5/mistral
Add support for Mistral
2023-12-29 14:45:58 -05:00
Joe Fioti
d6c9c977d8 removed metal 2023-12-29 14:43:24 -05:00
Joe Fioti
1a0f59943e removed metal 2023-12-29 14:43:24 -05:00
Joe Fioti
5032d894b8 Fixed mistral example 2023-12-29 14:30:12 -05:00
Joe Fioti
2046ee9ade Fixed mistral example 2023-12-29 14:30:12 -05:00
Joe Fioti
d48ac14458 KV cached mistral 2023-12-29 14:21:52 -05:00
Joe Fioti
24ff638e43 KV cached mistral 2023-12-29 14:21:52 -05:00
Joe Fioti
7bb4e856ec Small alterations 2023-12-28 16:06:49 -05:00
Joe Fioti
be93cfe817 Small alterations 2023-12-28 16:06:49 -05:00
Joe Fioti
140aeb4591 Faster mistral 2023-12-28 15:59:01 -05:00
Joe Fioti
907dadc6a0 Faster mistral 2023-12-28 15:59:01 -05:00
Hassan Hayat
c9a1e5c47d Convert the transformer layers into an array 2023-12-28 02:06:05 -06:00
Hassan Hayat
f36d98363c Convert the transformer layers into an array 2023-12-28 02:06:05 -06:00
Hassan Hayat
c05c0e0575 Implemented serialize module, and call keep_weights 2023-12-27 23:17:09 -06:00
Hassan Hayat
123b48d5ec Implemented serialize module, and call keep_weights 2023-12-27 23:17:09 -06:00
Hassan Hayat
c4553fc132 Revert "Merge remote-tracking branch 'upstream/main' into mistral"
This reverts commit c2a11bf114, reversing
changes made to 22ae700048.
2023-12-27 21:19:08 -06:00
Hassan Hayat
b38be86191 Revert "Merge remote-tracking branch 'upstream/main' into mistral"
This reverts commit c2a11bf114, reversing
changes made to 22ae700048.
2023-12-27 21:19:08 -06:00
Hassan Hayat
da3970082a Merge remote-tracking branch 'upstream/main' into mistral 2023-12-27 21:06:24 -06:00
Hassan Hayat
c2a11bf114 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-27 21:06:24 -06:00
Joe Fioti
75a141d8ba Shared Metal storage buffers 2023-12-27 21:36:54 -05:00
Joe Fioti
ee17a48dbe Shared Metal storage buffers 2023-12-27 21:36:54 -05:00
Hassan Hayat
02df7e7f8d Spring cleaning 2023-12-27 19:30:22 -06:00
Hassan Hayat
22ae700048 Spring cleaning 2023-12-27 19:30:22 -06:00
Hassan Hayat
dbe6a42018 Iteration works! Slowly but it works 2023-12-27 11:50:45 -06:00
Hassan Hayat
7387ca1b19 Iteration works! Slowly but it works 2023-12-27 11:50:45 -06:00
Hassan Hayat
9e03c3421f Update model.rs 2023-12-27 02:16:21 -06:00
Hassan Hayat
8a6d088ff3 Update model.rs 2023-12-27 02:16:21 -06:00
Hassan Hayat
305e8f104c Remove more unused code 2023-12-27 02:08:21 -06:00
Hassan Hayat
472eae1576 Remove more unused code 2023-12-27 02:08:21 -06:00
Hassan Hayat
6c234daba2 Simplify code 2023-12-27 02:00:43 -06:00
Hassan Hayat
dfb8691923 Simplify code 2023-12-27 02:00:43 -06:00
Hassan Hayat
2784738e41 Loading is fast 2023-12-27 01:57:32 -06:00
Hassan Hayat
b86b27e0c7 Loading is fast 2023-12-27 01:57:32 -06:00
Hassan Hayat
ba3faa49df It works! 2023-12-27 00:38:45 -06:00
Hassan Hayat
c833a65153 It works! 2023-12-27 00:38:45 -06:00
Hassan Hayat
cab6b2fff2 Made the slices into expressions 2023-12-26 17:10:50 -06:00
Hassan Hayat
35e5da1ff4 Made the slices into expressions 2023-12-26 17:10:50 -06:00
Hassan Hayat
67aac97299 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 17:01:56 -06:00
Hassan Hayat
2b884d6304 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 17:01:56 -06:00
Joe Fioti
9fa0b8d0a5 Added symbolic slicing 2023-12-26 17:59:19 -05:00
Joe Fioti
422fd32d74 Added symbolic slicing 2023-12-26 17:59:19 -05:00
Hassan Hayat
1d88be2001 Update model.rs 2023-12-26 14:22:49 -06:00
Hassan Hayat
3d5c3180be Update model.rs 2023-12-26 14:22:49 -06:00
Hassan Hayat
47b61ac847 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 14:20:05 -06:00
Hassan Hayat
18560d0852 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 14:20:05 -06:00
Joe Fioti
1400aecf1d Small changes 2023-12-26 13:36:25 -05:00
Joe Fioti
9e3bea8cac Small changes 2023-12-26 13:36:25 -05:00
Joe Fioti
b4bf84840e Assign operators 2023-12-26 12:50:48 -05:00
Joe Fioti
941a8b93eb Assign operators 2023-12-26 12:50:48 -05:00
Joe Fioti
666cbe6c5a Llama reductions 2023-12-26 12:40:07 -05:00
Joe Fioti
33b7f0914f Llama reductions 2023-12-26 12:40:07 -05:00
Joe Fioti
2b2e06d6fa Simplified llama 2023-12-26 12:27:44 -05:00
Joe Fioti
eaa4ad8ef5 Simplified llama 2023-12-26 12:27:44 -05:00
Hassan Hayat
750a6e9e8b Remove unused imports 2023-12-26 11:13:30 -06:00
Hassan Hayat
0028b5ca78 Remove unused imports 2023-12-26 11:13:30 -06:00
Hassan Hayat
d4b18a0e35 Update model.rs 2023-12-26 11:00:34 -06:00
Hassan Hayat
22d7c563cb Update model.rs 2023-12-26 11:00:34 -06:00
Hassan Hayat
4671708601 Try to do an inference loop and failt 2023-12-26 10:50:03 -06:00
Hassan Hayat
9b3948a3ff Try to do an inference loop and failt 2023-12-26 10:50:03 -06:00
Hassan Hayat
4c415fba7b Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 10:48:51 -06:00
Hassan Hayat
f775833e10 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-26 10:48:51 -06:00
Joe Fioti
b1b06b1e15 Removed forward_kv from llama 2023-12-26 11:47:27 -05:00
Joe Fioti
bf8f3d91d2 Removed forward_kv from llama 2023-12-26 11:47:27 -05:00
Hassan Hayat
b40fb1a94b we have logits 2023-12-26 09:58:01 -06:00
Hassan Hayat
2e52833bb5 we have logits 2023-12-26 09:58:01 -06:00
Hassan Hayat
7e2518bbba comment out attention mask 2023-12-26 06:58:20 -06:00
Hassan Hayat
808cf7849e comment out attention mask 2023-12-26 06:58:20 -06:00
Hassan Hayat
1a454b23f8 Remove unused code 2023-12-26 06:39:53 -06:00
Hassan Hayat
bc4483706b Remove unused code 2023-12-26 06:39:53 -06:00
Hassan Hayat
8d0cff2b0b Start debugging the full transformer pass 2023-12-26 06:34:27 -06:00
Hassan Hayat
35097e8e2b Start debugging the full transformer pass 2023-12-26 06:34:27 -06:00
Hassan Hayat
1fdc8de899 Single layer is correct now 2023-12-26 06:00:23 -06:00
Hassan Hayat
995293e5da Single layer is correct now 2023-12-26 06:00:23 -06:00
Hassan Hayat
e9d7604f0b Successfully move the code into a forward method 2023-12-26 05:37:34 -06:00
Hassan Hayat
85824bb1ee Successfully move the code into a forward method 2023-12-26 05:37:34 -06:00
Hassan Hayat
652f0e365f Yay! A full layer now just werks 2023-12-25 19:17:30 -06:00
Hassan Hayat
554331f567 Yay! A full layer now just werks 2023-12-25 19:17:30 -06:00
Hassan Hayat
7311c8f48c Got query states working 2023-12-25 18:28:21 -06:00
Hassan Hayat
9a904b6dcc Got query states working 2023-12-25 18:28:21 -06:00
Hassan Hayat
8b4234eb60 Get rotary embeddings working dammit 2023-12-25 18:24:54 -06:00
Hassan Hayat
b121bcb20b Get rotary embeddings working dammit 2023-12-25 18:24:54 -06:00
Hassan Hayat
d63ceba488 Precompute rope once using throwaway graph 2023-12-25 14:18:22 -06:00
Hassan Hayat
67d8d6b992 Precompute rope once using throwaway graph 2023-12-25 14:18:22 -06:00
Hassan Hayat
0130d5dfd9 Found the issue with rotary embeddings
It was f16. Rotary embeddings have to be precomputed in f32
2023-12-25 13:51:31 -06:00
Hassan Hayat
ab8f7187e6 Found the issue with rotary embeddings
It was f16. Rotary embeddings have to be precomputed in f32
2023-12-25 13:51:31 -06:00
Hassan Hayat
0c291d594b Merge remote-tracking branch 'upstream/main' into mistral 2023-12-25 12:53:05 -06:00
Hassan Hayat
1d828f7982 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-25 12:53:05 -06:00
Hassan Hayat
6f3d52f345 save progress 2023-12-25 12:52:57 -06:00
Hassan Hayat
ed1f76808d save progress 2023-12-25 12:52:57 -06:00
Joe Fioti
a00fe78aa1 Fixed scalar ops 2023-12-25 09:33:37 -05:00
Joe Fioti
58a56f9fc0 Fixed scalar ops 2023-12-25 09:33:37 -05:00
Hassan Hayat
2254b4c96c save progress 2023-12-25 02:15:15 -06:00
Hassan Hayat
b6a0caa79b save progress 2023-12-25 02:15:15 -06:00
Hassan Hayat
4e7c6c27ce repeat kv 2023-12-25 01:32:00 -06:00
Hassan Hayat
4eb0a8e1fb repeat kv 2023-12-25 01:32:00 -06:00
Hassan Hayat
e222cb7a97 Update model.rs 2023-12-25 01:14:26 -06:00
Hassan Hayat
858f198b43 Update model.rs 2023-12-25 01:14:26 -06:00
Hassan Hayat
4b5872b5d1 Applying rotary embeddings work 2023-12-25 01:12:21 -06:00
Hassan Hayat
d2269eebf7 Applying rotary embeddings work 2023-12-25 01:12:21 -06:00
Hassan Hayat
7f8b21f71f Get rotate half working 2023-12-25 00:36:21 -06:00
Hassan Hayat
ca1703745f Get rotate half working 2023-12-25 00:36:21 -06:00
Hassan Hayat
c3f2547349 Update model.rs 2023-12-24 23:55:01 -06:00
Hassan Hayat
5c24050775 Update model.rs 2023-12-24 23:55:01 -06:00
Hassan Hayat
927fb9fac2 Still not there with attention 2023-12-24 23:29:39 -06:00
Hassan Hayat
167944b422 Still not there with attention 2023-12-24 23:29:39 -06:00
Hassan Hayat
4cec36f4b5 Fix broken division 2023-12-24 23:23:55 -06:00
Hassan Hayat
66fbf23d67 Fix broken division 2023-12-24 23:23:55 -06:00
Hassan Hayat
7e401c69c7 past norm, to query states 2023-12-24 21:45:10 -06:00
Hassan Hayat
5c4076bc8c past norm, to query states 2023-12-24 21:45:10 -06:00
Hassan Hayat
d2cb4f0d48 Get Mistral RMS norm working 2023-12-24 21:04:25 -06:00
Hassan Hayat
ef16ee6b23 Get Mistral RMS norm working 2023-12-24 21:04:25 -06:00
Hassan Hayat
c518caacf2 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:28:52 -06:00
Hassan Hayat
acef1725f3 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:28:52 -06:00
Joe Fioti
6cad14a20b Changed gather 2023-12-24 21:28:32 -05:00
Joe Fioti
c33333724d Changed gather 2023-12-24 21:28:32 -05:00
Hassan Hayat
8149440f8f Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:25:16 -06:00
Hassan Hayat
b4717747d5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:25:16 -06:00
Joe Fioti
9fc98f3288 Fixed contiguous 2023-12-24 21:07:48 -05:00
Joe Fioti
24347bf69c Fixed contiguous 2023-12-24 21:07:48 -05:00
Hassan Hayat
c5aa4d2975 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:07:26 -06:00
Hassan Hayat
9ec05b25a8 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 20:07:26 -06:00
Joe Fioti
bd83d880a9 Fixed metal prim compiler and other things 2023-12-24 21:05:45 -05:00
Joe Fioti
8037d370ee Fixed metal prim compiler and other things 2023-12-24 21:05:45 -05:00
Hassan Hayat
4a0a86577e Embeddings is correct 2023-12-24 19:25:10 -06:00
Hassan Hayat
75ea980bd2 Embeddings is correct 2023-12-24 19:25:10 -06:00
Hassan Hayat
de5049577c Push code 2023-12-24 11:47:45 -06:00
Hassan Hayat
5f99756be4 Push code 2023-12-24 11:47:45 -06:00
Hassan Hayat
33724c7214 found the source of crash 2023-12-24 11:35:16 -06:00
Hassan Hayat
4f75032c7e found the source of crash 2023-12-24 11:35:16 -06:00
Hassan Hayat
a45b4b6e85 Focus on debug 2023-12-24 11:03:29 -06:00
Hassan Hayat
912db261fe Focus on debug 2023-12-24 11:03:29 -06:00
Hassan Hayat
fad53704fd include the correct value for rope theta 2023-12-24 07:17:43 -06:00
Hassan Hayat
29aeac0531 include the correct value for rope theta 2023-12-24 07:17:43 -06:00
Hassan Hayat
a426971470 Found a panic 2023-12-24 07:12:55 -06:00
Hassan Hayat
1bd50bff21 Found a panic 2023-12-24 07:12:55 -06:00
Hassan Hayat
d2d733b931 Add print statement, find all zeros 2023-12-24 06:03:24 -06:00
Hassan Hayat
1ad6edd9ce Add print statement, find all zeros 2023-12-24 06:03:24 -06:00
Hassan Hayat
d924809d85 Fix grouped query attention 2023-12-24 05:51:45 -06:00
Hassan Hayat
24b1b324e6 Fix grouped query attention 2023-12-24 05:51:45 -06:00
Hassan Hayat
531b28f75a Test inference code 2023-12-24 02:58:05 -06:00
Hassan Hayat
ce40bb7f58 Test inference code 2023-12-24 02:58:05 -06:00
Hassan Hayat
be667fb936 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 02:55:33 -06:00
Hassan Hayat
de2a2c8bb8 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-24 02:55:33 -06:00
Joe Fioti
55e68dff43 Moved allocations outside MetalKernelForward 2023-12-23 15:52:34 -05:00
Joe Fioti
e922d565a7 Moved allocations outside MetalKernelForward 2023-12-23 15:52:34 -05:00
Hassan Hayat
1763e85aa7 Remove unneeded annotation 2023-12-23 13:50:21 -06:00
Hassan Hayat
5d10422881 Remove unneeded annotation 2023-12-23 13:50:21 -06:00
Hassan Hayat
a004408327 Implement generic arange and argmax 2023-12-23 13:49:22 -06:00
Hassan Hayat
cc0b34a640 Implement generic arange and argmax 2023-12-23 13:49:22 -06:00
Hassan Hayat
21596a01d7 Successfully load all the weights 2023-12-23 02:02:54 -06:00
Hassan Hayat
68f0c6f6ca Successfully load all the weights 2023-12-23 02:02:54 -06:00
Hassan Hayat
3a2ab1d176 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 17:21:12 -06:00
Hassan Hayat
ff7289ef39 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 17:21:12 -06:00
Hassan Hayat
0b370359c4 precompute inverse freqs works 2023-12-22 16:41:27 -06:00
Hassan Hayat
d2b720da3f precompute inverse freqs works 2023-12-22 16:41:27 -06:00
Joe Fioti
ef1054a921 fixed embedding test 2023-12-22 16:27:11 -05:00
Joe Fioti
0b30af2a7a fixed embedding test 2023-12-22 16:27:11 -05:00
Joe Fioti
5d97b4ee52 Fixed metal subtraction and llama 2023-12-22 16:23:45 -05:00
Joe Fioti
e179494ac4 Fixed metal subtraction and llama 2023-12-22 16:23:45 -05:00
Hassan Hayat
db2fc3cbb0 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 13:46:05 -06:00
Hassan Hayat
935caa24ce Merge remote-tracking branch 'upstream/main' into mistral 2023-12-22 13:46:05 -06:00
Joe Fioti
268a9b2cf8 Added metal equal compiler 2023-12-22 10:38:29 -05:00
Joe Fioti
5d8238bcf4 Added metal equal compiler 2023-12-22 10:38:29 -05:00
Hassan Hayat
a1c4f18725 Almost done loading model 2023-12-22 02:13:39 -06:00
Hassan Hayat
19ec1f1d36 Almost done loading model 2023-12-22 02:13:39 -06:00
Hassan Hayat
3032c685cd yoke 2023-12-22 00:43:48 -06:00
Hassan Hayat
c890ebdbe1 yoke 2023-12-22 00:43:48 -06:00
Hassan Hayat
a26d2fe86f Get it to compile again 2023-12-21 23:32:08 -06:00
Hassan Hayat
312305fcb7 Get it to compile again 2023-12-21 23:32:08 -06:00
Hassan Hayat
a402a29f93 Implement the model 2023-12-21 23:00:33 -06:00
Hassan Hayat
ed964105ec Implement the model 2023-12-21 23:00:33 -06:00
Hassan Hayat
414a3dcc83 Initial attention impl 2023-12-21 21:41:12 -06:00
Hassan Hayat
c51e87385f Initial attention impl 2023-12-21 21:41:12 -06:00
Hassan Hayat
fec403b9f5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-21 17:35:06 -06:00
Hassan Hayat
90e06d90e5 Merge remote-tracking branch 'upstream/main' into mistral 2023-12-21 17:35:06 -06:00
Hassan Hayat
25bf6ee63a Load the embeddings properly 2023-12-21 17:31:51 -06:00
Hassan Hayat
9e5880b130 Load the embeddings properly 2023-12-21 17:31:51 -06:00
Joe Fioti
9220e7b1e0 redid metal arange compiler 2023-12-21 14:54:16 -05:00
Joe Fioti
4a97c8bee9 redid metal arange compiler 2023-12-21 14:54:16 -05:00
Joe Fioti
93d45509ad Added subtraction compiler 2023-12-21 14:20:19 -05:00
Joe Fioti
c0d0ec0c32 Added subtraction compiler 2023-12-21 14:20:19 -05:00
Hassan Hayat
c4b4233e20 Include start token 2023-12-21 06:09:14 -06:00
Hassan Hayat
61e59b27ec Include start token 2023-12-21 06:09:14 -06:00
Hassan Hayat
d16d22492e Get tokenizer working 2023-12-21 06:02:33 -06:00
Hassan Hayat
3a25325d37 Get tokenizer working 2023-12-21 06:02:33 -06:00
Joe Fioti
4599fec534 Primop gather 2023-12-20 19:36:11 -05:00
Joe Fioti
1637d0fdb8 Primop gather 2023-12-20 19:36:11 -05:00
Joe Fioti
13de77b68c New common metal buffer compiler 2023-12-20 16:56:16 -05:00
Joe Fioti
700d8f71e2 New common metal buffer compiler 2023-12-20 16:56:16 -05:00
Joe Fioti
84eea2a0eb Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-20 10:38:02 -05:00
Joe Fioti
b2be7b2583 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-20 10:38:02 -05:00
Joe Fioti
5c396368b6 Changed readme 2023-12-20 10:37:57 -05:00
Joe Fioti
7335d07755 Changed readme 2023-12-20 10:37:57 -05:00
Joe Fioti
b63746fe84 Merge pull request #8 from TheSeamau5/conv2d
Conv2d
2023-12-20 09:35:38 -05:00
Joe Fioti
ef964536e9 Merge pull request #8 from TheSeamau5/conv2d
Conv2d
2023-12-20 09:35:38 -05:00
Hassan Hayat
f96e3a903e Conv2D forward implemented 2023-12-19 23:23:13 -06:00
Hassan Hayat
7ec82a97d6 Conv2D forward implemented 2023-12-19 23:23:13 -06:00
Hassan Hayat
741b167910 First implementation of conv2d 2023-12-19 23:02:32 -06:00
Hassan Hayat
268fb4e9aa First implementation of conv2d 2023-12-19 23:02:32 -06:00
Joe Fioti
3b0b264ba5 Update README.md 2023-12-19 20:19:08 -05:00
Joe Fioti
1c7a3b8ed9 Update README.md 2023-12-19 20:19:08 -05:00
Joe Fioti
7c307c886e Made conv forward public 2023-12-19 20:00:50 -05:00
Joe Fioti
e00a89c647 Made conv forward public 2023-12-19 20:00:50 -05:00
Joe Fioti
c6d37ed5c5 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-19 19:59:54 -05:00
Joe Fioti
deef279977 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-19 19:59:54 -05:00
Joe Fioti
835527333c Move cumsum 2023-12-19 19:59:47 -05:00
Joe Fioti
b2735b8dc6 Move cumsum 2023-12-19 19:59:47 -05:00
Joe Fioti
7050a8bd7a Merge pull request #7 from TheSeamau5/conv1d
Conv1D module
2023-12-19 19:54:10 -05:00
Joe Fioti
2c6ac7124e Merge pull request #7 from TheSeamau5/conv1d
Conv1D module
2023-12-19 19:54:10 -05:00
Hassan Hayat
cc0c2bf8cb Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 16:40:57 -06:00
Hassan Hayat
e335bb24df Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 16:40:57 -06:00
Joe Fioti
ef0768ebef ARange in llama 2023-12-19 17:39:19 -05:00
Joe Fioti
a6c8c4c254 ARange in llama 2023-12-19 17:39:19 -05:00
Hassan Hayat
1f81ffb182 Remove extra comment 2023-12-19 16:32:39 -06:00
Hassan Hayat
23b7937507 Remove extra comment 2023-12-19 16:32:39 -06:00
Hassan Hayat
ac23472220 Remove extra comments 2023-12-19 16:31:49 -06:00
Hassan Hayat
0e07eb7614 Remove extra comments 2023-12-19 16:31:49 -06:00
Hassan Hayat
f6e2fd1be2 Fix and pass the tests, define conv as a Rank-2 tensor (remove a reshape) 2023-12-19 16:31:10 -06:00
Hassan Hayat
ddc6644a87 Fix and pass the tests, define conv as a Rank-2 tensor (remove a reshape) 2023-12-19 16:31:10 -06:00
Joe Fioti
6d987df3e2 Small change 2023-12-19 17:23:34 -05:00
Joe Fioti
7b2fd581b6 Small change 2023-12-19 17:23:34 -05:00
Joe Fioti
6f810111c4 Tril and triu 2023-12-19 17:20:57 -05:00
Joe Fioti
3b154540da Tril and triu 2023-12-19 17:20:57 -05:00
Joe Fioti
0675610007 ARange, better symbolic minimizer 2023-12-19 17:02:50 -05:00
Joe Fioti
2ae67dd894 ARange, better symbolic minimizer 2023-12-19 17:02:50 -05:00
Hassan Hayat
8623843e72 Remove extra .DS_Store 2023-12-19 16:01:37 -06:00
Hassan Hayat
98ef29fec0 Remove extra .DS_Store 2023-12-19 16:01:37 -06:00
Hassan Hayat
7d37b56c20 Add harder test, doesn't pass yet 2023-12-19 15:42:21 -06:00
Hassan Hayat
54c48df279 Add harder test, doesn't pass yet 2023-12-19 15:42:21 -06:00
Hassan Hayat
7460fcde9d Simple design, no pool_out 2023-12-19 15:16:21 -06:00
Hassan Hayat
fc2a56039a Simple design, no pool_out 2023-12-19 15:16:21 -06:00
Hassan Hayat
b29f8e3a0f Alternative design, custom forward with generics at the forward function 2023-12-19 14:51:47 -06:00
Hassan Hayat
3031ead6dc Alternative design, custom forward with generics at the forward function 2023-12-19 14:51:47 -06:00
Hassan Hayat
20951c0721 Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 14:28:37 -06:00
Hassan Hayat
75b1064922 Merge remote-tracking branch 'upstream/main' into conv1d 2023-12-19 14:28:37 -06:00
Joe Fioti
1a4135515b Symbolic changes 2023-12-19 11:28:15 -05:00
Joe Fioti
acbb1b6e2c Symbolic changes 2023-12-19 11:28:15 -05:00
Hassan Hayat
c0632cb689 Update convolution.rs 2023-12-18 22:35:58 -06:00
Hassan Hayat
144e3b7a98 Remove unnecessary comments 2023-12-18 22:30:58 -06:00
Hassan Hayat
dfd21a343b Conv1D module first pass 2023-12-18 22:29:50 -06:00
Joe Fioti
0faadea621 Added cumsum 2023-12-18 23:28:37 -05:00
Joe Fioti
0dd8f4b7c7 2D convolutions 2023-12-18 18:17:42 -05:00
Joe Fioti
96e39c2535 1D last dim pooling on 2D tensors 2023-12-18 11:51:09 -06:00
Joe Fioti
909d5b7836 Pooling with dilation 2023-12-17 17:06:39 -06:00
Joe Fioti
1125351f4c 1D pooling 2023-12-17 12:45:04 -06:00
Joe Fioti
345622f452 Merge branch 'main' of https://github.com/jafioti/luminal 2023-12-15 21:55:36 -06:00
Joe Fioti
53b9bd6e61 Added MetalArange 2023-12-15 21:55:29 -06:00
Joe Fioti
e7d0a08150 Merge pull request #4 from TheSeamau5/arange
Simple Pooling implementation
2023-12-15 11:23:24 -06:00
Joe Fioti
0939f50ce2 removed fake sum reduction, generalized constants 2023-12-15 11:18:31 -06:00
Hassan Hayat
84d7a0cedc Remove unused imports 2023-12-15 10:51:13 -06:00
Hassan Hayat
c9c540057b Let's keep it simple for now, kernel size = stride 2023-12-15 10:46:19 -06:00
Joe Fioti
a2edbe14ec Commonize more metal compilers 2023-12-15 10:46:06 -06:00
Hassan Hayat
694fa93d30 Reverting to simpler impl 2023-12-15 01:58:05 -06:00
Hassan Hayat
4214a33525 Save code 2023-12-15 01:34:06 -06:00
Hassan Hayat
404322b4ab Save code 2023-12-15 01:23:26 -06:00
Joe Fioti
afd3eeee88 Removed Function output type 2023-12-14 21:13:56 -06:00
Joe Fioti
84adc99c33 llama cleanups 2023-12-14 21:00:12 -06:00
Joe Fioti
3f4b592c60 Fixed 2023-12-14 20:47:06 -06:00
Joe Fioti
d61c848f6a Added more symbolic minimization rules 2023-12-14 20:38:03 -06:00
Joe Fioti
94c7d00517 Small changes 2023-12-14 19:44:01 -06:00
Hassan Hayat
e799363d0d remove unused code 2023-12-14 19:41:24 -06:00
Hassan Hayat
d0d7f74e42 remove comments 2023-12-14 19:39:08 -06:00
Hassan Hayat
de5835822d Merge remote-tracking branch 'upstream/main' into arange 2023-12-14 19:29:45 -06:00
Hassan Hayat
77fb4305e8 add tests 2023-12-14 19:28:05 -06:00
Joe Fioti
4cdb364e4a Metal vecmat 2023-12-14 19:27:23 -06:00
Hassan Hayat
fbebf6d485 Added more tests 2023-12-14 15:41:16 -06:00
Hassan Hayat
4fde0f4524 Make pool n-dimensional 2023-12-14 14:22:36 -06:00
Joe Fioti
6f3cff1cd4 symbolic changes 2023-12-13 15:02:08 -06:00
Hassan Hayat
b90847c43f Start working on pooling 2023-12-13 13:55:32 -06:00
Joe Fioti
e5c7c8b2a2 Removed indexer 2023-12-12 23:47:58 -06:00
Joe Fioti
9e453719e3 Unified expressions 2023-12-12 22:36:38 -06:00
Joe Fioti
63b04f1e9a remvoed checking stuff in print 2023-12-12 16:50:28 -06:00
Joe Fioti
904baefa68 removed unsafe graph ref dereference 2023-12-12 16:48:17 -06:00
Joe Fioti
c82a00981a Tweaks 2023-12-12 15:07:13 -06:00
Joe Fioti
d1add4231f Fixed slow vecmat 2023-12-12 15:05:29 -06:00
Joe Fioti
678591a1a5 Low performance vecmat 2023-12-12 13:16:07 -06:00
Joe Fioti
f4a07f5259 Small refinements 2023-12-11 22:19:28 -06:00
Joe Fioti
e5e904498c Batch matmul fixes 2023-12-11 20:34:42 -06:00
Joe Fioti
89740bdd30 small 2023-12-09 09:36:20 -06:00
Joe Fioti
c6b72fa317 still broken bmm 2023-12-09 09:35:52 -06:00
Joe Fioti
80b917b02f Added RemapDownstream compiler 2023-12-08 22:31:19 -06:00
Joe Fioti
971361feac Removed copy remap 2023-12-08 21:53:05 -06:00
Joe Fioti
4a553724a2 Symblic changes 2023-12-08 14:17:24 -06:00
Joe Fioti
b87b30f045 simd batch matmul 2023-12-08 12:01:54 -06:00
Joe Fioti
a3a69f53da Working paddded batch matmul 2023-12-08 11:55:04 -06:00
Joe Fioti
cb659f3c25 improved symbolic algebra minimizer 2023-12-08 10:08:32 -06:00
Joe Fioti
8135540b22 Changed metal input rendering 2023-12-07 20:45:05 -06:00
Joe Fioti
3a10b6f4db Changed test 2023-12-07 16:41:02 -06:00
Joe Fioti
82d4a96ae1 Simd matmul 2D 2023-12-07 16:25:45 -06:00
Joe Fioti
802091e15e Removed expression interfaces 2023-12-07 00:00:43 -06:00
Joe Fioti
4292259db1 Replace dims with expressions 2023-12-06 23:36:35 -06:00
Joe Fioti
b6efabf216 Expr interface 2023-12-06 16:02:58 -06:00
Joe Fioti
7e5471bdfa Added big expression 2023-12-06 15:05:30 -06:00
Joe Fioti
5fa5aff813 Added small symbolic algebra lib and removed savage 2023-12-06 14:57:25 -06:00
Joe Fioti
5035ad1d99 Small improvements' 2023-12-06 11:56:29 -06:00
Joe Fioti
63797c90f9 small changes 2023-12-05 21:44:50 -06:00
Joe Fioti
8b475ea4f2 Fix CI 2023-12-05 10:28:48 -06:00
Joe Fioti
1cd4fb2e73 Merge 2023-12-05 10:21:10 -06:00
Joe Fioti
946ea8dfb8 Hybrid matmul 2023-12-05 10:15:15 -06:00
Joe Fioti
8c264fb2a5 broken matmul 2023-12-03 19:17:43 -05:00
Joe Fioti
b96b792612 Changed tensor api 2023-11-28 18:41:56 -05:00
Joe Fioti
909ea995b6 Fixed tests 2023-11-27 14:36:41 -05:00
Joe Fioti
4e197b512f Changed metal dispatching 2023-11-27 14:35:24 -05:00
Joe Fioti
aa3f8cce3d Matmul correction 2023-11-27 14:24:32 -05:00
Joe Fioti
00dcc29eb1 Faster batch matmul 2023-11-27 14:03:00 -05:00
Joe Fioti
032cec5c5a Compile time col-row-major ordering 2023-11-26 16:31:15 -05:00
Joe Fioti
f0d6fedc90 Small rmsnorm opt 2023-11-24 10:33:53 -05:00
Joe Fioti
bb9ff4f113 removed mutex from shared command buffer 2023-11-21 16:50:11 -06:00
Joe Fioti
1c3f6735f8 Remvoed metal attn matmul 2023-11-21 15:45:33 -06:00
Joe Fioti
2d210641d3 Added metal gpu gather 2023-11-21 13:59:17 -06:00
Joe Fioti
4078d895c7 Common metal prim ops 2023-11-20 12:17:37 -06:00
Joe Fioti
d67820b6ba Finished cuda prim unification 2023-11-20 10:42:42 -06:00
Joe Fioti
6869047b44 Unifyed cuda unary ops 2023-11-20 01:02:22 -06:00
Joe Fioti
ba7c3972b5 Common cuda copyto and copyfrom 2023-11-19 23:08:04 -06:00
Joe Fioti
f931504a09 Fixed cuda 2023-11-19 22:13:10 -06:00
Joe Fioti
52c18171a1 Small 2023-11-19 16:57:49 -06:00
Joe Fioti
b89dbefb3c Added set marking 2023-11-19 16:54:25 -06:00
Joe Fioti
07936bc8e4 remove mac test 2023-11-19 15:25:44 -06:00
Joe Fioti
647eda7895 Update and rename rust.yml to test.yml 2023-11-19 15:25:19 -06:00
Joe Fioti
5bb703084c tests passing 2023-11-19 15:13:17 -06:00
Joe Fioti
e7283e9105 Fixed native concats 2023-11-19 14:12:50 -06:00
Joe Fioti
7102d06e73 Partial fix 2023-11-18 10:34:45 -06:00
Joe Fioti
cf54dee88e broken version of native concat 2023-11-18 09:03:47 -06:00
Joe Fioti
854864ac5e Merge branch 'main' of https://github.com/jafioti/luminal 2023-11-18 08:38:33 -06:00
Joe Fioti
3eceaae45f Added dim slices and pading 2023-11-18 08:38:25 -06:00
Joe Fioti
e604a8cba0 Update README.md 2023-11-17 23:28:58 -06:00
Joe Fioti
c351acb075 more optimizations 2023-11-17 23:27:05 -06:00
Joe Fioti
d880efc1db graph selector optimizations 2023-11-16 18:07:16 -06:00
Joe Fioti
e118b293fd toposort at compile 2023-11-16 17:40:37 -06:00
Joe Fioti
254996063a Reworked selector API 2023-11-16 17:37:19 -06:00
Joe Fioti
a85b2ac301 Shared Metal Command Buffers 2023-11-16 15:21:01 -06:00
Joe Fioti
d2f8471943 action change 2023-11-12 14:43:06 -06:00
Joe Fioti
b6a7a3bc1e action change 2023-11-12 14:41:15 -06:00
Joe Fioti
3b3007cbdd Changed action 2023-11-12 14:38:27 -06:00
Joe Fioti
adc2092275 Macos gpu action 2023-11-12 14:35:21 -06:00
Joe Fioti
96831f2d4e Small 2023-11-12 12:04:50 -06:00
Joe Fioti
baf8664d10 Small changes 2023-11-12 11:55:15 -06:00
Joe Fioti
d071fd5397 Pasing tests 2023-11-12 11:38:00 -06:00
Joe Fioti
44f9415811 New testing, fixed cpu bug 2023-11-12 10:34:55 -06:00
Joe Fioti
b104364edb Removed simple tracker again 2023-11-10 10:19:11 -05:00
Joe Fioti
24bbf0ead9 Removed dyn_data 2023-11-10 10:18:15 -05:00
Joe Fioti
e8af292958 Changed tests 2023-11-09 22:17:53 -05:00
Joe Fioti
8d14f83bc3 Simplifications and API changes 2023-11-09 22:15:48 -05:00
Joe Fioti
2ff89167c2 Closer to working CommonBufferCompiler 2023-11-09 21:18:59 -05:00
Joe Fioti
87854bbdf0 Working llama at same speed as before 2023-11-08 16:38:51 -05:00
Joe Fioti
a3a4a972d7 optimized common buffer compilation 2023-11-06 21:23:25 -05:00
Joe Fioti
75fbb709d7 Share command queues on fp16 primops 2023-11-06 16:28:01 -05:00
Joe Fioti
6a311347bf New common buffer 2023-11-06 16:04:20 -05:00
Joe Fioti
2787fdd8b6 Switched to compilers 2023-11-05 22:31:37 -05:00
Joe Fioti
634f5c26ee Added schedule dependencies 2023-11-05 14:04:28 -05:00
Joe Fioti
e7683ac3ff Bugged common buffer 2023-11-01 22:20:42 -04:00
Joe Fioti
72fdc3bcfe Added preliminary internal graph to shared buffer op 2023-10-27 21:47:27 -04:00
Joe Fioti
271977d1dd Added unary shared metal command buffer 2023-10-27 21:17:34 -04:00
Joe Fioti
e6de090ed3 Started shared command buffers 2023-10-26 21:53:17 -04:00
Joe Fioti
65c0224ae5 Small 2023-10-15 22:03:04 -05:00
Joe Fioti
e957a4c99a Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-14 09:53:09 -05:00
Joe Fioti
4e6d5b733c Added constant to fakesumreduce opt 2023-10-14 09:53:03 -05:00
Joe Fioti
be591b2f4a Removed arch flags 2023-10-11 23:09:35 -05:00
Joe Fioti
bb90b73533 Fixed example 2023-10-11 23:03:25 -05:00
Joe Fioti
0abf5c2379 Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-11 17:31:36 -05:00
Joe Fioti
ccbf55923d Small changes 2023-10-11 11:56:05 -05:00
Joe Fioti
eb842428b7 Small changes 2023-10-10 23:04:24 -05:00
Joe Fioti
e61aa736db Added slowbatch matmul 2023-10-10 19:42:03 -05:00
Joe Fioti
8794afb246 re-added toposort caching 2023-10-10 16:00:05 -05:00
Joe Fioti
ddf32b6215 Cached weights 2023-10-10 15:57:17 -05:00
Joe Fioti
811fe65412 Working llama fp16 metal 2023-10-10 15:10:13 -05:00
Joe Fioti
0ad73d19ed Added metal fp16 copy opt 2023-10-03 13:02:55 -05:00
Joe Fioti
0b845dc7ee Merged 2023-10-03 12:48:01 -05:00
Joe Fioti
a0449b4d6b CopyOptimizer 2023-10-03 12:46:13 -05:00
Joe Fioti
7e58e1f299 Added MeanReduce and RMSNorm fused ops 2023-10-02 20:53:57 -05:00
Joe Fioti
f1da8c3cb7 Added mps matmul 2023-10-02 11:31:36 -05:00
Joe Fioti
b87f0124b7 Merge branch 'main' of https://github.com/jafioti/luminal 2023-10-01 23:28:55 -05:00
Joe Fioti
49db4cdea8 Added metal half precision (tests failing) 2023-10-01 23:28:46 -05:00
Joe Fioti
b72e0a2270 Update Introduction.md 2023-09-30 23:58:55 -05:00
Joe Fioti
67965bc275 Added metal half precision (tests failing) 2023-09-30 23:48:05 -05:00
Joe Fioti
a8abee1422 Llama running on metal 2023-09-30 22:58:31 -05:00
Joe Fioti
4da5e94adf Complete metal primops 2023-09-30 14:08:59 -05:00
Joe Fioti
0dc1e71148 Added metal contiguous 2023-09-30 11:34:44 -05:00
Joe Fioti
7d7972d54c Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-29 23:54:34 -05:00
Joe Fioti
41de512cdc Fixed llama fp16 2023-09-29 23:54:32 -05:00
Joe Fioti
07b2b1f28c Removed cuda kernel 2023-09-29 11:18:34 -05:00
Joe Fioti
2aec49d0e5 Started metal primops 2023-09-29 11:17:33 -05:00
Joe Fioti
ffa50d43c5 Fixed feature 2023-09-27 23:20:25 -05:00
Joe Fioti
ef06f5a746 Added half precision (llama not working 2023-09-27 23:14:27 -05:00
Joe Fioti
aebdbe5ca8 Simplifications 2023-09-26 23:45:49 -05:00
Joe Fioti
acdcfc14fb Added test improvements 2023-09-26 23:27:27 -05:00
Joe Fioti
a6b403e667 Fixed feature 2023-09-26 20:02:28 -05:00
Joe Fioti
e5cfe80029 Added transfer_weights and mark_weights 2023-09-26 19:57:09 -05:00
Joe Fioti
798ac9dd69 Fixed llama setup script 2023-09-26 18:40:36 -05:00
Joe Fioti
64a05e2f14 Added cuda batch matmul 2023-09-26 18:19:26 -05:00
Joe Fioti
99f5843c42 remove rerun 2023-09-26 14:03:23 -05:00
Joe Fioti
3f2250e51f Added cublas matmul 2023-09-26 14:03:04 -05:00
Joe Fioti
9a3de0103d Fixed cuda! Validated llama run 2023-09-25 13:07:22 -05:00
Joe Fioti
1848ef4905 Partial fix of cuda 2023-09-24 23:34:36 -05:00
Joe Fioti
b8725ec9aa Cuda still broken 2023-09-23 23:55:10 -05:00
Joe Fioti
fbba2eb1db Fully precompiled cuda kernels 2023-09-18 23:52:57 -05:00
Joe Fioti
8554a1fcfc Precompiled unary cuda ops 2023-09-18 20:50:43 -05:00
Joe Fioti
2f4e189f93 First precompiled kernel? 2023-09-18 18:10:24 -05:00
Joe Fioti
1bda13aec0 Re-added cuda 2023-09-18 16:36:45 -05:00
Joe Fioti
49cadac789 Comment 2023-09-18 11:23:18 -05:00
Joe Fioti
1fe9f3a068 Selectors for multi-output 2023-09-18 11:22:17 -05:00
Joe Fioti
edb102f7a2 Added multi-output ops 2023-09-18 11:15:39 -05:00
Joe Fioti
97376b36bc Small changes 2023-09-17 22:52:23 -05:00
Joe Fioti
41d88b0c4a Multi-graph llama 2023-09-17 11:30:16 -05:00
Joe Fioti
ee70a44f8b Added proint 2023-09-17 11:14:22 -05:00
Joe Fioti
1be715f322 Small changes 2023-09-17 11:12:49 -05:00
Joe Fioti
519319c9b2 Added debug prints 2023-09-17 10:51:21 -05:00
Joe Fioti
76abe671e4 Added batched matmul cpu op 2023-09-16 23:45:28 -05:00
Joe Fioti
14541394dc Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-16 17:19:17 -05:00
Joe Fioti
78a10f89ed Re-added optimizers 2023-09-16 17:19:12 -05:00
Joe Fioti
f20a9fd2ed Chaanged tokenizer 2023-09-16 14:19:03 -05:00
Joe Fioti
18eb48735d Private data in ndexer 2023-09-16 10:19:29 -05:00
Joe Fioti
6274ba8169 Added indexer for CPU 2023-09-16 10:18:54 -05:00
Joe Fioti
a63bae227e Fixed llama 2023-09-16 08:33:33 -05:00
Joe Fioti
6182590829 Removed noop 2023-09-11 16:25:27 -05:00
Joe Fioti
0922bcb903 Fixed serialization test 2023-09-11 15:32:16 -05:00
Joe Fioti
6eb62664a5 Indexing fix 2023-09-11 14:50:08 -05:00
Joe Fioti
dcb2072f36 Partially fixed shapes 2023-09-11 13:39:07 -05:00
Joe Fioti
c5bd1a9ce9 Update README.md 2023-09-11 00:32:30 -05:00
Joe Fioti
da1192bd01 Even more fixes 2023-09-11 00:30:57 -05:00
Joe Fioti
ef3b917f5e More fixes 2023-09-11 00:04:01 -05:00
Joe Fioti
2f32bcbb8f Shape fixes 2023-09-10 23:35:49 -05:00
Joe Fioti
71adf60a71 Fixes 2023-09-10 16:04:38 -05:00
Joe Fioti
8a1c51317c Removed shape functions 2023-09-04 10:41:35 -05:00
Joe Fioti
783d01dd6f Added dyn map to graph 2023-09-04 09:11:17 -05:00
Joe Fioti
8c0567146f Added global dyn dim resolution fn 2023-09-04 09:05:28 -05:00
Joe Fioti
8a4a98fa27 Removed realdim and put symbols in Dim 2023-09-04 09:00:27 -05:00
Joe Fioti
37b363a92f Added dyn shape 2023-09-04 08:43:35 -05:00
Joe Fioti
e9352d0506 Merge branch 'main' of https://github.com/jafioti/luminal 2023-09-04 08:19:44 -05:00
Joe Fioti
6efdcdb2b9 Removed shape resolution 2023-09-04 08:19:32 -05:00
Joe Fioti
eb1355c65a More fixes 2023-09-03 18:46:56 -05:00
Joe Fioti
a135938588 Added graph() function 2023-09-03 00:02:59 -05:00
Joe Fioti
aa3bf1ef51 Minor fixes 2023-09-02 23:52:20 -05:00
Joe Fioti
dcaa13b20e re-added serialization 2023-09-02 23:07:46 -05:00
Joe Fioti
c8ca146d0c Fixes 2023-09-02 20:13:20 -05:00
Joe Fioti
8c73edb584 Re-added shape resolution 2023-09-02 18:51:35 -05:00
Joe Fioti
f6a52704d9 Added InputTensor system 2023-09-02 18:27:02 -05:00
Joe Fioti
40814bc323 More removals 2023-09-02 16:42:27 -05:00
Joe Fioti
c652f8050a Removed old shapetracker 2023-09-02 16:30:26 -05:00
Joe Fioti
c4bf441fc1 Finished first draft of shape tracker 2023-09-02 15:44:16 -05:00
Joe Fioti
8ca9add11e Remvoed movement ops 2023-09-02 11:20:17 -05:00
Joe Fioti
1282de3d05 Added padding and slicing 2023-09-02 10:58:41 -05:00
Joe Fioti
f25c40bb08 Merge 2023-09-01 22:53:05 -05:00
Joe Fioti
3e12aa3492 tmp 2023-09-01 22:37:11 -05:00
Joe Fioti
78696adb53 Changed tracker 2023-09-01 22:15:39 -05:00
Joe Fioti
51f649da8a Partway transition to new shape tracker 2023-09-01 12:50:15 -05:00
Joe Fioti
3a35e59691 Fixed cuda 2023-08-30 00:13:07 -05:00
Joe Fioti
f5098784d7 Moved shape resolution to graph execution loop 2023-08-30 00:03:27 -05:00
Joe Fioti
b3a21eaa52 Fixed mean reduce 2023-08-29 22:05:01 -05:00
Joe Fioti
cc1d92e62f Remvoed function from mean reduce 2023-08-29 21:53:49 -05:00
Joe Fioti
3fdc34e286 Re-added arange with cumsum function 2023-08-29 21:29:52 -05:00
Joe Fioti
10f3eaad39 Changed 100 magic number to usize::MAX 2023-08-29 21:10:31 -05:00
Joe Fioti
ede46bd1e0 Added broken arange function 2023-08-14 21:09:59 -05:00
Joe Fioti
aa48af32ea Merge branch 'main' of https://github.com/jafioti/luminal 2023-08-14 14:29:53 -05:00
Joe Fioti
4acc6d1114 Started pool 2023-08-14 14:29:43 -05:00
Joe Fioti
b781be3cc2 Fixed example 2023-08-14 14:22:17 -05:00
Joe Fioti
d4a04a5055 Removed max op 2023-08-14 14:07:39 -05:00
Joe Fioti
3985301749 Added binary comparisons 2023-08-12 23:10:59 -05:00
Joe Fioti
51b6d2536d Added data function 2023-08-12 22:26:05 -05:00
Joe Fioti
2001353e9e Cuda kernels use valid 2023-08-12 20:02:24 -05:00
Joe Fioti
da8b5f62d2 Added contiguous op 2023-08-12 19:50:49 -05:00
Joe Fioti
1cf47a06c6 Merge branch 'main' of https://github.com/jafioti/luminal 2023-08-12 19:28:02 -05:00
Joe Fioti
2b68b022f9 Generalized unary sequential opt 2023-08-12 19:27:54 -05:00
Joe Fioti
7de2e883b9 Update README.md 2023-08-12 14:08:19 -05:00
Joe Fioti
ebbbfa1998 Added move_outoing_edges 2023-08-11 15:25:01 -05:00
Joe Fioti
7120aede15 Added better readme and llama setup 2023-08-11 15:07:06 -05:00
Joe Fioti
14c93f0e96 Fixed bugs 2023-08-11 14:41:06 -05:00
305 changed files with 58739 additions and 37563 deletions

View File

@@ -1,22 +0,0 @@
name: Rust
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose

34
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Test
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
env:
CARGO_TERM_COLOR: always
jobs:
cpu_test:
name: CPU Tests
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v3
- name: Build
run: cargo build --no-default-features --verbose
- name: Run tests
run: cargo test --no-default-features --verbose
# macos_test:
# name: MacOS Tests
# runs-on: macos-13
# timeout-minutes: 20
# steps:
# - uses: actions/checkout@v3
# - name: Build
# run: cargo build --verbose
# - name: Run tests
# run: cargo test --verbose -- --test-threads 1

11
.gitignore vendored
View File

@@ -1,7 +1,14 @@
/target
/crates/**/target
.DS_Store
.vscode
*.vscode
Cargo.lock
*.st
*.npx
*.npz
*.npz
/**/llama-7b-hf
/**/mistral-7b-hf
/**/setup_weights/target
*.model
*.gguf

View File

@@ -1,31 +1,29 @@
[package]
name = "luminal"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
description = "Deep learning at the speed of light."
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
#default = ["cuda"]
cuda = ["dep:cudarc"]
[dependencies]
luminal_macro = { path = "./resources/luminal_macro" }
itertools = "0.11.0"
matrixmultiply = "0.3.7"
matrixmultiply = "0.3.8"
num-traits = "0.2.16"
petgraph = {path="./resources/petgraph"}
petgraph = "0.6.4"
rand = "0.8.5"
strum = { version = "0.25.0", features = ["derive"] }
urlencoding = "2.1.2"
webbrowser = "0.8.10"
dyn-clone = "1.0.12"
cudarc = {version="0.9.13", optional=true}
safetensors = "0.3.1"
memmap2 = "0.7.1"
half = "2.3.1"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
tinyvec = "1.6.0"
term_size = "0.3.2"
colored = "2.0.4"
regex = "1.9.5"
rustc-hash = "1.1.0"
[dev-dependencies]
dfdx = "0.13"
tokenizers = "0.13.3"
dfdx = { version = "0.13", features = ["f16"] }

201
LICENSE-APACHE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

21
LICENSE-MIT Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Joe Fioti
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

115
README.md
View File

@@ -1,76 +1,109 @@
# luminal
![image](https://raw.githubusercontent.com/jafioti/luminal/main/resources/dag.jpeg)
[![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/jafioti/luminal/actions)
[![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
[![](https://dcbadge.vercel.app/api/server/VQf3j8WWNd)](https://discord.gg/VQf3j8WWNd)
**Deep learning at the speed of light.**
Luminal is a deep learning library that prioritizes **static computation** and **operator fusion** to achieve high performance.
Luminal is a deep learning library that uses **composable compilers** to achieve high performance.
```rust
use luminal::prelude::*;
// Setup graph and tensors
let mut cx = Graph::new();
let a = cx.new_tensor::<R2<3, 1>>("A");
let b = cx.new_tensor::<R2<1, 4>>("B");
let a = cx.tensor::<R2<3, 1>>()
.set([[1.0], [2.0], [3.0]]);
let b = cx.tensor::<R2<1, 4>>()
.set([[1.0, 2.0, 3.0, 4.0]]);
// Do stuff...
let c = a.matmul(b);
// Do math...
let mut c = a.matmul(b).retrieve();
// Set inputs and mark outputs
a.set(vec![1.0, 2.0, 3.0]);
b.set(vec![1.0, 2.0, 3.0, 3.0]);
c.mark();
// Optimize and run graph
cx.optimize(GenericOptimizer::default());
// Compile and run graph
cx.compile(<(GenericCompiler, CPUCompiler)>::default(), &mut c);
cx.execute();
// Get result
println!("Result: {:?}", c.retrieve().unwrap().data);
println!("Result: {:?}", c);
```
## Why does this look so different from other DL libraries?
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. So when you see `x + y`, the addition actually happens right there. This is great for debugging, it works exactly as most developers expect.
## Getting Started
**Mistral 7B**
```bash
cd ./examples/mistral
# Download the model
bash ./setup/setup.sh
# Run the model
cargo run --release --features metal # MacOS (Recommended)
cargo run --release --features cuda # Nvidia
cargo run --release # CPU
```
However, this isn't great for performance because what makes sense for a developer doesn't make sense for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
## Features
### Speed
Luminal can run Q8 Mistral 7B on M-series Macbooks at 15-25 tokens per second. The goal is to become the fastest ML framework for any model on any device.
Luminal takes a different approach, more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Here everything's static. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, and executed later.
### Simplicity
The core of luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
## But Why?
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our optimizers have global knowledge and can do much more aggressive optimization **without any sync points**.
### RISC-style architecture
Everything in luminal boils down to 11 primitive ops:
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
- Binary - `Add, Mul, Mod, LessThan`
- Other - `SumReduce, MaxReduce, Contiguous`
Of course, we can still split the network into multiple seperate graphs if we want to insert dynamic control flow part-way through, which means this method doesn't preclude optimizations like KV caching, because the KV cached forward pass is just a seperate graph!
These ops are enough to support transformers, convnets, etc.
Some huge benefits are now unlocked:
### Native
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
### Validated against Pytorch
Correctness matters. So we write as many tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
## Ideology
### Why does this look so different from other DL libraries?
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. In PyTorch, when you see `x + y`, the addition actually happens right there. This is great for debugging because it works exactly as most developers expect.
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
### Compile everything
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is run does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
**But why?**
A consequence of this is that the actual computation that gets run can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compilers. Even autograd will be handled by a compiler!
Now we can do:
- Aggressive kernel fusion
- Shape-specific kernels compiled at runtime
- Devices and Dtypes are handled through optimizers (just run the CUDA optimizer to convert the graph to use CUDA kernels, then the fp16 optimizer to convert to half-precision kernels)
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
## RISC-style architecture
Luminal can be ran on new accelerators by implementing 11 primitive ops. Take a look at `src/optimizers/cuda/prim.rs` to see 1-to-1 CUDA translations of the primops.
### Compile-time Shape Checks
All operations are shape checked at compile time, so no more shape mismatches! Credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
Accellerators are free to implement their own custom ops, and their own optimizers to convert luminal primitive ops to their bespoke ops.
## Compile-time Shape Checks
All operations are shape checked at compile time, so no more shape mismatches! All credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
## View the Graph
Once you've written all your computation code, run `cx.display_graph()` to see the entire computation graph in all it's glory. Pretty messy looking! Now run `cx.optimize(GeneralOptimizer::default())` and display the graph again. Much better.
### View the Graph
Once you've written all your computation code, run `cx.display()` to see the entire computation graph in all its glory. Pretty messy looking! Now run `cx.compile(GenericCompiler::default())` and display the graph again. Much better.
## Where are we?
Currently luminal is extremely alpha. Please don't use this in prod.
- Llama 1 is implemented in `examples/llama`. You'll need to follow the instructions in [llama-dfdx](https://github.com/coreylowman/llama-dfdx) to download and convert the llama weights, and point this example loading path at them.
- The llama example shows how to implement a loader for a custom format. Safetensors loaders are already implemented, and are the recommended way to load a model.
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
- Performance on M-series macs with LLMs is within 20% of llama.cpp (a *heavily* optimized library)
- Mistral 7B and Llama 7B are implemented in `examples/`. See instructions above for running.
- We have a small library of NN modules in `nn`, including transformers.
- A signifigant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the tinygrad ops set.
- Currently there are very few optimizers, so primops are mostly used to run these models, which are very slow.
- Next release will bring a signifigant amount of optimizers which should fuse primops into much faster ops. The aim for 0.2 is to be usably fast, not SOTA yet.
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
- The aim for 0.3 is to achieve SOTA performance on an M1 pro (50 tok/s), and near SOTA on single nvidia gpus (>100 tok/s), as well as support many mainstream models (Whisper, Stable Diffusion, Yolo v9, etc.)
Some things on the roadmap:
- Write common sense cuda ops and optimizer (matmuls, mul-add, etc.)
- Optimize cuda and metal matmul kernels
- Fine-grained metal and cuda IR
- Build benchmarking suite to test against other libs
- Write specialized CUDA kernels for full transformer architecture (FlashAttention, etc.)
- Automatic differentiation of graphs
- Autograd engine
- Distributed data, pipeline and tensor parallel.
- Beat PT 2.0 perf on LLM training
- Write compiler for quantum photonic retro encabulator
- Build dyson swarm
## License
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.

View File

@@ -0,0 +1,22 @@
[package]
name = "luminal_cuda"
version = "0.2.0"
edition = "2021"
description = "Cuda compiler for luminal"
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
luminal = { path = "../.." }
luminal_cudarc = { version="0.10.0", features = [
"cublas",
"f16",
]}
itertools = "0.12.1"
rustc-hash = "1.1.0"
num-traits = "0.2.18"
[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
rand = "0.8.5"

View File

@@ -0,0 +1,612 @@
use std::{marker::PhantomData, sync::Arc};
use luminal_cudarc::{
driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig},
nvrtc::{compile_ptx_with_opts, CompileOptions},
};
use itertools::Itertools;
use luminal::{
op::*,
prelude::{petgraph::visit::EdgeRef, *},
};
use rustc_hash::FxHashMap;
use crate::{
get_idx_valid_exps, hash,
other::CudaARange,
prim::{CudaAdd, CudaCopyToDevice, CudaLessThan, CudaMul, CudaSumReduce},
render_dyn_dim_inputs, select_const, CudaData, CudaFloat,
};
/// Element-wise subtraction (`a - b`) executed by a single generated CUDA kernel.
///
/// Built by `CudaSub::new`, which renders the kernel source from the two input
/// shape trackers, and run through the `Operator` implementation below.
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
pub struct CudaSub<T> {
// Handle to the loaded kernel that `process` launches.
function: CudaFunction,
// Device used to allocate the output buffer.
device: Arc<CudaDevice>,
// Dynamic-dimension symbols the kernel takes as extra `int` arguments, in the
// order returned by `render_dyn_dim_inputs`.
dyn_symbols: Vec<char>,
// Raw pointer to the symbol -> value map; dereferenced (unsafely) on every
// `process` call to resolve `dyn_symbols`.
// NOTE(review): presumably points at the owning graph's `dyn_map`, which must
// then outlive this op — confirm against the call sites.
dyn_map: *const FxHashMap<char, usize>,
// Ties the op to its element type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T: CudaFloat> CudaSub<T> {
pub fn new(
a_shape: ShapeTracker,
b_shape: ShapeTracker,
dev: Arc<CudaDevice>,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (a_idx, a_valid) = get_idx_valid_exps(a_shape);
let (b_idx, b_valid) = get_idx_valid_exps(b_shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape]);
let type_name = T::type_name();
let mut code = format!(
"
#include \"cuda_fp16.h\"
extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp_a, const {type_name} *inp_b, int numel{rendered}) {{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {{
out[idx] =
(({a_valid}) == 0 ? {} : inp_a[{a_idx}])
- (({b_valid}) == 0 ? {} : inp_b[{b_idx}]);
}}
}}",
if T::is_f32() {
"0.0"
} else {
"__float2half(0.0)"
},
if T::is_f32() {
"0.0"
} else {
"__float2half(0.0)"
},
);
let name = format!("kernel_{}", hash(&code));
code = code.replace("kernel", &name);
if !dev.has_func(&name, &name) {
dev.load_ptx(
compile_ptx_with_opts(
code,
CompileOptions {
arch: Some("sm_75"),
include_paths: vec!["/usr/local/cuda/include".to_string()],
..Default::default()
},
)
.unwrap(),
&name,
&[name.clone().leak()],
)
.unwrap();
}
Self {
function: dev.get_func(&name, &name).unwrap(),
device: dev,
_phantom: Default::default(),
dyn_symbols,
dyn_map,
}
}
}
impl<T> Operator for CudaSub<T>
where
    T: std::fmt::Debug
        + Copy
        + luminal_cudarc::driver::DeviceRepr
        + std::marker::Unpin
        + luminal_cudarc::driver::ValidAsZeroBits,
    CudaData<T>: Data,
{
    /// Launches the subtraction kernel on the two input tensors.
    ///
    /// `tensors[0]` is the left operand and `tensors[1]` the right one; both must
    /// hold `CudaData<T>`. The output is a freshly allocated device buffer with
    /// `tensors[0].1.n_elements()` elements.
    ///
    /// # Panics
    /// Panics if an input is not `CudaData<T>`, if the element count is not a
    /// concrete `usize`, if a dynamic symbol is missing from the map, or if the
    /// allocation or the launch fails.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        let a = tensors[0]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<CudaData<T>>()
            .unwrap();
        let b = tensors[1]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<CudaData<T>>()
            .unwrap();
        let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
        // The kernel signature declares `int numel`, so pass a genuine 32-bit
        // value instead of a pointer to an 8-byte `usize`. The cast keeps the low
        // 32 bits, which is what the kernel read from the `usize` before.
        let numel = inp_size as i32;
        let out = self.device.alloc_zeros::<T>(inp_size).unwrap();
        let mut params = vec![
            (&out).as_kernel_param(),
            (&a.0).as_kernel_param(),
            (&b.0).as_kernel_param(),
            numel.as_kernel_param(),
        ];
        // SAFETY (assumed): `dyn_map` must point at a map that is still alive and
        // not being mutated during this call — this cannot be verified from here.
        let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
        // Resolve every dynamic dimension up front. The vector is sized to the
        // number of symbols (no fixed cap) and is not touched again, so the
        // per-element pointers pushed below stay valid until after the launch.
        let dims: Vec<i32> = self
            .dyn_symbols
            .iter()
            .map(|d| dyn_map[d] as i32)
            .collect();
        params.extend(dims.iter().map(|d| d.as_kernel_param()));
        // SAFETY (assumed): the parameter list mirrors the generated kernel
        // signature (out, inp_a, inp_b, numel, then one `int` per dynamic symbol).
        unsafe {
            self.function
                .clone()
                .launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params)
                .unwrap();
        }
        vec![Tensor {
            data: Box::new(CudaData(out)),
        }]
    }
}
/// Compiler pass that rewrites `a + (b * -1)` into a single [`CudaSub`] op.
///
/// It searches for a `-1.0` constant feeding a `CudaMul` whose result feeds a
/// `CudaAdd`, and replaces the add with a fused subtraction kernel.
#[derive(LuminalPrint, Default)]
pub struct CudaSubtractionCompiler<T: CudaFloat>(PhantomData<T>);
impl<T: CudaFloat> Compiler for CudaSubtractionCompiler<T>
where
    CudaData<T>: luminal::prelude::Data,
{
    /// Runs the rewrite over every match in `graph`.
    ///
    /// Matches are skipped when `check_no_delete` reports true for the matched
    /// nodes, or when the shape on the `mul -> add` edge is non-contiguous,
    /// sliced or padded.
    ///
    /// # Panics
    /// Panics if CUDA device 0 cannot be opened, if a matched node lacks the
    /// expected incoming data edge, or if kernel compilation fails.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
        let dev = CudaDevice::new(0).unwrap();
        let (mut neg_one, mut mul, mut add) = (
            NodeIndex::default(),
            NodeIndex::default(),
            NodeIndex::default(),
        );
        let mut searcher = select_const!(-1.0, T)
            .ptr(&mut neg_one)
            .edge(SelectOp::new().ty::<CudaMul<T>>().ptr(&mut mul))
            .edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut add))
            .search(graph);
        while searcher.next_match() {
            if check_no_delete(graph, &[neg_one, mul, add]) {
                continue;
            }
            // `a` is the add operand that is not the negated product.
            let (a, a_edge) = graph
                .graph
                .edges_directed(add, petgraph::Direction::Incoming)
                .find(|e| e.source() != mul)
                .map(|e| (e.source(), e.weight().as_data().unwrap()))
                .unwrap();
            // `b` is the mul operand that is not the `-1` constant.
            let (b, b_edge) = graph
                .graph
                .edges_directed(mul, petgraph::Direction::Incoming)
                .find(|e| e.source() != neg_one)
                .map(|e| (e.source(), e.weight().as_data().unwrap()))
                .unwrap();
            let b_final_shape = graph
                .graph
                .edges_connecting(mul, add)
                .next()
                .unwrap()
                .weight()
                .as_data()
                .unwrap()
                .2;
            // The fused kernel reads `b` through the shape on the `b -> mul` edge,
            // so a view applied on the `mul -> add` edge would be lost; skip those.
            if !b_final_shape.is_contiguous()
                || b_final_shape.is_sliced()
                || b_final_shape.is_padded()
            {
                continue;
            }
            let sub = graph
                .add_op(CudaSub::<T>::new(
                    a_edge.2,
                    b_edge.2,
                    dev.clone(),
                    &graph.dyn_map,
                ))
                .input(a, a_edge.1, a_edge.2)
                .input(b, b_edge.1, b_edge.2)
                .finish();
            move_outgoing_edge(add, sub, &mut graph.graph);
            // `add` is fully replaced by `sub` and can always go.
            graph.graph.remove_node(add);
            // `mul` and the `-1` constant may still feed other consumers, so only
            // drop them once nothing depends on them any more (same approach as
            // `CudaEqualCompiler`). Removing `mul` unconditionally would break any
            // other op reading the negated value.
            // NOTE(review): assumes `safe_remove_node(n, 0)` removes `n` only when
            // it has no remaining outgoing data edges — confirm in `Graph`.
            graph.safe_remove_node(mul, 0);
            graph.safe_remove_node(neg_one, 0);
        }
    }
}
/// Element-wise equality test executed by a single generated CUDA kernel.
///
/// The kernel writes `(T)(a == b)` for every output element; see `CudaEqual::new`
/// for the generated source.
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
pub struct CudaEqual<T> {
// Handle to the loaded kernel that `process` launches.
function: CudaFunction,
// Device used to allocate the output buffer.
device: Arc<CudaDevice>,
// Dynamic-dimension symbols the kernel takes as extra `int` arguments, in the
// order returned by `render_dyn_dim_inputs`.
dyn_symbols: Vec<char>,
// Raw pointer to the symbol -> value map; dereferenced (unsafely) on every
// `process` call to resolve `dyn_symbols`.
// NOTE(review): presumably points at the owning graph's `dyn_map`, which must
// then outlive this op — confirm against the call sites.
dyn_map: *const FxHashMap<char, usize>,
// Ties the op to its element type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T: CudaFloat> CudaEqual<T> {
pub fn new(
a_shape: ShapeTracker,
b_shape: ShapeTracker,
dev: Arc<CudaDevice>,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (a_idx, a_valid) = get_idx_valid_exps(a_shape);
let (b_idx, b_valid) = get_idx_valid_exps(b_shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape]);
let type_name = T::type_name();
let mut code = format!(
"
#include \"cuda_fp16.h\"
extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp_a, const {type_name} *inp_b, int numel{rendered}) {{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {{
{type_name} a_val = ({a_valid}) == 0 ? {} : inp_a[{a_idx}];
{type_name} b_val = ({b_valid}) == 0 ? {} : inp_b[{b_idx}];
out[idx] = ({type_name})(a_val == b_val);
}}
}}",
if T::is_f32() {
"0.0"
} else {
"__float2half(0.0)"
},
if T::is_f32() {
"0.0"
} else {
"__float2half(0.0)"
},
);
let name = format!("kernel_{}", hash(&code));
code = code.replace("kernel", &name);
if !dev.has_func(&name, &name) {
dev.load_ptx(
compile_ptx_with_opts(
code,
CompileOptions {
arch: Some("sm_75"),
include_paths: vec!["/usr/local/cuda/include".to_string()],
..Default::default()
},
)
.unwrap(),
&name,
&[name.clone().leak()],
)
.unwrap();
}
Self {
function: dev.get_func(&name, &name).unwrap(),
device: dev,
_phantom: Default::default(),
dyn_symbols,
dyn_map,
}
}
}
impl<T> Operator for CudaEqual<T>
where
    T: std::fmt::Debug
        + Copy
        + luminal_cudarc::driver::DeviceRepr
        + std::marker::Unpin
        + luminal_cudarc::driver::ValidAsZeroBits,
    CudaData<T>: Data,
{
    /// Launches the equality kernel on the two input tensors.
    ///
    /// `tensors[0]` and `tensors[1]` are the operands; both must hold
    /// `CudaData<T>`. The output is a freshly allocated device buffer with
    /// `tensors[0].1.n_elements()` elements.
    ///
    /// # Panics
    /// Panics if an input is not `CudaData<T>`, if the element count is not a
    /// concrete `usize`, if a dynamic symbol is missing from the map, or if the
    /// allocation or the launch fails.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        let a = tensors[0]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<CudaData<T>>()
            .unwrap();
        let b = tensors[1]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<CudaData<T>>()
            .unwrap();
        let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
        // The kernel signature declares `int numel`, so pass a genuine 32-bit
        // value instead of a pointer to an 8-byte `usize`. The cast keeps the low
        // 32 bits, which is what the kernel read from the `usize` before.
        let numel = inp_size as i32;
        let out = self.device.alloc_zeros::<T>(inp_size).unwrap();
        let mut params = vec![
            (&out).as_kernel_param(),
            (&a.0).as_kernel_param(),
            (&b.0).as_kernel_param(),
            numel.as_kernel_param(),
        ];
        // SAFETY (assumed): `dyn_map` must point at a map that is still alive and
        // not being mutated during this call — this cannot be verified from here.
        let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
        // Resolve every dynamic dimension up front. The vector is sized to the
        // number of symbols (no fixed cap) and is not touched again, so the
        // per-element pointers pushed below stay valid until after the launch.
        let dims: Vec<i32> = self
            .dyn_symbols
            .iter()
            .map(|d| dyn_map[d] as i32)
            .collect();
        params.extend(dims.iter().map(|d| d.as_kernel_param()));
        // SAFETY (assumed): the parameter list mirrors the generated kernel
        // signature (out, inp_a, inp_b, numel, then one `int` per dynamic symbol).
        unsafe {
            self.function
                .clone()
                .launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params)
                .unwrap();
        }
        vec![Tensor {
            data: Box::new(CudaData(out)),
        }]
    }
}
/// Compiler pass that replaces the primitive-op encoding of an equality test
/// with a single [`CudaEqual`] op.
///
/// The matched subgraph consists of a `1.0` constant, two `CudaLessThan` ops
/// reading the same pair of inputs, a `CudaAdd` and a `CudaSub`
/// (presumably `1 - ((a < b) + (b < a))` — confirm against the high-level op).
#[derive(LuminalPrint, Default)]
pub struct CudaEqualCompiler<T: CudaFloat>(PhantomData<T>);
impl<T: CudaFloat> Compiler for CudaEqualCompiler<T>
where
    CudaData<T>: luminal::prelude::Data,
{
    /// Runs the rewrite over every match in `graph`.
    ///
    /// A match is skipped when the two comparisons do not read the same set of
    /// source nodes, or when `check_no_delete` reports true for the matched nodes.
    ///
    /// # Panics
    /// Panics if CUDA device 0 cannot be opened, if an expected edge is missing,
    /// or if kernel compilation fails.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
        let dev = CudaDevice::new(0).unwrap();
        let mut less_than1 = NodeIndex::default();
        let mut less_than2 = NodeIndex::default();
        let mut add = NodeIndex::default();
        let mut one = NodeIndex::default();
        let mut sub = NodeIndex::default();
        let first_cmp = SelectOp::new()
            .ty::<CudaLessThan<T>>()
            .ptr(&mut less_than1)
            .edge(
                SelectOp::new()
                    .ty::<CudaLessThan<T>>()
                    .ptr(&mut less_than2)
                    .edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut add)),
            )
            .edge(SelectOp::new().ty::<CudaSub<T>>().ptr(&mut sub));
        let pattern = select_const!(1.0, T).ptr(&mut one).edge(first_cmp);
        let mut searcher = pattern.search(graph);
        while searcher.next_match() {
            // Both comparisons must read exactly the same source nodes.
            let first_sources = graph
                .graph
                .neighbors_directed(less_than1, petgraph::Direction::Incoming)
                .sorted()
                .collect_vec();
            let second_sources = graph
                .graph
                .neighbors_directed(less_than2, petgraph::Direction::Incoming)
                .sorted()
                .collect_vec();
            if first_sources != second_sources {
                continue;
            }
            // Operands of the first comparison, ordered by the first field of the
            // edge data.
            let operands = graph
                .graph
                .edges_directed(less_than1, petgraph::Direction::Incoming)
                .sorted_by_key(|e| e.weight().as_data().unwrap().0)
                .map(|e| e.source())
                .collect_vec();
            let a = operands[0];
            let b = operands[1];
            if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) {
                continue;
            }
            let a_edge = graph
                .graph
                .edges_connecting(a, less_than1)
                .next()
                .unwrap()
                .weight()
                .as_data()
                .unwrap();
            let b_edge = graph
                .graph
                .edges_connecting(b, less_than1)
                .next()
                .unwrap()
                .weight()
                .as_data()
                .unwrap();
            let fused = CudaEqual::<T>::new(a_edge.2, b_edge.2, dev.clone(), &graph.dyn_map);
            let equals = graph
                .add_op(fused)
                .input(a, a_edge.1, a_edge.2)
                .input(b, b_edge.1, b_edge.2)
                .finish();
            // Redirect consumers of the old result, then drop the replaced nodes
            // from the output backwards.
            move_outgoing_edge(sub, equals, &mut graph.graph);
            graph.graph.remove_node(sub);
            graph.safe_remove_node(add, 0);
            graph.safe_remove_node(one, 0);
            graph.safe_remove_node(less_than2, 0);
            graph.safe_remove_node(less_than1, 0);
            searcher.clear_cached_results();
        }
    }
}
/// CUDA operator that copies rows of a weight matrix selected by an index tensor.
///
/// `embed_dim` is the number of elements in each row that is copied.
#[derive(LuminalPrint, Clone, LuminalEqFalse)]
pub struct CudaGather<T> {
    function: CudaFunction,
    device: Arc<CudaDevice>,
    pub embed_dim: usize,
    _phantom: PhantomData<T>,
}

impl<T: CudaFloat> CudaGather<T> {
    /// Compiles the `gather` kernel for element type `T`, loads it onto `dev` and
    /// returns an operator bound to it.
    ///
    /// Panics if kernel compilation, module loading or function lookup fails.
    pub fn new(dev: Arc<CudaDevice>, embed_dim: usize) -> Self {
        // `type_name` is interpolated into the kernel source below.
        let type_name = T::type_name();
        let code = format!("
#include \"cuda_fp16.h\"
extern \"C\" __global__ void gather({type_name} *out, const {type_name} *weights, const float *inp, int n_embeddings, int embedding_dim) {{
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
if (x < n_embeddings && y < embedding_dim) {{
out[x * embedding_dim + y] = weights[(int)inp[x] * embedding_dim + y];
}}
}}");
        let compile_options = CompileOptions {
            arch: Some("sm_75"),
            include_paths: vec!["/usr/local/cuda/include".to_string()],
            ..Default::default()
        };
        let ptx = compile_ptx_with_opts(code, compile_options).unwrap();
        dev.load_ptx(ptx, "gather", &["gather"]).unwrap();
        let function = dev.get_func("gather", "gather").unwrap();
        Self {
            function,
            device: dev,
            embed_dim,
            _phantom: Default::default(),
        }
    }
}
impl<T> Operator for CudaGather<T>
where
    T: std::fmt::Debug + Copy + luminal_cudarc::driver::DeviceRepr + std::marker::Unpin + CudaFloat,
    CudaData<T>: Data,
{
    /// Copies one row of `embed_dim` elements from the weights for every index and
    /// returns them as a single tensor of `indexes.len() * embed_dim` elements.
    ///
    /// Input 0 must be a host `Vec<f32>` of indexes and input 1 a `CudaData<T>` holding
    /// the weights. Panics if either downcast, an allocation, the upload or the launch fails.
    fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        // Inp 1 should be Vec<f32> and inp 2 should be a CudaSlice<T>
        let indexes = inputs[0]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<Vec<f32>>()
            .unwrap();
        let weights = inputs[1]
            .0
            .borrowed()
            .data
            .as_any()
            .downcast_ref::<CudaData<T>>()
            .unwrap();
        let n_indexes = indexes.len();
        // The uninitialised buffer is filled by the host-to-device copy right below.
        let mut indexes_buffer = unsafe { self.device.alloc::<f32>(n_indexes).unwrap() };
        self.device
            .htod_copy_into(indexes.clone(), &mut indexes_buffer)
            .unwrap();
        let mut out = self
            .device
            .alloc_zeros::<T>(n_indexes * self.embed_dim)
            .unwrap();
        unsafe {
            self.function
                .clone()
                .launch(
                    // 16x16 threads per block; x covers the indexes, y the row elements.
                    LaunchConfig {
                        grid_dim: (
                            n_indexes.div_ceil(16) as u32,
                            self.embed_dim.div_ceil(16) as u32,
                            1,
                        ),
                        block_dim: (16, 16, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        &mut out,
                        &weights.0,
                        &indexes_buffer,
                        // The kernel declares both sizes as 32-bit `int`, so the host
                        // arguments are passed as `i32` instead of 64-bit `usize`.
                        n_indexes as i32,
                        self.embed_dim as i32,
                    ),
                )
                .unwrap();
        }
        vec![Tensor {
            data: Box::new(CudaData(out)),
        }]
    }
}
/// Compiler pass that replaces a matched subgraph of `CudaARange`, `CudaCopyToDevice`,
/// `CudaEqual`, `CudaMul` and `CudaSumReduce` with a single `CudaGather` op.
///
/// NOTE(review): despite the `Metal` prefix in its name, every op this pass matches
/// and emits is a CUDA op.
#[derive(LuminalPrint, Default)]
pub struct MetalGatherCompiler<T: CudaFloat>(PhantomData<T>);
impl<T: CudaFloat> Compiler for MetalGatherCompiler<T>
where
CudaData<T>: luminal::prelude::Data,
{
// Searches `graph` for the pattern and rewrites every acceptable match in place.
// The second argument (ids to remap) is ignored by this pass.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
// Slots that the selector fills with the node ids of the current match.
let (mut ind_copy, mut arange, mut equal, mut mul, mut sum_reduce) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
);
let s = SelectOp::new()
.ty::<CudaARange<T>>()
.ptr(&mut arange)
.edge(
SelectOp::new()
.ty::<CudaCopyToDevice<T>>()
.ptr(&mut ind_copy)
.edge(SelectOp::new().ty::<CudaEqual<T>>().ptr(&mut equal)),
)
.edge(SelectOp::new().ty::<CudaMul<T>>().ptr(&mut mul))
.edge(
SelectOp::new()
.ty::<CudaSumReduce<T>>()
.ptr(&mut sum_reduce),
);
let mut searcher = s.search(graph);
while searcher.next_match() {
// `ind_copy` is not part of this check, unlike the other matched nodes.
if check_no_delete(graph, &[arange, equal, mul, sum_reduce]) {
continue;
}
// Row width: third shape dimension of the non-schedule edge that enters `mul`
// from a node other than `equal`.
let embedding_dim = graph
.graph
.edges_directed(mul, petgraph::Direction::Incoming)
.find(|e| e.source() != equal && !e.weight().is_schedule())
.unwrap()
.weight()
.as_data()
.unwrap()
.2
.shape()[2]
.to_usize()
.unwrap();
let gather = graph
.add_op(CudaGather::<T>::new(dev.clone(), embedding_dim))
.finish();
// Rewire: inputs of `ind_copy` first, then (after `equal` is removed) the remaining
// inputs of `mul`; consumers of `sum_reduce` then read from `gather`.
move_incoming_edge(ind_copy, gather, &mut graph.graph);
graph.safe_remove_node(equal, 1);
move_incoming_edge(mul, gather, &mut graph.graph);
move_outgoing_edge(sum_reduce, gather, &mut graph.graph);
graph.graph.remove_node(sum_reduce);
graph.safe_remove_node(mul, 0);
graph.safe_remove_node(ind_copy, 0);
graph.safe_remove_node(arange, 0);
}
}
}

View File

@@ -0,0 +1,184 @@
mod binary;
mod matmul;
mod other;
mod prim;
#[cfg(test)]
mod tests;
use itertools::Itertools;
use luminal_cudarc::driver::{CudaSlice, DeviceRepr};
use std::{collections::hash_map::DefaultHasher, fmt::Write, hash::Hasher};
use luminal::prelude::*;
use self::symbolic::{BigExpression, Term};
/// The full set of CUDA compiler passes for element type `T`, bundled as one tuple.
///
/// Presumably the passes run in the order listed — confirm against the tuple
/// `Compiler` implementation in luminal.
pub type CudaCompiler<T> = (
prim::CudaPrimitiveCompiler<T>,
binary::CudaSubtractionCompiler<T>,
binary::CudaEqualCompiler<T>,
other::ARangeCompiler<T>,
binary::MetalGatherCompiler<T>,
matmul::CudaMatMulCompiler<T>,
prim::CopyCompiler<T>,
);
/// Element types the CUDA backend can operate on (implemented in this file for `f32` and `f16`).
pub trait CudaFloat:
std::fmt::Debug
+ Copy
+ luminal_cudarc::driver::DeviceRepr
+ std::marker::Unpin
+ luminal_cudarc::driver::ValidAsZeroBits
{
/// Converts this value to `f32`.
fn to_f32(self) -> f32;
/// Builds a value of this type from an `f32`.
fn from_f32(a: f32) -> Self;
/// Returns `true` only for the `f32` implementation.
fn is_f32() -> bool;
/// Name of this type as written in generated CUDA source (e.g. `"float"`).
fn type_name() -> &'static str;
}
impl CudaFloat for f32 {
    // Both conversions are the identity for `f32`.
    fn to_f32(self) -> f32 {
        self
    }

    fn from_f32(a: f32) -> Self {
        a
    }

    fn is_f32() -> bool {
        true
    }

    // Spelling of the type in generated CUDA source.
    fn type_name() -> &'static str {
        "float"
    }
}
/// Tensor storage that wraps a device-side `CudaSlice<T>`.
#[derive(Debug)]
pub struct CudaData<T>(CudaSlice<T>);

impl<T: DeviceRepr> Clone for CudaData<T> {
    /// Duplicates the underlying slice via `try_clone`; panics if that fails.
    fn clone(&self) -> Self {
        let duplicate = self.0.try_clone().unwrap();
        Self(duplicate)
    }
}
// Exposes `CudaData<f32>` as `dyn Any` so operators can downcast tensor data back to it.
impl Data for CudaData<f32> {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl CudaFloat for f16 {
    // Delegates to `f16`'s own `to_f32`.
    fn to_f32(self) -> f32 {
        f16::to_f32(self)
    }

    // Delegates to `f16`'s own `from_f32`.
    fn from_f32(a: f32) -> Self {
        f16::from_f32(a)
    }

    fn is_f32() -> bool {
        false
    }

    // Spelling of the type in generated CUDA source.
    fn type_name() -> &'static str {
        "__half"
    }
}
// Exposes `CudaData<f16>` as `dyn Any` so operators can downcast tensor data back to it.
impl Data for CudaData<f16> {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
/// Renders the term list of `expr` as a CUDA C expression string.
///
/// Terms are processed in order against a stack of already-rendered operands: numbers
/// and variables push themselves (the variable `z` is rendered as `(int)idx`), every
/// other term pops two operands and pushes the combined text. Panics if the stack
/// runs out of operands.
fn expr_to_cuda_string(expr: BigExpression) -> String {
    let mut stack: Vec<String> = vec![];
    for term in expr.terms {
        let rendered = match term {
            Term::Num(n) => n.to_string(),
            Term::Var('z') => "(int)idx".to_string(),
            Term::Var(c) => c.to_string(),
            Term::Max => {
                let first = stack.pop().unwrap();
                let second = stack.pop().unwrap();
                format!("max((int){}, (int){})", first, second)
            }
            Term::Min => {
                let first = stack.pop().unwrap();
                let second = stack.pop().unwrap();
                format!("min((int){}, (int){})", first, second)
            }
            _ => {
                // The operator text comes from the term's `Debug` output; the operand
                // popped first is written on the left.
                let first = stack.pop().unwrap();
                let second = stack.pop().unwrap();
                format!("({}{term:?}{})", first, second)
            }
        };
        stack.push(rendered);
    }
    stack.pop().unwrap()
}
/// Returns the CUDA source text for `shape`'s index expression and valid expression,
/// in that order.
fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) {
    let index_src = expr_to_cuda_string(shape.index_expression());
    let valid_src = expr_to_cuda_string(shape.valid_expression());
    (index_src, valid_src)
}
fn render_dyn_dim_inputs(shapes: &[ShapeTracker]) -> (Vec<char>, String) {
let symbols: Vec<char> = shapes
.iter()
.flat_map(|st| {
st.shape()
.into_iter()
.chain(
st.padding
.into_iter()
.flat_map(|i| [i.0.into(), i.1.into()]),
)
.chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
})
.flat_map(|d| d.to_symbols())
.unique()
.collect();
(
symbols.clone(),
symbols.into_iter().fold(String::default(), |mut acc, c| {
write!(&mut acc, ", const int {c}").unwrap();
acc
}),
)
}
/// Builds a `SelectOp` that matches a `CudaConstant<$t>` holding a float constant
/// within `0.0001` of `$i`.
///
/// Ops of any other type, and constants that are not `ConstantValue::Float`, do not match.
#[macro_export]
macro_rules! select_const {
($i: expr, $t: tt) => {
luminal::compiler_utils::SelectOp::new().check(|o, _| {
if let Some(c) = o.as_any().downcast_ref::<$crate::prim::CudaConstant<$t>>() {
if let luminal::op::ConstantValue::Float(f) = c.0 {
// Tolerance comparison rather than exact float equality.
(f - $i).abs() < 0.0001
} else {
false
}
} else {
false
}
})
};
}
/// Hashes `obj` with a freshly created `DefaultHasher` and returns the 64-bit digest.
fn hash<T: std::hash::Hash>(obj: T) -> u64 {
    let mut state = DefaultHasher::new();
    std::hash::Hash::hash(&obj, &mut state);
    Hasher::finish(&state)
}

View File

@@ -0,0 +1,353 @@
use std::{marker::PhantomData, sync::Arc};
use luminal_cudarc::{
cublas::{sys::cublasOperation_t::*, CudaBlas},
driver::{CudaDevice, DevicePtr, DevicePtrMut},
};
use crate::{
prim::{CudaMul, CudaSumReduce},
CudaData, CudaFloat,
};
use luminal::{
graph::NodeIndex,
op::{InputTensor, Operator},
prelude::*,
};
/// Multiplies a MxK matrix with a KxN matrix, resulting in a MxN matrix
///
/// Fields: the cuBLAS handle, the device used for allocation, and a marker for `T`.
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
pub struct CudaMatmul2D<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);
impl<T: CudaFloat + 'static> Operator for CudaMatmul2D<T>
where
CudaData<T>: Data,
{
// Runs one GEMM over the two inputs and returns a newly allocated `m * n` output.
// Panics if a dimension is not concrete, an input is not `CudaData<T>`, or
// allocation or the cuBLAS call fails.
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
// m and k come from the first input, n from the second.
let (m, k, n) = (
a_shape[0].to_usize().unwrap() as i32,
a_shape[1].to_usize().unwrap() as i32,
b_shape[1].to_usize().unwrap() as i32,
);
let a = inp[0]
.0
.borrowed()
.data
.as_any()
.downcast_ref::<CudaData<T>>()
.unwrap();
let b = inp[1]
.0
.borrowed()
.data
.as_any()
.downcast_ref::<CudaData<T>>()
.unwrap();
let mut out = self.1.alloc_zeros::<T>((m * n) as usize).unwrap();
// An input counts as row-major when its shape tracker's `indexes` are still in
// increasing order, i.e. the view has not been transposed.
let (a_row_major, b_row_major) = (
inp[0].1.indexes[1] > inp[0].1.indexes[0],
inp[1].1.indexes[1] > inp[1].1.indexes[0],
);
// `transa` applies to the first matrix handed to cuBLAS, which is `b` below.
let (transa, transb) = match (a_row_major, b_row_major) {
(true, true) => (CUBLAS_OP_N, CUBLAS_OP_N),
(false, false) => (CUBLAS_OP_T, CUBLAS_OP_T),
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
};
// `b` is passed before `a` and the sizes are given as (n, m, k); presumably this is
// the usual operand swap for getting a row-major result from column-major cuBLAS —
// TODO confirm.
if T::is_f32() {
unsafe {
luminal_cudarc::cublas::result::sgemm(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&1.0_f32 as *const f32,
*b.0.device_ptr() as *const f32,
if b_row_major { n } else { k },
*a.0.device_ptr() as *const f32,
if a_row_major { k } else { m },
&0.0_f32 as *const f32,
*out.device_ptr_mut() as *mut f32,
n,
)
.unwrap();
}
} else {
// Every non-`f32` element type takes the half-precision path.
unsafe {
luminal_cudarc::cublas::result::hgemm(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&f16::from_f32(1.0) as *const f16,
*b.0.device_ptr() as *const f16,
if b_row_major { n } else { k },
*a.0.device_ptr() as *const f16,
if a_row_major { k } else { m },
&f16::from_f32(0.0) as *const f16,
*out.device_ptr_mut() as *mut f16,
n,
)
.unwrap();
}
}
vec![Tensor {
data: Box::new(CudaData(out)),
}]
}
}
/// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix
///
/// The second input is read as a two-dimensional matrix and passed to cuBLAS with a
/// batch stride of 0, so the same matrix is used for every batch entry.
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
pub struct CudaBatchMatmul2D<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);
impl<T: CudaFloat + 'static> Operator for CudaBatchMatmul2D<T>
where
CudaData<T>: Data,
{
// Runs one strided-batched GEMM and returns a newly allocated `batch_size * m * n` output.
// Panics if a dimension or stride is not concrete, an input is not `CudaData<T>`, or
// allocation or the cuBLAS call fails.
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let a_strides = inp[0].1.strides();
// batch, m and k come from the first input; n is the second dimension of the second input.
let (batch_size, m, k, n) = (
a_shape[0].to_usize().unwrap() as i32,
a_shape[1].to_usize().unwrap() as i32,
a_shape[2].to_usize().unwrap() as i32,
b_shape[1].to_usize().unwrap() as i32,
);
let a = inp[0]
.0
.borrowed()
.data
.as_any()
.downcast_ref::<CudaData<T>>()
.unwrap();
let b = inp[1]
.0
.borrowed()
.data
.as_any()
.downcast_ref::<CudaData<T>>()
.unwrap();
let mut out = self
.1
.alloc_zeros::<T>((m * n * batch_size) as usize)
.unwrap();
// Layout is judged on the last two dimensions of `a` and on the two dimensions of `b`.
let (a_row_major, b_row_major) = (
inp[0].1.indexes[2] > inp[0].1.indexes[1],
inp[1].1.indexes[1] > inp[1].1.indexes[0],
);
// `transa` applies to the first matrix handed to cuBLAS, which is `b` below.
let (transa, transb) = match (a_row_major, b_row_major) {
(true, true) => (CUBLAS_OP_N, CUBLAS_OP_N),
(false, false) => (CUBLAS_OP_T, CUBLAS_OP_T),
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
};
// Batch strides: 0 for `b`, `a_strides[0]` for `a`, `m * n` for the output.
if T::is_f32() {
unsafe {
luminal_cudarc::cublas::result::sgemm_strided_batched(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&1.0_f32 as *const f32,
*b.0.device_ptr() as *const f32,
if b_row_major { n } else { k },
0,
*a.0.device_ptr() as *const f32,
if a_row_major { k } else { m },
a_strides[0].to_usize().unwrap() as i64,
&0.0_f32 as *const f32,
*out.device_ptr_mut() as *mut f32,
n,
(m * n) as i64,
batch_size,
)
.unwrap();
}
} else {
// Every non-`f32` element type takes the half-precision path.
unsafe {
luminal_cudarc::cublas::result::hgemm_strided_batched(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&f16::from_f32(1.0) as *const f16,
*b.0.device_ptr() as *const f16,
if b_row_major { n } else { k },
0,
*a.0.device_ptr() as *const f16,
if a_row_major { k } else { m },
a_strides[0].to_usize().unwrap() as i64,
&f16::from_f32(0.0) as *const f16,
*out.device_ptr_mut() as *mut f16,
n,
(m * n) as i64,
batch_size,
)
.unwrap();
}
}
vec![Tensor {
data: Box::new(CudaData(out)),
}]
}
}
/// Compiler pass that replaces `CudaMul` -> `CudaSumReduce` pairs matching the matmul
/// broadcast pattern with a `CudaMatmul2D` or `CudaBatchMatmul2D` op.
#[derive(Default)]
pub struct CudaMatMulCompiler<T>(PhantomData<T>);
impl<T: CudaFloat + 'static> Compiler for CudaMatMulCompiler<T>
where
CudaData<T>: Data,
{
// Runs the 2D rewrite first, then the batched rewrite. References held in `remap`,
// `no_delete` and `to_retrieve` are moved from the replaced nodes to the new op.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
let dev = CudaDevice::new(0).unwrap();
// Look for the matmul pattern
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
// Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]
// Actually starts at [A,B] | [B, C]
let s = SelectEdge::new(
SelectOp::new()
.ty::<CudaMul<T>>()
.shapes([['A', 'C', 'B'], ['A', 'C', 'B']])
.fakes([
[Some(false), Some(true), Some(false)],
[Some(true), Some(false), Some(false)],
])
.ptr(&mut mul),
SelectOp::new()
.ty::<CudaSumReduce<T>>()
.check(|o, _| {
// Only a sum-reduce whose `.2` field is 2 matches.
if let Some(o) = o.as_any().downcast_ref::<CudaSumReduce<T>>() {
o.2 == 2
} else {
false
}
})
.ptr(&mut sum_reduce),
);
let mut searcher = s.search(graph);
while searcher.next_match() {
if graph.no_delete.contains(&mul) {
// The intermediate mul can't be deleted
continue;
}
// Insert MatMul2D op
let mut srcs = graph.get_sources(mul);
// Undo expansions and permute
srcs[0].2.remove_dim(1);
srcs[1].2.remove_dim(0);
srcs[1].2.permute(&[1, 0]);
// A new cuBLAS handle is created for every rewritten match.
let new_op = graph
.add_op(CudaMatmul2D::<T>(
Arc::new(CudaBlas::new(dev.clone()).unwrap()),
dev.clone(),
Default::default(),
))
.input(srcs[0].0, 0, srcs[0].2)
.input(srcs[1].0, 0, srcs[1].2)
.finish();
// Create edges to dests
move_outgoing_edge(sum_reduce, new_op, &mut graph.graph);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
sum_reduce,
new_op,
);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
mul,
new_op,
);
// Remove the old ops
graph.graph.remove_node(mul);
graph.graph.remove_node(sum_reduce);
}
// Look for the batch matmul pattern
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
// Mul ([D, A, C(fake), B] | [D(fake), A(fake), C, B]) -> SumReduce(3) -> [D, A, C]
// After the dimension removals below the second source is two-dimensional.
let mut searcher = SelectEdge::new(
SelectOp::new()
.ty::<CudaMul<T>>()
.shapes([['D', 'A', 'C', 'B'], ['D', 'A', 'C', 'B']])
.fakes([
[Some(false), Some(false), Some(true), Some(false)],
[Some(true), Some(true), Some(false), Some(false)],
])
.ptr(&mut mul),
SelectOp::new()
.ty::<CudaSumReduce<T>>()
.check(|o, _| {
// Only a sum-reduce whose `.2` field is 3 matches.
if let Some(o) = o.as_any().downcast_ref::<CudaSumReduce<T>>() {
o.2 == 3
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
while searcher.next_match() {
if graph.no_delete.contains(&mul) {
// The intermediate mul can't be deleted
continue;
}
// Insert BatchMatMul2D op
let mut srcs = graph.get_sources(mul);
// Undo expansions and permute
srcs[0].2.remove_dim(2);
srcs[1].2.remove_dim(1);
srcs[1].2.remove_dim(0);
srcs[1].2.permute(&[1, 0]);
// A new cuBLAS handle is created for every rewritten match.
let new_op = graph
.add_op(CudaBatchMatmul2D::<T>(
Arc::new(CudaBlas::new(dev.clone()).unwrap()),
dev.clone(),
Default::default(),
))
.input(srcs[0].0, 0, srcs[0].2)
.input(srcs[1].0, 0, srcs[1].2)
.finish();
// Create edges to dests
move_outgoing_edge(sum_reduce, new_op, &mut graph.graph);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
sum_reduce,
new_op,
);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
mul,
new_op,
);
// Remove the old ops
graph.graph.remove_node(mul);
graph.graph.remove_node(sum_reduce);
}
}
}

View File

@@ -0,0 +1,195 @@
use std::{marker::PhantomData, sync::Arc};
use luminal_cudarc::{
driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig},
nvrtc::{compile_ptx_with_opts, CompileOptions},
};
use luminal::{
op::*,
prelude::{petgraph::visit::EdgeRef, *},
shape::symbolic::BigExpression,
};
use rustc_hash::FxHashMap;
use crate::{
binary::CudaSub,
prim::{CudaAdd, CudaContiguous, CudaSumReduce},
select_const, CudaData, CudaFloat,
};
/// CUDA operator that fills a buffer with `0, 1, 2, ...` converted to `T`.
///
/// `size` is the element count as an expression; it is evaluated against the map
/// behind `dyn_map` each time the operator runs.
#[derive(LuminalPrint, Clone, LuminalEqFalse)]
pub struct CudaARange<T> {
    function: CudaFunction,
    device: Arc<CudaDevice>,
    pub size: BigExpression,
    dyn_map: *const FxHashMap<char, usize>,
    _phantom: PhantomData<T>,
}

impl<T: CudaFloat> CudaARange<T> {
    /// Compiles the `arange` kernel for element type `T`, loads it onto `dev` and
    /// returns an operator bound to it.
    ///
    /// Panics if kernel compilation, module loading or function lookup fails.
    pub fn new(
        dev: Arc<CudaDevice>,
        size: BigExpression,
        dyn_map: *const FxHashMap<char, usize>,
    ) -> Self {
        // `type_name` is interpolated into the kernel source below.
        let type_name = T::type_name();
        let code = format!(
            "
#include \"cuda_fp16.h\"
extern \"C\" __global__ void arange({type_name} *out, int n_elements) {{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_elements) {{
out[idx] = ({type_name})idx;
}}
}}"
        );
        let compile_options = CompileOptions {
            arch: Some("sm_75"),
            include_paths: vec!["/usr/local/cuda/include".to_string()],
            ..Default::default()
        };
        let ptx = compile_ptx_with_opts(code, compile_options).unwrap();
        dev.load_ptx(ptx, "arange", &["arange"]).unwrap();
        let function = dev.get_func("arange", "arange").unwrap();
        Self {
            function,
            device: dev,
            size,
            _phantom: Default::default(),
            dyn_map,
        }
    }
}
impl<T> Operator for CudaARange<T>
where
    T: std::fmt::Debug + Copy + luminal_cudarc::driver::DeviceRepr + std::marker::Unpin + CudaFloat,
    CudaData<T>: Data,
{
    /// Evaluates `size`, allocates that many elements and launches the `arange` kernel
    /// over them. Takes no inputs; returns the filled buffer as the only output.
    ///
    /// Panics if `dyn_map` is null, `size` cannot be evaluated, or allocation or the
    /// launch fails.
    fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        // `dyn_map` is a raw pointer supplied at construction time.
        let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
        let n_elements = self.size.exec(dyn_map).unwrap();
        let mut out = self.device.alloc_zeros::<T>(n_elements).unwrap();
        let launch_config = LaunchConfig::for_num_elems(n_elements as u32);
        let kernel = self.function.clone();
        unsafe {
            kernel
                .launch(launch_config, (&mut out, n_elements as i32))
                .unwrap();
        }
        let data = Box::new(CudaData(out));
        vec![Tensor { data }]
    }
}
/// Compiler pass that replaces a matched chain of a constant `1.0`, four
/// `CudaContiguous` ops, a `CudaSumReduce` and a final subtract-one step with a
/// single `CudaARange` op.
#[derive(LuminalPrint, Default)]
pub struct ARangeCompiler<T: CudaFloat>(PhantomData<T>);
impl<T: CudaFloat> Compiler for ARangeCompiler<T>
where
CudaData<T>: Data,
{
// Searches `graph` for both spellings of the pattern and rewrites each match in place.
// The second argument (ids to remap) is ignored by this pass.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
// Slots that the selectors fill with the node ids of the current match.
let (
mut one_const,
mut contig1,
mut contig2,
mut contig3,
mut contig4,
mut sum_reduce,
mut subtraction_constant,
mut subtraction,
) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
);
// TODO: Make sure this actually checks the shape transformations to ensure pooling happens
let contig = SelectOp::new().ty::<CudaContiguous<T>>();
// Shared prefix of both patterns: const 1.0, four contiguous ops, then a sum-reduce.
let pre_sub_pattern = select_const!(1.0, T)
.ptr(&mut one_const)
.edge(contig.clone().ptr(&mut contig1))
.edge(contig.clone().ptr(&mut contig2))
.edge(contig.clone().ptr(&mut contig3))
.edge(contig.clone().ptr(&mut contig4))
.edge(
SelectOp::new()
.ty::<CudaSumReduce<T>>()
.ptr(&mut sum_reduce),
);
// First spelling: the final step is a `CudaSub` fed by a constant 1.0.
let mut s1 = pre_sub_pattern
.clone()
.edge(
select_const!(1.0, T)
.ptr(&mut subtraction_constant)
.edge(SelectOp::new().ty::<CudaSub<T>>().ptr(&mut subtraction)),
)
.search(graph);
// Second spelling: the final step is a `CudaAdd` fed by a constant -1.0.
let mut s2 = pre_sub_pattern
.edge(
select_const!(-1.0, T)
.ptr(&mut subtraction_constant)
.edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut subtraction)),
)
.search(graph);
// `s2` is only consulted when `s1` has no further match (short-circuit `||`).
while s1.next_match() || s2.next_match() {
// Range length: the dimension that `indexes` maps to the last position of the
// shape on the edge from the constant to the first contiguous op.
let arange_amount = {
let sh = graph
.graph
.edge_weight(
graph
.graph
.edges_connecting(one_const, contig1)
.next()
.unwrap()
.id(),
)
.unwrap()
.as_data()
.unwrap()
.2;
sh.dims[sh.indexes[sh.len() - 1]]
};
let arange_op = graph
.add_op(CudaARange::<T>::new(
dev.clone(),
arange_amount.into(),
&graph.dyn_map,
))
.finish();
// Consumers of the final step now read from the arange op; then the matched
// nodes are removed from the end of the chain back to the constant.
move_outgoing_edge(subtraction, arange_op, &mut graph.graph);
graph.graph.remove_node(subtraction);
graph.safe_remove_node(subtraction_constant, 0);
graph.safe_remove_node(sum_reduce, 0);
graph.safe_remove_node(contig4, 0);
graph.safe_remove_node(contig3, 0);
graph.safe_remove_node(contig2, 0);
graph.safe_remove_node(contig1, 0);
graph.safe_remove_node(one_const, 0);
// The graph changed, so both searchers discard their cached results.
s1.clear_cached_results();
s2.clear_cached_results();
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,973 @@
use dfdx::prelude::{Module as DfdxModule, *};
use itertools::Itertools;
use rand::{rngs::StdRng, SeedableRng};
use luminal::{
nn::{activation::ReLU, linear::Linear, norm::RMSNorm},
prelude::{symbolic::Expression, Module, *},
};
#[allow(unused_imports)]
use dfdx::prelude::{
Axes as DAxes, Axes2 as DAxes2, Axes3 as DAxes3, Axes4 as DAxes4, Axes5 as DAxes5,
Axis as DAxis, Const as DConst, *,
};
#[allow(unused_imports)]
use luminal::{
prelude::{
Axes as LAxes, Axes2 as LAxes2, Axes3 as LAxes3, Axes4 as LAxes4, Axes5 as LAxes5,
Axis as LAxis, Const as LConst, *,
},
tests::{
assert_close, assert_close_precision, assert_exact, random_vec, random_vec_rng, test_graphs,
},
};
use crate::CudaCompiler;
/// Permute followed by reshape on CUDA must match the same operations in dfdx.
#[test]
fn test_contiguous() {
    let mut graph = Graph::new();
    let input = random_vec(12);
    let src = graph.tensor::<R2<3, 4>>().set(input.clone());
    let mut out = src
        .permute::<R2<4, 3>, _>()
        .reshape::<R2<12, 1>>()
        .retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(input, (DConst::<3>, DConst::<4>));
    let expected = reference
        .permute::<Rank2<4, 3>, _>()
        .reshape::<Rank2<12, 1>>();
    assert_close(&out.data(), &expected.as_vec());
}
/// Softmax along axis 1 on CUDA must match dfdx.
#[test]
fn test_softmax() {
    let mut graph = Graph::new();
    let input = random_vec(12);
    let logits = graph.tensor::<R2<1, 12>>().set(input.clone());
    let mut probs = logits.softmax::<1>().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut probs);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(input, (DConst::<1>, DConst::<12>));
    let expected = reference.softmax::<DAxis<1>>();
    assert_close(&probs.data(), &expected.as_vec());
}
/// Slice, negate and concat must give the same result before and after CUDA compilation.
#[test]
fn test_rotate() {
    const D: usize = 2;
    const S: usize = 2;
    const H: usize = 2;
    let mut graph = Graph::new();
    let input = random_vec(D * S * H);
    let permuted = graph
        .tensor::<R4<1, D, S, H>>()
        .set(input)
        .keep()
        .permute::<_, LAxes4<0, 2, 1, 3>>();
    let front_half = permuted.slice((.., .., .., ..Expression::from(H / 2)));
    let back_half = permuted.slice((.., .., .., Expression::from(H / 2)..));
    let mut rotated = (-back_half)
        .concat_along::<R4<1, S, D, H>, LAxis<3>, _>(front_half)
        .retrieve();

    // Reference result from the uncompiled graph.
    graph.execute();
    let uncompiled = rotated.data();

    graph.compile(CudaCompiler::<f32>::default(), &mut rotated);
    graph.execute();
    assert_close(&uncompiled, &rotated.data());
}
/// A constant built from dynamic dimension `a` must track changes to that dimension.
#[test]
fn test_constant() {
    let mut graph = Graph::new();
    let symbol = graph.constant_expr('a');
    let mut squared = (symbol * symbol).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut squared);

    graph.set_dyn_dim('a', 10);
    graph.execute();
    assert_exact(&squared.data(), &[100.0]);

    // Drop the stored result, change the dimension and run again.
    squared.drop();
    graph.set_dyn_dim('a', 25);
    graph.execute();
    assert_exact(&squared.data(), &[625.0]);
}
/// Element-wise `log2` on CUDA must match `f32::log2`.
#[test]
fn test_log2() {
    let mut graph = Graph::new();
    let input = random_vec(3);
    let src = graph.tensor::<R1<3>>().set(input.clone());
    let mut out = src.log2().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let actual = out.data();
    let expected = input.into_iter().map(|i| i.log2()).collect::<Vec<_>>();
    assert_close(&actual, &expected);
}
/// Element-wise `exp2` on CUDA must match `f32::exp2`.
#[test]
fn test_exp2() {
    let mut graph = Graph::new();
    let input = random_vec(3);
    let src = graph.tensor::<R1<3>>().set(input.clone());
    let mut out = src.exp2().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let actual = out.data();
    let expected = input.into_iter().map(|i: f32| i.exp2()).collect::<Vec<_>>();
    assert_close(&actual, &expected);
}
/// Element-wise reciprocal on CUDA must match dfdx.
#[test]
fn test_recip() {
    let mut graph = Graph::new();
    let src = graph.tensor::<R1<3>>().set(vec![1., 2., 4096.]);
    let mut out = src.recip().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor([1., 2., 4096.]);
    let expected = reference.recip();
    assert_close(&out.data(), &expected.as_vec());
}
/// Element-wise sine on CUDA must match dfdx.
#[test]
fn test_sin() {
    let mut graph = Graph::new();
    let src = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut out = src.sin().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor([1., 2., 3.]);
    let expected = reference.sin();
    assert_close(&out.data(), &expected.as_vec());
}
/// `a / sqrt(a)` on CUDA must match dfdx.
#[test]
fn test_sqrt() {
    let mut graph = Graph::new();
    let src = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut out = (src / src.sqrt()).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut out);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor([1., 2., 3.]);
    let expected = reference.clone() / reference.sqrt();
    assert_close(&out.data(), &expected.as_vec());
}
/// Element-wise addition on CUDA must match dfdx.
#[test]
fn test_add() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut sum = (lhs + rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut sum);
    graph.execute();

    let cpu = Cpu::default();
    let ref_lhs = cpu.tensor([1., 2., 3.]);
    let ref_rhs = cpu.tensor([1., 2., 3.]);
    let expected = ref_lhs + ref_rhs;
    assert_close(&sum.data(), &expected.as_vec());
}
/// Element-wise subtraction on CUDA must match dfdx.
#[test]
fn test_sub() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut difference = (lhs - rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut difference);
    graph.execute();

    let cpu = Cpu::default();
    let ref_lhs = cpu.tensor([1., 2., 3.]);
    let ref_rhs = cpu.tensor([1., 2., 3.]);
    let expected = ref_lhs - ref_rhs;
    assert_close(&difference.data(), &expected.as_vec());
}
/// Squaring a tensor with two dynamic dimensions on CUDA must match dfdx.
#[test]
fn test_square() {
    let mut graph = Graph::new();
    let mut rng = rand::thread_rng();
    let input = random_vec_rng(40960, &mut rng);
    let src = graph
        .tensor::<(Dyn<'b'>, Dyn<'s'>, luminal::prelude::Const<4096>)>()
        .set_dyn(input.clone(), &[1, 10, 4096]);
    let mut squared = (src * src).retrieve();
    graph.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        &mut squared,
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference_shape = (
        dfdx::prelude::Const::<1>,
        dfdx::prelude::Const::<10>,
        dfdx::prelude::Const::<4096>,
    );
    let reference = cpu.tensor_from_vec::<Rank3<1, 10, 4096>>(input, reference_shape);
    let expected = reference.clone() * reference;
    assert_close(&squared.data(), &expected.as_vec());
}
/// Element-wise multiplication on CUDA must match dfdx.
#[test]
fn test_mul() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut product = (lhs * rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let ref_lhs = cpu.tensor([1., 2., 3.]);
    let ref_rhs = cpu.tensor([1., 2., 3.]);
    let expected = ref_lhs * ref_rhs;
    assert_close(&product.data(), &expected.as_vec());
}
/// Multiplying a dynamically shaped tensor by an expanded scalar must match dfdx exactly.
#[test]
fn test_mul2() {
    let mut graph = Graph::new();
    let matrix = graph
        .tensor::<(LConst<1>, LConst<1>, Dyn<'a'>, Dyn<'a'>)>()
        .set_dyn(vec![82.4, 783.0, 99.6, 974.5], &[1, 1, 2, 2]);
    let scalar = graph.tensor::<R0>().set(vec![0.57735026]);
    let mut scaled = (matrix * scalar.expand()).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut scaled);
    graph.execute();

    let cpu = Cpu::default();
    let ref_matrix = cpu.tensor([[[[82.4, 783.0], [99.6, 974.5]]]]);
    let ref_scalar = cpu.tensor(0.57735026);
    let expected = ref_matrix * ref_scalar.broadcast::<_, dfdx::shapes::Axes4<0, 1, 2, 3>>();
    assert_exact(&scaled.data(), &expected.as_vec());
}
/// Element-wise division on CUDA must match dfdx.
#[test]
fn test_div() {
    let mut graph = Graph::new();
    let numerator = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let denominator = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut quotient = (numerator / denominator).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut quotient);
    graph.execute();

    let cpu = Cpu::default();
    let ref_numerator = cpu.tensor([1., 2., 3.]);
    let ref_denominator = cpu.tensor([1., 2., 3.]);
    let expected = ref_numerator / ref_denominator;
    assert_close(&quotient.data(), &expected.as_vec());
}
/// Element-wise maximum on CUDA must match dfdx's `maximum`.
#[test]
fn test_max() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut larger = lhs.max(rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut larger);
    graph.execute();

    let cpu = Cpu::default();
    let ref_lhs = cpu.tensor([1., 2., 3.]);
    let ref_rhs = cpu.tensor([1., 2., 3.]);
    let expected = ref_lhs.maximum(ref_rhs);
    assert_close(&larger.data(), &expected.as_vec());
}
/// Element-wise remainder on CUDA must match Rust's `%` on `f32`.
#[test]
fn test_mod() {
    let mut graph = Graph::new();
    let lhs_data = random_vec(3);
    let rhs_data = random_vec(3);
    let lhs = graph.tensor::<R1<3>>().set(lhs_data.clone());
    let rhs = graph.tensor::<R1<3>>().set(rhs_data.clone());
    let mut remainder = (lhs % rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut remainder);
    graph.execute();

    // No dfdx equivalent
    let actual = remainder.data();
    let expected = lhs_data
        .into_iter()
        .zip(rhs_data)
        .map(|(a, b)| a % b)
        .collect_vec();
    assert_close(&actual, &expected);
}
// Reduction op tests
/// Sum reduction over each of the three axes on CUDA must match dfdx.
#[test]
fn test_sum_reduce() {
    let input = random_vec(40960);
    let mut graph = Graph::new();
    let src = graph.tensor::<R3<1, 10, 4096>>().set(input.clone());
    let mut over_axis2 = src.sum_reduce::<_, LAxis<2>>().retrieve();
    let mut over_axis1 = src.sum_reduce::<_, LAxis<1>>().retrieve();
    let mut over_axis0 = src.sum_reduce::<_, LAxis<0>>().retrieve();
    graph.compile(
        CudaCompiler::<f32>::default(),
        (&mut over_axis2, &mut over_axis1, &mut over_axis0),
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(input, (DConst::<1>, DConst::<10>, DConst::<4096>));
    let expected_axis2 = reference.clone().sum::<_, DAxis<2>>();
    let expected_axis1 = reference.clone().sum::<_, DAxis<1>>();
    let expected_axis0 = reference.sum::<_, DAxis<0>>();
    assert_close(&over_axis2.data(), &expected_axis2.as_vec());
    assert_close(&over_axis1.data(), &expected_axis1.as_vec());
    assert_close(&over_axis0.data(), &expected_axis0.as_vec());
}
/// Sum reduction over axis 2 of a rank-5 tensor on CUDA must match dfdx exactly.
#[test]
fn test_sum_reduce2() {
    let mut graph = Graph::new();
    let input = random_vec(32 * 10 * 10 * 128);
    let src = graph.tensor::<R5<1, 32, 10, 10, 128>>().set(input.clone());
    let mut reduced = src.sum_reduce::<_, LAxis<2>>().retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut reduced);
    graph.execute();

    let cpu = Cpu::default();
    let reference_shape = (
        DConst::<1>,
        DConst::<32>,
        DConst::<10>,
        DConst::<10>,
        DConst::<128>,
    );
    let reference = cpu.tensor_from_vec(input, reference_shape);
    let expected = reference.sum::<_, DAxis<2>>();
    assert_exact(&reduced.data(), &expected.as_vec());
}
/// Max reduction over each of the three axes on CUDA must match dfdx.
#[test]
fn test_max_reduce() {
    let input = random_vec(40960);
    let mut graph = Graph::new();
    let src = graph.tensor::<R3<1, 10, 4096>>().set(input.clone());
    let mut over_axis2 = src.max_reduce::<_, LAxis<2>>().retrieve();
    let mut over_axis1 = src.max_reduce::<_, LAxis<1>>().retrieve();
    let mut over_axis0 = src.max_reduce::<_, LAxis<0>>().retrieve();
    graph.compile(
        CudaCompiler::<f32>::default(),
        (&mut over_axis2, &mut over_axis1, &mut over_axis0),
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(input, (DConst::<1>, DConst::<10>, DConst::<4096>));
    let expected_axis2 = reference.clone().max::<_, DAxis<2>>();
    let expected_axis1 = reference.clone().max::<_, DAxis<1>>();
    let expected_axis0 = reference.max::<_, DAxis<0>>();
    assert_close(&over_axis2.data(), &expected_axis2.as_vec());
    assert_close(&over_axis1.data(), &expected_axis1.as_vec());
    assert_close(&over_axis0.data(), &expected_axis0.as_vec());
}
/// Mean reduction over each of the three axes on CUDA must match dfdx.
#[test]
fn test_mean_reduce() {
    let input = random_vec(40960);
    let mut graph = Graph::new();
    let src = graph.tensor::<R3<1, 10, 4096>>().set(input.clone());
    let mut over_axis2 = src.mean_reduce::<_, LAxis<2>>().retrieve();
    let mut over_axis1 = src.mean_reduce::<_, LAxis<1>>().retrieve();
    let mut over_axis0 = src.mean_reduce::<_, LAxis<0>>().retrieve();
    graph.compile(
        CudaCompiler::<f32>::default(),
        (&mut over_axis2, &mut over_axis1, &mut over_axis0),
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(input, (DConst::<1>, DConst::<10>, DConst::<4096>));
    let expected_axis2 = reference.clone().mean::<_, DAxis<2>>();
    let expected_axis1 = reference.clone().mean::<_, DAxis<1>>();
    let expected_axis0 = reference.mean::<_, DAxis<0>>();
    assert_close(&over_axis2.data(), &expected_axis2.as_vec());
    assert_close(&over_axis1.data(), &expected_axis1.as_vec());
    assert_close(&over_axis0.data(), &expected_axis0.as_vec());
}
/// A fixed-size 256x256 matmul on CUDA must match dfdx.
#[test]
fn test_matmul_simple() {
    let mut graph = Graph::new();
    let lhs_data = random_vec(256 * 256);
    let rhs_data = random_vec(256 * 256);
    let lhs = graph.tensor::<R2<256, 256>>().set(lhs_data.clone());
    let rhs = graph.tensor::<R2<256, 256>>().set(rhs_data.clone());
    let mut product = lhs.matmul(rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let ref_lhs = cpu.tensor_from_vec(lhs_data, (DConst::<256>, DConst::<256>));
    let ref_rhs = cpu.tensor_from_vec(rhs_data, (DConst::<256>, DConst::<256>));
    let expected = ref_lhs.matmul(ref_rhs);
    assert_close(&product.data(), &expected.as_vec());
}
/// A matmul with dynamic M, K and N, compiled once, must match dfdx for a sweep of sizes.
#[test]
fn test_matmul() {
    let cpu = Cpu::default();
    let mut graph = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);
    let lhs = graph.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
    let rhs = graph.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
    let mut product = lhs.matmul(rhs).retrieve();
    graph.compile(CudaCompiler::<f32>::default(), &mut product);

    for m in (1..23).step_by(4) {
        for k in (1..35).step_by(3) {
            for n in (1..70).step_by(7) {
                let lhs_data = random_vec_rng(m * k, &mut rng);
                let rhs_data = random_vec_rng(k * n, &mut rng);
                lhs.set_dyn(lhs_data.clone(), &[m, k]);
                rhs.set_dyn(rhs_data.clone(), &[k, n]);
                graph.execute();

                let ref_lhs = cpu.tensor_from_vec(lhs_data, (m, k));
                let ref_rhs = cpu.tensor_from_vec(rhs_data, (k, n));
                let expected = ref_lhs.matmul(ref_rhs);
                assert_close(&product.data(), &expected.as_vec());
                // Clear the stored result before the next size is run.
                product.drop();
            }
        }
    }
}
/// Batched 4D matmul with attention-like shapes (1x32x11x128 times 1x32x128x11),
/// compared against dfdx with a precision argument of 2.
#[test]
fn test_attn_matmul() {
    let mut cx = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);
    let lhs_data = random_vec_rng(32 * 11 * 128, &mut rng);
    let rhs_data = random_vec_rng(32 * 11 * 128, &mut rng);

    // luminal side; both inputs are kept so they survive execution.
    let lhs = cx
        .named_tensor::<R4<1, 32, 11, 128>>("Input")
        .set(lhs_data.clone())
        .keep();
    let rhs = cx
        .named_tensor::<R4<1, 32, 128, 11>>("Input")
        .set(rhs_data.clone())
        .keep();
    let mut product = lhs.matmul(rhs).retrieve();
    cx.compile(CudaCompiler::<f32>::default(), &mut product);
    cx.execute();

    // dfdx reference
    let cpu = Cpu::default();
    let lhs_shape = (DConst::<1>, DConst::<32>, DConst::<11>, DConst::<128>);
    let rhs_shape = (DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>);
    let ref_lhs = cpu.tensor_from_vec(lhs_data, lhs_shape);
    let ref_rhs = cpu.tensor_from_vec(rhs_data, rhs_shape);
    let ref_product = ref_lhs.matmul(ref_rhs);

    assert_close_precision(&product.data(), &ref_product.as_vec(), 2);
}
/// Dynamically shaped batched matmul ((B, M, K) x (K, N)) with M fixed at 12,
/// re-executed over a grid of batch, K and N sizes and checked against dfdx.
#[test]
fn test_batch_matmul() {
    let m = 12;
    let mut cx = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);

    // Built and compiled once; the dynamic dimensions are supplied per iteration.
    let lhs = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>();
    let rhs = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
    let mut product = lhs.matmul(rhs).retrieve();
    cx.compile(CudaCompiler::<f32>::default(), &mut product);

    for batch in (1..23).step_by(4) {
        for k in (1..35).step_by(3) {
            for n in (1..48).step_by(7) {
                let lhs_data = random_vec_rng(batch * m * k, &mut rng);
                let rhs_data = random_vec_rng(k * n, &mut rng);
                lhs.set_dyn(lhs_data.clone(), &[batch, m, k]);
                rhs.set_dyn(rhs_data.clone(), &[k, n]);
                cx.execute();

                let cpu = Cpu::default();
                let ref_lhs = cpu.tensor_from_vec(lhs_data, (batch, m, k));
                let ref_rhs = cpu.tensor_from_vec(rhs_data, (k, n));
                let ref_product = ref_lhs.matmul(ref_rhs);
                assert_close_precision(
                    &product.data(),
                    &ref_product.to_dtype::<f32>().as_vec(),
                    2,
                );

                // Release the retrieved output before the next execution.
                product.drop();
            }
        }
    }
}
/// Batched matmul with every combination of plain / transposed operands
/// (A·Bᵀ, A·B, Aᵀ·Bᵀ, Aᵀ·B in storage terms), checked against dfdx.
#[test]
fn test_batch_matmul_transpose() {
    const B: usize = 1;
    const M: usize = 48;
    const K: usize = 4096;
    const N: usize = 4096;

    let mut cx = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);

    // Inputs: `b` is stored as (N, K) and `a_t` as (B, K, M); both are permuted before use.
    let a_data = random_vec_rng(B * M * K, &mut rng);
    let a = cx.named_tensor::<R3<B, M, K>>("A").set(a_data.clone());
    let b_data = random_vec_rng(K * N, &mut rng);
    let b = cx.named_tensor::<R2<N, K>>("B").set(b_data.clone());
    let a_t_data = random_vec_rng(B * K * M, &mut rng);
    let a_t = cx.named_tensor::<R3<B, K, M>>("A_T").set(a_t_data.clone());
    let b_t_data = random_vec_rng(K * N, &mut rng);
    let b_t = cx.named_tensor::<R2<K, N>>("B_T").set(b_t_data.clone());

    // luminal products
    let mut a_b = a.matmul(b.permute::<_, LAxes2<1, 0>>()).retrieve();
    let mut a_b_t = a.matmul(b_t).retrieve();
    let mut a_t_b = a_t
        .permute::<_, LAxes3<0, 2, 1>>()
        .matmul(b.permute::<_, LAxes2<1, 0>>())
        .retrieve();
    let mut a_t_b_t = a_t.permute::<_, LAxes3<0, 2, 1>>().matmul(b_t).retrieve();

    cx.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        (&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
    );
    cx.execute();

    // dfdx references built from the same data and layouts
    let cpu = Cpu::default();
    let ref_a = cpu.tensor_from_vec(a_data, (DConst::<B>, DConst::<M>, DConst::<K>));
    let ref_b = cpu.tensor_from_vec(b_data, (DConst::<N>, DConst::<K>));
    let ref_a_t = cpu.tensor_from_vec(a_t_data, (DConst::<B>, DConst::<K>, DConst::<M>));
    let ref_b_t = cpu.tensor_from_vec(b_t_data, (DConst::<K>, DConst::<N>));

    let ref_a_b = ref_a
        .clone()
        .matmul(ref_b.clone().permute::<_, DAxes2<1, 0>>());
    let ref_a_b_t = ref_a.matmul(ref_b_t.clone());
    let ref_a_t_b = ref_a_t
        .clone()
        .permute::<_, DAxes3<0, 2, 1>>()
        .matmul(ref_b.permute::<_, DAxes2<1, 0>>());
    let ref_a_t_b_t = ref_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(ref_b_t);

    assert_close_precision(&a_b.data(), &ref_a_b.as_vec(), 1);
    assert_close_precision(&a_b_t.data(), &ref_a_b_t.as_vec(), 1);
    assert_close_precision(&a_t_b.data(), &ref_a_t_b.as_vec(), 1);
    assert_close_precision(&a_t_b_t.data(), &ref_a_t_b_t.as_vec(), 1);
}
/// 2D matmul with every combination of plain / transposed operands, checked against dfdx.
#[test]
fn test_matmul_transpose() {
    const M: usize = 1024;
    const K: usize = 16;
    // Note: 767 is not a multiple of 256.
    const N: usize = 767;

    let mut cx = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);

    // Inputs: `b` is stored as (N, K) and `a_t` as (K, M); both are permuted before use.
    let a_data = random_vec_rng(M * K, &mut rng);
    let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
    let b_data = random_vec_rng(K * N, &mut rng);
    let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
    let a_t_data = random_vec_rng(K * M, &mut rng);
    let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
    let b_t_data = random_vec_rng(K * N, &mut rng);
    let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());

    // luminal products
    let mut a_b = a.matmul(b.permute()).retrieve();
    let mut a_b_t = a.matmul(b_t).retrieve();
    let mut a_t_b = a_t
        .permute::<_, LAxes2<1, 0>>()
        .matmul(b.permute())
        .retrieve();
    let mut a_t_b_t = a_t.permute::<_, LAxes2<1, 0>>().matmul(b_t).retrieve();

    cx.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        (&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
    );
    cx.execute();

    // dfdx references built from the same data and layouts
    let cpu = Cpu::default();
    let ref_a = cpu.tensor_from_vec(a_data, (DConst::<M>, DConst::<K>));
    let ref_b = cpu.tensor_from_vec(b_data, (DConst::<N>, DConst::<K>));
    let ref_a_t = cpu.tensor_from_vec(a_t_data, (DConst::<K>, DConst::<M>));
    let ref_b_t = cpu.tensor_from_vec(b_t_data, (DConst::<K>, DConst::<N>));

    let ref_a_b = ref_a.clone().matmul(ref_b.clone().permute());
    let ref_a_b_t = ref_a.matmul(ref_b_t.clone());
    let ref_a_t_b = ref_a_t
        .clone()
        .permute::<_, DAxes2<1, 0>>()
        .matmul(ref_b.permute());
    let ref_a_t_b_t = ref_a_t.permute::<_, DAxes2<1, 0>>().matmul(ref_b_t);

    assert_close(&a_b.data(), &ref_a_b.as_vec());
    assert_close(&a_b_t.data(), &ref_a_b_t.as_vec());
    assert_close(&a_t_b.data(), &ref_a_t_b.as_vec());
    assert_close(&a_t_b_t.data(), &ref_a_t_b_t.as_vec());
}
/// Runs a Linear -> ReLU -> Linear stack on a single vector and on a batch, first
/// uncompiled and then compiled for CUDA, and checks the single-vector output against dfdx.
#[test]
fn test_relu_and_linear() {
    let mut cx = Graph::new();
    let input_data = random_vec(32);
    let w1 = random_vec(32 * 64);
    let w2 = random_vec(32 * 64);

    let batch_in = cx
        .named_tensor::<R2<2, 32>>("Batch")
        .set(random_vec(32 * 2));
    let single_in = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());

    let lum_model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
    lum_model.0.weight.set(w1.clone());
    lum_model.2.weight.set(w2.clone());
    let mut single_out = lum_model.forward(single_in).retrieve();
    let mut batch_out = lum_model.forward(batch_in).retrieve();

    // First pass: run the graph as built, without any compiler applied.
    cx.execute();
    let unoptimized_single = single_out.data();
    let unoptimized_batch = batch_out.data();
    single_out.drop();
    batch_out.drop();

    // Second pass: compile, re-run, and compare with the uncompiled results.
    cx.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        (&mut single_out, &mut batch_out),
    );
    cx.execute();
    assert_close_precision(&unoptimized_single, &single_out.data(), 2);
    assert_close_precision(&unoptimized_batch, &batch_out.data(), 2);

    // Reference: the same network in dfdx, with weights permuted after loading.
    let dev = Cpu::default();
    let mut ref_model = <(
        dfdx::nn::modules::builders::UnbiasedLinear<32, 64>,
        dfdx::nn::modules::builders::ReLU,
        dfdx::nn::modules::builders::UnbiasedLinear<64, 32>,
    )>::build_on_device(&dev);
    ref_model.0.weight = dev
        .tensor_from_vec(w1, (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>))
        .permute();
    ref_model.2.weight = dev
        .tensor_from_vec(w2, (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>))
        .permute();
    let ref_in = dev.tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,));
    let ref_out = ref_model.forward(ref_in);

    assert_close_precision(&unoptimized_single, &ref_out.as_vec(), 2);
}
/// Runs the compiled `RMSNorm<32>` module on a 15x32 input and compares it with an
/// RMS normalisation written out by hand in dfdx (epsilon 1e-6).
#[test]
fn test_rms_norm() {
    let inp_data = random_vec(15 * 32);
    let weight_data = random_vec(32);

    // luminal side
    let mut cx = Graph::new();
    let input = cx.tensor::<R2<15, 32>>().set(inp_data.clone());
    let norm = RMSNorm::<32>::initialize(&mut cx);
    norm.weight.set(weight_data.clone());
    let mut normed = norm.forward(input).retrieve();
    cx.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        &mut normed,
    );
    cx.execute();

    // dfdx reference: weight * x / sqrt(mean(x^2) + 1e-6), reduced over axis 1.
    let dev = Cpu::default();
    let ref_weight = dev.tensor_from_vec(weight_data, (DConst::<32>,));
    let ref_input = dev.tensor_from_vec(inp_data, (DConst::<15>, DConst::<32>));
    let mean_sq = ref_input.clone().square().mean::<_, DAxis<1>>();
    let rms = (mean_sq + 1e-6).sqrt();
    let scaled = ref_input / rms.broadcast();
    let expected = ref_weight.broadcast() * scaled;

    assert_close(&normed.data(), &expected.as_vec());
}
/// Compares `layer_norm` over axis 0 and axis 2 of a 15x16x32 tensor with dfdx's `normalize`.
#[test]
fn test_layer_norm() {
    let mut cx = Graph::new();
    let input_data = random_vec(15 * 16 * 32);

    // luminal side
    let input = cx.tensor::<R3<15, 16, 32>>().set(input_data.clone());
    let mut norm_ax0 = input.layer_norm::<0, _>(1e-5).retrieve();
    let mut norm_ax2 = input.layer_norm::<2, _>(1e-5).retrieve();
    cx.compile(
        <(GenericCompiler, CudaCompiler<f32>)>::default(),
        (&mut norm_ax0, &mut norm_ax2),
    );
    cx.execute();

    // dfdx reference
    let cpu = Cpu::default();
    let ref_input =
        cpu.tensor_from_vec(input_data, (DConst::<15>, DConst::<16>, DConst::<32>));
    let ref_ax0 = ref_input.clone().normalize::<DAxis<0>>(1e-5);
    let ref_ax2 = ref_input.normalize::<DAxis<2>>(1e-5);

    assert_close_precision(&norm_ax0.data(), &ref_ax0.as_vec(), 2);
    assert_close_precision(&norm_ax2.data(), &ref_ax2.as_vec(), 2);
}
/// Runs a luminal `TransformerEncoderBlock<32, 64, 1>` uncompiled and compiled for CUDA,
/// checks the two runs agree, then checks the compiled output against an equivalent
/// dfdx encoder block loaded with the same weights (biases zeroed, norm gamma = 1, beta = 0).
#[test]
fn test_transformer_encoder_block() {
    let mut cx = Graph::new();
    let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<32, 64, 1> =
        InitModule::initialize(&mut cx);

    // Random weights for the attention projections and the feed-forward layers.
    let w_k_weight = random_vec(32 * 32);
    model.attention.w_k.weight.set(w_k_weight.clone());
    let w_q_weight = random_vec(32 * 32);
    model.attention.w_q.weight.set(w_q_weight.clone());
    let w_v_weight = random_vec(32 * 32);
    model.attention.w_v.weight.set(w_v_weight.clone());
    let w_o_weight = random_vec(32 * 32);
    model.attention.w_o.weight.set(w_o_weight.clone());
    let ff_0_weight = random_vec(32 * 64);
    model.ff.0.weight.set(ff_0_weight.clone());
    let ff_1_weight = random_vec(64 * 32);
    model.ff.2.weight.set(ff_1_weight.clone());

    // Input holds 2 * 32 values: batch 1, sequence length 2, embedding size 32.
    // The shape slice now states the real size of the constant last axis (32).
    let a_data = random_vec(2 * 32);
    let a = cx
        .tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>()
        .set_dyn(a_data.clone(), &[1, 2, 32])
        .keep();
    cx.keep_tensors(state_dict(&model));
    let mut b = model.forward(a).retrieve();

    // Uncompiled run.
    cx.execute();
    let unopt_b = b.data();
    b.drop();

    // Compiled run must agree with the uncompiled one.
    cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut b);
    cx.execute();
    assert_close_precision(&unopt_b, &b.data(), 2);

    // dfdx reference model with identical parameters.
    let d_dev = Cpu::default();
    let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> =
        d_dev
            .build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<32, 1, 64>, f32>();
    d_model.self_attn.w_k.bias.copy_from(&[0.; 32]);
    d_model.self_attn.w_v.bias.copy_from(&[0.; 32]);
    d_model.self_attn.w_q.bias.copy_from(&[0.; 32]);
    d_model.self_attn.w_o.bias.copy_from(&[0.; 32]);
    d_model.self_attn.w_o.weight = d_dev
        .tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>))
        .permute();
    d_model.self_attn.w_k.weight = d_dev
        .tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>))
        .permute();
    d_model.self_attn.w_q.weight = d_dev
        .tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>))
        .permute();
    d_model.self_attn.w_v.weight = d_dev
        .tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>))
        .permute();
    d_model.ff.0 .0.weight = d_dev
        .tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>))
        .permute();
    d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,));
    d_model.ff.0 .2.weight = d_dev
        .tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>))
        .permute();
    d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
    d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
    d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
    d_model.norm1.epsilon = 1e-5;
    d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
    d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
    d_model.norm2.epsilon = 1e-5;

    // The dfdx input has no batch axis: (sequence, embedding) = (2, 32).
    let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>));
    let d_b = d_model.forward(d_a);
    assert_close_precision(&b.data(), &d_b.as_vec(), 2);
}
/// Smoke test: two retrieved outputs (`log2` and `sin`) share the intermediate `a * a1`.
/// Nothing is asserted; the test only requires compilation and execution to complete.
#[test]
fn test_common_buffer() {
    let data = random_vec(32);
    let mut cx = Graph::new();

    let lhs = cx.tensor::<R1<32>>();
    lhs.set(data.clone());
    let rhs = cx.tensor::<R1<32>>();
    rhs.set(data.clone());

    // Shared intermediate consumed by both outputs.
    let shared = lhs * rhs;
    let mut log_out = shared.log2().retrieve();
    let mut sin_out = shared.sin().retrieve();

    cx.compile(
        CudaCompiler::<f32>::default(),
        (&mut log_out, &mut sin_out),
    );
    cx.execute();
}
/// Looks up rows of a 3x4 embedding table for a single index vector and for a batch,
/// and compares both results with dfdx's `Embedding`.
#[test]
fn test_embedding() {
    let mut cx = Graph::new();

    // luminal takes the indices as floats.
    let batch_idx = cx
        .named_tensor::<R2<2, 3>>("Batch")
        .set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0])
        .keep();
    let single_idx = cx
        .named_tensor::<R1<3>>("Single")
        .set(vec![1.0, 0.0, 1.0])
        .keep();

    let embedding: luminal::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx);
    embedding
        .weight
        .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
    let mut single_out = embedding.forward(single_idx).retrieve();
    let mut batch_out = embedding.forward(batch_idx).retrieve();

    cx.compile(
        CudaCompiler::<f32>::default(),
        (&mut single_out, &mut batch_out),
    );
    cx.execute();

    // dfdx reference with the same table; here the indices are integers.
    let cpu = Cpu::default();
    let mut ref_embedding: modules::Embedding<3, 4, f32, Cpu> =
        <dfdx::nn::modules::builders::Embedding<3, 4>>::build_on_device(&cpu);
    ref_embedding.weight = cpu.tensor_from_vec(
        vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.],
        (DConst::<3>, DConst::<4>),
    );
    let ref_single_idx = cpu.tensor_from_vec(vec![1, 0, 1], (DConst::<3>,));
    let ref_batch_idx =
        cpu.tensor_from_vec(vec![1, 0, 2, 1, 0, 1], (DConst::<2>, DConst::<3>));
    let ref_single_out = ref_embedding.forward(ref_single_idx);
    let ref_batch_out = ref_embedding.forward(ref_batch_idx);

    assert_close(&single_out.data(), &ref_single_out.as_vec());
    assert_close(&batch_out.data(), &ref_batch_out.as_vec());
}
/// Takes the first 20 elements of a 256-element vector and requires an exact match with dfdx.
#[test]
fn test_slice() {
    let data = random_vec(256);

    // luminal side
    let mut cx = Graph::new();
    let full = cx.tensor::<R1<256>>().set(data.clone());
    let mut head: GraphTensor<R1<20>> = full
        .slice((..Expression::from(20),))
        .realize()
        .contiguous()
        .retrieve();
    cx.compile(CudaCompiler::<f32>::default(), &mut head);
    cx.execute();

    // dfdx reference
    let cpu = Cpu::default();
    let ref_full = cpu.tensor_from_vec(data, (DConst::<256>,));
    let ref_head = ref_full.slice((..20,));

    assert_exact(&head.data(), &ref_head.as_vec());
}
/// Pads an 8x2 matrix to 10x4 (two trailing elements on each axis) and requires an
/// exact match with a dfdx reference assembled by concatenating zero tensors.
#[test]
fn test_pad() {
    let data = random_vec(8 * 2);

    // luminal side
    let mut cx = Graph::new();
    let small = cx.tensor::<R2<8, 2>>().set(data.clone());
    let mut padded = small
        .pad::<R2<10, 4>, _, _>(&[(0, 2), (0, 2)])
        .contiguous()
        .retrieve();
    cx.compile(CudaCompiler::<f32>::default(), &mut padded);
    cx.execute();

    // dfdx has no pad op, so the reference appends zeros along axis 0, then along axis 1.
    let cpu = Cpu::default();
    let ref_small = cpu.tensor_from_vec(data, (8, 2));
    let ref_rows = (ref_small, cpu.zeros_like(&(2, 2))).concat_along(DAxis::<0>);
    let ref_padded = (ref_rows, cpu.zeros_like(&(10, 2))).concat_along(DAxis::<1>);

    assert_exact(&padded.data(), &ref_padded.as_vec());
}
/// Checks that a padded-then-contiguous copy and a sliced-then-divided-by-one copy of a
/// dynamically shaped 13x24 tensor both reproduce the source data.
#[test]
fn test_pad_contig() {
    let m = 13;
    let k = 24;
    let mut cx = Graph::new();
    let mut rng = StdRng::seed_from_u64(0);
    let source_data = random_vec_rng(m * k, &mut rng);

    let mut source = cx
        .tensor::<(Dyn<'M'>, Dyn<'K'>)>()
        .set_dyn(source_data, &[m, k])
        .retrieve();
    // Padding amount on the last axis is the symbolic expression `16 - K`.
    let mut padded: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = source
        .pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')])
        .contiguous()
        .retrieve();
    let mut sliced: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
        (source.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve();

    cx.compile(
        CudaCompiler::<f32>::default(),
        (&mut source, &mut padded, &mut sliced),
    );
    cx.execute();

    // Tolerance-based rather than exact comparison.
    // NOTE(review): the earlier comment attributed this to a 16-bit path, but this test
    // compiles with `CudaCompiler::<f32>` — confirm whether exact equality would hold.
    assert_close(&source.data(), &padded.data());
    assert_close(&source.data(), &sliced.data());
}
/// Pads a 32-element vector to 42, slices the first 25 back out, and requires an exact
/// match with slicing the original data in dfdx.
#[test]
fn test_movement() {
    let data = random_vec(32);

    // luminal side: pad -> contiguous -> slice -> contiguous
    let mut cx = Graph::new();
    let original = cx.tensor::<R1<32>>().set(data.clone());
    let padded: GraphTensor<R1<42>> = original.pad(&[(0, 10)]).contiguous().retrieve();
    let mut head: GraphTensor<R1<25>> = padded
        .slice((..Expression::from(25),))
        .realize()
        .contiguous()
        .retrieve();
    cx.compile(CudaCompiler::<f32>::default(), &mut head);
    cx.execute();

    // dfdx reference: the first 25 elements of the unpadded data.
    let cpu = Cpu::default();
    let ref_original = cpu.tensor_from_vec(data, (DConst::<32>,));
    let ref_head = ref_original.slice((..25,));

    assert_exact(&head.data(), &ref_head.as_vec());
}

View File

@@ -0,0 +1,2 @@
mod fp16;
mod fp32;

View File

@@ -0,0 +1,19 @@
[package]
name = "luminal_metal"
version = "0.2.0"
edition = "2021"
description = "Metal compiler for luminal"
license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
itertools = "0.12.1"
luminal = { path = "../.." }
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"] }
num-traits = "0.2.18"
rand = "0.8.5"
rustc-hash = "1.1.0"
[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }

View File

@@ -0,0 +1,652 @@
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
use itertools::Itertools;
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
ComputePipelineState, Device, MTLResourceOptions, MTLSize,
};
use rustc_hash::FxHashMap;
use crate::{
compile_function, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims,
render_dyn_dim_inputs, select_const, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel,
MetalKernelWrapper, SetInt,
};
use super::prim::*;
use luminal::{
op::{InputTensor, Operator},
prelude::{
petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction},
*,
},
shape::symbolic::BigExpression,
};
use super::other::MetalARange;
/// Elementwise subtraction op (`inp_a - inp_b`) backed by a Metal compute kernel.
///
/// `T` is the element type the kernel source is generated for.
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
pub struct MetalSub<T> {
// Compiled compute pipeline for the generated `mkernel` function.
pipeline: ComputePipelineState,
// Queue used to create command buffers when the op is run through `Operator::process`.
queue: CommandQueue,
// Device used to allocate the output buffer and to recompile the kernel.
device: Device,
// Dynamic-dimension symbols the kernel takes as extra inputs.
dyn_symbols: Vec<char>,
// Raw pointer to the map of dynamic-dimension values; dereferenced when encoding the
// kernel, so the pointed-to map must still be alive at that point.
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> MetalSub<T> {
pub fn new(
a_shape: ShapeTracker,
b_shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (a_idx_exp, a_valid_exp) = get_idx_valid_exps(a_shape);
let (b_idx_exp, b_valid_exp) = get_idx_valid_exps(b_shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape], 4);
let type_name = T::type_name();
let code = format!(
"
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
out[idx] =
(({a_valid_exp}) == 0 ? 0.0 : inp_a[{a_idx_exp}])
- (({b_valid_exp}) == 0 ? 0.0 : inp_b[{b_idx_exp}]);
}}
}}
");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_symbols,
dyn_map,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalSub<T> {
    /// One output buffer: element count of the first input times `size_of::<T>()`.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        let out_bytes = input_shapes[0].n_elements() * size_of::<T>();
        vec![out_bytes]
    }

    /// Encodes the subtraction kernel into `command_buffer`.
    ///
    /// Slots 0 and 1 are the two inputs, slot 2 the output, slot 3 the element count,
    /// and dynamic dimensions follow from slot 4. One thread is dispatched per element
    /// of the first input.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let n_elements = inputs[0].1.n_elements().to_usize().unwrap();
        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        encoder.set_compute_pipeline_state(&self.pipeline);

        // Bind the kernel arguments.
        encoder.set_buffer(0, Some(inputs[0].0), 0);
        encoder.set_buffer(1, Some(inputs[1].0), 0);
        encoder.set_buffer(2, Some(output_buffers[0]), 0);
        encoder.set_u32(3, n_elements as u32);
        // The raw pointer is dereferenced here; it was supplied to `new` by the caller.
        let dyn_map = unsafe { self.dyn_map.as_ref() }.unwrap();
        input_dyn_dims(&self.dyn_symbols, dyn_map, encoder, 4);

        // Launch and close the encoder.
        encoder.dispatch_1d(n_elements);
        encoder.end_encoding();
    }
}
impl<T: MetalFloat> Operator for MetalSub<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[
(get_buffer_from_tensor(&tensors[0].0), tensors[0].1),
(get_buffer_from_tensor(&tensors[1].0), tensors[1].1),
],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 - input1".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
input_shapes[0],
input_shapes[1],
self.device.clone(),
self.queue.clone(),
self.dyn_map,
)
}
}
None
}
}
/// Compiler pass that replaces a multiply-by-`-1.0`-then-add pattern with a single
/// `MetalSub` op (see its `Compiler` impl). `T` selects the kernel element type.
#[derive(LuminalPrint, Default)]
pub struct MetalSubtractionCompiler<T: MetalFloat>(PhantomData<T>);
impl<T: MetalFloat> Compiler for MetalSubtractionCompiler<T> {
/// Rewrites `a + (b * -1.0)` into `MetalSub(a, b)`.
///
/// Searches `graph` for a `-1.0` constant feeding a `MetalMul` that feeds a `MetalAdd`.
/// For each match, a `MetalSub` node is added that reads the add's other input (`a`)
/// and the mul's other input (`b`), the add's outgoing edges are moved to it, and the
/// mul and add nodes are removed.
///
/// Panics if no system default Metal device is available.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Filled in by the searcher on every match.
let (mut neg_one, mut mul, mut add) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
);
let mut searcher = select_const!(-1.0, T)
.ptr(&mut neg_one)
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul))
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add))
.search(graph);
while searcher.next_match() {
// Skip this match when `check_no_delete` reports true
// (presumably: one of the nodes must be preserved — confirm against its definition).
if check_no_delete(graph, &[neg_one, mul, add]) {
continue;
}
// `a`: the add input that does not come from the mul.
let (a, a_edge) = graph
.graph
.edges_directed(add, petgraph::Direction::Incoming)
.find(|e| e.source() != mul)
.map(|e| (e.source(), e.weight().as_data().unwrap()))
.unwrap();
// `b`: the mul input that is not the -1.0 constant.
let (b, b_edge) = graph
.graph
.edges_directed(mul, petgraph::Direction::Incoming)
.find(|e| e.source() != neg_one)
.map(|e| (e.source(), e.weight().as_data().unwrap()))
.unwrap();
// Shape recorded on the mul -> add edge.
let b_final_shape = graph
.graph
.edges_connecting(mul, add)
.next()
.unwrap()
.weight()
.as_data()
.unwrap()
.2;
// The new op is wired with the shape of `b` before the mul, so only fuse when the
// mul -> add edge is contiguous, unsliced and unpadded.
if !b_final_shape.is_contiguous()
|| b_final_shape.is_sliced()
|| b_final_shape.is_padded()
{
continue;
}
let sub = graph
.add_op(MetalSub::<T>::new(
a_edge.2,
b_edge.2,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.input(a, a_edge.1, a_edge.2)
.input(b, b_edge.1, b_edge.2)
.finish();
move_outgoing_edge(add, sub, &mut graph.graph);
// Remove the constant only when it has exactly one destination.
if graph.get_dests(neg_one).len() == 1 {
graph.graph.remove_node(neg_one);
}
graph.graph.remove_node(mul);
graph.graph.remove_node(add);
}
}
}
/// Elementwise equality op backed by a Metal compute kernel; the kernel writes the
/// result of `a_val == b_val` cast to the element type `T`.
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
pub struct MetalEqual<T> {
// Compiled compute pipeline for the generated `mkernel` function.
pipeline: ComputePipelineState,
// Queue used to create command buffers when the op is run through `Operator::process`.
queue: CommandQueue,
// Device used to allocate the output buffer and to recompile the kernel.
device: Device,
// Dynamic-dimension symbols the kernel takes as extra inputs.
dyn_symbols: Vec<char>,
// Raw pointer to the map of dynamic-dimension values; dereferenced when encoding the
// kernel, so the pointed-to map must still be alive at that point.
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> MetalEqual<T> {
pub fn new(
a_shape: ShapeTracker,
b_shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (a_idx_exp, a_valid_exp) = get_idx_valid_exps(a_shape);
let (b_idx_exp, b_valid_exp) = get_idx_valid_exps(b_shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape], 4);
let type_name = T::type_name();
let code = format!(
"
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
{type_name} a_val = (({a_valid_exp}) == 0 ? 0.0 : inp_a[{a_idx_exp}]);
{type_name} b_val = (({b_valid_exp}) == 0 ? 0.0 : inp_b[{b_idx_exp}]);
out[idx] = ({type_name})(a_val == b_val);
}}
}}
");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_symbols,
dyn_map,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalEqual<T> {
    /// One output buffer: element count of the first input times `size_of::<T>()`.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        let out_bytes = input_shapes[0].n_elements() * size_of::<T>();
        vec![out_bytes]
    }

    /// Encodes the equality kernel into `command_buffer`.
    ///
    /// Slots 0 and 1 are the two inputs, slot 2 the output, slot 3 the element count,
    /// and dynamic dimensions follow from slot 4. One thread is dispatched per element
    /// of the first input.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let n_elements = inputs[0].1.n_elements().to_usize().unwrap();
        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        encoder.set_compute_pipeline_state(&self.pipeline);

        // Bind the kernel arguments.
        encoder.set_buffer(0, Some(inputs[0].0), 0);
        encoder.set_buffer(1, Some(inputs[1].0), 0);
        encoder.set_buffer(2, Some(output_buffers[0]), 0);
        encoder.set_u32(3, n_elements as u32);
        // The raw pointer is dereferenced here; it was supplied to `new` by the caller.
        let dyn_map = unsafe { self.dyn_map.as_ref() }.unwrap();
        input_dyn_dims(&self.dyn_symbols, dyn_map, encoder, 4);

        // Launch and close the encoder.
        encoder.dispatch_1d(n_elements);
        encoder.end_encoding();
    }
}
impl<T: MetalFloat> Operator for MetalEqual<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[
(get_buffer_from_tensor(&tensors[0].0), tensors[0].1),
(get_buffer_from_tensor(&tensors[1].0), tensors[1].1),
],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 == input1 ? 1.0 : 0.0".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
input_shapes[0],
input_shapes[1],
self.device.clone(),
self.queue.clone(),
self.dyn_map,
)
}
}
None
}
}
/// Compiler pass that replaces a less-than / add / subtract-from-one pattern with a
/// single `MetalEqual` op (see its `Compiler` impl). `T` selects the kernel element type.
#[derive(LuminalPrint, Default)]
pub struct MetalEqualCompiler<T: MetalFloat>(PhantomData<T>);
impl<T: MetalFloat> Compiler for MetalEqualCompiler<T> {
/// Replaces a matched group of `1.0` constant, two `MetalLessThan` nodes, a `MetalAdd`
/// and a `MetalSub` with a single `MetalEqual` node.
///
/// NOTE(review): the pattern presumably encodes `1 - ((a < b) + (b < a))`; the exact
/// topology expressed by the nested `.edge` calls depends on the selector API — confirm.
///
/// Panics if no system default Metal device is available.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Filled in by the searcher on every match.
let (mut less_than1, mut less_than2, mut add, mut one, mut sub) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
);
let s = select_const!(1.0, T).ptr(&mut one).edge(
SelectOp::new()
.ty::<MetalLessThan<T>>()
.ptr(&mut less_than1)
.edge(
SelectOp::new()
.ty::<MetalLessThan<T>>()
.ptr(&mut less_than2)
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add)),
)
.edge(SelectOp::new().ty::<MetalSub<T>>().ptr(&mut sub)),
);
let mut searcher = s.search(graph);
while searcher.next_match() {
// Both comparisons must read from the same set of source nodes.
let lt1_inputs = graph
.graph
.neighbors_directed(less_than1, Direction::Incoming)
.sorted()
.collect::<Vec<_>>();
let lt2_inputs = graph
.graph
.neighbors_directed(less_than2, Direction::Incoming)
.sorted()
.collect::<Vec<_>>();
if lt1_inputs != lt2_inputs {
continue;
}
// Sources of the first comparison, ordered by the first field of the edge data
// (presumably the input position — confirm).
let inputs = graph
.graph
.edges_directed(less_than1, Direction::Incoming)
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| e.source())
.collect::<Vec<_>>();
let (a, b) = (inputs[0], inputs[1]);
// Skip this match when `check_no_delete` reports true
// (presumably: one of the nodes must be preserved — confirm against its definition).
if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) {
continue;
}
// Edge data (including the shape) of the a -> less_than1 edge.
let a_edge = graph
.graph
.edge_weight(
graph
.graph
.edges_connecting(a, less_than1)
.next()
.unwrap()
.id(),
)
.unwrap()
.as_data()
.unwrap();
// Edge data (including the shape) of the b -> less_than1 edge.
let b_edge = graph
.graph
.edge_weight(
graph
.graph
.edges_connecting(b, less_than1)
.next()
.unwrap()
.id(),
)
.unwrap()
.as_data()
.unwrap();
let equals = graph
.add_op(MetalEqual::<T>::new(
a_edge.2,
b_edge.2,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.input(a, a_edge.1, a_edge.2)
.input(b, b_edge.1, b_edge.2)
.finish();
// The new node takes over the consumers of `sub`; `sub` is removed outright and the
// remaining nodes go through `safe_remove_node`.
move_outgoing_edge(sub, equals, &mut graph.graph);
graph.graph.remove_node(sub);
graph.safe_remove_node(add, 0);
graph.safe_remove_node(one, 0);
graph.safe_remove_node(less_than2, 0);
graph.safe_remove_node(less_than1, 0);
// The graph changed, so previously cached matches are discarded.
searcher.clear_cached_results();
}
}
}
/// Gather (embedding lookup) op backed by a Metal compute kernel: for each index it
/// copies the matching row of `embed_dim` values from the weight buffer to the output.
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
pub struct MetalGather<T> {
// Compiled compute pipeline for the `metal_gather` kernel.
pipeline: ComputePipelineState,
// Device used to allocate the index and output buffers.
device: Device,
// Queue used to create the command buffer in `process`.
queue: CommandQueue,
/// Number of values per embedding row.
pub embed_dim: usize,
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> MetalGather<T> {
fn new(device: Device, queue: CommandQueue, embed_dim: usize) -> Self {
let type_name = T::type_name();
Self {pipeline: compile_function("metal_gather", &format!(
"
#include <metal_stdlib>
using namespace metal;
kernel void metal_gather(device float *inp [[buffer(0)]], device {type_name} *weights [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_embeddings [[buffer(3)]], device int& embedding_dim [[buffer(4)]], uint2 i_ [[thread_position_in_grid]]) {{
if (i_.x < n_embeddings && i_.y < embedding_dim) {{
out[i_.x * embedding_dim + i_.y] = weights[(int)inp[i_.x] * embedding_dim + i_.y];
}}
}}"), &device), device, embed_dim, queue, _phantom: Default::default()}
}
}
impl<T: MetalFloat> Operator for MetalGather<T> {
    /// Runs the gather kernel and blocks until it has finished.
    ///
    /// `tensors[0]` must hold the indices as a host-side `Vec<f32>`; `tensors[1]` must
    /// hold the weights as a `MetalBuffer`. Returns one tensor wrapping a new Metal
    /// buffer of `indexes.len() * embed_dim` elements of `T`.
    ///
    /// # Panics
    /// Panics if either input does not have the expected storage type.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Indices live on the host; copy them into a shared Metal buffer.
            let indexes = tensors[0]
                .0
                .borrowed()
                .data
                .as_any()
                .downcast_ref::<Vec<f32>>()
                .unwrap();
            let index_buffer = self.device.new_buffer_with_data(
                // A pointer-to-pointer conversion only needs a cast, not `transmute`.
                indexes.as_ptr().cast(),
                (indexes.len() * std::mem::size_of::<f32>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );
            // Weights are already a Metal buffer.
            let b_inp = tensors[1]
                .0
                .borrowed()
                .data
                .as_any()
                .downcast_ref::<MetalBuffer>()
                .unwrap();

            // Setup command queue / command buffer / encoder
            let command_buffer = self.queue.new_command_buffer();
            let out = self.device.new_buffer(
                (indexes.len() * self.embed_dim * std::mem::size_of::<T>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );
            let encoder = command_buffer
                .compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
            encoder.set_compute_pipeline_state(&self.pipeline);

            // Set inputs: indices, weights, output, index count, row width.
            encoder.set_buffer(0, Some(&index_buffer), 0);
            encoder.set_buffer(1, Some(b_inp), 0);
            encoder.set_buffer(2, Some(&out), 0);
            encoder.set_u32(3, indexes.len() as u32);
            encoder.set_u32(4, self.embed_dim as u32);

            // Execute: one thread per (index, row element), in 16x16 threadgroups.
            encoder.dispatch_threads(
                MTLSize {
                    width: indexes.len() as u64,
                    height: self.embed_dim as u64,
                    depth: 1,
                },
                MTLSize {
                    width: 16,
                    height: 16,
                    depth: 1,
                },
            );
            encoder.end_encoding();
            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(MetalBuffer(out))]
        })
    }
}
/// Compiler pass that replaces an arange / equal / mul / sum-reduce pattern with a
/// single `MetalGather` op (see its `Compiler` impl). `T` selects the kernel element type.
#[derive(LuminalPrint, Default)]
pub struct MetalGatherCompiler<T: MetalFloat>(PhantomData<T>);
impl<T: MetalFloat> Compiler for MetalGatherCompiler<T> {
/// Replaces a matched group of `MetalARange`, `MetalCopyToDevice`, `MetalEqual`,
/// `MetalMul` and `MetalSumReduce` nodes with a single `MetalGather` node.
///
/// NOTE(review): this looks like the one-hot formulation of an embedding lookup
/// (compare indices with an arange, multiply by the weights, sum) — confirm.
///
/// Panics if no system default Metal device is available.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Filled in by the searcher on every match.
let (mut ind_copy, mut arange, mut equal, mut mul, mut sum_reduce) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
);
let s = SelectOp::new()
.ty::<MetalARange<T>>()
.ptr(&mut arange)
.edge(
SelectOp::new()
.ty::<MetalCopyToDevice<T>>()
.ptr(&mut ind_copy)
.edge(SelectOp::new().ty::<MetalEqual<T>>().ptr(&mut equal)),
)
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul))
.edge(
SelectOp::new()
.ty::<MetalSumReduce<T>>()
.ptr(&mut sum_reduce),
);
let mut searcher = s.search(graph);
while searcher.next_match() {
// Skip this match when `check_no_delete` reports true
// (presumably: one of the nodes must be preserved — confirm against its definition).
// `ind_copy` is not part of this check.
if check_no_delete(graph, &[arange, equal, mul, sum_reduce]) {
continue;
}
// Row width: dimension 2 of the shape on the mul's data input that does not come
// from `equal`.
let embedding_dim = graph
.graph
.edges_directed(mul, Direction::Incoming)
.find(|e| e.source() != equal && !e.weight().is_schedule())
.unwrap()
.weight()
.as_data()
.unwrap()
.2
.shape()[2]
.to_usize()
.unwrap();
let gather = graph
.add_op(MetalGather::<T>::new(
dev.clone(),
queue.clone(),
embedding_dim,
))
.finish();
// Rewire: the gather takes the incoming edges of `ind_copy`, then (after `equal` is
// removed) those of `mul`, and finally the consumers of `sum_reduce`.
// The order of these calls matters because each one mutates the graph.
move_incoming_edge(ind_copy, gather, &mut graph.graph);
graph.safe_remove_node(equal, 1);
move_incoming_edge(mul, gather, &mut graph.graph);
move_outgoing_edge(sum_reduce, gather, &mut graph.graph);
graph.graph.remove_node(sum_reduce);
graph.safe_remove_node(mul, 0);
graph.safe_remove_node(ind_copy, 0);
graph.safe_remove_node(arange, 0);
}
}
}
#[cfg(test)]
mod tests {
    use luminal::{prelude::*, tests::assert_close};
    use crate::MetalCompiler;

    /// Runs `a - b` and `-a + b` uncompiled, then compiled with `MetalCompiler::<f16>`,
    /// and checks that both runs agree.
    #[test]
    fn test_subtraction() {
        let mut cx = Graph::new();
        let values = cx
            .tensor::<R1<10>>()
            .set(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
        let scalar = cx.tensor::<R0>().set(vec![1.]);
        let mut diff = (values - scalar.expand()).retrieve();
        let mut neg_sum = (-values + scalar.expand()).retrieve();

        // Reference results from the uncompiled graph.
        cx.execute();
        let unopt_diff = diff.data();
        diff.drop();
        let unopt_neg_sum = neg_sum.data();
        neg_sum.drop();

        // Same graph after Metal compilation.
        cx.compile(
            MetalCompiler::<f16>::default(),
            (&mut diff, &mut neg_sum),
        );
        cx.execute();

        assert_close(&unopt_diff, &diff.data());
        assert_close(&unopt_neg_sum, &neg_sum.data());
    }
}

View File

@@ -0,0 +1,320 @@
use std::{any::Any, cell::UnsafeCell, ops::Deref, sync::Arc};
use itertools::Itertools;
use metal_rs::{Buffer, CommandBuffer, CommandQueue, Device};
use petgraph::{
stable_graph::NodeIndex,
visit::EdgeRef,
Direction::{self},
};
use rustc_hash::{FxHashMap, FxHashSet};
use luminal::{
op::{InputTensor, Operator},
prelude::*,
};
use crate::{MetalBuffer, MetalKernel, MetalKernelWrapper};
use super::get_buffer_from_tensor;
/// Compiler pass that groups nodes answering the `"metal"` custom query into
/// sets, gives each set one shared Metal command buffer, and adds an
/// `ExecuteMetalKernels` node per set that commits that buffer.
#[derive(Default, LuminalPrint)]
pub struct CommandBufferCompiler;
impl Compiler for CommandBufferCompiler {
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, _: T) {
let is_metal: FxHashSet<NodeIndex> = graph
.graph
.node_indices()
.collect::<Vec<_>>()
.into_iter()
.filter(|i| {
graph
.graph
.node_weight_mut(*i)
.unwrap()
.custom("metal", Box::new(()))
.is_some()
})
.collect();
// Do forward pass
let mut forward_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
for node in graph
.graph
.node_indices()
.filter(|n| graph.graph.edges_directed(*n, Direction::Incoming).count() == 0)
.sorted()
{
let mut stack = vec![node];
while let Some(node) = stack.pop() {
// Get rank as max of predecessors
let rank = graph
.graph
.neighbors_directed(node, Direction::Incoming)
.filter_map(|i| forward_map.get(&i).map(|r| (i, *r)))
.map(|(node_index, rank)| {
if is_metal.contains(&node) != is_metal.contains(&node_index) {
rank + 1
} else {
rank
}
})
.max()
.unwrap_or_default();
// Max it with the current entry in the map or insert
if let Some(entry) = forward_map.get_mut(&node) {
if rank > *entry {
*entry = rank;
stack.extend(graph.graph.neighbors_directed(node, Direction::Outgoing));
}
} else {
forward_map.insert(node, rank);
stack.extend(graph.graph.neighbors_directed(node, Direction::Outgoing));
}
}
}
// Do backward pass
let mut backward_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
for node in graph
.graph
.node_indices()
.filter(|n| graph.graph.edges_directed(*n, Direction::Outgoing).count() == 0)
.sorted()
{
let mut stack = vec![node];
while let Some(node) = stack.pop() {
// Get rank as max of successors
let rank = graph
.graph
.neighbors_directed(node, Direction::Outgoing)
.filter_map(|i| backward_map.get(&i).map(|r| (i, *r)))
.map(|(node_index, rank)| {
if is_metal.contains(&node) != is_metal.contains(&node_index) {
rank + 1
} else {
rank
}
})
.max()
.unwrap_or_default();
// Max it with the current entry in the map or insert
if let Some(entry) = backward_map.get_mut(&node) {
if rank > *entry {
*entry = rank;
stack.extend(graph.graph.neighbors_directed(node, Direction::Incoming));
}
} else {
backward_map.insert(node, rank);
stack.extend(graph.graph.neighbors_directed(node, Direction::Incoming));
}
}
}
// Get sets (Rank -> # of nodes with that rank)
let forward_sets = forward_map
.iter()
.sorted_by_key(|(_, v)| **v)
.group_by(|(_, v)| **v)
.into_iter()
.map(|(k, g)| (k, g.count()))
.collect::<FxHashMap<_, _>>();
let backward_sets = backward_map
.iter()
.sorted_by_key(|(_, v)| **v)
.group_by(|(_, v)| **v)
.into_iter()
.map(|(k, g)| (k, g.count()))
.collect::<FxHashMap<_, _>>();
// Assign nodes to sets
let mut node_sets: FxHashMap<(bool, usize), FxHashSet<NodeIndex>> = FxHashMap::default();
for node in graph.graph.node_indices().filter(|i| is_metal.contains(i)) {
let forward_bigger =
forward_sets[&forward_map[&node]] >= backward_sets[&backward_map[&node]];
node_sets
.entry((
forward_bigger,
if forward_bigger {
forward_map[&node]
} else {
backward_map[&node]
},
))
.and_modify(|set| {
set.insert(node);
})
.or_insert({
let mut set = FxHashSet::default();
set.insert(node);
set
});
}
// Add sets to graph
let dev = Device::system_default().unwrap();
let mut queue = dev.new_command_queue();
let mut num_buffers_on_queue = 0;
for set in node_sets.values() {
if num_buffers_on_queue >= 63 {
num_buffers_on_queue = 0;
queue = dev.new_command_queue();
} else {
num_buffers_on_queue += 1;
}
#[allow(clippy::arc_with_non_send_sync)]
let buffer = Arc::new(UnsafeCell::new(queue.new_command_buffer().to_owned()));
let exec = graph
.add_op(ExecuteMetalKernels {
queue: queue.clone(),
buffer: buffer.clone(),
})
.finish();
for node in set {
// Create schedule dependency
graph.add_schedule_dependency(*node, exec);
// Wrap node in MetalKernelOperation
let wrapper = graph
.graph
.node_weight_mut(*node)
.unwrap()
.custom("metal", Box::new(()))
.unwrap()
.downcast::<MetalKernelWrapper>()
.unwrap();
*graph.graph.node_weight_mut(*node).unwrap() = Box::new(CommandBufferWrapper {
wrapper,
buffer: buffer.clone(),
dyn_map: &graph.dyn_map,
});
// Create schedule dependencies from exec to consumers
for outside_node in graph
.graph
.edges_directed(*node, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.map(|e| e.target())
.filter(|n| !set.contains(n))
.collect::<Vec<_>>()
{
graph.add_schedule_dependency(exec, outside_node);
}
}
}
}
}
/// Operator that commits a shared command buffer, waits for it to finish, and
/// replaces it with a fresh buffer from `queue`.
#[derive(LuminalEqFalse, LuminalPrint)]
struct ExecuteMetalKernels {
// Queue used to create the replacement command buffer after each run
queue: CommandQueue,
// Command buffer shared with the `CommandBufferWrapper` nodes of the same set
buffer: Arc<UnsafeCell<CommandBuffer>>,
}
impl Operator for ExecuteMetalKernels {
    /// Commits the shared command buffer, blocks until it has completed, then
    /// installs a new buffer from the queue. Inputs are ignored and no tensors
    /// are produced.
    fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        let pending = unsafe { &mut *self.buffer.get() };
        pending.commit();
        pending.wait_until_completed();
        let fresh = self.queue.new_command_buffer().to_owned();
        *pending = fresh;
        Vec::new()
    }
}
/// Wraps a metal kernel so that it encodes into a shared command buffer
/// instead of the command buffer passed by the caller.
#[derive(Clone, LuminalEqFalse)]
struct CommandBufferWrapper {
// The wrapped kernel, obtained from the node's "metal" custom query
wrapper: Box<MetalKernelWrapper>,
// Command buffer shared with the set's `ExecuteMetalKernels` node
buffer: Arc<UnsafeCell<CommandBuffer>>,
// Raw pointer to the graph's `dyn_map`; dereferenced in `process`
dyn_map: *const FxHashMap<char, usize>,
}
impl std::fmt::Debug for CommandBufferWrapper {
    /// Prints the wrapped kernel's debug form inside `MetalKernel(...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.wrapper.0;
        f.write_fmt(format_args!("MetalKernel({:?})", inner))
    }
}
impl MetalKernel for CommandBufferWrapper {
    /// Delegates to the wrapped kernel.
    fn intermediate_buffer_sizes(
        &self,
        input_shapes: &[ShapeTracker],
    ) -> Vec<symbolic::BigExpression> {
        let inner = &self.wrapper.0;
        inner.intermediate_buffer_sizes(input_shapes)
    }

    /// Delegates to the wrapped kernel.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<symbolic::BigExpression> {
        let inner = &self.wrapper.0;
        inner.output_buffer_sizes(input_shapes)
    }

    /// Runs the wrapped kernel's `metal_forward`, substituting the shared
    /// command buffer for the one supplied by the caller.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        _: &metal_rs::CommandBufferRef,
        intermediate_buffers: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let shared: &metal_rs::CommandBufferRef = unsafe { &*self.buffer.get() };
        self.wrapper
            .0
            .metal_forward(inputs, shared, intermediate_buffers, output_buffers);
    }

    /// Same as `metal_forward`, using the shared command buffer.
    fn without_command_buffer(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        intermediate_buffers: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let shared: &metal_rs::CommandBufferRef = unsafe { &*self.buffer.get() };
        self.metal_forward(inputs, shared, intermediate_buffers, output_buffers)
    }
}
impl Operator for CommandBufferWrapper {
    /// Extracts the Metal buffers from the input tensors, runs the kernel via
    /// `without_storage_buffers` on the shared command buffer, and wraps each
    /// returned buffer in a `Tensor`.
    fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        let inputs = inp
            .iter()
            .map(|(tensor, shape)| (get_buffer_from_tensor(tensor).deref(), *shape))
            .collect::<Vec<_>>();
        let command_buffer = unsafe { &*self.buffer.get() };
        let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
        let outputs = self.without_storage_buffers(&inputs, command_buffer, dyn_map);
        outputs
            .into_iter()
            .map(|buf| Tensor::new(MetalBuffer(buf)))
            .collect()
    }

    /// Answers the `"metal"` query with a `MetalKernelWrapper` around a clone
    /// of this wrapper; every other key yields `None`.
    #[allow(clippy::arc_with_non_send_sync)]
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        match key {
            "metal" => Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            ))))),
            _ => None,
        }
    }
}
/// Checks that `(a + b) * c` gives the same result before and after
/// compilation with `MetalCompiler<f16>`.
#[cfg(test)]
#[test]
fn test_common_buffer() {
    use crate::MetalCompiler;
    use luminal::{
        prelude::*,
        tests::{assert_close, random_vec},
    };

    let mut cx = Graph::new();
    let lhs = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
    let rhs = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
    let scale = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
    let mut result = ((lhs + rhs) * scale).retrieve();

    // Reference result from the uncompiled graph
    cx.execute();
    let expected = result.data();
    result.drop();

    cx.compile(MetalCompiler::<f16>::default(), &mut result);
    cx.execute();

    assert_close(&result.data(), &expected);
}

View File

@@ -0,0 +1,459 @@
use rustc_hash::{FxHashMap, FxHashSet};
use std::{any::Any, marker::PhantomData, ops::Deref, sync::Arc};
use itertools::Itertools;
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
ComputePipelineState, Device, MTLResourceOptions,
};
use luminal::{
op::{InputTensor, Operator},
prelude::{
petgraph::{visit::EdgeRef, Direction},
*,
},
};
use crate::{get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper};
use self::symbolic::BigExpression;
use super::{
compile_function, get_idx_valid_exps, input_dyn_dims, prim::MetalConstant,
render_dyn_dim_inputs, DispatchNElements, SetInt,
};
/// Compiler pass that merges connected nodes answering the `"elementwise"`
/// custom query into single `FusedElementwiseOp` kernels. `T` is the element
/// type used when rendering the kernel source.
#[derive(Default, Debug)]
pub struct ElementwiseFusionCompiler<T>(PhantomData<T>);
impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
// Fuses pairs of elementwise ops (a -> b) until no more matches are found,
// then renders and compiles one Metal kernel per fused op.
//
// Each elementwise op exposes its expression as a string through the
// "elementwise" custom query, with operands written as `input{N}`. Fusing
// substitutes a's expression for the operand of b that a fed, and renumbers
// the remaining operands to match b's new incoming edges.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
let device = Device::system_default().unwrap();
let queue = device.new_command_queue();
// Find two elementwise ops that have a contiguous edge
let (mut a, mut b) = (NodeIndex::default(), NodeIndex::default());
let mut selector = SelectOp::new()
.check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
.ptr(&mut a)
.edge(
SelectOp::new()
.check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
.ptr(&mut b),
)
.search(graph);
// Nodes that currently hold a FusedElementwiseOp and still need a kernel compiled
let mut fused_ops = FxHashSet::default();
while selector.next_match() {
// Skip if `a` is marked no_delete, or if `a` has more than one outgoing
// data edge and is not a MetalConstant
if graph.no_delete.contains(&a)
|| (graph
.graph
.edges_directed(a, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.count()
> 1
&& !graph
.graph
.node_weight(a)
.unwrap()
.as_any()
.is::<MetalConstant<T>>())
{
continue;
}
// Skip if the connecting shape isn't contiguous, or is sliced or padded
let (edge_id, (to_input, _, connecting_shape)) = graph
.graph
.edges_connecting(a, b)
.find_map(|e| e.weight().as_data().map(|i| (e.id(), i)))
.unwrap();
if !connecting_shape.is_contiguous()
|| connecting_shape.is_sliced()
|| connecting_shape.is_padded()
{
continue;
}
// Fuse into a FusedElementwiseOp
let new_op;
let mut a_equation = graph
.node_custom::<String, _>(a, "elementwise", ())
.unwrap();
// Next input slot on b to hand out; starts at the slot a occupied
let mut curr_input = to_input;
// Keep track of original edges to a and b
let a_orig_edges = graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
.sorted_by_key(|i| i.1)
.collect::<Vec<_>>();
let b_orig_edges = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
.sorted_by_key(|i| i.1)
.collect::<Vec<_>>();
// Remove edge a -> b, and decrement indexes of all edges higher than it
graph.graph.remove_edge(edge_id);
for edge in graph
.graph
.edges_directed(b, Direction::Incoming)
.map(|e| e.id())
.collect_vec()
{
if let Some(Dependency::Data { input_order, .. }) =
graph.graph.edge_weight_mut(edge)
{
if *input_order > curr_input {
*input_order -= 1;
}
}
}
// Add edges if they don't exist
for input_edge in graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.sorted_by_key(|i| i.1)
.collect_vec()
{
// Find edge or add it (an existing edge from the same source output is reused)
if !graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.any(|(src, _, out_ind, _)| src == input_edge.0 && out_ind == input_edge.2)
{
// Move all edges >= curr_input up by one
for edge in graph
.graph
.edges_directed(b, Direction::Incoming)
.map(|e| e.id())
.collect_vec()
{
if let Some(Dependency::Data { input_order, .. }) =
graph.graph.edge_weight_mut(edge)
{
if *input_order >= curr_input {
*input_order += 1;
}
}
}
// Add edge
graph.graph.add_edge(
input_edge.0,
b,
Dependency::Data {
input_order: curr_input,
output_order: input_edge.2,
shape: input_edge.3,
},
);
curr_input += 1;
}
}
// Alter a_equation to reflect the correct input indexes
let mut replacements = vec![];
for (src, inp_ind, out_ind) in a_orig_edges {
let n = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
.unwrap();
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
}
a_equation = multi_replace(&a_equation, &replacements);
// Alter b_equation to reflect the correct input indexes
// (only operands that came after the slot a occupied are renumbered)
replacements.clear();
for (src, inp_ind, out_ind) in b_orig_edges {
if inp_ind > to_input {
let n = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
.unwrap();
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
}
}
// NOTE(review): the `.replace("input{to_input}", ...)` calls below are plain
// substring replacements, so e.g. "input1" would also match the start of
// "input10" — confirm ops never reach 10+ inputs or make this token-aware.
if let Some(fused_op) = graph
.graph
.node_weight_mut(b)
.unwrap()
.as_any_mut()
.downcast_mut::<FusedElementwiseOp<T>>()
{
// B is already fused, just combine a into it
new_op = b;
// Render a into b as input to_input
fused_op.equation = multi_replace(&fused_op.equation, &replacements)
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
} else {
let mut b_equation = graph
.node_custom::<String, _>(b, "elementwise", ())
.unwrap();
b_equation = multi_replace(&b_equation, &replacements)
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// B is not a fused op, let's create a new one
new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel: None,
dyn_map: &graph.dyn_map,
dyn_chars: vec![],
equation: b_equation,
queue: queue.clone(),
device: device.clone(),
_phantom: Default::default(),
})
.finish();
// Rewire b's edges and external references onto the new fused node
move_incoming_edge(b, new_op, &mut graph.graph);
move_outgoing_edge(b, new_op, &mut graph.graph);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
b,
new_op,
);
graph.graph.remove_node(b);
fused_ops.remove(&b);
}
// Remove a
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
a,
new_op,
);
// a is only deleted once it has no remaining outgoing data edges
if graph
.graph
.edges_directed(a, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.count()
== 0
{
graph.graph.remove_node(a);
}
fused_ops.remove(&a);
fused_ops.insert(new_op);
// The graph changed, so restart the search
selector.reset();
}
// Compile all the kernels we placed
let type_name = T::type_name();
for fused_op in fused_ops {
let edges = graph
.graph
.edges_directed(fused_op, Direction::Incoming)
.filter_map(|e| e.weight().as_data())
.collect_vec();
if let Some(op) = graph
.graph
.node_weight_mut(fused_op)
.unwrap()
.as_any_mut()
.downcast_mut::<FusedElementwiseOp<T>>()
{
// Buffer slots: one per input edge, then `out`, then `n_elements`;
// dynamic-dimension arguments start at edges.len() + 2
let (dyn_chars, rendered) = render_dyn_dim_inputs(
&edges.iter().map(|i| i.2).collect_vec(),
edges.len() + 2,
);
// Replace each `input{N}` operand with an indexed load from that buffer
for (inp_ind, _, sh) in &edges {
let (ind, val) = get_idx_valid_exps(*sh);
// NOTE(review): the first operand of this `||` is implied by the second,
// so the test reduces to `!sh.is_sliced() && !sh.is_padded()`.
// NOTE(review): substring replace — "input1" also matches inside
// "input10"; confirm fused ops stay below 10 inputs.
if (sh.is_contiguous() && !sh.is_sliced() && !sh.is_padded())
|| (!sh.is_sliced() && !sh.is_padded())
{
op.equation = op.equation.replace(
&format!("input{inp_ind}"),
&format!("(float)input{inp_ind}[{ind}]"),
);
} else {
// Sliced or padded input: yield 0.0 where the validity expression is 0
op.equation = op.equation.replace(
&format!("input{inp_ind}"),
&format!("(({val} != 0) ? (float)input{inp_ind}[{ind}] : 0.0)"),
);
}
}
let kernel = format!(
"
#include <metal_stdlib>
using namespace metal;
kernel void mkernel({} device {type_name} *out [[buffer({})]], device uint& n_elements [[buffer({})]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
out[idx] = ({type_name})({});
}}
}}",
edges
.iter()
.map(|(inp_ind, _, _)| format!(
"device {type_name}* input{inp_ind} [[buffer({inp_ind})]],"
))
.collect_vec()
.join(" "),
edges.len(),
edges.len() + 1,
op.equation
);
op.kernel = Some(compile_function("mkernel", &kernel, &device));
op.dyn_chars = dyn_chars;
}
}
}
}
/// Applies every `(from, to)` pair in `replacements` to `input` simultaneously.
///
/// Substitution happens in two passes through private-use placeholder
/// characters, so the output of one replacement is never picked up as the
/// input of another (e.g. swapping `input0` and `input1` works).
///
/// Keys are substituted longest first, so a key that is a prefix of another
/// key (`input1` vs `input10`) cannot clobber the longer occurrence. Keys of
/// equal length keep the order they were given in.
fn multi_replace(input: &str, replacements: &[(String, String)]) -> String {
    // Use Unicode Private Use Areas as unlikely placeholders, starting at U+E000
    const PLACEHOLDER_START: u32 = 0xE000;

    // Pair every replacement with its own placeholder
    let mut plan: Vec<(&str, &str, String)> = replacements
        .iter()
        .enumerate()
        .map(|(i, (from, to))| {
            let placeholder = char::from_u32(PLACEHOLDER_START + i as u32)
                .expect("code points from U+E000 upwards are valid chars");
            (from.as_str(), to.as_str(), placeholder.to_string())
        })
        .collect();
    // Longest key first (stable sort keeps the given order among equal lengths)
    plan.sort_by_key(|(from, _, _)| std::cmp::Reverse(from.len()));

    let mut output = input.to_string();
    // First pass: Replace all target strings with placeholders
    for (from, _, placeholder) in &plan {
        output = output.replace(*from, placeholder.as_str());
    }
    // Second pass: Replace placeholders with final strings
    for (_, to, placeholder) in &plan {
        output = output.replace(placeholder.as_str(), *to);
    }
    output
}
/// A single Metal kernel evaluating a fused elementwise expression.
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
pub struct FusedElementwiseOp<T> {
// Compiled pipeline; `None` until the fusion compiler compiles the kernel
kernel: Option<ComputePipelineState>,
// Raw pointer to the graph's `dyn_map`; dereferenced when running the kernel
dyn_map: *const FxHashMap<char, usize>,
// Dynamic dimension characters passed to the kernel via `input_dyn_dims`
dyn_chars: Vec<char>,
// Expression text, with operands written as `input{N}` before compilation
equation: String,
queue: CommandQueue,
device: Device,
_phantom: PhantomData<T>,
}
impl<T> MetalKernel for FusedElementwiseOp<T> {
    /// Returns the size in bytes of the single output buffer, derived from the
    /// first input shape: its physical element count when there is exactly one
    /// input, otherwise its logical element count.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        let elem_bytes = std::mem::size_of::<T>();
        let size = match input_shapes.len() {
            // Assume since it's a unary op, we're outputting 1-1 elements from input
            1 => input_shapes[0].n_physical_elements() * elem_bytes,
            // If it isn't a unary op, output the contiguous buffer length
            _ => input_shapes[0].n_elements() * elem_bytes,
        };
        vec![size]
    }

    /// Encodes one dispatch of the compiled kernel into `command_buffer`.
    ///
    /// Buffer slots are: every input in order, then the output buffer, then
    /// the element count, then the dynamic dimensions.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let enc =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        enc.set_compute_pipeline_state(self.kernel.as_ref().unwrap());

        // One thread per element of the largest input
        let n_threads = inputs
            .iter()
            .map(|input| input.1.n_elements().to_usize().unwrap())
            .max()
            .unwrap();
        let n_inputs = inputs.len();

        // Set function inputs
        for (slot, (buf, _)) in inputs.iter().enumerate() {
            enc.set_buffer(slot as u64, Some(*buf), 0);
        }
        enc.set_buffer(n_inputs as u64, Some(output_buffers[0]), 0);
        enc.set_u32(n_inputs + 1, n_threads as u32);
        let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
        input_dyn_dims(&self.dyn_chars, dyn_map, enc, n_inputs + 2);

        // Execute
        enc.dispatch_1d(n_threads);
        enc.end_encoding();
    }
}
impl<T: MetalFloat> Operator for FusedElementwiseOp<T> {
    /// Allocates the output buffer, encodes and runs the kernel on a fresh
    /// command buffer, waits for completion and returns the output tensor.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            let command_buffer = self.queue.new_command_buffer();

            // Size the output from the input shapes and current dynamic dims
            let shapes = tensors.iter().map(|(_, shape)| *shape).collect_vec();
            let out_bytes = self.output_buffer_sizes(&shapes)[0]
                .exec(unsafe { self.dyn_map.as_ref().unwrap() })
                .unwrap() as u64;
            let out = self
                .device
                .new_buffer(out_bytes, MTLResourceOptions::StorageModeShared);

            let inputs = tensors
                .iter()
                .map(|(tensor, shape)| (get_buffer_from_tensor(tensor).deref(), *shape))
                .collect_vec();
            self.metal_forward(&inputs, command_buffer, &[], &[&out]);

            command_buffer.commit();
            command_buffer.wait_until_completed();
            vec![Tensor::new(MetalBuffer(out))]
        })
    }

    /// Answers the custom queries this op supports:
    /// `"metal"` (a `MetalKernelWrapper` around a clone of this op),
    /// `"non_contiguous"` (a unit marker) and `"elementwise"` (the equation
    /// string). Any other key yields `None`.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        match key {
            "metal" => Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            ))))),
            // This op can accept non contiguous inputs
            "non_contiguous" => Some(Box::new(())),
            "elementwise" => Some(Box::new(self.equation.clone())),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::MetalCompiler;
    use luminal::{
        prelude::*,
        tests::{assert_close, random_vec},
    };

    /// Checks that `relu(exp2(a) - sin(b))` gives the same result before and
    /// after compiling with `GenericCompiler` + `MetalCompiler<f16>`.
    #[test]
    fn test_fusion() {
        let mut cx = Graph::new();
        let lhs = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
        let rhs = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
        let mut result = (lhs.exp2() - rhs.sin()).relu().retrieve();

        // Reference result from the uncompiled graph
        cx.execute();
        let expected = result.data();
        result.drop();

        cx.compile(
            <(GenericCompiler, MetalCompiler<f16>)>::default(),
            &mut result,
        );
        cx.execute();

        assert_close(&result.data(), &expected);
    }
}

View File

@@ -0,0 +1,307 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
// Converts a float to a 16-bit bfloat bit pattern: NaN inputs map to 0x7FC0,
// everything else is rounded to nearest-even and truncated to its upper 16 bits.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for nan
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
// Take upper 16 bits
return float_bits >> 16;
}
// Reinterprets a 16-bit bfloat bit pattern as a float by placing it in the
// upper half of a 32-bit word.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
// True for any type other than _MLX_BFloat16 that is convertible to float.
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
// True for any type other than _MLX_BFloat16 that float is convertible to.
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
// Software bfloat16: stores the 16-bit pattern in `bits_` and converts through
// float. Constructors and conversion operators are repeated once per address
// space (thread, threadgroup, device, constant).
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
// Tag type selecting the constructor that takes raw bits instead of a value
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
// Negation: converts to float, negates, and converts the result back.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
// bfloat_binop_base: defines one operator for (atype, btype) that casts both
// operands to ctype, applies __op__, and returns otype.
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
// bfloat_binop_helper: defines the operator for both (bfloat, itype) and
// (itype, bfloat) operand orders.
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
// Mixing bfloat with float or half returns float; mixing with the integer
// types returns bfloat. All arithmetic is carried out in float.
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
_MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
// Comparisons always return bool and compare the operands as floats.
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \
float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
// The helper macros are only needed for the instantiations above
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
// bfloat_inplace_op_helper (mixed types): defines the compound-assignment
// operator for (bfloat &, itype) and (itype &, bfloat) in the given address
// space; the arithmetic is done in float and assigned back to lhs.
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
addr_space _MLX_BFloat16 &lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \
_MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
// Instantiates the helper for the device, thread and threadgroup address spaces
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
// Instantiates +=, -=, *= and /= for one partner type
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
// The helper names are reused below for the bfloat-with-bfloat case
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
// Without __HAVE_BFLOAT__, bfloat16_t is the software _MLX_BFloat16 type
typedef struct _MLX_BFloat16 bfloat16_t;
/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
namespace metal {
// numeric_limits specialization for the software bfloat16 type; every value
// below is built directly from its raw 16-bit pattern.
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
static constexpr constant int digits = 8;
static constexpr constant int digits10 = 2;
static constexpr constant int max_digits10 = 4;
static constexpr constant int radix = 2;
static constexpr constant int min_exponent = -125;
static constexpr constant int min_exponent10 = -37;
static constexpr constant int max_exponent = 128;
static constexpr constant int max_exponent10 = 38;
static constexpr bfloat16_t min() {
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t lowest() {
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t max() {
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t epsilon() {
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t round_error() {
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t infinity() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t quiet_NaN() {
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
}
// NOTE(review): 0x7F80 is the same bit pattern infinity() returns above, so
// this value is not a NaN — confirm whether a NaN pattern was intended.
static constexpr bfloat16_t signaling_NaN() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t denorm_min() {
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
}
};
// NaN is the only value that compares unequal to itself
METAL_FUNC bool isnan(_MLX_BFloat16 x) { return x != x; }
} // namespace metal
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#include "bf16_math.h"

View File

@@ -0,0 +1,365 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
/*
Following the Metal Shading Language Specification (Metal 3.1)
"bfloat is an extended floating point type that only allows implicit conversion
to a type of greater floating point rank. While bfloat can be implicitly
converted to float, it cannot be implicitly converted to half, and neither
float nor half can be implicitly converted to bfloat."
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat, and calling one with an argument of type bfloat will result in that
argument getting implicitly converted to float, which then returns an output
that is (likely) a float, which cannot be implicitly converted into a bfloat.
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs returns float
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
For the moment, I will be adding overloaded instantiations of the math
functions to automatically handle the casting accordingly.
*/
// Generates overloads of the scalar math functions for a given type.
//   itype - parameter type of every generated overload
//   otype - return type of every generated overload
//   ctype - type the arguments are cast to before the __metal_* builtin call
//   mfast - extra trailing argument forwarded to most __metal_* builtins
//           (fma, frexp and nextafter are called without it; fdim does not
//           call a __metal_* builtin and is built on select() instead)
// Each overload casts its arguments to ctype, calls the builtin and casts the
// result to otype. abs/max/max3/median3/min/min3 forward to the same builtins
// as fabs/fmax/fmax3/fmedian3/fmin/fmin3.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
\
METAL_FUNC otype abs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acos(itype x) { \
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acosh(itype x) { \
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asin(itype x) { \
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asinh(itype x) { \
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atan(itype y_over_x) { \
return static_cast<otype>( \
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
} \
METAL_FUNC otype atan2(itype y, itype x) { \
return static_cast<otype>( \
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atanh(itype x) { \
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype ceil(itype x) { \
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cos(itype x) { \
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cosh(itype x) { \
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cospi(itype x) { \
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype divide(itype x, itype y) { \
return static_cast<otype>( \
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype exp(itype x) { \
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp10(itype x) { \
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp2(itype x) { \
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fabs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fdim(itype x, itype y) { \
ctype t = static_cast<ctype>(x - y); \
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
} \
METAL_FUNC otype floor(itype x) { \
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fma(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fma( \
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
} \
METAL_FUNC otype fmax(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype fmin(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype fmod(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fract(itype x) { \
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype frexp(itype x, thread int &exp) { \
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
} \
METAL_FUNC otype ldexp(itype x, int k) { \
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
} \
METAL_FUNC otype log(itype x) { \
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log10(itype x) { \
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log2(itype x) { \
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype max(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype max3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype median3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype min(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype min3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3(static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), mfast)); \
} \
METAL_FUNC otype nextafter(itype x, itype y) { \
return static_cast<otype>( \
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
} \
METAL_FUNC otype pow(itype x, itype y) { \
return static_cast<otype>( \
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype powr(itype x, itype y) { \
return static_cast<otype>( \
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype rint(itype x) { \
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype round(itype x) { \
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype rsqrt(itype x) { \
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sin(itype x) { \
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinh(itype x) { \
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinpi(itype x) { \
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sqrt(itype x) { \
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tan(itype x) { \
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanh(itype x) { \
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanpi(itype x) { \
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype trunc(itype x) { \
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
}
// Instantiate the math overloads for bfloat16_t (taking and returning
// bfloat16_t, computing in float) three times: in metal, metal::fast and
// metal::precise, each with the correspondingly named __METAL_*_MATH__ macro
// passed as the trailing builtin argument.
namespace metal {
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
__METAL_MAYBE_FAST_MATH__);
namespace fast {
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
__METAL_FAST_MATH__);
} // namespace fast
namespace precise {
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
__METAL_PRECISE_MATH__);
} // namespace precise
} // namespace metal
///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////
// Generates overloads of the simd broadcast/shuffle functions for a type.
//   itype          - parameter type of the data arguments
//   otype          - return type
//   ctype          - type actually handed to the __metal_simd_* builtins
//   itype_to_ctype - callable/macro converting itype -> ctype
//   ctype_to_otype - callable/macro converting ctype -> otype
// Each overload converts its data argument(s) with itype_to_ctype, calls the
// matching __metal_simd_* builtin and converts the result with ctype_to_otype.
// The shuffle_and_fill overloads without a `modulo` parameter pass
// __metal_get_simdgroup_size(ushort()) in its place.
#define instantiate_metal_simd_comm_funcs(itype, otype, ctype, itype_to_ctype, \
ctype_to_otype) \
\
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
return ctype_to_otype( \
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
return ctype_to_otype( \
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down(itype data, itype filling_data, \
ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down(itype data, itype filling_data, \
ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up(itype data, itype filling_data, \
ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up(itype data, itype filling_data, \
ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
return ctype_to_otype( \
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
}
// Generates overloads of the simd reduction / prefix-scan functions.
//   itype - parameter type, otype - return type,
//   ctype - type the value is static_cast to before the builtin call.
// Each overload casts `data` to ctype, calls the matching __metal_simd_*
// builtin and casts the result to otype.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
\
METAL_FUNC otype simd_max(itype data) { \
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_min(itype data) { \
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_product(itype data) { \
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_sum(itype data) { \
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_xor(itype data) { \
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
// Bit-level conversions between bfloat16_t and uint16_t.
// With __HAVE_BFLOAT__ defined the bits are reinterpreted with as_type<>;
// otherwise the raw bits_ member of the _MLX_BFloat16 fallback struct is read
// directly and values are rebuilt through its bits_to_bfloat constructor.
#if defined(__HAVE_BFLOAT__)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
// Instantiate the simd overloads for bfloat16_t inside namespace metal.
// Shuffles/broadcasts move the raw 16-bit pattern (ctype uint16_t, using the
// bit-conversion macros above); reductions are computed in float.
namespace metal {
instantiate_metal_simd_comm_funcs(bfloat16_t, bfloat16_t, uint16_t,
bfloat16_to_uint16, uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal

View File

@@ -0,0 +1,115 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
struct complex64_t;
// True when T is not complex64_t itself and T is convertible to float.
// Gates the converting constructors of complex64_t.
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
// True when T is not complex64_t itself and either float or bfloat16_t is
// convertible to T. Gates the conversion operators of complex64_t.
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
!is_same_v<T, complex64_t> &&
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
// Complex number made of two floats (real and imaginary part).
// The converting constructors and conversion operators are repeated once per
// address space qualifier (thread, threadgroup, device, constant); the bodies
// are identical in all four variants.
struct complex64_t {
float real;
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
// Conversions to complex64_t
// Build from a scalar T: the value becomes the real part, imag is set to 0.
template <typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) thread : real(x), imag(0) {}
template <typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
template <typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) device : real(x), imag(0) {}
template <typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
// Conversions from complex64_t
// Convert to a scalar T: only the real part is used, the imaginary part is
// dropped.
template <typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const thread {
return static_cast<T>(real);
}
template <typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const threadgroup {
return static_cast<T>(real);
}
template <typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const device {
return static_cast<T>(real);
}
template <typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const constant {
return static_cast<T>(real);
}
};
// Unary negation: flips the sign of both the real and the imaginary part.
constexpr complex64_t operator-(complex64_t x) {
const float neg_real = -x.real;
const float neg_imag = -x.imag;
return {neg_real, neg_imag};
}
// Lexicographic ">=" on (real, imag): a strictly larger real part wins
// outright; with equal real parts the imaginary parts are compared.
constexpr bool operator>=(complex64_t a, complex64_t b) {
if (a.real > b.real) {
return true;
}
return a.real == b.real && a.imag >= b.imag;
}
// Lexicographic ">" on (real, imag): a strictly larger real part wins
// outright; with equal real parts the imaginary part must be strictly larger.
constexpr bool operator>(complex64_t a, complex64_t b) {
if (a.real > b.real) {
return true;
}
return a.real == b.real && a.imag > b.imag;
}
// "a <= b" is evaluated as "b >= a": the lexicographic (real, imag)
// comparison of operator>= written out with the operands swapped.
constexpr bool operator<=(complex64_t a, complex64_t b) {
if (b.real > a.real) {
return true;
}
return b.real == a.real && b.imag >= a.imag;
}
// "a < b" is evaluated as "b > a": the lexicographic (real, imag)
// comparison of operator> written out with the operands swapped.
constexpr bool operator<(complex64_t a, complex64_t b) {
if (b.real > a.real) {
return true;
}
return b.real == a.real && b.imag > a.imag;
}
// Equality: true only when both the real and the imaginary parts compare
// equal component-wise.
constexpr bool operator==(complex64_t a, complex64_t b) {
const bool same_real = a.real == b.real;
return same_real && a.imag == b.imag;
}
// Component-wise sum of two complex numbers.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
const float sum_real = a.real + b.real;
const float sum_imag = a.imag + b.imag;
return {sum_real, sum_imag};
}
// Component-wise difference of two complex numbers.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
const float diff_real = a.real - b.real;
const float diff_imag = a.imag - b.imag;
return {diff_real, diff_imag};
}
// Complex product: (ar + i*ai) * (br + i*bi)
//   = (ar*br - ai*bi) + i*(ar*bi + ai*br).
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
const float prod_real = a.real * b.real - a.imag * b.imag;
const float prod_imag = a.real * b.imag + a.imag * b.real;
return {prod_real, prod_imag};
}
// Complex quotient a / b, computed as a * conj(b) / |b|^2.
// No scaling is applied, so |b|^2 is formed directly from the squares of
// b.real and b.imag.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
const auto norm_sq = b.real * b.real + b.imag * b.imag;
const auto num_real = a.real * b.real + a.imag * b.imag;
const auto num_imag = a.imag * b.real - a.real * b.imag;
return {num_real / norm_sq, num_imag / norm_sq};
}
// Component-wise remainder: for each of real and imag, the quotient is
// truncated through a cast to int64_t and quotient * divisor is subtracted
// from the dividend.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
const int64_t real_quot = static_cast<int64_t>(a.real / b.real);
const int64_t imag_quot = static_cast<int64_t>(a.imag / b.imag);
return {a.real - (b.real * real_quot), a.imag - (b.imag * imag_quot)};
}

View File

@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.
#pragma once
// MTL_CONST expands to the Metal `constant` address space qualifier when
// __METAL__ is defined and to nothing otherwise - presumably so this header
// can also be included from non-Metal (host) code; confirm against callers.
#ifdef __METAL__
#define MTL_CONST constant
#else
#define MTL_CONST
#endif
// Compile-time tuning limits. Their consumers are outside this header; the
// names suggest per-kernel dimension caps and read counts - TODO confirm
// exact semantics at the use sites.
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;

View File

@@ -0,0 +1,539 @@
// Copyright © 2023 Apple Inc.
// #pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
// Shorthand used below for compile-time constants placed in the `constant`
// address space.
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
// Copies one tile of a matrix from device memory into threadgroup memory.
//
// Template parameters:
//   T            element type
//   BROWS/BCOLS  tile shape; swapped for the destination when `transpose`
//   BK           tile extent along the reduction axis (used for tstride)
//   vec_size     consecutive elements each thread copies per row
//   tgp_size     divided by n_vecs to get the row step `bstride`
//                (GEMMKernel passes the threadgroup thread count WM*WN*32)
//   transpose    whether the source tile is stored transposed
//   ldK          together with `transpose` selects whether next() advances
//                by BK elements or by BK rows (BK * src_ld)
//   tgp_padding  extra elements added to the destination leading dimension
//
// Each thread owns the column range [bj, bj + vec_size) and visits rows
// bi, bi + bstride, ... of the tile.
template <
typename T,
int BROWS,
int BCOLS,
int BK,
int vec_size,
int tgp_size,
bool transpose,
bool ldK,
int tgp_padding = 0>
struct BlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
// Leading dimension for src
const int src_ld;
// Stride along reduction axis between blocks
const int tstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
// tstride is BK when (transpose ^ ldK) is true and BK * src_ld otherwise.
// thread_idx assumes 32 lanes per simdgroup. dst and src are pre-offset to
// this thread's first element (row bi, column bj).
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tstride(
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
// src_tile_dim holds the valid extent of the source tile as (x, y); it is
// swapped first when `transpose` is set. Elements outside the valid extent
// are written to threadgroup memory as T(0).
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
// Iterate over rows of block
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
// Row is in bounds, we check against column
if ((bi + i) < src_tile_dim.y) {
// Use fast thread memory for bound checks
short tmp_idx[vec_size];
T tmp_val[vec_size];
// Make sure tmp_idx only contains valid indices
// (out-of-range columns are redirected to offset 0 so the read below
// stays inside the row)
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
}
// Read all valid indices into tmp_val
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
}
// Zero out unneeded values
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
// Row is out of bounds, we just fill tgp memory with zeros
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
// Advances the source pointer by tstride to the next tile along the
// reduction axis; the destination pointer is left unchanged.
METAL_FUNC void next() {
src += tstride;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
// Identity epilogue: apply() only converts the value from InT to OutT with a
// static_cast. Used as the default Epilogue of BlockMMA / GEMMKernel.
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
// Maps an element type T to the type used for accumulation. The primary
// template selects float for every T.
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
// Per-thread state and helpers for multiplying a (BM, BK) tile of A with a
// (BK, BN) tile of B held in threadgroup memory, accumulating into 8x8
// simdgroup matrices.
//
// Template parameters:
//   T                element type of the tiles and of the output matrix
//   BM, BN, BK       tile sizes
//   WM, WN           used as BM / (WM * 8) and BN / (WN * 8) to size the
//                    per-simdgroup work (presumably the simdgroup grid shape
//                    within the threadgroup - confirm against the launch)
//   transpose_a/b    whether the A / B tile is stored transposed
//   tgp_padding_a/b  padding added to the threadgroup leading dimensions
//   AccumType        accumulation type (default AccumHelper<T>::accum_type)
//   Epilogue         functor applied to each result element before the store
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along N
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Per-lane row (sm) and column (sn) offsets, derived from simd_lane_id in
// the constructor.
// NOTE(review): presumably matches the lane layout of thread_elements() in
// an 8x8 simdgroup matrix - confirm against the Metal specification.
short sm;
short sn;
/* Constructor */
METAL_FUNC BlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
// As / Bs point at the start of the A / B tiles in threadgroup memory.
// Results are accumulated into `results`; nothing is written to As or Bs.
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
// C points at the output tile and ldc is its leading dimension. Each thread
// writes its two elements of every result matrix, without bound checks.
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
// Bound-checked variant of store_result: dst_tile_dims is the valid extent
// of the output tile as (x = columns, y = rows); elements whose row or
// column falls outside it are not written.
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Tiled matrix multiply C = A * B for one (BM, BN) output tile per
// threadgroup, built from two BlockLoaders (A and B tiles) and one BlockMMA.
//
// Template parameters:
//   T                element type of A, B and C
//   BM, BN, BK       tile sizes along M, N and the reduction axis K
//   WM, WN           forwarded to BlockMMA; WM * WN * 32 is the thread count
//   transpose_a/b    whether A / B is stored transposed
//   MN_aligned       when true, the tile is loaded and stored without bound
//                    checks along M and N
//   K_aligned        when true (and MN_aligned), the K loop runs without a
//                    bound-checked tail
//   AccumType        accumulation type
//   Epilogue         functor applied to each result element before the store
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
// Padding of the threadgroup tiles: 16 bytes expressed in elements of T.
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
// Threadgroup memory needed (in elements) for the padded A and B tiles.
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
// Elements copied per thread per row by the loaders: 8 for 64x64 tiles,
// 4 otherwise.
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t = BlockLoader<
T,
BM,
BK,
BK,
vec_size,
tgp_size,
transpose_a,
true,
tgp_padding_a>;
using loader_b_t = BlockLoader<
T,
BK,
BN,
BK,
vec_size,
tgp_size,
transpose_b,
false,
tgp_padding_b>;
using mma_t = BlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
// A, B: input matrices; C: output matrix; M, N, K: problem dimensions.
// batch_stride_a/b/c: element offsets applied per batch index (tid.z); the
// B offset uses tid.z / batch_size_b.
// tgp_memory must hold at least tgp_mem_size elements: the A tile comes
// first, the B tile starts at offset tgp_mem_size_a.
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant int& M [[buffer(3)]],
const constant int& N [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& batch_stride_a [[buffer(6)]],
const constant int& batch_stride_b [[buffer(7)]],
const constant int& batch_size_b [[buffer(8)]],
const constant int& batch_stride_c [[buffer(9)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Adjust for batch
A += batch_stride_a * tid.z;
B += batch_stride_b * (tid.z / batch_size_b);
C += batch_stride_c * tid.z;
// Adjust for transpose
const int lda_dev = transpose_a ? M : K;
const int ldb_dev = transpose_b ? K : N;
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
A += transpose_a ? c_row : c_row * K;
B += transpose_b ? c_col * K : c_col;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
mma_t mma_op(simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned && K_aligned) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN aligned, K unaligned loop
else if (MN_aligned && !K_aligned) {
// Main loop
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Loop tail
// Runs unconditionally: the remaining K - k columns/rows are loaded with
// bound checks and everything past them is zero-filled by load_safe.
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MNK unaligned loop
else { // Loop over K - unaligned case
// Valid extent of this output tile: (columns, rows), clipped at the matrix
// edges.
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
// Full interior tile: only the K tail needs bound checking.
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, N);
return;
} else {
// Edge tile: every load and the final store are bound-checked.
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(short2(BK, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_none);
mma_op.store_result_safe(C, N, src_tile_dims);
return;
}
}
}
};

View File

@@ -0,0 +1,95 @@
// Copyright © 2023 Apple Inc.
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/gemm.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Kernel entry point for the tiled GEMM. It declares the threadgroup scratch
// buffer (GEMMKernel::tgp_mem_size elements of T) and forwards every argument
// unchanged to GEMMKernel::run. The threadgroup is capped at WM * WN * 32
// threads through max_total_threads_per_threadgroup.
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant int &M [[buffer(3)]],
const constant int &N [[buffer(4)]],
const constant int &K [[buffer(5)]],
const constant int &batch_stride_a [[buffer(6)]],
const constant int& batch_stride_b [[buffer(7)]],
const constant int& batch_size_b [[buffer(8)]],
const constant int& batch_stride_c [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
A, B, C,
M, N, K,
batch_stride_a, batch_stride_b, batch_size_b, batch_stride_c,
tgp_memory,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates one gemm kernel and gives it the host-visible name
//   gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// The kernel is instantiated with element type `itype`; `otype` only
// contributes to the name through `oname` and is otherwise unused here.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant int &M [[buffer(3)]], \
const constant int &N [[buffer(4)]], \
const constant int &K [[buffer(5)]], \
const constant int &batch_stride_a [[buffer(6)]], \
const constant int& batch_stride_b [[buffer(7)]], \
const constant int& batch_size_b [[buffer(8)]], \
const constant int& batch_stride_c [[buffer(9)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands instantiate_gemm for all four (MN_aligned, K_aligned) combinations;
// "taligned" names the true case and "naligned" the false case.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
// TODO: Accumulation in different type

View File

@@ -0,0 +1,575 @@
// Copyright © 2023 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/defines.h"
#include "KERNEL_PATH/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
// Matrix-vector product helper: out_vec[row] = sum over col of
// mat[row * in_vec_size + col] * in_vec[col], with mat read as row-major
// (out_vec_size rows of in_vec_size elements).
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding TN elements of the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the simdgroup with simd_sum
// 4. Lane 0 of each simdgroup writes that simdgroup's TM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * Simdgroups whose rows start outside the matrix return without reading or writing
// * The last simdgroup that partially overlaps with the matrix is shifted inwards
// such that its TM rows fit exactly in the matrix
// Threadgroup scratch size in elements. Each lane's slot is 2 * TN elements
// wide (see in_vec_block below); run() only touches the first TN of them.
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Threadgroup in_vec cache, indexed by lane so that the same lane of every
// simdgroup shares one slot
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
// Block position
int out_row = (tid.x * BM + simd_gid) * TM;
// Exit simdgroup if rows out of bound
// NOTE(review): simdgroups that return here never reach the threadgroup
// barriers below - confirm this is acceptable on the target hardware.
if(out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
// (assumes out_vec_size >= TM, otherwise out_row becomes negative - TODO confirm)
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Advance matrix
mat += out_row * in_vec_size;
// Loop over in_vec in blocks of BN * TN
// NOTE(review): the trip count depends on simd_lid, so lanes whose first
// column is already past in_vec_size skip the barriers inside this loop.
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
// Only simdgroup 0 fills the cache; the other simdgroups read it after the barrier
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}
} else { // Edgecase
// Pad with zeros past the end of the vector
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
} else { // Edgecase
// Clamp the column to the last valid one; the matching v_coeff entry
// was zero-padded above, so the clamped element is multiplied by 0
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * in_vec_size + col_idx];
}
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
// Vector-matrix product helper: out_vec[col] = sum over row of
// in_vec[row] * mat[row * out_vec_size + col], with mat read as row-major
// (in_vec_size rows of out_vec_size elements).
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// Threadgroup scratch size in elements: TN partial sums for each of the
// BN * BM threads
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)simd_gid;
(void)simd_lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Threadgroup accumulation results: one BM * TN region per lid.x, holding
// TN partial sums for every lid.y
threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN;
int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
// Edgecase handling
if (out_col < out_vec_size) {
// Shift the last partially overlapping column block inwards
// (assumes out_vec_size >= TN, otherwise out_col becomes negative - TODO confirm)
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly,
// possibly because it helps exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
// Fewer than TM rows remain: process only the rows inside the matrix
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
// Executed by every thread, including those whose column block lies outside
// the matrix (their result stays zero)
#pragma clang loop unroll(full)
for(int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
// The lid.y == 0 thread adds the partial sums of rows 1..BM-1 to its own
if(lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
// Batched matrix-vector kernel entry point for uniformly strided batches.
// tid.z selects the batch element; the per-batch work is delegated to
// GEMVKernel::run.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& vector_batch_stride [[buffer(5)]],
    const constant int& matrix_batch_stride [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVKernel<T, BM, BN, TM, TN>;

  // Scratch buffer sized by the kernel helper itself
  threadgroup T scratch[kernel_t::tgp_mem_size];

  // Move all three pointers to the start of this batch element. The output
  // is packed, so it advances by out_vec_size per batch.
  const uint batch = tid.z;
  out_vec += batch * out_vec_size;
  mat += batch * matrix_batch_stride;
  in_vec += batch * vector_batch_stride;

  kernel_t::run(
      mat, in_vec, out_vec,
      in_vec_size, out_vec_size,
      scratch,
      tid, lid, simd_gid, simd_lid);
}
// Batched matrix-vector kernel entry point for batches whose offsets are
// described by a shape/stride list. The input offsets come from elem_to_loc;
// the output is packed and advances by out_vec_size per batch.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides_vec [[buffer(7)]],
    const device size_t* nc_strides_mat [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVKernel<T, BM, BN, TM, TN>;

  // Scratch buffer sized by the kernel helper itself
  threadgroup T scratch[kernel_t::tgp_mem_size];

  // Resolve the batch index tid.z to element offsets for both inputs
  out_vec += tid.z * out_vec_size;
  mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
  in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);

  kernel_t::run(
      mat, in_vec, out_vec,
      in_vec_size, out_vec_size,
      scratch,
      tid, lid, simd_gid, simd_lid);
}
// Explicitly instantiates gemv<itype, bm, bn, tm, tn> and exports it under the
// host name "gemv_<name>_bm<bm>_bn<bn>_tm<tm>_tn<tn>".
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Same as instantiate_gemv_c but for gemv_nc; the host name gets an "_nc" suffix.
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Instantiates both the strided and the "_nc" variant for one configuration.
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
// The (bm, bn, tm, tn) configurations compiled for each element type.
// bn is always 32, matching the static_assert(BN == SIMD_SIZE) in GEMVKernel.
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
// Batched vector-matrix kernel entry point for uniformly strided batches.
// tid.z selects the batch element; the per-batch work is delegated to
// GEMVTKernel::run.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& vector_batch_stride [[buffer(5)]],
    const constant int& matrix_batch_stride [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVTKernel<T, BM, BN, TM, TN>;

  // Scratch buffer sized by the kernel helper itself
  threadgroup T scratch[kernel_t::tgp_mem_size];

  // Move all three pointers to the start of this batch element. The output
  // is packed, so it advances by out_vec_size per batch.
  const uint batch = tid.z;
  out_vec += batch * out_vec_size;
  mat += batch * matrix_batch_stride;
  in_vec += batch * vector_batch_stride;

  kernel_t::run(
      mat, in_vec, out_vec,
      in_vec_size, out_vec_size,
      scratch,
      tid, lid, simd_gid, simd_lid);
}
// Batched vector-matrix kernel entry point for batches whose offsets are
// described by a shape/stride list. The input offsets come from elem_to_loc;
// the output is packed and advances by out_vec_size per batch.
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(2)]],
    const constant int& in_vec_size [[buffer(3)]],
    const constant int& out_vec_size [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides_vec [[buffer(7)]],
    const device size_t* nc_strides_mat [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVTKernel<T, BM, BN, TM, TN>;

  // Scratch buffer sized by the kernel helper itself
  threadgroup T scratch[kernel_t::tgp_mem_size];

  // Resolve the batch index tid.z to element offsets for both inputs
  out_vec += tid.z * out_vec_size;
  mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
  in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);

  kernel_t::run(
      mat, in_vec, out_vec,
      in_vec_size, out_vec_size,
      scratch,
      tid, lid, simd_gid, simd_lid);
}
// Explicitly instantiates gemv_t<itype, bm, bn, tm, tn> and exports it under
// the host name "gemv_t_<name>_bm<bm>_bn<bn>_tm<tm>_tn<tn>".
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Same as instantiate_gemv_t_c but for gemv_t_nc; the host name gets an "_nc" suffix.
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// Instantiates both the strided and the "_nc" variant for one configuration.
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
// The (bm, bn, tm, tn) configurations compiled for each element type.
// bm is always 8 while bn ranges from 8 to 128.
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);

View File

@@ -0,0 +1,228 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/defines.h"
#include "KERNEL_PATH/utils.h"
using namespace metal;
// Exponential used by the softmax kernels.
// The callers subtract the running maximum first, so the argument lies in
// (-oo, 0] and the result is later divided by sum(exp(x_i)); the fast,
// lower-precision exponential is therefore sufficient.
template <typename T>
inline T softmax_exp(T x) {
  const T result = fast::exp(x);
  return result;
}
// Softmax over one row per threadgroup, for rows short enough that every
// thread keeps its N_READS elements in registers.
//   in / out         : row gid starts at gid * axis_size
//   axis_size        : number of elements in a row
//   local_max        : threadgroup scratch, indexed by simdgroup and by lane
//   local_normalizer : threadgroup scratch, indexed by simdgroup and by lane
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  const int lid = _lid;
  // Index within the row of this thread's first element
  const int first = lid * N_READS;
  // True when all N_READS elements of this thread are inside the row
  const bool whole_chunk = first + N_READS <= axis_size;

  // Load this thread's elements; positions past the row end are padded with
  // the lowest finite value so they never win the max
  T vals[N_READS];
  in += gid * axis_size + first;
  if (!whole_chunk) {
    for (int i = 0; i < N_READS; i++) {
      vals[i] = ((first + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
  }

  // Seed the scratch slots with the reduction identities so that slots of
  // simdgroups that do not exist are harmless in the final reductions
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Row maximum: per thread, then per simdgroup, then across simdgroups
  T row_max = Limits<T>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    if (row_max < vals[i]) {
      row_max = vals[i];
    }
  }
  row_max = simd_max(row_max);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = row_max;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    row_max = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = row_max;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  row_max = local_max[0];

  // Replace each value by exp(x_i - row_max) and sum them the same way
  T denom = 0;
  for (int i = 0; i < N_READS; i++) {
    vals[i] = softmax_exp(vals[i] - row_max);
    denom += vals[i];
  }
  denom = simd_sum(denom);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = denom;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    denom = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = denom;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  const T scale = 1 / local_normalizer[0];

  // Scale and store, skipping positions past the row end
  out += gid * axis_size + first;
  if (whole_chunk) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = vals[i] * scale;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((first + i) < axis_size) {
        out[i] = vals[i] * scale;
      }
    }
  }
}
// Softmax over one row per threadgroup, for rows too long to hold in
// registers: each thread strides over the row, keeping a running maximum and
// a running normalizer that is rescaled whenever the maximum grows.
//   in / out         : row gid starts at gid * axis_size
//   axis_size        : number of elements in a row
//   local_max        : threadgroup scratch, indexed by simdgroup and by lane
//   local_normalizer : threadgroup scratch, indexed by simdgroup and by lane
//   lsize            : threads per threadgroup
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Seed every scratch slot with the reduction identity, as
  // softmax_single_row does. The cross-simdgroup reductions below read one
  // slot per lane but only one slot per simdgroup gets written, so without
  // this a threadgroup with fewer simdgroups than lanes would reduce over
  // uninitialised threadgroup memory.
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  in += gid * axis_size;

  // Get the max and the normalizer in one go
  T prevmax;
  T maxval = Limits<T>::finite_min;
  T normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    T vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = in[offset + i];
      }
    } else {
      // Pad positions past the row end with the lowest finite value
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    // Rescale what was accumulated under the old maximum
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }

  // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
  // lsize) parts. We need to combine them.
  //    1. We start by finding the max across simd groups
  //    2. We then change the partial normalizers to account for a possible
  //       change in max
  //    3. We sum all normalizers
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // Now the normalizer and max value is correct for each simdgroup. We write
  // them shared memory and combine them.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  normalizer = 1 / normalizer;

  // Finally given the normalizer and max value we can directly write the
  // softmax output
  out += gid * axis_size;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
        }
      }
    }
  }
}
// Explicitly instantiates softmax_single_row<itype> and exports it under the
// host name "softmax_<name>".
// The parameter attributes must mirror the template definition: gid is the
// threadgroup position (one threadgroup per row), not the thread position.
#define instantiate_softmax_single_row(name, itype) \
  template [[host_name("softmax_" #name)]] [[kernel]] void \
  softmax_single_row<itype>( \
      const device itype* in, \
      device itype* out, \
      constant int& axis_size, \
      threadgroup itype* local_max [[threadgroup(0)]], \
      threadgroup itype* local_normalizer [[threadgroup(1)]], \
      uint gid [[threadgroup_position_in_grid]], \
      uint _lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// Explicitly instantiates softmax_looped<itype> and exports it under the host
// name "softmax_looped_<name>".
#define instantiate_softmax_looped(name, itype) \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// Instantiates both softmax kernels for one element type.
#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
// Both helper macros end in ';', so the invocations need no separator.
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,312 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "loader.h"
#include "mma.h"
#include "transforms.h"
#include "../utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
// Empty tag type that carries three alignment flags as template parameters.
// GEMMKernel::gemm_loop takes one as a defaulted, otherwise unused argument.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
// Threadgroup-level GEMM building block.
//   T / U       : element type of A and B / element type of C
//   BM, BN, BK  : tile sizes; run() advances C by BM rows and BN columns per
//                 threadgroup and K by BK per iteration
//   WM, WN      : tgp_size is WM * WN * 32 threads (presumably a WM x WN grid
//                 of 32-wide simdgroups - TODO confirm)
//   transpose_a / transpose_b : select the layout used to address A and B
//   MN_aligned / K_aligned    : compile-time promises that select the
//                 unchecked (load_unsafe) or masked (load_safe) paths
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
// Each threadgroup tile row is padded by 16 bytes worth of elements
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
// Threadgroup memory footprint (in elements) of the padded A and B tiles
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
// Threads per threadgroup
STEEL_CONST short tgp_size = WM * WN * 32;
// Loader for the A tile; rows/cols swap when A is transposed
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
// Loader for the B tile; rows/cols swap when B is transposed
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
// Multiply-accumulate operator working on the padded threadgroup tiles
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* K loop shared by the MN-unaligned paths of run(). The template flags pick
   unchecked or masked loads for A (M_aligned) and B (N_aligned), and whether
   a masked tail iteration of lbk leftover K elements follows (K_aligned_). */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
// Masks are only built for the operands that need bounds checking
if (!M_aligned) {
short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_a.set_mask(tile_dims_A, mask_A);
}
if (!N_aligned) {
short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
loader_b.set_mask(tile_dims_B, mask_B);
}
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(mask_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(mask_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Tail iteration: rebuild both masks for the lbk leftover K elements and
// always use the masked loads
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.set_mask(tile_dims_A_last, mask_A);
loader_b.set_mask(tile_dims_B_last, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function: computes one BM x BN tile of C per threadgroup */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Remap the launch grid position to a tile position: the low swizzle_log
// bits of tid.x are folded into the tile row
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Threadgroups mapped outside the tile grid have nothing to do
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// NOTE(review): this barrier uses mem_none, whereas gemm_loop issues a
// mem_threadgroup barrier before its tail overwrites As/Bs - confirm the
// difference is intentional.
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
// Leftover K elements after the full BK-sized iterations
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(C, params->ldc);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
// Extent of this tile that actually lies inside C
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
// Dispatch on which tile dimensions are complete; only a fully interior
// tile uses the unchecked store
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(C, params->ldc);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,89 @@
// Copyright © 2024 Apple Inc.
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
gemm_kernel::run(
A, B, C,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates one `gemm` kernel and exports it under the host name
//   steel_gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// tname/iname/oname/aname/kname only contribute to that name. `otype` is
// accepted but not referenced by the expansion: A, B and C all use `itype`.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant GEMMParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands to the four (MN_aligned, K_aligned) combinations for one layout and
// tile shape; "taligned" tags a true flag and "naligned" a false flag.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// Expands to the four (transpose_a, transpose_b) combinations, tagged
// nn / nt / tn / tt in the host name.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// Expands to every supported (bm, bn, bk, wm, wn) tile configuration for one
// element type.
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
// Instantiate the full kernel family for half, bfloat16_t and float.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -0,0 +1,260 @@
// Copyright © 2024 Apple Inc.
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Fused GEMM + addend kernel.
//
// Each threadgroup accumulates one BM x BN tile of the A/B product and writes
// it to D after combining every accumulated value with the matching element of
// C through `Epilogue` (constructed from params->alpha and params->beta).
//
// Template parameters:
//   T                        element type of A, B, C and D
//   BM, BN, BK               tile sizes along M, N and K
//   WM, WN                   simdgroup grid; WM * WN * 32 bounds the threadgroup size
//   transpose_a/transpose_b  select how the A / B tile offsets are computed
//   MN_aligned, K_aligned    compile-time switches that skip bounds handling
//   AccumType, Epilogue      accumulator type and output transform
//
// Buffers: A(0), B(1), C(2, read-only addend), D(3, output), params(4).
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMAddMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Threadgroup staging tiles for A and B
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
D += params->batch_stride_d * tid.z;
// Map the swizzled grid position to a tile: the low swizzle_log bits of tid.x
// go into the tile row index, the remaining bits form the tile column index.
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Threadgroups that land outside the tile grid have nothing to do
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
// C is addressed with a row stride (ldc) and a column stride (fdc);
// D uses a row stride (ldd) and unit column stride.
C += c_row * params->ldc + c_col * params->fdc;
D += c_row * params->ldd + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Output transform applied to (accumulator, C element) pairs on store
const Epilogue epilogue_op(params->alpha, params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
// Wait until the previous iteration has finished reading As / Bs
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
// Make the freshly loaded tiles visible to the whole threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
// Number of K elements left after the full BK-sized iterations
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
// Valid extent of the last tile as (x, y); only the K extent is partial
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
// Valid rows / columns of this output tile (smaller than BM / BN at the edges)
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
// Pick the loop specialisation matching which of M / N is a full tile.
// Only the full-tile case may use the unchecked store.
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
// Partial in M, full in N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
// Full in M, partial in N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
// Partial in both M and N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates one `addmm` kernel and exports it under the host name
//   steel_addmm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>_<ep_name>
// The accumulator type is fixed to float and the epilogue is
// `epilogue<itype, float>`. `otype` is accepted but not referenced by the
// expansion: every buffer uses `itype`.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMAddMMParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands to both epilogue variants: "add" (TransformAdd) and "axpby"
// (TransformAxpby).
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
// Expands to the four (MN_aligned, K_aligned) combinations; "taligned" tags a
// true flag and "naligned" a false flag.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// Expands to the four (transpose_a, transpose_b) combinations, tagged
// nn / nt / tn / tt in the host name.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// Expands to every supported (bm, bn, bk, wm, wn) tile configuration for one
// element type.
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
// Instantiate the full kernel family for half, bfloat16_t and float.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -0,0 +1,280 @@
// Copyright © 2024 Apple Inc.
#include "KERNEL_PATH/bf16.h"
#include "KERNEL_PATH/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Split-K GEMM kernel.
//
// The grid z index selects a partition of the K dimension. Each threadgroup
// accumulates the product of its BM x BN tile over that partition only and
// stores the partial result into the C slice belonging to the partition
// (offset by split_k_partition_stride * tid.z). Summing the slices is left to
// a separate pass.
//
// Template parameters:
//   T / U                    element type of the inputs / of the partial output
//   BM, BN, BK               tile sizes along M, N and K
//   WM, WN                   simdgroup grid; WM * WN * 32 bounds the threadgroup size
//   transpose_a/transpose_b  select how the A / B offsets are computed
//   MN_aligned, K_aligned    compile-time switches that skip bounds handling
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// lid is unused; the cast silences the compiler warning
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Threadgroup staging tiles for A and B
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// x / y select the output tile, z selects the K partition
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
// Threadgroups that land outside the tile grid have nothing to do
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
// First K index handled by this partition
const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
// Each partition writes to its own slice of C
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Valid rows / columns of this output tile (smaller than BM / BN at the edges)
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
// Main pass over this partition: always run with the K-aligned loop
// (third LoopAlignment flag is true); any K remainder is handled below by
// the last partition only.
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
// Partial in M, full in N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
// Full in M, partial in N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
// Partial in both M and N
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// The last partition additionally consumes whatever part of K lies beyond
// the end of its own partition
if ((tid_z + 1) == (params->split_k_partitions)) {
// Full BK-sized iterations remaining after this partition
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
// Run when there are full iterations left, or when K is not a multiple of
// BK so that a leftover tile may exist
if(!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
// Unchecked store for full tiles, bounds-checked store at the M / N edges
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates one `gemm_splitk` kernel and exports it under the
// host name
//   steel_gemm_splitk_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>_MN_<aname>_K_<kname>
// A and B use `itype`; the partial-result buffer C uses `otype`.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands to the four (MN_aligned, K_aligned) combinations; "taligned" tags a
// true flag and "naligned" a false flag.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// Expands to the four (transpose_a, transpose_b) combinations, tagged
// nn / nt / tn / tt in the host name.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// Expands to every split-K tile configuration (bm, bn, bk, wm, wn).
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// Inputs may be half, bfloat16_t or float; partial results are always float.
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
// Explicitly instantiates both split-K reduction kernels for one
// (accumulator type, output type) pair and exports them under the host names
//   steel_gemm_splitk_accum_<oname>_<aname>
//   steel_gemm_splitk_accum_<oname>_<aname>_axpby
// `atype` is the element type of the partial results, `otype` the element type
// of D (and of C in the axpby variant). Both use their default Epilogue.
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
// Partial results are always float; outputs may be bfloat16_t, half or float.
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);

View File

@@ -0,0 +1,160 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "../utils.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
// Cooperative loader that copies a BROWS x BCOLS tile from device memory into
// threadgroup memory.
//
// Every thread of the threadgroup owns a fixed position (bi, bj) inside the
// tile and moves `vec_size` consecutive elements per row it visits; rows are
// visited with a stride of TROWS.
//
// Template parameters:
//   T              element type
//   BROWS, BCOLS   tile extent in rows / columns
//   dst_ld         leading dimension of the threadgroup destination
//   reduction_dim  non-zero: next() advances by BCOLS elements,
//                  zero: next() advances by BROWS rows of the source
//   tgp_size       number of threads taking part in the load
//   alignment      alignment of the vector type used by load_unsafe, in elements
//   n_reads        elements moved per thread and row
//   TCOLS, TROWS   thread grid layout over the tile
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  // Number of rows each thread visits (ceil(BROWS / TROWS))
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  // Number of consecutive elements each thread moves per row
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  // Distance src moves on every call to next()
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  // Row (bi) and first column (bj) of this thread inside the tile
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  // Opaque blob of vec_size elements used to copy a whole vector at once
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Precompute the per-element validity mask consumed by load_safe(mask).
   *
   * src_tile_dims is the valid extent of the source tile as (x = columns,
   * y = rows). mask[i][j] is true when element j of the i-th row visited by
   * this thread lies inside that extent. The i-th visited row is
   * bi + i * TROWS, matching the row stride used by the load functions. With
   * n_rows == 1 only i == 0 occurs, so the row test reduces to
   * bi < src_tile_dims.y.
   */
  METAL_FUNC void set_mask(
      thread const short2& src_tile_dims,
      thread bool mask[n_rows][vec_size]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        mask[i][j] = ((bi + i * TROWS) < src_tile_dims.y) &&
            ((bj + j) < src_tile_dims.x);
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Make the extent relative to this thread's position in the tile
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val; invalid ones read src[0] instead so
      // that no access goes outside the tile
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking,
   * using a mask previously filled in by set_mask */
  METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
    T tmp_val[vec_size];

    // i is the row offset inside the tile, ii the matching mask row
    STEEL_PRAGMA_UNROLL
    for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
      simdgroup_barrier(mem_flags::mem_none);
      // Use fast thread memory for bound checks

      // Read valid indices into tmp_val; masked-out ones read src[0] instead
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper: advance the source pointer to the next tile */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,264 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "transforms.h"
#include "../utils.h"
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory, combining
 * each accumulated value with an element read from C via epilogue_op.
 *
 * D / ldd: output pointer and its row stride.
 * C / ldc / fdc: input pointer, its row stride and its column stride.
 * No bounds checks are performed. */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
// The second element reads C one column over (stride fdc) but is written
// to the adjacent element of D
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
/* Bounds-checked variant of the epilogue store: writes
 * epilogue_op.apply(accum, C[...]) into D only for elements inside
 * dst_tile_dims (x = columns, y = rows, measured after subtracting this
 * thread's own offset). C is read only for elements that are written. */
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
// Skip whole tile rows that start past the row limit
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offsets in C and D
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output D
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,79 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
// Plain parameter block for a GEMM kernel. All fields are const ints; field
// order is part of the layout and must not change.
// NOTE(review): the per-field meanings below are inferred from the names only
// (no kernel using this struct is visible here) - confirm before relying on them.
struct GEMMParams {
// presumably the problem sizes (M x K times K x N)
const int M;
const int N;
const int K;
// presumably the leading dimensions (row strides) of A, B and C
const int lda;
const int ldb;
const int ldc;
// presumably the number of threadgroup tiles along N and M
const int tiles_n;
const int tiles_m;
// presumably the element stride between consecutive batches of A, B and C
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
// presumably log2 of the swizzle factor consumed by BlockSwizzle::swizzle
const int swizzle_log;
const int gemm_k_iterations_aligned;
};
// Parameter block for a split-K GEMM variant. Field order is part of the layout.
// NOTE(review): "Spilt" in the type name looks like a typo for "Split"; it is
// kept because the name is part of the interface. Field meanings are inferred
// from names only - confirm against the kernel that consumes this struct.
struct GEMMSpiltKParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
// presumably how the K dimension is partitioned across threadgroups
const int split_k_partitions;
const int split_k_partition_stride;
const int split_k_partition_size;
const int gemm_k_iterations_aligned;
};
// Parameter block for a GEMM variant with a separate output D and two float
// scalars (alpha, beta). Field order is part of the layout.
// NOTE(review): field meanings are inferred from names only - confirm against
// the kernel that consumes this struct.
struct GEMMAddMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
// presumably the scale factors passed to TransformAxpby(alpha, beta)
const float alpha;
const float beta;
// presumably the column stride of C, matching the fdc argument of store_result
const int fdc;
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,63 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "../utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
// Identity transform: converts InT to OutT with a static_cast.
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
// Two-argument overload: the second argument is ignored.
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
// Additive transform: returns static_cast<OutT>(x) + c.
template <typename OutT, typename InT>
struct TransformAdd {
// Accepts two floats and ignores them, mirroring TransformAxpby's constructor
// arguments.
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
// Scaled-add transform: returns static_cast<OutT>(x * alpha + beta * c).
template <typename OutT, typename InT>
struct TransformAxpby {
// Scale applied to x
const float alpha;
// Scale applied to c
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
// The sum is formed before the conversion to OutT.
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
// Type trait naming an accumulator type for T. The primary template, shown
// here, selects float for every T.
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
  // Remaps a threadgroup position: the low `swizzle_log` bits of tid.x are
  // moved into the low bits of the returned y, and the remaining high bits of
  // tid.x become the returned x.
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int low_bits_mask = (1 << swizzle_log) - 1;
    const int swizzled_x = (tid.x) >> swizzle_log;
    const int swizzled_y = ((tid.y) << swizzle_log) + ((tid.x) & low_bits_mask);
    return int2(swizzled_x, swizzled_y);
  }
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,5 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/params.h"

View File

@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "host.h"
// Qualifier bundle for compile-time constants declared in the constant address space.
#define STEEL_CONST static constant constexpr const
// Requests full unrolling of the loop that immediately follows.
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@@ -0,0 +1,212 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "bf16.h"
#include "complex.h"
#include <metal_math>
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
// Primary template: declares the four limit constants for a type U. The
// values are provided by the specializations that follow.
template <typename U> struct Limits {
static const constant U max;
static const constant U min;
static const constant U finite_max;
static const constant U finite_min;
};
// Specializes Limits<type> so that max / finite_max are numeric_limits::max()
// and min / finite_min are numeric_limits::min(), i.e. the finite and
// non-finite limits coincide.
#define instantiate_default_limit(type) \
template <> struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
// Applied to all fixed-width integer types.
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
// Specializes Limits<type> for floating-point types: max / min are +/-infinity,
// while finite_max / finite_min are +/- numeric_limits::max() (the largest
// finite magnitude, not the smallest positive value).
#define instantiate_float_limit(type) \
template <> struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
// bool specialization: max is true, min is false. Unlike the other
// specializations it defines no finite_max / finite_min members.
template <> struct Limits<bool> {
static constexpr constant bool max = true;
static constexpr constant bool min = false;
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
// Converts a linear element index into a strided memory offset, peeling
// coordinates off from the last dimension to the first (device-memory variant).
inline size_t elem_to_loc(uint elem, device const int *shape,
                          device const size_t *strides, int ndim) {
  size_t offset = 0;
  for (int axis = ndim - 1; axis >= 0; --axis) {
    const int extent = shape[axis];
    offset += (elem % extent) * strides[axis];
    elem /= extent;
  }
  return offset;
}
// Converts a linear element index into a strided memory offset, peeling
// coordinates off from the last dimension to the first (constant-memory variant).
inline size_t elem_to_loc(uint elem, constant const int *shape,
                          constant const size_t *strides, int ndim) {
  size_t offset = 0;
  for (int axis = ndim - 1; axis >= 0; --axis) {
    const int extent = shape[axis];
    offset += (elem % extent) * strides[axis];
    elem /= extent;
  }
  return offset;
}
// Computes the strided offsets of one element in two arrays that share a
// shape. elem.x / elem.y index the last / second-to-last dimensions directly;
// elem.z is a linear index over the remaining leading dimensions.
// Returns (offset in a, offset in b), both narrowed to uint.
template <int NDIM>
inline uint2 elem_to_loc_2_nd(uint3 elem, constant const int shape[NDIM],
                              constant const size_t a_strides[NDIM],
                              constant const size_t b_strides[NDIM]) {
  uint2 offsets = {static_cast<uint>(elem.x * a_strides[NDIM - 1] +
                                     elem.y * a_strides[NDIM - 2]),
                   static_cast<uint>(elem.x * b_strides[NDIM - 1] +
                                     elem.y * b_strides[NDIM - 2])};
  for (int axis = NDIM - 3; axis >= 0; --axis) {
    const uint coord = elem.z % shape[axis];
    offsets.x += coord * a_strides[axis];
    offsets.y += coord * b_strides[axis];
    elem.z /= shape[axis];
  }
  return offsets;
}
// Computes the strided offset of one element. elem.x / elem.y index the last /
// second-to-last dimensions directly; elem.z is a linear index over the
// remaining leading dimensions.
template <int NDIM>
inline size_t elem_to_loc_nd(uint3 elem, constant const int shape[NDIM],
                             constant const size_t strides[NDIM]) {
  size_t offset = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
  for (int axis = NDIM - 3; axis >= 0; --axis) {
    const int extent = shape[axis];
    offset += (elem.z % extent) * strides[axis];
    elem.z /= extent;
  }
  return offset;
}
// 1-D case: the offset is the index scaled by the single stride.
inline size_t elem_to_loc_1(uint elem, constant const size_t &stride) {
  const size_t offset = elem * stride;
  return offset;
}
// 2-D case: elem.x pairs with the last stride, elem.y with the first.
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  const size_t x_part = elem.x * strides[1];
  const size_t y_part = elem.y * strides[0];
  return x_part + y_part;
}
// 3-D case: elem.x pairs with the last stride, elem.z with the first.
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  const size_t x_part = elem.x * strides[2];
  const size_t y_part = elem.y * strides[1];
  const size_t z_part = elem.z * strides[0];
  return x_part + y_part + z_part;
}
// Non-templated version for a runtime number of dimensions. elem.x / elem.y
// index the last / second-to-last dimensions directly; elem.z is a linear
// index over the remaining leading dimensions.
inline size_t elem_to_loc(uint3 elem, constant const int *shape,
                          constant const size_t *strides, int ndim) {
  size_t offset = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
  for (int axis = ndim - 3; axis >= 0; --axis) {
    const int extent = shape[axis];
    offset += (elem.z % extent) * strides[axis];
    elem.z /= extent;
  }
  return offset;
}
// Runtime-ndim version of the two-array offset computation: returns
// (offset in a, offset in b) for one element, both narrowed to uint.
inline uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape,
                              constant const size_t *a_strides,
                              constant const size_t *b_strides, int ndim) {
  uint2 offsets = {static_cast<uint>(elem.x * a_strides[ndim - 1] +
                                     elem.y * a_strides[ndim - 2]),
                   static_cast<uint>(elem.x * b_strides[ndim - 1] +
                                     elem.y * b_strides[ndim - 2])};
  for (int axis = ndim - 3; axis >= 0; --axis) {
    const uint coord = elem.z % shape[axis];
    offsets.x += coord * a_strides[axis];
    offsets.y += coord * b_strides[axis];
    elem.z /= shape[axis];
  }
  return offsets;
}
// Linear-index -> strided-offset conversion for a compile-time number of
// dimensions. Only the explicit specializations below (NDIM = 1..4) are
// defined; each is the hand-unrolled form of the generic loop, walking from
// the last dimension to the first.
// Note: the result is accumulated in a uint although strides are size_t, so
// offsets are truncated to 32 bits.
template <int NDIM>
inline uint elem_to_loc_nd(uint elem, device const int *shape,
device const size_t *strides);
template <>
inline uint elem_to_loc_nd<1>(uint elem, device const int *shape,
device const size_t *strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(uint elem, device const int *shape,
device const size_t *strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(uint elem, device const int *shape,
device const size_t *strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(uint elem, device const int *shape,
device const size_t *strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Integer ceiling division: computes ceil(N / M) without floating point. */
inline size_t ceildiv(size_t N, size_t M) {
  const size_t rounded_up = N + M - 1;
  return rounded_up / M;
}
// log(1 + x) computed with the correction described in
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// When 1 + x rounds to exactly 1, x itself is returned.
inline float log1p(float x) {
  const float one_plus_x = 1.0f + x;
  if (one_plus_x == 1.0f) {
    return x;
  }
  return x * (metal::log(one_plus_x) / (one_plus_x - 1.0f));
}
// bfloat16 overload of log1p: the sum 1 + x and the logarithm are evaluated in
// float, and the result is converted back to bfloat16_t.
inline bfloat16_t log1p(bfloat16_t x) {
  const float one_plus_x = 1.0f + static_cast<float>(x);
  if (one_plus_x == 1.0f) {
    return x;
  }
  bfloat16_t result =
      bfloat16_t(x * (metal::log(one_plus_x) / (one_plus_x - 1.0f)));
  return result;
}

View File

@@ -0,0 +1,458 @@
use std::{
any::{Any, TypeId},
fmt::{Debug, Write},
ops::Deref,
sync::Arc,
};
#[cfg(test)]
mod tests;
mod binary;
mod command_buffer;
mod elementwise_fusion;
mod matmul;
mod other;
mod prim;
mod quantized;
mod storage_buffer;
mod unary;
use itertools::Itertools;
use metal_rs::*;
pub use quantized::*;
use rustc_hash::FxHashMap;
use luminal::{
op::InputTensor,
prelude::{
symbolic::{BigExpression, Term},
*,
},
};
/// Compile graphs to run on Metal-supported macOS devices in supported data formats
///
/// A tuple of compiler passes, listed in this order: primitive lowering,
/// specialized-op replacement, copy cleanup, contiguous elimination and
/// elementwise fusion. `BufferCompilers` is currently left out (commented out
/// below).
pub type MetalCompiler<T> = (
prim::PrimitiveCompiler<T>,
SpecialOpsCompiler<T>,
other::CopyCompiler<T>,
other::ContiguousElimination<T>,
elementwise_fusion::ElementwiseFusionCompiler<T>,
// BufferCompilers,
);
/// Compilers to share command and storage buffers
///
/// NOTE(review): this alias is not referenced by `MetalCompiler` (its entry is
/// commented out there), so it may be reported as unused.
type BufferCompilers = (
command_buffer::CommandBufferCompiler,
storage_buffer::StorageBufferCompiler,
);
/// Compiler to replace metal ops with specialized variants
///
/// The tuple lists the individual pattern compilers in this order; the matmul
/// compiler comes last.
type SpecialOpsCompiler<T> = (
binary::MetalSubtractionCompiler<T>,
binary::MetalEqualCompiler<T>,
other::ARangeCompiler<T>,
binary::MetalGatherCompiler<T>,
unary::MetalExpCompiler<T>,
unary::MetalCosCompiler<T>,
unary::MeanReduceCompiler<T>,
unary::StdNormCompiler<T>,
unary::SoftmaxCompiler<T>,
unary::RopeCompiler<T>,
matmul::MetalMatMulCompiler<T>,
);
/// Newtype around a Metal [`Buffer`] so it can be stored as tensor data
/// (it implements `Data` and derefs to the inner buffer).
#[derive(Debug, Clone)]
pub struct MetalBuffer(pub Buffer);
impl Deref for MetalBuffer {
type Target = Buffer;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Data for MetalBuffer {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
/// A floating-point element type usable in Metal kernels.
pub trait MetalFloat: Copy + 'static {
/// Convert this value to `f32`.
fn to_f32(self) -> f32;
/// Build a value of this type from an `f32`.
fn from_f32(a: f32) -> Self;
/// Whether this type is `f32`.
fn is_f32() -> bool;
/// Type name as a string (the impls in this file return `"float"` and
/// `"half"`).
fn type_name() -> &'static str;
}
// Quantization types
/// Associates a (possibly quantized) storage format with the matmul compiler
/// to use for it.
pub trait MetalQuantizationType {
/// Matmul compiler used for this format.
type MatmulCompiler;
}
/// 8-bit quantization. Equivalent to the ggml Q8_0 datatype
pub struct Q8_0;
// Q8_0 is paired with the f16 matmul compiler.
impl MetalQuantizationType for Q8_0 {
type MatmulCompiler = matmul::MetalMatMulCompiler<f16>;
}
// Unquantized formats use the matmul compiler of their own type.
impl MetalQuantizationType for f32 {
type MatmulCompiler = matmul::MetalMatMulCompiler<Self>;
}
impl MetalQuantizationType for f16 {
type MatmulCompiler = matmul::MetalMatMulCompiler<Self>;
}
// Main metal dtypes
/// `f32` is its own 32-bit representation, so both conversions are identities.
impl MetalFloat for f32 {
    fn to_f32(self) -> f32 {
        self
    }

    fn from_f32(value: f32) -> Self {
        value
    }

    fn is_f32() -> bool {
        true
    }

    fn type_name() -> &'static str {
        "float"
    }
}
/// Half precision: conversions delegate to the inherent `f16` methods.
impl MetalFloat for f16 {
    fn to_f32(self) -> f32 {
        // Inherent `f16::to_f32` takes precedence over this trait method, so
        // this does not recurse.
        f16::to_f32(self)
    }

    fn from_f32(value: f32) -> Self {
        f16::from_f32(value)
    }

    fn is_f32() -> bool {
        false
    }

    fn type_name() -> &'static str {
        "half"
    }
}
/// An operation that can be encoded onto a Metal command buffer.
pub trait MetalKernel: Debug {
    /// Annotate the buffer sizes of the intermediate buffers
    ///
    /// Defaults to no intermediate buffers.
    fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
        vec![]
    }

    /// Annotate the buffer sizes of the output buffers
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression>;

    /// Set up the kernel on the buffer
    ///
    /// Implementations encode their work onto `command_buffer`; submitting the
    /// command buffer is the caller's job.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        intermediate_buffers: &[&Buffer],
        output_buffers: &[&Buffer],
    );

    /// Run the kernel on a freshly created command queue and command buffer.
    ///
    /// The command buffer is private to this call, so it is committed and
    /// waited on here; without that the encoded work would never be submitted
    /// and `output_buffers` would never be written.
    ///
    /// # Panics
    /// Panics if no system default Metal device is available.
    fn without_command_buffer(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        intermediate_buffers: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let dev = Device::system_default().unwrap();
        let queue = dev.new_command_queue();
        let command_buffer = queue.new_command_buffer();
        self.metal_forward(inputs, command_buffer, intermediate_buffers, output_buffers);
        command_buffer.commit();
        command_buffer.wait_until_completed();
    }

    /// Encode the kernel onto `command_buffer`, allocating the intermediate and
    /// output buffers it needs, and return the output buffers.
    ///
    /// Buffer sizes come from [`Self::intermediate_buffer_sizes`] and
    /// [`Self::output_buffer_sizes`], evaluated against `dyn_map`. The command
    /// buffer is not committed here.
    ///
    /// # Panics
    /// Panics if no system default Metal device is available or if a size
    /// expression cannot be evaluated with `dyn_map`.
    fn without_storage_buffers(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        dyn_map: &FxHashMap<char, usize>,
    ) -> Vec<Buffer> {
        let dev = Device::system_default().unwrap();
        let inp_shapes = inputs.iter().map(|(_, s)| *s).collect::<Vec<_>>();

        // Allocate one shared-storage buffer per requested size
        let allocate = |sizes: Vec<BigExpression>| -> Vec<Buffer> {
            sizes
                .into_iter()
                .map(|n| {
                    dev.new_buffer(
                        n.exec(dyn_map).unwrap() as u64,
                        MTLResourceOptions::StorageModeShared,
                    )
                })
                .collect()
        };

        let intermediate_buffers = allocate(self.intermediate_buffer_sizes(&inp_shapes));
        let intermediate_buffers_ref = intermediate_buffers.iter().collect::<Vec<_>>();
        let output_buffers = allocate(self.output_buffer_sizes(&inp_shapes));
        let output_buffers_ref = output_buffers.iter().collect::<Vec<_>>();

        self.metal_forward(
            inputs,
            command_buffer,
            &intermediate_buffers_ref,
            &output_buffers_ref,
        );
        output_buffers
    }
}
/// Shared, type-erased handle to a [`MetalKernel`].
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
pub struct MetalKernelWrapper(pub Arc<Box<dyn MetalKernel>>);
impl Default for MetalKernelWrapper {
/// Wraps the `()` kernel, which has no outputs and encodes nothing.
fn default() -> Self {
Self(Arc::new(Box::new(())))
}
}
/// No-op kernel: reports no output buffers and encodes no work.
impl MetalKernel for () {
    fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
        Vec::new()
    }

    fn metal_forward(
        &self,
        _: &[(&Buffer, ShapeTracker)],
        _: &CommandBufferRef,
        _: &[&Buffer],
        _: &[&Buffer],
    ) {
        // Intentionally empty.
    }
}
/// Compile Metal source into a library with fast math enabled.
///
/// Every occurrence of `KERNEL_PATH` in `source` is replaced with this crate's
/// `src/kernels` directory (resolved at build time) before compilation.
///
/// # Panics
/// Panics if the source fails to compile.
fn compile_lib(device: &Device, source: &str) -> Library {
    let compile_options = CompileOptions::new();
    compile_options.set_fast_math_enabled(true);
    // A randomized install name used to be set on the options here; it is
    // currently disabled.
    let kernel_dir = format!("{}/src/kernels", env!("CARGO_MANIFEST_DIR"));
    let resolved_source = source.replace("KERNEL_PATH", &kernel_dir);
    device
        .new_library_with_source(&resolved_source, &compile_options)
        .unwrap()
}
/// Build a compute pipeline state for the function named `function` in `lib`.
///
/// # Panics
/// Panics if the function is missing from the library or the pipeline cannot
/// be created.
fn select_function_from_lib(
    lib: &Library,
    function: &str,
    device: &Device,
) -> ComputePipelineState {
    let kernel_fn = lib.get_function(function, None).unwrap();
    let descriptor = ComputePipelineDescriptor::new();
    descriptor.set_compute_function(Some(&kernel_fn));
    let bound_fn = descriptor.compute_function().unwrap();
    device
        .new_compute_pipeline_state_with_function(bound_fn)
        .unwrap()
}
/// Compile `code` and return the pipeline state for the kernel called `name`.
fn compile_function(name: &str, code: &str, device: &Device) -> ComputePipelineState {
    select_function_from_lib(&compile_lib(device, code), name, device)
}
/// Returns `true` when `type_id` identifies the type `T`.
fn is<T: Any>(type_id: TypeId) -> bool {
    TypeId::of::<T>() == type_id
}
trait DispatchNElements {
fn dispatch_1d(&self, n: usize);
}
impl DispatchNElements for ComputeCommandEncoderRef {
fn dispatch_1d(&self, n: usize) {
self.dispatch_thread_groups(
MTLSize {
width: n.div_ceil(1024) as u64,
height: 1,
depth: 1,
},
MTLSize {
width: 1024,
height: 1,
depth: 1,
},
);
}
}
/// Helpers for binding a single scalar as inline bytes at a buffer index.
/// (Despite the name, float setters are included.)
trait SetInt {
    fn set_i32(&self, index: usize, value: i32);
    fn set_u32(&self, index: usize, value: u32);
    fn set_f32(&self, index: usize, value: f32);
    fn set_i64(&self, index: usize, value: i64);
    fn set_u64(&self, index: usize, value: u64);
    fn set_f64(&self, index: usize, value: f64);
}

// Each setter passes a pointer to its local `value` together with the value's
// size in bytes to `set_bytes`.
impl SetInt for ComputeCommandEncoderRef {
    fn set_i32(&self, index: usize, value: i32) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const i32).cast());
    }

    fn set_u32(&self, index: usize, value: u32) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const u32).cast());
    }

    fn set_f32(&self, index: usize, value: f32) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const f32).cast());
    }

    fn set_i64(&self, index: usize, value: i64) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const i64).cast());
    }

    fn set_u64(&self, index: usize, value: u64) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const u64).cast());
    }

    fn set_f64(&self, index: usize, value: f64) {
        let len = std::mem::size_of_val(&value) as u64;
        self.set_bytes(index as u64, len, (&value as *const f64).cast());
    }
}
/// Bind the current value of every dynamic dimension symbol as a `u32`,
/// using consecutive buffer indices starting at `index`.
///
/// # Panics
/// Panics if a symbol is missing from `dyn_map`.
fn input_dyn_dims(
    dyn_symbols: &[char],
    dyn_map: &FxHashMap<char, usize>,
    encoder: &ComputeCommandEncoderRef,
    index: usize,
) {
    dyn_symbols
        .iter()
        .enumerate()
        .for_each(|(slot, symbol)| encoder.set_u32(slot + index, dyn_map[symbol] as u32));
}
fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>, String) {
let symbols: Vec<char> = shapes
.iter()
.flat_map(|st| {
st.shape()
.into_iter()
.chain(
st.padding
.into_iter()
.flat_map(|i| [i.0.into(), i.1.into()]),
)
.chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
})
.flat_map(|d| d.to_symbols())
.unique()
.collect();
(
symbols.clone(),
symbols
.into_iter()
.enumerate()
.fold(String::default(), |mut acc, (i, c)| {
write!(&mut acc, ", device int& {c} [[buffer({})]]", i + offset).unwrap();
acc
}),
)
}
/// Render a symbolic expression as a Metal source expression.
///
/// Terms are processed in order with a stack: numbers and variables are
/// pushed, every other term pops two operands and pushes the combined string.
/// The variable `z` is rendered as `(int)idx`.
///
/// # Panics
/// Panics if an operator finds fewer than two operands on the stack or the
/// expression is empty.
fn expr_to_metal_string(expr: BigExpression) -> String {
    let mut stack: Vec<String> = Vec::new();
    for term in expr.terms {
        let rendered = match term {
            Term::Num(n) => n.to_string(),
            Term::Var('z') => "(int)idx".to_string(),
            Term::Var(c) => c.to_string(),
            operator => {
                // The most recently pushed operand is written first.
                let top = stack.pop().unwrap();
                let below = stack.pop().unwrap();
                match operator {
                    Term::Max => format!("max((int){top}, (int){below})"),
                    Term::Min => format!("min((int){top}, (int){below})"),
                    other => format!("({top}{other:?}{below})"),
                }
            }
        };
        stack.push(rendered);
    }
    stack.pop().unwrap()
}
/// Render the index expression and the validity expression of `shape` as
/// Metal source strings, in that order.
fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) {
    let index_src = expr_to_metal_string(shape.index_expression());
    let valid_src = expr_to_metal_string(shape.valid_expression());
    (index_src, valid_src)
}
/// Borrow the [`MetalBuffer`] stored in `tensor`.
///
/// # Panics
/// Panics if the tensor's data is not a `MetalBuffer`.
fn get_buffer_from_tensor<'a>(tensor: &'a InputTensor) -> &'a MetalBuffer {
    let inner = tensor.borrowed();
    let as_metal = inner.data.as_any().downcast_ref::<MetalBuffer>();
    as_metal.expect("Tensor does not contain a metal buffer")
}
/// Build a `SelectOp` that matches a `MetalConstant<$t>` node holding a float
/// constant within 0.0001 of `$i`. Nodes of any other type, and constants that
/// are not `ConstantValue::Float`, do not match.
///
/// NOTE(review): the macro is `#[macro_export]`ed but expands to
/// `$crate::prim::MetalConstant`, and `prim` is declared as a private module in
/// this file - confirm whether the macro is meant to be usable outside the crate.
#[macro_export]
macro_rules! select_const {
($i: expr, $t: tt) => {
luminal::compiler_utils::SelectOp::new().check(|o, _| {
if let Some(c) = o.as_any().downcast_ref::<$crate::prim::MetalConstant<$t>>() {
if let luminal::op::ConstantValue::Float(f) = c.0 {
(f - $i).abs() < 0.0001
} else {
false
}
} else {
false
}
})
};
}

View File

@@ -0,0 +1,480 @@
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
use luminal::{
op::{InputTensor, Operator},
prelude::*,
shape::symbolic::BigExpression,
};
use metal_rs::{objc::rc::autoreleasepool, *};
use crate::{
compile_lib, get_buffer_from_tensor,
prim::{MetalContiguous, MetalMul, MetalSumReduce},
select_function_from_lib, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper, SetInt,
};
/// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix
///
/// Holds two pipelines: `matvec_pipeline` is used when M == 1 and the batch
/// size is 1, `matmul_pipeline` otherwise.
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
pub struct Matmul<T> {
matmul_pipeline: ComputePipelineState,
matvec_pipeline: ComputePipelineState,
// Queue and device used when the op is executed directly through `process`
queue: CommandQueue,
device: Device,
_phantom: PhantomData<T>,
}
// Threadgroup dimensions for the matrix-vector kernel: it is dispatched with
// BN x BM threads per group, and both values are baked into the kernel name
// (`gemv_..._bm{BM}_bn{BN}_tm4_tn4`).
const BM: u64 = 8;
const BN: u64 = 32;
impl<T> MetalKernel for Matmul<T> {
/// One output buffer of `batch * M * N` elements of `T` (size in bytes), where
/// the batch is the product of A's leading dimensions (at least 1).
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
let batch_size = input_shapes[0]
.shape()
.into_iter()
.take(input_shapes[0].len() - 2)
.product::<BigExpression>()
.max(BigExpression::from(1));
vec![batch_size * m * n * size_of::<T>()]
}
/// Encode either the matrix-vector or the matrix-matrix kernel onto
/// `command_buffer`. `inputs[0]` is A, `inputs[1]` is B; all dimensions must
/// be concrete (`to_usize().unwrap()` panics otherwise).
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let (a_shape, b_shape) = (
inputs[0]
.1
.shape()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
inputs[1]
.1
.shape()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
);
let a_dims = a_shape.len();
let m = a_shape[a_dims - 2];
let batch_size = a_shape.iter().take(a_dims - 2).product::<usize>().max(1);
// B's batch size only counts leading dimensions that are not fake (broadcast)
let b_batch_size = b_shape
.iter()
.enumerate()
.take(b_shape.len() - 2)
.filter(|(i, _)| !inputs[1].1.fake[inputs[1].1.indexes[*i]])
.map(|(_, i)| *i)
.product::<usize>()
.max(1);
let b_dims = b_shape.len();
let k = b_shape[b_dims - 2];
let n = b_shape[b_dims - 1];
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
if m == 1 && batch_size == 1 {
// Matvec
encoder.set_compute_pipeline_state(&self.matvec_pipeline);
// The matrix (B) is bound first and the vector (A) second
encoder.set_buffer(0, Some(inputs[1].0), 0);
encoder.set_buffer(1, Some(inputs[0].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
// NOTE(review): `m == 1` always holds in this branch, so these two
// arguments are always `k` and `n`.
encoder.set_i32(3, if m == 1 { k } else { m } as i32);
encoder.set_i32(4, if m == 1 { n } else { m } as i32);
encoder.set_i32(5, 0);
encoder.set_i32(6, 0);
// Threadgroup memory size depends on whether B's last two dimensions are
// in their original order
encoder.set_threadgroup_memory_length(
0,
if inputs[1].1.indexes[inputs[1].1.len() - 1]
> inputs[1].1.indexes[inputs[1].1.len() - 2]
{
BN * BM * 4
} else {
BN * 8
},
);
let b = if inputs[1].1.is_contiguous() { BN } else { BM };
// NOTE(review): `(n + b*4 - 1).div_ceil(b*4)` rounds up twice, so one extra
// threadgroup is launched whenever `n` is not congruent to 1 modulo `b*4`
// (including exact multiples). Confirm the kernel bounds-checks its output
// index, or use `n.div_ceil(b*4)`.
encoder.dispatch_thread_groups(
MTLSize::new((n as u64 + b * 4 - 1).div_ceil(b * 4), 1, 1),
MTLSize::new(BN, BM, 1),
);
} else {
// Matmul
encoder.set_compute_pipeline_state(&self.matmul_pipeline);
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_i32(3, m as i32);
encoder.set_i32(4, n as i32);
encoder.set_i32(5, k as i32);
encoder.set_i32(6, (m * k) as i32); // A batch stride
// NOTE(review): the trailing comment says "before 3rd to last", which would
// be the first `len - 3` dimensions, but the code inspects only the first
// `len - 4` - confirm which is intended.
if inputs[1].1.len() > 2 // 3D or larger
&& inputs[1].1.fake[inputs[1].1.indexes[inputs[1].1.len() - 3]] // 3rd to last dimension is fake
&& inputs[1]
.1
.indexes
.iter()
.take(inputs[1].1.len().saturating_sub(4))
.any(|i| !inputs[1].1.fake[*i])
// At least one non-fake dimension before 3rd to last
{
encoder.set_i32(7, (k * n) as i32); // B batch stride
// B batch size 2
encoder.set_i32(8, b_shape[inputs[1].1.len() - 3] as i32);
} else {
// A B batch stride of 0 reuses the same B matrix for every batch
encoder.set_i32(7, if b_batch_size == 1 { 0 } else { n * k } as i32); // B batch stride
encoder.set_i32(8, 1); // B batch size
}
encoder.set_i32(9, (m * n) as i32); // C batch stride
// Execute
// NOTE(review): `(x + 31).div_ceil(32)` rounds up twice, so an extra row
// or column of 32x32 tiles is launched whenever `x` is not congruent to 1
// modulo 32 (including exact multiples). Confirm the kernel tolerates
// out-of-range tiles, or use `x.div_ceil(32)`.
encoder.dispatch_thread_groups(
MTLSize::new(
(n + 31).div_ceil(32) as u64,
(m + 31).div_ceil(32) as u64,
batch_size as u64,
),
MTLSize::new(32, 2, 2),
);
}
encoder.end_encoding();
}
}
impl<T: 'static + Clone> Operator for Matmul<T> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let n = b_shape.last().unwrap().to_usize().unwrap();
let batch_size = a_shape
.iter()
.map(|i| i.to_usize().unwrap())
.take(a_shape.len() - 2)
.product::<usize>();
let m = a_shape[a_shape.len() - 2].to_usize().unwrap();
let out = self.device.new_buffer(
(batch_size * m * n * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[
(get_buffer_from_tensor(&inp[0].0), inp[0].1),
(get_buffer_from_tensor(&inp[1].0), inp[1].1),
],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
None
}
}
/// Compiler pass that replaces Mul -> SumReduce patterns with a fused
/// [`Matmul`] op.
#[derive(Default, Debug)]
pub struct MetalMatMulCompiler<T>(PhantomData<T>);
impl<T: MetalFloat> Compiler for MetalMatMulCompiler<T> {
/// Find every `MetalMul` feeding a `MetalSumReduce` over the last dimension
/// whose inputs carry the expected fake (broadcast) dimensions, for 2D up to
/// 5D operands, and replace the pair with a single `Matmul` op.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Filled in by the searchers on each match
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
// Look for the matmul pattern
// Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]
// Actually starts at [A,B] | [B, C]
let mut searcher_2d = SelectOp::new()
.ty::<MetalMul<T>>()
.shapes([['M', 'N', 'K'], ['M', 'N', 'K']])
.fakes([
[None, Some(true), Some(false)],
[Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
SelectOp::new()
.check(|o, _| {
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
o.dim == 2
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
// Same pattern with one leading batch dimension
let mut searcher_3d = SelectOp::new()
.ty::<MetalMul<T>>()
.shapes([['D', 'A', 'C', 'B'], ['D', 'A', 'C', 'B']])
.fakes([
[Some(false), Some(false), Some(true), Some(false)],
[None, Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
SelectOp::new()
.ty::<MetalSumReduce<T>>()
.check(|o, _| {
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
o.dim == 3
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
// Two leading batch dimensions
let mut searcher_4d = SelectOp::new()
.ty::<MetalMul<T>>()
.shapes([['E', 'D', 'A', 'C', 'B'], ['E', 'D', 'A', 'C', 'B']])
.fakes([
[
Some(false),
Some(false),
Some(false),
Some(true),
Some(false),
],
[None, None, Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
SelectOp::new()
.ty::<MetalSumReduce<T>>()
.check(|o, _| {
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
o.dim == 4
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
// Three leading batch dimensions
let mut searcher_5d = SelectOp::new()
.ty::<MetalMul<T>>()
.shapes([
['F', 'E', 'D', 'A', 'C', 'B'],
['F', 'E', 'D', 'A', 'C', 'B'],
])
.fakes([
[
Some(false),
Some(false),
Some(false),
Some(false),
Some(true),
Some(false),
],
[None, None, None, Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
SelectOp::new()
.ty::<MetalSumReduce<T>>()
.check(|o, _| {
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
o.dim == 5
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
// Compile both kernel libraries once, before the rewrite loop
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
// `||` short-circuits, so a later searcher only advances once all earlier
// ones report no further match
while searcher_2d.next_match()
|| searcher_3d.next_match()
|| searcher_4d.next_match()
|| searcher_5d.next_match()
{
if graph.no_delete.contains(&mul) {
// The intermediate mul can't be deleted
continue;
}
// Insert Matmul op
let srcs = graph.get_sources(mul);
let (mut src1, mut src1_shape) = (srcs[0].0, srcs[0].2);
let (mut src2, mut src2_shape) = (srcs[1].0, srcs[1].2);
// Undo expansions and permute
src1_shape.remove_dim(src1_shape.len() - 2);
src2_shape.remove_dim(src2_shape.len() - 3);
// Swap the last two dimensions of src2
let mut dims = (0..src2_shape.len()).collect::<Vec<_>>();
dims.swap(src2_shape.len() - 2, src2_shape.len() - 1);
src2_shape.permute(&dims);
// If src1 is padded or sliced, or batch dim isn't first, we need to make it contiguous
if src1_shape
.indexes
.iter()
.take(src1_shape.len() - 2)
.enumerate()
.any(|(a, b)| a != *b)
|| src1_shape.is_sliced()
|| src1_shape.is_padded()
{
src1 = graph
.add_op(MetalContiguous::<T>::new(
src1_shape,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.input(src1, 0, src1_shape)
.finish();
src1_shape = src1_shape.contiguous();
}
// If src2 is padded or sliced, or batch dim isn't first, we need to make it contiguous
// (fake dimensions are ignored when checking the batch dimension order)
if src2_shape
.indexes
.iter()
.take(src2_shape.len() - 2)
.filter(|i| !src2_shape.fake[**i])
.enumerate()
.any(|(a, b)| a != *b)
|| src2_shape.is_sliced()
|| src2_shape.is_padded()
{
src2 = graph
.add_op(MetalContiguous::<T>::new(
src2_shape,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.input(src2, 0, src2_shape)
.finish();
src2_shape = src2_shape.contiguous();
}
let type_name = if T::is_f32() { "float32" } else { "float16" };
// Kernel names encode the layout of each operand: "n" / "t" for the gemm
// and an optional "t_" prefix for the gemv, chosen from the shapes computed
// above
let matmul_op = graph
.add_op(Matmul::<T> {
matmul_pipeline: select_function_from_lib(
&matmul_library,
&format!( "gemm_{}{}_{type_name}_{type_name}_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned", if src1_shape.is_contiguous() {"n"} else {"t"}, if src2_shape.indexes[src2_shape.len() - 1] > src2_shape.indexes[src2_shape.len() - 2] {"n"} else {"t"}),
&dev
),
matvec_pipeline: select_function_from_lib(
&matvec_library,
&format!(
"gemv_{}{type_name}_bm{BM}_bn{BN}_tm4_tn4",
if src2_shape.indexes[src2_shape.len() - 1] > src2_shape.indexes[src2_shape.len() - 2] { "t_" } else { "" }
),
&dev
),
queue: queue.clone(),
device: dev.clone(),
_phantom: Default::default()
})
.input(src1, 0, src1_shape)
.input(src2, 0, src2_shape)
.finish();
// Create edges to dests
move_outgoing_edge(sum_reduce, matmul_op, &mut graph.graph);
move_references(
&mut remap,
&mut graph.no_delete,
&mut graph.to_retrieve,
sum_reduce,
matmul_op,
);
// Remove the old ops
graph.graph.remove_node(mul);
graph.graph.remove_node(sum_reduce);
}
}
}
// Compares the Metal matmul paths against dfdx's CPU matmul on random data.
#[cfg(test)]
mod tests {
use dfdx::{
tensor::TensorFromVec,
tensor_ops::{PermuteTo, TryMatMul},
};
use luminal::{
prelude::*,
tests::{assert_close_precision, random_vec},
};
use crate::MetalCompiler;
/// A 1 x M vector times a permuted N x M matrix, compiled for f16.
#[test]
fn test_matrix_vector() {
const M: usize = 53;
const N: usize = 256;
let mut cx = Graph::new();
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
let mut a = cx.named_tensor::<R2<1, M>>("Vec").set(a_vec.clone());
let mut b = cx.named_tensor::<R2<N, M>>("Mat").set(b_mat.clone());
let mut c = a.matmul(b.permute()).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
(&mut a, &mut b, &mut c),
);
cx.execute();
// Reference result computed with dfdx on the CPU
let d_dev = dfdx::tensor::Cpu::default();
let d_a = d_dev.tensor_from_vec(a_vec, (dfdx::shapes::Const::<M>,));
let d_b =
d_dev.tensor_from_vec(b_mat, (dfdx::shapes::Const::<N>, dfdx::shapes::Const::<M>));
let d_c = d_a.matmul(d_b.permute());
assert_close_precision(&c.data(), &d_c.as_vec(), 2);
}
/// A 1 x 1 x M tensor times an M x N matrix, compiled for f16.
#[test]
fn test_batch_matrix_vector() {
const M: usize = 256;
const N: usize = 256;
let mut cx = Graph::new();
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
let mut a = cx.named_tensor::<R3<1, 1, M>>("Vec").set(a_vec.clone());
let mut b = cx.named_tensor::<R2<M, N>>("Mat").set(b_mat.clone());
let mut c = a.matmul(b).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
(&mut a, &mut b, &mut c),
);
cx.execute();
// Reference result computed with dfdx on the CPU
let d_dev = dfdx::tensor::Cpu::default();
let d_a = d_dev.tensor_from_vec(
a_vec,
(
dfdx::shapes::Const::<1>,
dfdx::shapes::Const::<1>,
dfdx::shapes::Const::<M>,
),
);
let d_b =
d_dev.tensor_from_vec(b_mat, (dfdx::shapes::Const::<M>, dfdx::shapes::Const::<N>));
let d_c = d_a.matmul(d_b);
assert_close_precision(&c.data(), &d_c.to_dtype::<f32>().as_vec(), 2);
}
}

View File

@@ -0,0 +1,370 @@
use std::{any::Any, marker::PhantomData, sync::Arc};
use luminal::{
op::{InputTensor, Operator},
prelude::{
petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction},
*,
},
shape::symbolic::BigExpression,
};
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
ComputePipelineState, Device, MTLResourceOptions,
};
use rustc_hash::FxHashMap;
use crate::{
compile_function,
prim::{MetalAdd, MetalContiguous, MetalCopyFromDevice, MetalCopyToDevice, MetalSumReduce},
select_const, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
SetInt,
};
use super::binary::MetalSub;
/// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up
///
/// A copy immediately followed by the inverse copy is a no-op, so both nodes
/// are removed and the consumers are rewired to the original source.
#[derive(LuminalPrint, Default)]
pub struct CopyCompiler<T>(PhantomData<T>);

impl<T: MetalFloat> Compiler for CopyCompiler<T> {
    /// Removes redundant copy pairs from `graph`, redirecting any references
    /// held in `remap`, `no_delete` and `to_retrieve` to the surviving source.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
        let (mut first, mut second) = (NodeIndex::default(), NodeIndex::default());
        // One selector per direction: a host->device copy undone by a
        // device->host copy, and the reverse. Both write into the same slots.
        let mut to_then_from = SelectOp::new()
            .ty::<MetalCopyToDevice<T>>()
            .ptr(&mut first)
            .edge(
                SelectOp::new()
                    .ty::<MetalCopyFromDevice<T>>()
                    .ptr(&mut second),
            )
            .search(graph);
        let mut from_then_to = SelectOp::new()
            .ty::<MetalCopyFromDevice<T>>()
            .ptr(&mut first)
            .edge(
                SelectOp::new()
                    .ty::<MetalCopyToDevice<T>>()
                    .ptr(&mut second),
            )
            .search(graph);
        while to_then_from.next_match() || from_then_to.next_match() {
            // Ensure there are no dests from first that are not copies
            if graph
                .graph
                .edges_directed(first, petgraph::Direction::Outgoing)
                .filter(|e| {
                    let target = graph.graph.node_weight(e.target()).unwrap().as_any();
                    !target.is::<MetalCopyFromDevice<T>>() && !target.is::<MetalCopyToDevice<T>>()
                })
                .count()
                > 0
                || graph.no_delete.contains(&first)
            {
                continue;
            }
            // The node feeding the first copy takes over for everything downstream
            let Some((source, _, _)) = graph.get_sources(first).pop() else {
                continue;
            };
            move_outgoing_edge(second, source, &mut graph.graph);
            move_references(
                &mut remap,
                &mut graph.no_delete,
                &mut graph.to_retrieve,
                second,
                source,
            );
            graph.graph.remove_node(second);
            // Every remaining dest of `first` is a copy as well (checked above)
            for dest in graph
                .get_dests(first)
                .iter()
                .map(|(i, _)| *i)
                .collect::<Vec<_>>()
            {
                move_outgoing_edge(dest, source, &mut graph.graph);
                move_references(
                    &mut remap,
                    &mut graph.no_delete,
                    &mut graph.to_retrieve,
                    dest,
                    source,
                );
                graph.graph.remove_node(dest);
            }
            graph.graph.remove_node(first);
            // The graph changed, so previously cached matches are stale
            to_then_from.clear_cached_results();
            from_then_to.clear_cached_results();
        }
    }
}
/// Special kernel for producing aranges
///
/// The kernel writes `out[idx] = idx` for every index below the requested size.
#[derive(Clone, LuminalEqFalse)]
pub struct MetalARange<T: MetalFloat> {
// Compiled `metal_arange` compute pipeline
pipeline: ComputePipelineState,
// Queue used to create command buffers when the op is run through `process`
queue: CommandQueue,
// Device used to allocate the output buffer in `process`
device: Device,
// Number of elements to produce; evaluated against `dyn_map` at run time
pub size: BigExpression,
// Raw pointer to the map of dynamic dimension values; dereferenced at run time
// NOTE(review): presumably points at the owning graph's `dyn_map` and must not outlive it - confirm
dyn_map: *const FxHashMap<char, usize>,
// Ties the op to its element type without storing a value
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> std::fmt::Debug for MetalARange<T> {
    /// Prints the op as `MetalARange(<size expression>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("MetalARange({:?})", self.size))
    }
}
impl<T: MetalFloat> MetalARange<T> {
fn new(
device: Device,
queue: CommandQueue,
size: BigExpression,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let type_name = T::type_name();
Self {
pipeline: compile_function("metal_arange", &format!("
#include <metal_stdlib>
using namespace metal;
kernel void metal_arange(device {type_name} *out [[buffer(0)]], device int& n_elements [[buffer(1)]], uint idx [[thread_position_in_grid]]) {{
if (idx < n_elements) {{
out[idx] = ({type_name})idx;
}}
}}"), &device),
queue,
device,
size,
dyn_map,
_phantom: Default::default(),
}
}
}
impl<T: MetalFloat> MetalKernel for MetalARange<T> {
    /// One output buffer holding `size` elements of type `T`.
    ///
    /// The byte size uses `size_of::<T>()` because the kernel writes elements
    /// of `T`; sizing for `f16` would under-allocate whenever `T` is wider.
    fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
        vec![self.size.clone() * std::mem::size_of::<T>()]
    }

    /// Encodes the arange kernel into `command_buffer`, writing into
    /// `output_buffers[0]`. The op has no inputs and no intermediate buffers.
    fn metal_forward(
        &self,
        _: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        // Calculate size
        let size = self
            .size
            .exec(unsafe { self.dyn_map.as_ref().unwrap() })
            .unwrap();
        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        encoder.set_compute_pipeline_state(&self.pipeline);

        // Set inputs
        encoder.set_buffer(0, Some(output_buffers[0]), 0);
        encoder.set_u32(1, size as u32);

        // Execute
        encoder.dispatch_1d(size);
        encoder.end_encoding();
    }
}
impl<T: MetalFloat> Operator for MetalARange<T> {
    /// Runs the kernel standalone: allocates the output buffer, encodes the
    /// kernel, and blocks until the GPU has finished.
    fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Set up command buffer and output buffer
            let command_buffer = self.queue.new_command_buffer();
            let size = self
                .size
                .exec(unsafe { self.dyn_map.as_ref().unwrap() })
                .unwrap();
            // The kernel writes `size` elements of `T`, so size the buffer for `T`
            let out = self.device.new_buffer(
                (size * std::mem::size_of::<T>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );

            self.metal_forward(&[], command_buffer, &[], &[&out]);

            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(MetalBuffer(out))]
        })
    }

    /// Answers the `"metal"` key with a clone of this op wrapped in a
    /// `MetalKernelWrapper`; every other key yields `None`.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        if key == "metal" {
            #[allow(clippy::arc_with_non_send_sync)]
            return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            )))));
        }
        None
    }
}
/// Replace the arange pattern with a special kernel. This must be run **after** the subtraction compiler
#[derive(Default, LuminalPrint)]
pub struct ARangeCompiler<T: MetalFloat>(PhantomData<T>);

impl<T: MetalFloat> Compiler for ARangeCompiler<T> {
    /// Finds the op chain `const 1 -> 4x contiguous -> sum reduce -> (x - 1 | x + -1)`
    /// and swaps it for a single `MetalARange` node.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
        let dev = Device::system_default().unwrap();
        let queue = dev.new_command_queue();

        // Slots the selectors fill in for each match
        let mut one_const = NodeIndex::default();
        let mut contig1 = NodeIndex::default();
        let mut contig2 = NodeIndex::default();
        let mut contig3 = NodeIndex::default();
        let mut contig4 = NodeIndex::default();
        let mut sum_reduce = NodeIndex::default();
        let mut subtraction_constant = NodeIndex::default();
        let mut subtraction = NodeIndex::default();

        // TODO: Make sure this actually checks the shape transformations to ensure pooling happens
        let contig = SelectOp::new().ty::<MetalContiguous<T>>();
        let pre_sub_pattern = select_const!(1.0, T)
            .ptr(&mut one_const)
            .edge(contig.clone().ptr(&mut contig1))
            .edge(contig.clone().ptr(&mut contig2))
            .edge(contig.clone().ptr(&mut contig3))
            .edge(contig.clone().ptr(&mut contig4))
            .edge(
                SelectOp::new()
                    .ty::<MetalSumReduce<T>>()
                    .ptr(&mut sum_reduce),
            );

        // Variant 1: the final step is `x - 1`
        let minus_one = select_const!(1.0, T)
            .ptr(&mut subtraction_constant)
            .edge(SelectOp::new().ty::<MetalSub<T>>().ptr(&mut subtraction));
        let mut s1 = pre_sub_pattern.clone().edge(minus_one).search(graph);

        // Variant 2: the final step is `x + (-1)`
        let plus_negative_one = select_const!(-1.0, T)
            .ptr(&mut subtraction_constant)
            .edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut subtraction));
        let mut s2 = pre_sub_pattern.edge(plus_negative_one).search(graph);

        while s1.next_match() || s2.next_match() {
            // The arange length is the last dimension of the shape on the
            // edge between the constant and the first contiguous op
            let edge_id = graph
                .graph
                .edges_connecting(one_const, contig1)
                .next()
                .unwrap()
                .id();
            let sh = graph
                .graph
                .edge_weight(edge_id)
                .unwrap()
                .as_data()
                .unwrap()
                .2;
            let arange_amount = sh.dims[sh.indexes[sh.len() - 1]];

            let arange_kernel = MetalARange::<T>::new(
                dev.clone(),
                queue.clone(),
                arange_amount.into(),
                &graph.dyn_map,
            );
            let arange_op = graph.add_op(arange_kernel).finish();

            // Hand the consumers over to the new op, then tear the old chain
            // down from the end towards the start
            move_outgoing_edge(subtraction, arange_op, &mut graph.graph);
            graph.graph.remove_node(subtraction);
            graph.safe_remove_node(subtraction_constant, 0);
            graph.safe_remove_node(sum_reduce, 0);
            graph.safe_remove_node(contig4, 0);
            graph.safe_remove_node(contig3, 0);
            graph.safe_remove_node(contig2, 0);
            graph.safe_remove_node(contig1, 0);
            graph.safe_remove_node(one_const, 0);

            s1.clear_cached_results();
            s2.clear_cached_results();
        }
    }
}
/// Removes `MetalContiguous` nodes that feed ops able to read non-contiguous
/// inputs directly.
#[derive(Debug, Default)]
pub struct ContiguousElimination<T>(PhantomData<T>);

impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
    /// Looks for contiguous calls going to ops that answer the
    /// `"non_contiguous"` custom key and splices the contiguous node out when
    /// the shape on the connecting edge is plain.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
        let (mut contig, mut op) = (NodeIndex::default(), NodeIndex::default());
        let accepts_non_contiguous = SelectOp::new()
            .check(|candidate, _| candidate.custom("non_contiguous", Box::new(())).is_some())
            .ptr(&mut op);
        let mut selector = SelectOp::new()
            .ty::<MetalContiguous<T>>()
            .ptr(&mut contig)
            .edge(accepts_non_contiguous)
            .search(graph);

        while selector.next_match() {
            // Leave the node alone if it is protected or has several consumers
            let fan_out = graph
                .graph
                .edges_directed(contig, Direction::Outgoing)
                .count();
            if graph.no_delete.contains(&contig) || fan_out > 1 {
                continue;
            }

            // Shape going from contig to op
            let second_shape = graph
                .graph
                .edges_connecting(contig, op)
                .find_map(|e| e.weight().as_data())
                .unwrap()
                .2;

            // Here we should check if second shape and first shape are mergeable
            // instead of just checking if second_shape is contiguous
            let plain = second_shape.is_contiguous()
                && !second_shape.is_sliced()
                && !second_shape.is_padded();
            if !plain {
                continue;
            }

            let source = graph
                .graph
                .neighbors_directed(contig, Direction::Incoming)
                .next()
                .unwrap();
            move_incoming_edge(contig, op, &mut graph.graph);
            move_references(
                &mut remap,
                &mut graph.no_delete,
                &mut graph.to_retrieve,
                contig,
                source,
            );
            graph.graph.remove_node(contig);

            // Tell the op about the input shapes it now receives
            let new_shapes = graph
                .get_sources(op)
                .into_iter()
                .map(|(_, _, s)| s)
                .collect::<Vec<_>>();
            graph
                .graph
                .node_weight_mut(op)
                .unwrap()
                .custom("recompile_shapes", Box::new(new_shapes));

            selector.clear_cached_results();
        }
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,508 @@
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
ComputePipelineState, Device, MTLResourceOptions, MTLSize,
};
use petgraph::visit::EdgeRef;
use luminal::{
op::{InputTensor, Operator},
prelude::*,
shape::symbolic::BigExpression,
};
use crate::{
binary::MetalGather, get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel,
MetalKernelWrapper,
};
use super::{compile_function, SetInt};
/// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix. This expects the second input (the KxN operand) to be a quantized 2D matrix
///
/// The kernel reads the quantized operand as `block_q8_0` blocks: one `half` scale followed by 32 `int8_t` values.
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
pub struct QuantizedMatmul<T> {
// Compiled `mkernel` matrix-vector pipeline
matvec_pipeline: ComputePipelineState,
// Queue used to create command buffers when the op is run through `process`
queue: CommandQueue,
// Device used to allocate the output buffer in `process`
device: Device,
// Ties the op to its float element type without storing a value
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> QuantizedMatmul<T> {
fn new(device: Device, queue: CommandQueue) -> Self {
let type_name = T::type_name();
Self {
matvec_pipeline: compile_function("mkernel", &format!("
using namespace metal;
#define QK8_0 32
#define NB_Q8_0 8
typedef struct {{
half d; // delta
int8_t qs[QK8_0]; // quants
}} block_q8_0;
kernel void mkernel(
device block_q8_0* x [[buffer(0)]], // Quantized 2D matrix
device {type_name}* y [[buffer(1)]], // Float src vector
device {type_name}* dst [[buffer(2)]], // Float dest vector
constant int64_t & src_vec_size [[buffer(3)]], // Matrix n cols (src vector size) (Must be >= 32)
constant int64_t & dest_vec_size [[buffer(4)]], // Matrix n rows (dest vector size) (Must be >= 4)
constant int64_t & mat_batch_stride [[buffer(5)]], // Matrix batch stride
constant int64_t & vec_batch_stride [[buffer(6)]], // Vector batch stride
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index_in_simdgroup[[thread_index_in_simdgroup]],
uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]] // 2 simdgroups in a threadgroup
) {{
const int num_rows = 4;
const int num_simdgroups_per_threadgroup = 2;
const int quant_width = 32;
const int num_quants_per_row = src_vec_size / 32; // Number of quants per row
// This is the first row the simdgroup will work on (each simdgroup handles a block of 4 rows)
const int first_row = (threadgroup_position_in_grid.x * num_simdgroups_per_threadgroup + simdgroup_index_in_threadgroup) * num_rows;
// Offsets
x += first_row * num_quants_per_row + threadgroup_position_in_grid.z * (mat_batch_stride / 32);
y += threadgroup_position_in_grid.z * vec_batch_stride;
dst += (threadgroup_position_in_grid.z * dest_vec_size);
// thread-local cache of vector values to work on. This thread must only work on 8 at a time
{type_name} yl[8];
// thread-local cache of 4 row sums
float sumf[num_rows] = {{0.f}};
const int ix = thread_index_in_simdgroup / 4;
const int il = thread_index_in_simdgroup % 4;
y += thread_index_in_simdgroup * 8;
// each thread in a SIMD group deals with 8 quants at a time
// we start at 0-7 (ix) depending on the simdgroup index, and jump 8 indexes each time
for (int ib = ix; ib < num_quants_per_row; ib += 8) {{ // ib: current column position
// Load vector values into the cache
for (int i = 0; i < 8; ++i) {{
yl[i] = y[i];
}}
// Loop through 4 matrix rows
for (int row = 0; row < 4; ++row) {{
// Get pointer to matrix data
device const int8_t* qs = x[ib + row * num_quants_per_row].qs + il * 8;
float sumq = 0.f; // Partial sum
// Loop through 8 columns
for (int iq = 0; iq < 8; ++iq) {{
sumq += qs[iq] * yl[iq]; // Multiply int with vector value (auto converts to float?)
}}
sumf[row] += sumq * x[ib + row * num_quants_per_row].d; // multiply by delta (scaling factor)
}}
y += 256; // Jump by 256
}}
// each simdgroup is responsible for saving 4 final vector values (n rows)
for (int row = 0; row < num_rows; ++row) {{
const float tot = simd_sum(sumf[row]);
if (thread_index_in_simdgroup == 0 && first_row + row < dest_vec_size) {{
dst[first_row + row] = ({type_name})tot;
}}
}}
}}
"), &device),
queue,
device,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for QuantizedMatmul<T> {
    /// One output buffer of `batch * m * n` elements of `T`, where `m` comes
    /// from the first input, `n` from the second, and `batch` is the product
    /// of the first input's leading dimensions (at least 1).
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        let lhs = &input_shapes[0];
        let m = lhs.shape()[lhs.len() - 2].clone();
        let rhs = &input_shapes[1];
        let n = rhs.shape()[rhs.len() - 1].clone();
        let leading_dims = lhs.len() - 2;
        let batch_size = lhs
            .shape()
            .into_iter()
            .take(leading_dims)
            .product::<BigExpression>()
            .max(BigExpression::from(1));
        vec![batch_size * m * n * size_of::<T>()]
    }

    /// Encodes the quantized matrix-vector kernel. `inputs[0]` is the float
    /// operand and `inputs[1]` the quantized weight matrix, which must not be
    /// contiguous. Only a batch size of 1 is implemented.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        assert!(
            !inputs[1].1.is_contiguous(),
            "Weight matrix must be column-major"
        );

        // Resolve both shapes to concrete dimensions
        let a_shape: Vec<usize> = inputs[0]
            .1
            .shape()
            .into_iter()
            .map(|d| d.to_usize().unwrap())
            .collect();
        let b_shape: Vec<usize> = inputs[1]
            .1
            .shape()
            .into_iter()
            .map(|d| d.to_usize().unwrap())
            .collect();

        let a_dims = a_shape.len();
        let m = a_shape[a_dims - 2];
        let batch_size = a_shape[..a_dims - 2].iter().product::<usize>().max(1);
        let b_dims = b_shape.len();
        let k = b_shape[b_dims - 2];
        let n = b_shape[b_dims - 1];

        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        if batch_size != 1 {
            todo!()
        }

        // Matvec
        encoder.set_compute_pipeline_state(&self.matvec_pipeline);
        encoder.set_buffer(0, Some(inputs[1].0), 0); // Matrix
        encoder.set_buffer(1, Some(inputs[0].0), 0); // Vector
        encoder.set_buffer(2, Some(output_buffers[0]), 0); // Dest vector
        encoder.set_i64(3, k as i64); // Src vec size
        encoder.set_i64(4, n as i64); // Dest vec size
        encoder.set_i64(5, 0); // Matrix batch stride
        encoder.set_i64(6, k as i64); // Vector batch stride
        let grid = MTLSize::new(n.div_ceil(8) as u64, 1, m as u64);
        let threadgroup = MTLSize::new(8, 8, 1);
        encoder.dispatch_thread_groups(grid, threadgroup);

        encoder.end_encoding();
    }
}
impl<T: 'static + Clone> Operator for QuantizedMatmul<T> {
    /// Runs the kernel standalone: allocates the output buffer, encodes the
    /// kernel via `metal_forward`, and blocks until the GPU has finished.
    ///
    /// `inp[0]` is the float operand and `inp[1]` the quantized matrix.
    fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Setup command queue / command buffer / encoder
            let command_buffer = self.queue.new_command_buffer();
            let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
            let n = b_shape[1].to_usize().unwrap();
            let (batch_size, m) = if a_shape.len() == 3 {
                (
                    a_shape[0].to_usize().unwrap(),
                    a_shape[1].to_usize().unwrap(),
                )
            } else {
                // A 2D left operand is a single batch. Using 0 here would
                // request a zero-byte output buffer below.
                (1, a_shape[0].to_usize().unwrap())
            };

            let out = self.device.new_buffer(
                (batch_size * m * n * std::mem::size_of::<T>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );

            self.metal_forward(
                &[
                    (get_buffer_from_tensor(&inp[0].0), inp[0].1),
                    (get_buffer_from_tensor(&inp[1].0), inp[1].1),
                ],
                command_buffer,
                &[],
                &[&out],
            );

            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(MetalBuffer(out))]
        })
    }

    /// Answers the `"metal"` key with a clone of this op wrapped in a
    /// `MetalKernelWrapper`; every other key yields `None`.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        if key == "metal" {
            return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            )))));
        }
        None
    }
}
/// Gathers embedding rows out of a weight table stored as `block_q8_0` blocks, dequantizing each value by multiplying it with its block's scale.
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
pub struct QuantizedGather<T> {
// Compiled `metal_gather` compute pipeline
pipeline: ComputePipelineState,
// Device used to allocate the index and output buffers
device: Device,
// Queue used to create the command buffer in `process`
queue: CommandQueue,
// Number of values per embedding row
embed_dim: usize,
// Ties the op to its output float type without storing a value
_phantom: PhantomData<T>,
}
impl<T: MetalFloat> QuantizedGather<T> {
fn new(device: Device, queue: CommandQueue, embed_dim: usize) -> Self {
let type_name = T::type_name();
Self {pipeline: compile_function("metal_gather", &format!(
"
#include <metal_stdlib>
using namespace metal;
#define QK8_0 32
typedef struct {{
half d; // delta
int8_t qs[QK8_0]; // quants
}} block_q8_0;
kernel void metal_gather(device float *inp [[buffer(0)]], device block_q8_0 *weights [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_embeddings [[buffer(3)]], device int& embedding_dim [[buffer(4)]], uint2 idx [[thread_position_in_grid]]) {{
if (idx.x < n_embeddings && idx.y < embedding_dim) {{
int block_idx = ((int)inp[idx.x] * embedding_dim + idx.y) / QK8_0;
out[idx.x * embedding_dim + idx.y] = weights[block_idx].qs[idx.y % QK8_0] * weights[block_idx].d;
}}
}}"), &device), device, embed_dim, queue, _phantom: Default::default()}
}
}
impl<T: MetalFloat> Operator for QuantizedGather<T> {
    /// Looks up one embedding row per index and returns them in a new Metal
    /// buffer.
    ///
    /// `tensors[0]` must hold the indexes as a `Vec<f32>`; `tensors[1]` is the
    /// quantized weight table already living in a Metal buffer.
    ///
    /// # Panics
    /// Panics if `tensors[0]` is not backed by a `Vec<f32>`.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Setup buffers
            let indexes = tensors[0]
                .0
                .borrowed()
                .data
                .as_any()
                .downcast_ref::<Vec<f32>>()
                .unwrap();
            // A plain pointer cast is enough here; no transmute (and no `unsafe`) needed
            let index_buffer = self.device.new_buffer_with_data(
                indexes.as_ptr().cast(),
                (indexes.len() * std::mem::size_of::<f32>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );

            // Setup command queue / command buffer / encoder
            let command_buffer = self.queue.new_command_buffer();

            let out = self.device.new_buffer(
                (indexes.len() * self.embed_dim * std::mem::size_of::<T>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );

            let encoder = command_buffer
                .compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
            encoder.set_compute_pipeline_state(&self.pipeline);

            // Set inputs
            encoder.set_buffer(0, Some(&index_buffer), 0);
            encoder.set_buffer(1, Some(get_buffer_from_tensor(&tensors[1].0)), 0);
            encoder.set_buffer(2, Some(&out), 0);
            encoder.set_u32(3, indexes.len() as u32);
            encoder.set_u32(4, self.embed_dim as u32);

            // Execute: one thread per (index, embedding column) pair
            encoder.dispatch_threads(
                MTLSize {
                    width: indexes.len() as u64,
                    height: self.embed_dim as u64,
                    depth: 1,
                },
                MTLSize {
                    width: 16,
                    height: 16,
                    depth: 1,
                },
            );
            encoder.end_encoding();

            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(MetalBuffer(out))]
        })
    }
}
/// Metal compiler that swaps the ops consuming the given weight nodes for
/// their quantized counterparts. Holds the ids of the quantized weights.
#[derive(Default)]
pub struct MetalQuantizedCompiler<T>(Vec<NodeIndex>, PhantomData<T>);

impl<T> MetalQuantizedCompiler<T> {
    /// Creates a compiler that treats `weights` as quantized tensors.
    pub fn new<To: ToIds>(weights: To) -> Self {
        let weight_ids = weights.to_ids();
        Self(weight_ids, PhantomData)
    }
}
impl<T: MetalFloat + Default> Compiler for MetalQuantizedCompiler<T> {
    /// Runs the regular Metal compilers, replaces every gather / matmul fed by
    /// a quantized weight with its quantized version, then runs the buffer
    /// compilers.
    ///
    /// # Panics
    /// Panics if a quantized weight is not input 1 of its consumer, or if the
    /// consumer is neither a gather nor a matmul.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
        let device = Device::system_default().unwrap();
        let queue = device.new_command_queue();

        // Track the weight ids next to the caller's ids so they stay valid
        // while the first round of compilers rewrites the graph
        let mut weight_ids = self.0.clone();
        let mut local_remap = remap.to_ids_mut();
        for weight_id in weight_ids.iter_mut() {
            local_remap.push(weight_id);
        }

        // Normal metal compilation
        let op_compilers = <(
            super::prim::PrimitiveCompiler<T>,
            super::SpecialOpsCompiler<T>,
            super::other::CopyCompiler<T>,
            super::other::ContiguousElimination<T>,
            super::elementwise_fusion::ElementwiseFusionCompiler<T>,
        )>::default();
        graph.compile(op_compilers, &mut local_remap);

        // Modify ops directly downstream of weights
        for weight in downstream(&weight_ids, graph) {
            let consumers = graph
                .graph
                .edges_directed(weight, petgraph::Direction::Outgoing)
                .filter_map(|e| e.weight().as_data().map(|i| (e.target(), i)))
                .collect::<Vec<_>>();
            for (target, (inp_ind, _, _)) in consumers {
                assert_eq!(
                    inp_ind, 1,
                    "Quantized weight {target:?} is the wrong input!",
                );
                let op_node = graph.graph.node_weight_mut(target).unwrap();
                if let Some(gather) = op_node.as_any().downcast_ref::<MetalGather<T>>() {
                    let embed_dim = gather.embed_dim;
                    *op_node = Box::new(QuantizedGather::<T>::new(
                        device.clone(),
                        queue.clone(),
                        embed_dim,
                    ));
                } else if op_node.as_any().is::<super::matmul::Matmul<T>>() {
                    *op_node = Box::new(QuantizedMatmul::<T>::new(device.clone(), queue.clone()));
                } else {
                    panic!("Quantized weight {target:?} is an input to a node that isn't a matmul or gather!");
                }
            }
        }

        // Finish normal metal compilation
        graph.compile(super::BufferCompilers::default(), &mut remap);
    }
}
#[cfg(test)]
mod tests {
    use dfdx::{
        tensor::TensorFromVec,
        tensor_ops::{PermuteTo, TryMatMul},
    };
    use luminal::{
        prelude::*,
        tests::{assert_close, random_vec_rng},
    };
    use metal_rs::{Device, MTLResourceOptions};
    use rand::{thread_rng, Rng};

    use crate::{MetalBuffer, MetalQuantizedCompiler};

    /// Host-side mirror of the kernel's `block_q8_0`: one scale plus 32 quants.
    #[repr(C, packed)]
    struct BlockQ8_0 {
        _d: f16,
        _qs: [i8; 32],
    }

    /// Wraps the blocks in a Metal buffer without copying. The slice must
    /// outlive every use of the returned tensor.
    fn quantized_buffer(weights: &[BlockQ8_0], dev: &Device) -> Tensor {
        let buffer = dev.new_buffer_with_bytes_no_copy(
            weights.as_ptr() as *mut _,
            std::mem::size_of_val(weights) as u64,
            MTLResourceOptions::StorageModeShared,
            None,
        );
        Tensor {
            data: Box::new(MetalBuffer(buffer)),
        }
    }

    /// Packs raw 8-bit values into blocks of 32 with a scale of 1.0, so the
    /// dequantized weights equal the raw values.
    fn unit_scale_blocks(values: &[i8]) -> Vec<BlockQ8_0> {
        values
            .chunks_exact(32)
            .map(|chunk| {
                let mut quants = [0i8; 32];
                quants.copy_from_slice(chunk);
                BlockQ8_0 {
                    _d: f16::from_f32(1.0),
                    _qs: quants,
                }
            })
            .collect::<Vec<_>>()
    }

    /// Quantized vector x matrix must match the unquantized luminal result.
    #[test]
    fn test_quantized_matvec() {
        let mut rng = thread_rng();
        let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
        let vec_data = random_vec_rng(1024, &mut rng);

        let mut cx = Graph::new();
        let weights = cx.tensor::<R2<512, 1024>>();
        let vec = cx.tensor::<R1<1024>>().set(vec_data.clone());
        let mut out = vec.matmul(weights.permute()).retrieve();

        // "Load" weights in 8bit
        let blocks = unit_scale_blocks(&mat_data);
        let dev = Device::system_default().unwrap();
        cx.tensors
            .insert((weights.id, 0), quantized_buffer(&blocks, &dev));

        cx.compile(
            MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
            &mut out,
        );
        cx.execute();

        // Reference: same product with the weights as plain f32
        let mut cx1 = Graph::new();
        let float_weights = mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>();
        let weights = cx1.tensor::<R2<512, 1024>>().set(float_weights);
        let vec = cx1.tensor::<R1<1024>>().set(vec_data);
        let out_32 = vec.matmul(weights.permute()).retrieve();
        cx1.execute();

        assert_close(&out.data(), &out_32.data());
    }

    /// Quantized matrix x matrix must match the dfdx CPU result.
    #[test]
    fn test_quantized_matmul() {
        let mut rng = thread_rng();
        let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
        let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);

        let mut cx = Graph::new();
        let weights = cx.tensor::<R2<512, 1024>>();
        let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
        let mut out = inp_mat.matmul(weights.permute()).retrieve();

        // "Load" weights in 8bit
        let blocks = unit_scale_blocks(&mat_data);
        let dev = Device::system_default().unwrap();
        cx.tensors
            .insert((weights.id, 0), quantized_buffer(&blocks, &dev));

        cx.compile(
            MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
            &mut out,
        );
        cx.execute();

        // Reference computation on the CPU
        let cpu = dfdx::tensor::Cpu::default();
        let float_weights = mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>();
        let d_a = cpu.tensor_from_vec(
            float_weights,
            (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>),
        );
        let d_b = cpu.tensor_from_vec(
            inp_mat_data,
            (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>),
        );
        let d_c = d_b.matmul(d_a.permute());

        assert_close(&out.data(), &d_c.as_vec());
    }
}

View File

@@ -0,0 +1,61 @@
use crate::{
prelude::{metal::prim::MetalAdd, *},
select_const, select_ty,
};
use petgraph::stable_graph::NodeIndex;
use super::{
binary::MetalSub,
prim::{MetalConstant, MetalLessThan, MetalMul},
};
/// Builds a selector in which both `s1` and `s2` lead into a `MetalLessThan`
/// node; the matched node is recorded through `ptrs`.
pub fn less_than<T: MetalFloat>(
    s1: SelectEdge,
    s2: SelectEdge,
    ptrs: &mut Vec<NodeIndex>,
) -> SelectEdge {
    let comparison = select_ty!(MetalLessThan<T>).ptr(ptrs);
    let from_first = s1.edge(comparison);
    s2.edge(from_first)
}
/// Builds a selector in which both `s1` and `s2` lead into a `MetalMul` node;
/// the matched node is recorded through `ptrs`.
pub fn mul<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    let product = select_ty!(MetalMul<T>).ptr(ptrs);
    let from_first = s1.edge(product);
    s2.edge(from_first)
}
/// Builds a selector in which both `s1` and `s2` lead into a `MetalAdd` node;
/// the matched node is recorded through `ptrs`.
pub fn add<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    let sum = select_ty!(MetalAdd<T>).ptr(ptrs);
    let from_first = s1.edge(sum);
    s2.edge(from_first)
}
/// Builds a selector in which both `s1` and `s2` lead into a `MetalSub` node;
/// the matched node is recorded through `ptrs`.
pub fn sub<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    let difference = select_ty!(MetalSub<T>).ptr(ptrs);
    let from_first = s1.edge(difference);
    s2.edge(from_first)
}
/// Builds the selector for a "less than or equal" expressed with primitives:
/// a constant `1.0` combined through `sub` with `less_than(s2, s1)`.
pub fn less_than_equal<T: MetalFloat>(
    s1: SelectEdge,
    s2: SelectEdge,
    mut ptrs: &mut Vec<NodeIndex>,
) -> SelectEdge {
    // Built in the same order as the arguments would be evaluated inline
    let one: SelectEdge = select_const!(1.0, T).ptr(&mut ptrs).into();
    let reversed_less_than = less_than::<T>(s2, s1, &mut ptrs);
    sub::<T>(one, reversed_less_than, ptrs)
}
/// Builds the selector for a maximum expressed with primitives:
/// `less_than(s1, s2) * s2 + less_than_equal(s2, s1) * s1`.
pub fn max<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    // Term that passes `s2` through
    let s2_mask = less_than::<T>(s1.clone(), s2.clone(), ptrs);
    let a = mul::<T>(s2_mask, s2.clone(), ptrs);

    // Term that passes `s1` through
    let s1_mask = less_than_equal::<T>(s2, s1.clone(), ptrs);
    let b = mul::<T>(s1_mask, s1, ptrs);

    add::<T>(a, b, ptrs)
}
/// Builds the selector for a ReLU expressed as `max(s1, 0.0)`.
pub fn relu<T: MetalFloat>(s1: SelectEdge, mut ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    let zero: SelectEdge = select_const!(0.0, T).ptr(&mut ptrs).into();
    max::<T>(s1, zero, &mut ptrs)
}
/// Builds the selector for an absolute value expressed as
/// `relu(s1) + relu(s1 * -1.0)`.
pub fn abs<T: MetalFloat>(s1: SelectEdge, mut ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
    // Sub-patterns are built in the same order as they would be inline
    let positive_part = relu::<T>(s1.clone(), &mut ptrs);
    let minus_one: SelectEdge = select_const!(-1.0, T).ptr(&mut ptrs).into();
    let negated = mul::<T>(s1, minus_one, &mut ptrs);
    let negative_part = relu::<T>(negated, &mut ptrs);
    add::<T>(positive_part, negative_part, &mut ptrs)
}

View File

@@ -0,0 +1,404 @@
use std::{
cell::UnsafeCell,
collections::{BTreeMap, BTreeSet},
ops::Deref,
sync::Arc,
};
use itertools::Itertools;
use metal_rs::{Buffer, Device, MTLResourceOptions};
use rustc_hash::{FxHashMap, FxHashSet};
use luminal::{
op::{InputTensor, Operator},
prelude::{
petgraph::{algo::toposort, stable_graph::NodeIndex, visit::EdgeRef, Direction},
symbolic::BigExpression,
*,
},
};
use crate::{MetalBuffer, MetalKernelWrapper};
use super::get_buffer_from_tensor;
/// Plans a shared pool of Metal buffers for all ops that expose a `"metal"` kernel.
///
/// The pass works in three steps:
/// 1. walk the graph in topological order and compute, per node, a "clear set" of upstream nodes;
/// 2. assign every output / intermediate buffer an index into one shared pool, reusing the
///    buffer of a clear-set node when the sizes are equal;
/// 3. add an `AllocateMetalBuffers` op scheduled before the graph's starting nodes and wrap each
///    kernel op in a `StorageBufferWrapper` that carries its assigned pool indexes.
///
/// NOTE(review): the clear set presumably holds nodes whose consumers have all run, so their buffers are free to reuse - confirm.
#[derive(Default, LuminalPrint)]
pub struct StorageBufferCompiler;
impl Compiler for StorageBufferCompiler {
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
// First pass - get clear sets for each node
// Per node: (tentative sets: ancestor -> remaining children of that ancestor, clear set)
#[allow(clippy::type_complexity)]
let mut first_pass: FxHashMap<
NodeIndex,
(
BTreeMap<NodeIndex, BTreeSet<NodeIndex>>,
BTreeSet<NodeIndex>,
),
> = FxHashMap::default();
let toposort = toposort(&graph.graph, None).unwrap();
// Loop through nodes in graph
for node in &toposort {
// Run through parents to build new tentative set and clear set
let (mut tenative_sets, mut clear_set) = (BTreeMap::default(), BTreeSet::default());
// Only data edges count; schedule-only edges are ignored
for parent in graph
.graph
.edges_directed(*node, Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.map(|e| e.source())
{
// Start the parent's entry with all of its data consumers
let parent_children = graph
.graph
.edges_directed(parent, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.map(|e| e.target())
.collect::<BTreeSet<_>>();
tenative_sets.insert(parent, parent_children);
if let Some((parent_tenative_set, parent_clear_set)) = first_pass.get(&parent) {
// Inherit the parent's tentative sets, with the parent itself removed from each
for (node_index, new_tenative_set) in
parent_tenative_set.iter().map(|(n, c)| {
let mut c = c.clone();
c.retain(|n| *n != parent);
(*n, c)
})
{
// An entry reached through several parents keeps only the common members
if let Some(set) = tenative_sets.get(&node_index) {
*tenative_sets.get_mut(&node_index).unwrap() =
btreeset_intersection(new_tenative_set, set);
} else {
tenative_sets.insert(node_index, new_tenative_set);
}
}
// Entries whose set ran empty move from the tentative sets to the clear set
clear_set.extend(
tenative_sets
.iter()
.filter(|(_, v)| v.is_empty())
.map(|(n, _)| *n),
);
tenative_sets.retain(|_, v| !v.is_empty());
// Whatever was clear for the parent is clear here too
clear_set.extend(parent_clear_set);
}
}
first_pass.insert(*node, (tenative_sets, clear_set));
}
// Second pass - assign buffers
// Buffer sizes (outputs, intermediates) of every kernel op outside `no_delete`
let available_buffers = graph
.graph
.node_indices()
.filter(|n| !graph.no_delete.contains(n))
.collect::<Vec<_>>()
.into_iter()
.filter_map(|n| {
// Only ops answering the "metal" key with a MetalKernelWrapper take part
if let Some(Ok(wrapper)) = graph
.graph
.node_weight_mut(n)
.unwrap()
.custom("metal", Box::new(()))
.map(|n| n.downcast::<MetalKernelWrapper>())
{
Some((n, wrapper))
} else {
None
}
})
.collect::<Vec<_>>()
.into_iter()
.map(|(n, wrapper)| {
let input_shapes = graph
.get_sources(n)
.into_iter()
.map(|(_, _, i)| i)
.collect::<Vec<_>>();
let output_buffers = wrapper.0.output_buffer_sizes(&input_shapes);
let intermediate_buffers = wrapper.0.intermediate_buffer_sizes(&input_shapes);
(n, (output_buffers, intermediate_buffers))
})
.collect::<FxHashMap<_, _>>();
// Loop through nodes in graph
// `buffers`: size of each pool slot; `buffer_map`: node -> (output slots, intermediate slots)
let mut buffers = vec![];
let mut buffer_map = FxHashMap::default();
// Nodes that already gave one of their buffers away
let mut used = FxHashSet::<NodeIndex>::default();
for node in &toposort {
// no_delete nodes get dedicated buffers further down
if graph.no_delete.contains(node) {
continue;
}
let Some(Ok(wrapper)) = graph
.graph
.node_weight_mut(*node)
.unwrap()
.custom("metal", Box::new(()))
.map(|e| e.downcast::<MetalKernelWrapper>())
else {
continue;
};
buffer_map.insert(*node, (vec![], vec![]));
let input_shapes = graph
.get_sources(*node)
.into_iter()
.map(|(_, _, i)| i)
.collect::<Vec<_>>();
// Assign output buffers
for required_buffer in wrapper.0.output_buffer_sizes(&input_shapes) {
// Find an applicable buffer
// Candidates: output buffers of unused, deletable clear-set nodes with an equal size expression
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
available_buffers[i]
.0
.iter()
.cloned()
.enumerate()
.map(|(o, b)| (o, *i, b))
})
.find(|(_, _, size)| *size == required_buffer)
{
let buffer = buffer_map.get(&source_node).unwrap().0[buffer_index];
buffer_map.get_mut(node).unwrap().0.push(buffer);
// Mark the source node as used so its buffers can't be handed out again
used.insert(source_node);
} else {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().0.push(buffers.len());
buffers.push(required_buffer);
}
}
// Assign intermediate buffers
for required_buffer in wrapper.0.intermediate_buffer_sizes(&input_shapes) {
// Find an applicable buffer
// Same search as above, but over the candidates' intermediate buffers
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !graph.no_delete.contains(i))
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
available_buffers[i]
.1
.iter()
.cloned()
.enumerate()
.map(|(o, b)| (o, *i, b))
})
.find(|(_, _, size)| *size == required_buffer)
{
let buffer = buffer_map.get(&source_node).unwrap().1[buffer_index];
buffer_map.get_mut(node).unwrap().1.push(buffer);
used.insert(source_node);
} else {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().1.push(buffers.len());
buffers.push(required_buffer);
}
}
}
// Loop through no_delete nodes and add buffers just for them
for node in &toposort {
if !graph.no_delete.contains(node) {
continue;
}
let Some(Ok(wrapper)) = graph
.graph
.node_weight_mut(*node)
.unwrap()
.custom("metal", Box::new(()))
.map(|e| e.downcast::<MetalKernelWrapper>())
else {
continue;
};
buffer_map.insert(*node, (vec![], vec![]));
let input_shapes = graph
.get_sources(*node)
.into_iter()
.map(|(_, _, i)| i)
.collect::<Vec<_>>();
// Assign output buffers
for required_buffer in wrapper.0.output_buffer_sizes(&input_shapes) {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().0.push(buffers.len());
buffers.push(required_buffer);
}
// Assign intermediate buffers
for required_buffer in wrapper.0.intermediate_buffer_sizes(&input_shapes) {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().1.push(buffers.len());
buffers.push(required_buffer);
}
}
// We now have the buffers to allocate, and the buffers needed for each op.
// Let's create the allocator op and wrap all the metal ops
// The pool is shared between the allocator op and every wrapper
let shared_buffers = Arc::new(UnsafeCell::new(vec![]));
let allocator = graph
.add_op(AllocateMetalBuffers {
dev: Device::system_default().unwrap(),
dyn_map: &graph.dyn_map,
buffer_sizes: buffers,
buffers: shared_buffers.clone(),
})
.finish();
// Ensure allocator is run before any nodes that use the buffers
for node in graph
.graph
.node_indices()
// Starting node must have no incoming edges
.filter(|e| {
graph
.graph
.edges_directed(*e, Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.count()
== 0
})
// Starting node must have at least one outgoing edge
.filter(|e| {
graph
.graph
.edges_directed(*e, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.count()
> 0
})
.collect_vec()
{
graph.add_schedule_dependency(allocator, node);
}
// Wrap nodes in StorageBufferWrapper
// Nodes that were assigned no buffers at all are left as they are
for (node, (output_buffers, intermediate_buffers)) in buffer_map
.into_iter()
.filter(|(_, b)| !b.0.is_empty() || !b.1.is_empty())
{
let wrapper = graph
.graph
.node_weight_mut(node)
.unwrap()
.custom("metal", Box::new(()))
.unwrap()
.downcast::<MetalKernelWrapper>()
.unwrap();
*graph.graph.node_weight_mut(node).unwrap() = Box::new(StorageBufferWrapper {
wrapper,
buffers: shared_buffers.clone(),
output_buffers,
intermediate_buffers,
});
}
}
}
/// Returns the elements of `a` that are also present in `b`.
///
/// `a` is consumed; `b` is only borrowed for membership checks.
fn btreeset_intersection<T: Ord>(a: BTreeSet<T>, b: &BTreeSet<T>) -> BTreeSet<T> {
    a.into_iter().filter(|item| b.contains(item)).collect()
}
// Operator that owns the sizing information for the shared Metal buffer pool and
// (re)creates the buffers in that pool every time it is processed.
#[derive(LuminalEqFalse, LuminalPrint)]
struct AllocateMetalBuffers {
// Device the buffers are created on.
dev: Device,
// Raw pointer to the map of dynamic dimension values (char -> usize) used to evaluate
// `buffer_sizes`. It is dereferenced without checks when the op is processed.
// NOTE(review): assumes the pointed-to map outlives this op — confirm.
dyn_map: *const FxHashMap<char, usize>,
// One size expression per buffer; evaluated against `dyn_map` and passed to the device
// as the buffer length.
buffer_sizes: Vec<BigExpression>,
// Shared list the allocated buffers are written into, in the same order as `buffer_sizes`.
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
}
impl Operator for AllocateMetalBuffers {
// Makes sure the shared buffer list holds one buffer per entry in `buffer_sizes`, each with
// the length its size expression currently evaluates to. Takes no inputs and returns no tensors.
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
// NOTE(review): both dereferences below are unchecked. They assume nothing else holds a
// reference into the buffer list during this call and that `dyn_map` points at a live map.
let buffers = unsafe { &mut *self.buffers.get() };
let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
// Allocate all buffers
if buffers.is_empty() {
// First run: create every buffer from scratch.
*buffers = self
.buffer_sizes
.iter()
.map(|e| {
self.dev.new_buffer(
e.exec(dyn_map).unwrap() as u64,
MTLResourceOptions::StorageModeShared,
)
})
.collect();
} else {
// Later runs: replace only the buffers whose required size has changed.
for (size, buffer) in self.buffer_sizes.iter().zip(buffers) {
let size = size.exec(dyn_map).unwrap() as u64;
if buffer.length() != size {
// TODO: For some reason this causes bad outputs. Maybe we are relying on buffer length somewhere? We shouldn't be.
// Also, it seems we are getting the benefits of this without actually doing it. Maybe metal is doing it in the background?
// Similar allocation strategy to Rust's Vec
// let mut length = buffer.length();
// while length < size {
// length *= 2;
// }
// Until the TODO above is resolved, allocate exactly the requested size.
let length = size;
*buffer = self
.dev
.new_buffer(length, MTLResourceOptions::StorageModeShared);
}
}
}
vec![]
}
}
// Wraps a Metal kernel so that it runs against buffers taken from the shared pool
// instead of buffers of its own.
#[derive(LuminalEqFalse)]
struct StorageBufferWrapper {
// The wrapped kernel; its inner value is what actually gets executed.
wrapper: Box<MetalKernelWrapper>,
// Shared buffer list, read (never written) when this op is processed.
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
// Indices into `buffers` used as the kernel's intermediate buffers.
intermediate_buffers: Vec<usize>,
// Indices into `buffers` used as the kernel's output buffers.
output_buffers: Vec<usize>,
}
/// Debug output is delegated to the wrapped kernel, so the wrapper itself never shows up
/// in printed output.
impl std::fmt::Debug for StorageBufferWrapper {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { wrapper, .. } = self;
        wrapper.0.fmt(formatter)
    }
}
impl Operator for StorageBufferWrapper {
    /// Runs the wrapped kernel with its intermediate and output buffers looked up in the
    /// shared pool, then returns one tensor per output buffer.
    fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        let pool = unsafe { self.buffers.get().as_ref().unwrap() };

        // Resolve the stored indices into references to the pooled buffers.
        let scratch: Vec<_> = self
            .intermediate_buffers
            .iter()
            .map(|idx| &pool[*idx])
            .collect();
        let outputs: Vec<_> = self.output_buffers.iter().map(|idx| &pool[*idx]).collect();

        // Pair every input's underlying buffer with its shape tracker.
        let inputs: Vec<_> = inp
            .iter()
            .map(|(tensor, shape)| (get_buffer_from_tensor(tensor).deref(), *shape))
            .collect();

        self.wrapper
            .0
            .without_command_buffer(&inputs, &scratch, &outputs);

        outputs
            .iter()
            .map(|buf| Tensor::new(MetalBuffer((*buf).clone())))
            .collect()
    }
}
/// Runs a small graph uncompiled, then compiled with the f16 Metal compiler, and checks
/// that both runs agree.
#[test]
fn test_shared_buffers() {
    use luminal::prelude::*;
    use luminal::tests::{assert_close_precision, random_vec};

    let mut graph = Graph::new();
    let input = graph.tensor::<R1<5>>().set(random_vec(5)).keep();
    let exp = input.exp2();
    let scaled_log = input.log2() * exp;
    let inverse = exp.recip();
    let mut output = (scaled_log + inverse).retrieve();

    // Reference run before compilation.
    graph.execute();
    let reference = output.data();
    output.drop();

    // Compiled run.
    graph.compile(crate::MetalCompiler::<f16>::default(), &mut output);
    graph.execute();
    assert_close_precision(&output.data(), &reference, 2);
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
use dfdx::prelude::{Module as DfdxModule, *};
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use luminal::{
nn::{activation::ReLU, linear::Linear},
prelude::{Module, *},
tests::{assert_close, assert_close_precision, random_vec, random_vec_rng},
};
use crate::MetalCompiler;
/// Checks a permute followed by a reshape on Metal against dfdx.
#[test]
fn test_contiguous() {
    let mut graph = Graph::new();
    let values = random_vec(12);
    let input = graph.tensor::<R2<3, 4>>().set(values.clone());
    let mut output = input
        .permute::<R2<4, 3>, _>()
        .reshape::<R2<12, 1>>()
        .retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let cpu = Cpu::default();
    let reference_input =
        cpu.tensor_from_vec(values, (dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>));
    let expected = reference_input
        .permute::<Rank2<4, 3>, _>()
        .reshape::<Rank2<12, 1>>();
    assert_close(&output.data(), &expected.as_vec());
}
/// Checks Metal `log2` against `f32::log2` applied element-wise.
#[test]
fn test_log2() {
    let mut graph = Graph::new();
    let values = random_vec(3);
    let input = graph.tensor::<R1<3>>().set(values.clone());
    let mut output = input.log2().retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let expected: Vec<f32> = values.into_iter().map(f32::log2).collect();
    assert_close(&output.data(), &expected);
}
/// Checks Metal `exp2` against `f32::exp2` applied element-wise.
#[test]
fn test_exp2() {
    let mut graph = Graph::new();
    let values = random_vec(3);
    let input = graph.tensor::<R1<3>>().set(values.clone());
    let mut output = input.exp2().retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let expected: Vec<f32> = values.into_iter().map(f32::exp2).collect();
    assert_close(&output.data(), &expected);
}
/// Checks Metal `recip` against dfdx on a fixed input.
#[test]
fn test_recip() {
    let mut graph = Graph::new();
    let input = graph.tensor::<R1<3>>().set(vec![1., 2., 4096.]);
    let mut output = input.recip().retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let cpu = Cpu::default();
    let reference_input = cpu.tensor([1., 2., 4096.]);
    let expected = reference_input.recip();
    assert_close(&output.data(), &expected.to_dtype::<f32>().as_vec());
}
/// Checks Metal `sin` against dfdx on a fixed input.
#[test]
fn test_sin() {
    let mut graph = Graph::new();
    let input = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut output = input.sin().retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let cpu = Cpu::default();
    let reference_input = cpu.tensor([1., 2., 3.]);
    let expected = reference_input.sin();
    assert_close(&output.data(), &expected.to_dtype::<f32>().as_vec());
}
/// Checks Metal `sqrt` against dfdx on a fixed input.
#[test]
fn test_sqrt() {
    let mut graph = Graph::new();
    let input = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut output = input.sqrt().retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut output);
    graph.execute();

    let cpu = Cpu::default();
    let reference_input = cpu.tensor([1., 2., 3.]);
    let expected = reference_input.sqrt();
    assert_close(&output.data(), &expected.to_dtype::<f32>().as_vec());
}
/// Checks element-wise addition on Metal against dfdx.
#[test]
fn test_add() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut sum = lhs + rhs;
    sum.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut sum);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor([1., 2., 3.]);
    let reference_rhs = cpu.tensor([1., 2., 3.]);
    let expected = reference_lhs + reference_rhs;
    assert_close(&sum.data(), &expected.as_vec());
}
/// Checks element-wise subtraction on Metal against dfdx.
#[test]
fn test_sub() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut difference = lhs - rhs;
    difference.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut difference);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor([1., 2., 3.]);
    let reference_rhs = cpu.tensor([1., 2., 3.]);
    let expected = reference_lhs - reference_rhs;
    assert_close(&difference.data(), &expected.as_vec());
}
/// Multiplies a dynamically shaped tensor by itself on Metal and compares with dfdx.
#[test]
fn test_square() {
    let mut graph = Graph::new();
    let mut generator = rand::thread_rng();
    let values: Vec<f32> = (0..40960)
        .map(|_| generator.gen_range(-0.01..0.01))
        .collect();
    let input = graph
        .tensor::<(Dyn<'b'>, Dyn<'s'>, luminal::prelude::Const<4096>)>()
        .set_dyn(values.clone(), &[1, 10, 4096]);
    let mut squared = input * input;
    squared.retrieve();
    graph.compile(
        <(GenericCompiler, MetalCompiler<f32>)>::default(),
        &mut squared,
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference_shape = (
        dfdx::prelude::Const::<1>,
        dfdx::prelude::Const::<10>,
        dfdx::prelude::Const::<4096>,
    );
    let reference_input = cpu.tensor_from_vec::<Rank3<1, 10, 4096>>(values, reference_shape);
    let expected = reference_input.clone() * reference_input;
    assert_close(&squared.data(), &expected.as_vec());
}
/// Checks element-wise multiplication on Metal against dfdx.
#[test]
fn test_mul() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut product = lhs * rhs;
    product.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor([1., 2., 3.]);
    let reference_rhs = cpu.tensor([1., 2., 3.]);
    let expected = reference_lhs * reference_rhs;
    assert_close(&product.data(), &expected.as_vec());
}
/// Multiplies a dynamically shaped rank-4 tensor by an expanded scalar on Metal and
/// compares with the dfdx broadcast equivalent.
#[test]
fn test_mul2() {
    let mut graph = Graph::new();
    let matrix = graph
        .tensor::<(
            luminal::prelude::Const<1>,
            luminal::prelude::Const<1>,
            Dyn<'a'>,
            Dyn<'a'>,
        )>()
        .set_dyn(vec![82.4, 783.0, 99.6, 974.5], &[1, 1, 2, 2]);
    let scalar = graph.tensor::<R0>().set(vec![0.57735026]);
    let mut product = matrix * scalar.expand();
    product.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let reference_matrix = cpu.tensor([[[[82.4, 783.0], [99.6, 974.5]]]]);
    let reference_scalar = cpu.tensor(0.57735026);
    let expected =
        reference_matrix * reference_scalar.broadcast::<_, dfdx::shapes::Axes4<0, 1, 2, 3>>();
    assert_close(&product.data(), &expected.as_vec());
}
/// Checks element-wise division on Metal against dfdx.
#[test]
fn test_div() {
    let mut graph = Graph::new();
    let numerator = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let denominator = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut quotient = numerator / denominator;
    quotient.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut quotient);
    graph.execute();

    let cpu = Cpu::default();
    let reference_numerator = cpu.tensor([1., 2., 3.]);
    let reference_denominator = cpu.tensor([1., 2., 3.]);
    let expected = reference_numerator / reference_denominator;
    assert_close(&quotient.data(), &expected.as_vec());
}
/// Checks element-wise maximum on Metal against dfdx `maximum`.
#[test]
fn test_max() {
    let mut graph = Graph::new();
    let lhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let rhs = graph.tensor::<R1<3>>().set(vec![1., 2., 3.]);
    let mut larger = lhs.max(rhs).retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut larger);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor([1., 2., 3.]);
    let reference_rhs = cpu.tensor([1., 2., 3.]);
    let expected = reference_lhs.maximum(reference_rhs);
    assert_close(&larger.data(), &expected.as_vec());
}
/// Checks element-wise remainder on Metal against Rust's `%` on `f32`.
#[test]
fn test_mod() {
    let mut graph = Graph::new();
    let lhs_values = random_vec(3);
    let rhs_values = random_vec(3);
    let lhs = graph.tensor::<R1<3>>().set(lhs_values.clone());
    let rhs = graph.tensor::<R1<3>>().set(rhs_values.clone());
    let mut remainder = lhs % rhs;
    remainder.retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut remainder);
    graph.execute();

    // dfdx has no equivalent op, so the reference is computed directly.
    let expected = lhs_values
        .into_iter()
        .zip(rhs_values)
        .map(|(x, y)| x % y)
        .collect_vec();
    assert_close(&remainder.data(), &expected);
}
// Reduction op tests
/// Sum-reduces a rank-3 tensor along each axis on Metal and compares with dfdx.
#[test]
fn test_sum_reduce() {
    let mut graph = Graph::new();
    let values = random_vec(4096);
    let input = graph.tensor::<R3<1, 1, 4096>>();
    input.set(values.clone());
    let mut over_axis1 = input
        .sum_reduce::<_, luminal::prelude::Axis<1>>()
        .retrieve();
    let mut over_axis0 = input
        .sum_reduce::<_, luminal::prelude::Axis<0>>()
        .retrieve();
    let mut over_axis2 = input
        .sum_reduce::<_, luminal::prelude::Axis<2>>()
        .retrieve();
    graph.compile(
        MetalCompiler::<f32>::default(),
        (&mut over_axis1, &mut over_axis0, &mut over_axis2),
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(
        values,
        (
            dfdx::shapes::Const::<1>,
            dfdx::shapes::Const::<1>,
            dfdx::shapes::Const::<4096>,
        ),
    );
    let expected_axis1 = reference.clone().sum::<_, dfdx::shapes::Axis<1>>();
    let expected_axis0 = reference.clone().sum::<_, dfdx::shapes::Axis<0>>();
    let expected_axis2 = reference.sum::<_, dfdx::shapes::Axis<2>>();
    assert_close(&over_axis1.data(), &expected_axis1.as_vec());
    assert_close(&over_axis0.data(), &expected_axis0.as_vec());
    assert_close(&over_axis2.data(), &expected_axis2.as_vec());
}
/// Max-reduces a rank-3 tensor along each axis on Metal and compares with dfdx.
#[test]
fn test_max_reduce() {
    let mut graph = Graph::new();
    let values = random_vec(12);
    let input = graph.tensor::<R3<2, 2, 3>>();
    input.set(values.clone());
    let mut over_axis1 = input
        .max_reduce::<_, luminal::prelude::Axis<1>>()
        .retrieve();
    let mut over_axis0 = input
        .max_reduce::<_, luminal::prelude::Axis<0>>()
        .retrieve();
    let mut over_axis2 = input
        .max_reduce::<_, luminal::prelude::Axis<2>>()
        .retrieve();
    graph.compile(
        MetalCompiler::<f32>::default(),
        (&mut over_axis1, &mut over_axis0, &mut over_axis2),
    );
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(
        values,
        (
            dfdx::shapes::Const::<2>,
            dfdx::shapes::Const::<2>,
            dfdx::shapes::Const::<3>,
        ),
    );
    let expected_axis1 = reference.clone().max::<_, dfdx::shapes::Axis<1>>();
    let expected_axis0 = reference.clone().max::<_, dfdx::shapes::Axis<0>>();
    let expected_axis2 = reference.max::<_, dfdx::shapes::Axis<2>>();
    assert_close(&over_axis1.data(), &expected_axis1.as_vec());
    assert_close(&over_axis0.data(), &expected_axis0.as_vec());
    assert_close(&over_axis2.data(), &expected_axis2.as_vec());
}
/// Mean-reduces the last axis of a 1x10x4096 tensor on Metal and compares with dfdx.
#[test]
fn test_mean_reduce() {
    let values = random_vec(40960);
    let mut graph = Graph::new();
    let input = graph.tensor::<R3<1, 10, 4096>>().set(values.clone());
    let mut mean = input
        .mean_reduce::<_, luminal::prelude::Axis<2>>()
        .retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut mean);
    graph.execute();

    let cpu = Cpu::default();
    let reference = cpu.tensor_from_vec(
        values,
        (
            dfdx::shapes::Const::<1>,
            dfdx::shapes::Const::<10>,
            dfdx::shapes::Const::<4096>,
        ),
    );
    let expected = reference.mean::<_, dfdx::shapes::Axis<2>>();
    assert_close(&mean.data(), &expected.as_vec());
}
/// Multiplies two random 256x256 matrices on Metal and compares with dfdx.
#[test]
fn test_matmul_simple() {
    let mut graph = Graph::new();
    let lhs_values = random_vec(256 * 256);
    let rhs_values = random_vec(256 * 256);
    let lhs = graph.tensor::<R2<256, 256>>().set(lhs_values.clone());
    let rhs = graph.tensor::<R2<256, 256>>().set(rhs_values.clone());
    let mut product = lhs.matmul(rhs).retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor_from_vec(
        lhs_values,
        (dfdx::shapes::Const::<256>, dfdx::shapes::Const::<256>),
    );
    let reference_rhs = cpu.tensor_from_vec(
        rhs_values,
        (dfdx::shapes::Const::<256>, dfdx::shapes::Const::<256>),
    );
    let expected = reference_lhs.matmul(reference_rhs);
    assert_close(&product.data(), &expected.as_vec());
}
/// Multiplies two random 512x512 matrices on Metal and compares with dfdx.
#[test]
fn test_matmul() {
    let mut graph = Graph::new();
    let lhs_values = random_vec(512 * 512);
    let rhs_values = random_vec(512 * 512);
    let lhs = graph.tensor::<R2<512, 512>>().set(lhs_values.clone());
    let rhs = graph.tensor::<R2<512, 512>>().set(rhs_values.clone());
    let mut product = lhs.matmul(rhs).retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let reference_lhs = cpu.tensor_from_vec(
        lhs_values,
        (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<512>),
    );
    let reference_rhs = cpu.tensor_from_vec(
        rhs_values,
        (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<512>),
    );
    let expected = reference_lhs.matmul(reference_rhs);
    assert_close(&product.data(), &expected.as_vec());
}
/// Multiplies a 2x2x3 batch by a 3x4 matrix on Metal and compares with dfdx.
#[test]
fn test_batch_matmul() {
    let mut graph = Graph::new();
    let batch = graph
        .tensor::<R3<2, 2, 3>>()
        .set(vec![1., 2., 3., 1., 2., 1., 1., 2., 3., 1., 2., 1.]);
    let weights = graph
        .tensor::<R2<3, 4>>()
        .set(vec![1., 2., 3., 1., 1., 2., 1., 2., -1., -2., 1., 2.]);
    let mut product = batch.matmul(weights).retrieve();
    graph.compile(MetalCompiler::<f32>::default(), &mut product);
    graph.execute();

    let cpu = Cpu::default();
    let reference_batch =
        cpu.tensor([[[1., 2., 3.], [1., 2., 1.]], [[1., 2., 3.], [1., 2., 1.]]]);
    let reference_weights =
        cpu.tensor([[1., 2., 3., 1.], [1., 2., 1., 2.], [-1., -2., 1., 2.]]);
    let expected = reference_batch.matmul(reference_weights);
    assert_close(&product.data(), &expected.as_vec());
}
// Checks all four combinations of plain / permuted operands for a matmul
// (A*B^T, A*B_t, A_t^T*B^T, A_t^T*B_t) on Metal against dfdx.
#[test]
fn test_matmul_transpose() {
const M: usize = 1024; // Any
const K: usize = 16; // >= 16
const N: usize = 256; // >= 256, power of 2
let mut cx = Graph::new();
// The RNG is seeded and the four data vectors are drawn in a fixed order
// (a, b, a_t, b_t); reordering these statements changes the generated data.
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(M * K, &mut rng);
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
let b_data = random_vec_rng(K * N, &mut rng);
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
let a_t_data = random_vec_rng(K * M, &mut rng);
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
let b_t_data = random_vec_rng(K * N, &mut rng);
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
// `b` is stored as (N, K) and `a_t` as (K, M), so each is permuted before the matmul.
let mut a_b = a.matmul(b.permute()).retrieve();
let mut a_b_t = a.matmul(b_t).retrieve();
let mut a_t_b = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b.permute())
.retrieve();
let mut a_t_b_t = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b_t)
.retrieve();
cx.compile(
MetalCompiler::<f32>::default(),
(&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
);
cx.execute();
// Build the same four products with dfdx from the same data.
let d_dev = Cpu::default();
let d_a = d_dev.tensor_from_vec(a_data, (dfdx::shapes::Const::<M>, dfdx::shapes::Const::<K>));
let d_b = d_dev.tensor_from_vec(b_data, (dfdx::shapes::Const::<N>, dfdx::shapes::Const::<K>));
let d_a_t = d_dev.tensor_from_vec(
a_t_data,
(dfdx::shapes::Const::<K>, dfdx::shapes::Const::<M>),
);
let d_b_t = d_dev.tensor_from_vec(
b_t_data,
(dfdx::shapes::Const::<K>, dfdx::shapes::Const::<N>),
);
let d_a_b = d_a.clone().matmul(d_b.clone().permute());
let d_a_b_t = d_a.matmul(d_b_t.clone());
let d_a_t_b = d_a_t
.clone()
.permute::<_, dfdx::shapes::Axes2<1, 0>>()
.matmul(d_b.permute());
let d_a_t_b_t = d_a_t
.permute::<_, dfdx::shapes::Axes2<1, 0>>()
.matmul(d_b_t);
assert_close(&a_b.data(), &d_a_b.as_vec());
assert_close(&a_b_t.data(), &d_a_b_t.as_vec());
assert_close(&a_t_b.data(), &d_a_t_b.as_vec());
assert_close(&a_t_b_t.data(), &d_a_t_b_t.as_vec());
}
/// Runs a Linear -> ReLU -> Linear model on a single vector and on a batch, first
/// uncompiled and then compiled for Metal, and compares the single-vector result with dfdx.
#[test]
fn test_relu_and_linear() {
    let mut graph = Graph::new();
    let single_values = random_vec(32);
    let first_weights = random_vec(32 * 64);
    let second_weights = random_vec(32 * 64);
    let batch_input = graph
        .named_tensor::<R2<2, 32>>("Batch")
        .set(random_vec(32 * 2));
    let single_input = graph
        .named_tensor::<R1<32>>("Single")
        .set(single_values.clone());
    let network: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut graph);
    network.0.weight.set(first_weights.clone());
    network.2.weight.set(second_weights.clone());
    let mut single_out = network.forward(single_input).retrieve();
    let mut batch_out = network.forward(batch_input).retrieve();

    // Reference run on the uncompiled graph.
    graph.execute();
    let reference_single = single_out.data();
    let reference_batch = batch_out.data();
    single_out.drop();
    batch_out.drop();

    // Compiled run must reproduce the uncompiled results.
    graph.compile(
        <(GenericCompiler, MetalCompiler<f32>)>::default(),
        (&mut single_out, &mut batch_out),
    );
    graph.execute();
    assert_close_precision(&reference_single, &single_out.data(), 2);
    assert_close_precision(&reference_batch, &batch_out.data(), 2);

    // Cross-check the single-vector result against dfdx.
    let cpu = Cpu::default();
    let mut reference_model = <(
        dfdx::nn::modules::builders::UnbiasedLinear<32, 64>,
        dfdx::nn::modules::builders::ReLU,
        dfdx::nn::modules::builders::UnbiasedLinear<64, 32>,
    )>::build_on_device(&cpu);
    reference_model.0.weight = cpu
        .tensor_from_vec(
            first_weights,
            (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>),
        )
        .permute();
    reference_model.2.weight = cpu
        .tensor_from_vec(
            second_weights,
            (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>),
        )
        .permute();
    let reference_input = cpu.tensor_from_vec(single_values, (dfdx::shapes::Const::<32>,));
    let reference_out = reference_model.forward(reference_input);
    assert_close_precision(&reference_single, &reference_out.as_vec(), 2);
}
// Builds a luminal transformer encoder block with fixed weights, runs it through the
// Metal compiler, and compares the output with a dfdx encoder block given the same weights.
#[test]
fn test_transformer_encoder_block() {
let mut cx = Graph::new();
// NOTE(review): the const generics are <3, 4, 1> here but <3, 1, 4> on the dfdx block below;
// presumably the two libraries order the feed-forward size and head count differently — confirm.
let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<3, 4, 1> =
InitModule::initialize(&mut cx);
// Fixed attention projection weights (nine values each).
model
.attention
.w_k
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
model
.attention
.w_q
.weight
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
model
.attention
.w_v
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
model
.attention
.w_o
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
// Fixed feed-forward weights (twelve values each).
model
.ff
.0
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
model
.ff
.2
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
// Input uses two dynamic dimensions, set to 1 and 2 at runtime.
let a = cx
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<3>)>()
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[1, 2, 3]);
let mut b = model.forward(a).retrieve();
cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut b);
cx.execute();
// Reference model in dfdx: zero every bias and copy in the same weights.
let d_dev = Cpu::default();
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> =
d_dev.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<3, 1, 4>, f32>();
d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]);
// Each weight is loaded with the same values as above and then permuted before assignment.
d_model.self_attn.w_o.weight = d_dev
.tensor_from_vec(
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
)
.permute();
d_model.self_attn.w_k.weight = d_dev
.tensor_from_vec(
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
)
.permute();
d_model.self_attn.w_q.weight = d_dev
.tensor_from_vec(
vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
)
.permute();
d_model.self_attn.w_v.weight = d_dev
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
)
.permute();
d_model.ff.0 .0.weight = d_dev
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.],
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>),
)
.permute();
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (dfdx::shapes::Const::<4>,));
d_model.ff.0 .2.weight = d_dev
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.],
(dfdx::shapes::Const::<4>, dfdx::shapes::Const::<3>),
)
.permute();
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
// Layer norms: gamma all ones, beta all zeros, epsilon 1e-5.
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (dfdx::shapes::Const::<3>,));
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (dfdx::shapes::Const::<3>,));
d_model.norm1.epsilon = 1e-5;
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
d_model.norm2.epsilon = 1e-5;
// The dfdx input is unbatched (2, 3); it holds the same six values as the luminal input.
let d_a = d_dev.tensor_from_vec(
vec![-1., 2., 3., 3., 3., -1.],
(dfdx::shapes::Const::<2>, dfdx::shapes::Const::<3>),
);
let d_b = d_model.forward(d_a);
assert_close(&b.data(), &d_b.as_vec());
}

View File

@@ -0,0 +1,2 @@
mod fp16;
mod fp32;

File diff suppressed because it is too large Load Diff

38
docs/01 Introduction.md Normal file
View File

@@ -0,0 +1,38 @@
# Luminal Introduction
Let's get up to speed with how to use luminal, and how it works internally.
First we'll take a look at what the simplest program will look like:
```rust
use luminal::prelude::*;
// Setup graph and tensors (1)
let mut cx = Graph::new();
let a = cx.new_tensor::<R1<3>>()
.set(vec![1.0, 2.0, 3.0]);
let b = cx.new_tensor::<R1<3>>()
.set(vec![1.0, 2.0, 3.0]);
// Actual operations (2)
let c = (a + b).retrieve();
// Run graph (3)
cx.execute();
// Get result (4)
println!("Result: {:?}", c);
// Prints out [2.0, 4.0, 6.0]
```
Wow! A lot is going on here just to add two tensors together. That's because luminal isn't really designed for such simple computation, and there's little benefit to using it here. But we'll see it pay off when we start doing more complex operations.
So what's happening here?
1) We're setting up a new `Graph` which tracks all computation and actually does execution. We're also defining two new tensors, both of shape (3,). At this point, these "tensors" are actually `GraphTensor`s that don't hold any data. Also, notice we pass in the shape as a type generic. *Types are known at compile time, similar to [dfdx](https://github.com/coreylowman/dfdx)!*
2) Now we can start doing the thing we came here for: the addition. So we add two `GraphTensor`s together, and get a new `GraphTensor`. Notice this *does not* consume anything, and we're free to use a or b later on. This is because `GraphTensor` is a super lightweight tracking struct which implements `Copy`. "But wait, we never set the values of a and b, how can we add them?" **We aren't actually adding them here.** Instead, we're writing this addition to the graph, and getting out c, which points to the result when it's actually done.
Then we set the data for these tensors. But if `GraphTensor` doesn't hold data, how can we set it? Well we aren't actually setting it *in* the tensor, just passing it through to the graph to say *once you run, set this tensor to this value.* We also need to mark the output we want to retrieve later. This is so that when the graph runs, it doesn't delete the data for c part-way through execution (a common optimization for unused tensors). Notice we're setting the sources *after* we define the computation. This is backward from a lot of other libs, but it means we can redefine the data and rerun everything without redefining the computation later on.
3) Once we call `cx.execute()`, we've already set all our sources, so our addition actually gets run and stored in c!
4) Now since we're done computing c, we can fetch the data for c and see the result.
Alright, that was a lot but now we've touched on all the main aspects of running a model in luminal.
[Let's take a look at each piece in more depth.](https://github.com/jafioti/luminal/blob/main/docs/02%20GraphTensor%20API.md)

View File

@@ -0,0 +1,19 @@
# GraphTensors
We're working with pretty complicated graphs to build our computation on, but we don't want to manually place all the nodes ourselves! So how can we build these static graphs in a nice, familiar way? GraphTensors!
Essentially GraphTensors are pointers to a specific node on the graph, as well as some metadata about the output of that node, such as its shape. We can make a new GraphTensor by doing:
```rust
let mut cx = Graph::new(); // We need a graph to build!
let a: GraphTensor<R1<3>> = cx.tensor(); // Here we create a new node on the graph and get a GraphTensor back, pointing to it.
```
Notice the type of `a`: `GraphTensor<R1<3>>`. So what's that generic all about? It's the shape! We make tensor shapes part of the type, so they're tracked at compile time! In this case, the shape is rank 1, with 3 elements, or in other words, a vector of 3 dimensions. (Side note: `R1<N>` is a typedef of `(Const<N>,)`) It should be impossible to accidentally get a runtime shape mismatch.
Now we can use the `a` as you would in a library like PyTorch, performing linear algebra:
```rust
let b = a.exp().sqrt();
let c = b + a;
```
Looks familiar!
[Let's take a look at how GraphTensors are used to build whole neural networks.](https://github.com/jafioti/luminal/blob/main/docs/03%20Modules.md)

31
docs/03 Modules.md Normal file
View File

@@ -0,0 +1,31 @@
# NN Modules
Like any good DL library, we organize our networks into `Module`s. Here is the module trait:
```rust
/// A module with a forward pass
pub trait Module<I> {
type Output;
fn forward(&self, input: I) -> Self::Output;
}
```
Super simple, we just define a forward function that takes an input and returns an output. A consequence of this is it allows us to define separate forward passes for single and batched inputs!
Now let's take a look at how `Linear` is defined:
```rust
/// A simple linear layer
pub struct Linear<const A: usize, const B: usize> {
pub(crate) weight: GraphTensor<R2<A, B>>,
}
impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
type Output = GraphTensor<R1<B>>;
fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
input.matmul(self.weight)
}
}
```
Here we see a single weight matrix as the internal state, of size AxB. We've written a single forward function for single input vectors of shape (A,) and matmul it by our weight matrix to get an output of shape (B,).
Now all of these ops are recorded on the graph, to be compiled and run later on.
[So how does this compilation work? Let's find out!](https://github.com/jafioti/luminal/blob/main/docs/04%20Compilers.md)

27
docs/04 Compilers.md Normal file
View File

@@ -0,0 +1,27 @@
# Compilers
So now we have our graph all set up. We did our forward passes through the model, so now what? Do we run it?
We could! But it wouldn't be very fast. Right now your graph is full of **primops**, which are the simplest set of primitive operations in luminal. One of the key tenets of luminal is a small primop set, which makes it easy to add new backends and write compilers for. But another consequence of a small primop set is that even simple operations usually end up creating quite a few operations, and even small neural networks can end up with hundreds or thousands of primops, which are slow to run directly. So it's time to compile the graph!
Compilers are structs that implement the `Compiler` trait, which simply specifies a single function:
```rust
pub trait Compiler {
/// Run a compilation pass
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, remap: T);
}
```
So all a compiler does is take a mutable reference to the graph, something called remap (beyond the scope of this introduction), and does something to the graph. That something is compilation, usually in the form of finding patterns of nodes and replacing them with other nodes. For instance, there's no Subtract operation in the primops, so subtractions are implemented as `add(a, mul(b, -1))`. We can have a compiler that looks for that pattern of nodes and directly replaces it with a `Subtract` operation. We'll look at how to do this in the [Writing Compilers](https://github.com/jafioti/luminal/blob/main/docs/06%20Writing%20Compilers.md) section.
All you need to know for now is that we can use this compiler on the graph by doing:
```rust
cx.compile(SubtractionCompiler::default());
```
Now the graph will have the old mul + add pattern removed and Subtract ops placed in. There are plenty of different compilers for different purposes. Some of the popular ones:
- GenericCompiler - A handful of hardware-agnostic optimizations like [CSE](https://en.wikipedia.org/wiki/Common_subexpression_elimination) to be run before any hardware-specific compilers.
- CudaCompiler<T> - The full stack of cuda compilers to convert a graph to a cuda-specialized graph with T as the datatype (either f32 or f16). Imported from luminal_cuda
- MetalCompiler<T> - Same as CudaCompiler. Imported from luminal_metal
Compilers are entirely separate from luminal, so they can be fully implemented by third party crates. For instance, everything specific to Cuda is contained in luminal_cuda.
[Now let's look into how to load weights from a file.](https://github.com/jafioti/luminal/blob/main/docs/05%20Serialization.md)

1
docs/05 Serialization.md Normal file
View File

@@ -0,0 +1 @@
Coming Soon

View File

@@ -0,0 +1 @@
Coming Soon

10
docs/CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,10 @@
# Contributing to luminal
![image](https://raw.githubusercontent.com/jafioti/luminal/main/resources/dag.jpeg)
Please take a look at the [issues](https://github.com/jafioti/luminal/issues) and [roadmap](https://github.com/users/jafioti/projects/1) to see what's targeted for upcoming releases. Contributions for those features are preferred and will be reviewed and merged very rapidly. Other contributions are welcome, but please note luminal is and always will be a fairly minimal library.
The core design of luminal is heavily predicated on extensibility. Compilers allow for immense complexity to be removed from the core library and added with third party compilers. For instance, datatypes and devices are typically first class primitives. In luminal, they're compilers and the core has no idea about them. This is the general trend we'll stick to: core remains brutally simple, and everything that can be externalized to a compiler will be.
We will be adding training support soon, and as you guessed, it will entirely reside in a compiler. Just define the model's graph, run the output through an optimizer, and then run the `AutogradCompiler` before any other compilers. Boom, we got training, and the core of the library has no idea! (aside from some quality of life apis)
PRs that remove complexity are always welcome, but note that line count often is a bad proxy for complexity. Ideally the entire luminal core should be a few thousand lines of code, but anything remotely resembling code golf is not allowed.

View File

@@ -1,68 +0,0 @@
## Luminal Introduction
Let's get up to speed with how to use luminal, and how it works internally.
First we'll take a look at what the simplest program will look like:
```rust
use luminal::prelude::*;
// Setup graph and tensors (1)
let mut cx = Graph::new();
let a = cx.new_tensor::<R1<3>>();
let b = cx.new_tensor::<R1<3>>();
// Actual operations (2)
let c = a + b;
// Set inputs and mark outputs (3)
a.set(vec![1.0, 2.0, 3.0]);
b.set(vec![1.0, 2.0, 3.0]);
c.mark();
// Run graph (4)
cx.execute();
// Get result (5)
println!("Result: {:?}", c.retrieve().unwrap().real_data(c.view().unwrap()).unwrap());
// Prints out [2.0, 4.0, 6.0]
```
Wow! A lot is going on here just to add two tensors together. That's because luminal isn't really designed for such simple computation, and there's little benefit to using it here. But we'll see it pay off when we start doing more complex operations.
So what's happening here?
1) We're setting up a new `Graph` which tracks all computation and actually does execution. We're also defining two new tensors, both of shape (3,). At this point, these "tensors" are actually `GraphTensor`s that don't hold any data. Also, notice we pass in the shape as a type generic. *Types are known at compile time, similar to [dfdx](https://github.com/coreylowman/dfdx)!*
2) Now we can start doing the thing we came here for: the addition. So we add two `GraphTensor`s together, and get a new `GraphTensor`. Notice this *does not* consume anything, and we're free to use a or b later on. This is because `GraphTensor` is a super lightweight tracking struct which implements `Copy`. "But wait, we never set the values of a and b, how can we add them?" **We aren't actually adding them here.** Instead, we're writing this addition to the graph, and getting out c, which points to the result when it's actually done.
3) Then we set the data for these tensors. But if `GraphTensor` doesn't hold data, how can we set it? Well we aren't actually setting it *in* the tensor, just passing it through to the graph to say *once you run, set this tensor to this value.* We also need to mark the output we want to retrieve later. This is so that when the graph runs, it doesn't delete the data for c part-way through execution (a common optimization for unused tensors). Notice we're setting the sources *after* we define the computation. This is backward from a lot of other libs, but it means we can redefine the data and rerun everything without redefining the computation later on.
4) Once we call `cx.execute()`, we've already set all our sources, so our addition actually gets run and stored in c!
5) Now since we're done computing c, we can fetch the data for c and see the result. *This API is likely to change, as it's very ugly.*
Alright, that was a lot but now we've touched on all the main aspects of running a model in luminal.
## NN Modules
Like any good DL library, we organize our networks into `Module`s. Here is the module trait:
```rust
/// A module with a forward pass
pub trait Module<I> {
type Output;
fn forward(&self, input: I) -> Self::Output;
}
```
Super simple, we just define a forward function that takes an input and returns an output. A consequence of this is it allows us to define separate forward passes for single and batched inputs!
Now let's take a look at how `Linear` is defined:
```rust
/// A simple linear layer
pub struct Linear<const A: usize, const B: usize> {
pub(crate) weight: GraphTensor<R2<A, B>>,
}
impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
type Output = GraphTensor<R1<B>>;
fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
input.matmul(self.weight)
}
}
```
Here we see a single weight matrix as the internal state, of size AxB. We've written a single forward function for single input vectors of shape (A,) and matmul it by our weight matrix to get an output of shape (B,).
Again, notice we're only dealing with `GraphTensor`s here, so when this code actually gets run, **no computation happens, it just gets recorded to the graph.**

16
examples/llama/.gitignore vendored Normal file
View File

@@ -0,0 +1,16 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
setup/llama-7b-hf
.vscode

20
examples/llama/Cargo.toml Normal file
View File

@@ -0,0 +1,20 @@
[package]
name = "llama"
version = "0.1.0"
edition = "2021"
[features]
metal = ["dep:luminal_metal", "dep:metal-rs"]
cuda = ["dep:luminal_cuda"]
[dependencies]
luminal = {path="../.."}
luminal_metal = {path="../../crates/luminal_metal", optional=true}
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
rust_tokenizers = "8.1.0"
clap = { version = "4.4.18", features = ["derive"] }
byteorder = "1.5.0"
memmap2 = "0.9.4"
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
colored = "2.1.0"
itertools = "0.12.1"

View File

@@ -1,22 +0,0 @@
// Common
/// Vocabulary size the model is instantiated with.
pub const VOCAB: usize = 32_000;
/// Width of a single attention head.
pub const HEAD_DIM: usize = 128;
/// Half of `HEAD_DIM`, derived so the two values can never drift apart.
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
// 7B
pub const HIDDEN: usize = 4096;
pub const INTERMEDIATE: usize = 11008;
pub const HEADS: usize = 32;
// NOTE(review): the 13B/65B presets below list 40/80 layers; a single layer here
// looks like a debugging value — confirm before relying on generated text.
pub const LAYERS: usize = 1;
// 13B
// pub const HIDDEN: usize = 5120;
// pub const INTERMEDIATE: usize = 13824;
// pub const HEADS: usize = 40;
// pub const LAYERS: usize = 40;
// 65B
// pub const HIDDEN: usize = 8192;
// pub const INTERMEDIATE: usize = 22016;
// pub const HEADS: usize = 64;
// pub const LAYERS: usize = 80;

View File

@@ -1,92 +0,0 @@
mod config;
mod loader;
mod model;
use luminal::prelude::*;
use model::LlamaForCausalLM;
use crate::model::KVCache;
/// Greedy LLaMA text generation demo.
///
/// Builds the full-prompt graph, loads the weights, runs one prompt pass, then builds a
/// KV-cache graph and decodes one token per iteration in an endless loop, printing the
/// decoded text after every step.
#[rustfmt::skip]
fn main() {
    let tokenizer = tokenizers::tokenizer::Tokenizer::from_pretrained("oobabooga/llama-tokenizer", None).unwrap();
    let mut input: Vec<usize> = tokenizer.encode("The young boy ran over to the", false).unwrap().get_ids().iter().map(|i| *i as usize).collect();

    println!("Creating Graph...");
    let mut cx = Graph::new();
    let model: LlamaForCausalLM<
        { config::VOCAB },
        { config::HEADS },
        { config::HIDDEN },
        { config::INTERMEDIATE },
        { config::HEAD_DIM },
        { config::HEAD_DIM_OVER_2 },
        { config::LAYERS },
    > = InitModule::initialize(&mut cx);
    let inp = cx.new_tensor::<(usize, usize)>("Input");
    let (out, cache_src) = model.forward(inp);
    out.mark();
    // The prompt-pass caches feed the decode graph, so they must survive execution.
    for (k, v) in &cache_src {
        k.mark_no_delete();
        v.mark_no_delete();
    }

    println!("Loading...");
    loader::DfdxDeferredLoader::new("../../Desktop/llama-dfdx-main/llama-7b-hf").load(&model, &mut cx);

    println!("Inferencing...");
    // First pass
    inp.set_dyn(input.clone(), vec![1, input.len()]);
    let now = std::time::Instant::now();
    cx.display_shapes();
    cx.execute();
    println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
    let out = out.retrieve().unwrap().real_data(out.view().unwrap()).unwrap();
    // Only the logits of the last prompt position are sampled; the row stride is the
    // vocabulary size the model was instantiated with (previously a bare `32_000`).
    input.push(sample_index(&out[(input.len() - 1) * config::VOCAB..]));
    println!("{}", tokenizer.decode(input.iter().map(|i| *i as u32).collect(), false).unwrap());

    // Build KV cache forward graph
    let (out, cache_dest): (_, Vec<KVCache<_, usize, {config::HEADS}, {config::HEAD_DIM}>>) = model.forward_kv((inp, cache_src.clone()));
    out.mark();
    for (k, v) in &cache_dest {
        k.mark_no_delete();
        v.mark_no_delete();
    }
    cx.prune([out.id], cache_src.iter().flat_map(|(k, v)| [k.id, v.id]));

    loop {
        // Feed only the most recently generated token; earlier context lives in the cache.
        inp.set_dyn(vec![*input.last().unwrap()], vec![1, 1]);
        let now = std::time::Instant::now();
        cx.execute();
        println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
        let o = out.retrieve().unwrap().real_data(out.view().unwrap()).unwrap();
        // Sample tokens
        input.push(sample_index(&o));
        println!("{}", tokenizer.decode(input.iter().map(|i| *i as u32).collect(), false).unwrap());

        // Swap caches
        for ((src_k, src_v), (dest_k, dest_v)) in cache_src.iter().copied().zip(cache_dest.iter().copied()) {
            // Move dest caches to src
            cx.swap_tensors(src_k, dest_k);
            cx.swap_tensors(src_v, dest_v);
            // Drop dest caches
            dest_k.drop();
            dest_v.drop();
        }
    }
}
// Currently just an argmax, do actual sampling here
/// Returns the index of the largest value in `dist`.
///
/// When two candidates are equal or incomparable (NaN), the later index wins.
/// Panics if `dist` is empty.
fn sample_index(dist: &[f32]) -> usize {
    let (best_idx, _) = dist
        .iter()
        .copied()
        .enumerate()
        // Keep the running best only when it is strictly greater than the candidate.
        .reduce(|best, cand| if best.1 > cand.1 { best } else { cand })
        .unwrap();
    best_idx
}

View File

@@ -1,749 +0,0 @@
#![allow(clippy::type_complexity)]
use std::ops::{Add, Mul};
use luminal::{
nn::{activation::RMSNorm, embedding::Embedding},
op,
prelude::{movement::TryConcatAlong, *},
};
use rand::{thread_rng, Rng};
// Full LLaMa model implementation, heavily based off of https://github.com/coreylowman/llama-dfdx/blob/main/src/modeling.rs
/// Key/value cache for one attention layer: a `(keys, values)` pair of graph tensors,
/// each shaped `(Batch, NUM_HEADS, Seq, HEAD_DIM)`.
pub type KVCache<Batch, Seq, const NUM_HEADS: usize, const HEAD_DIM: usize> = (
    GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
    GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
);
/// Gated feed-forward block. `I` is the inner (intermediate) width and `H` the hidden
/// width; all weights are transposed (`permute`) at the point of use.
pub struct Mlp<const I: usize, const H: usize> {
    /// Gate projection weight, stored as `(I, H)`.
    pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
    /// Down projection weight, stored as `(H, I)`.
    pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
    /// Up projection weight, stored as `(I, H)`.
    pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
}
/// Forward pass of the gated MLP over `(B, S, H)` activations.
impl<const I: usize, const H: usize, B: Dim, S: Dim> Module<GraphTensor<(B, S, Const<H>)>>
    for Mlp<I, H>
{
    type Output = GraphTensor<(B, S, Const<H>)>;

    /// Computes `down(up(x) * silu(gate(x)))`, where `silu(g) = sigmoid(g) * g`.
    fn forward(&self, input: GraphTensor<(B, S, Const<H>)>) -> Self::Output {
        let gate_pre = input.matmul(self.gate_proj.permute());
        let gate_act = gate_pre.sigmoid() * gate_pre;
        let gated = input.matmul(self.up_proj.permute()) * gate_act;
        gated.matmul(self.down_proj.permute())
    }
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
    /// Registers the three weight tensors on `cx` (gate, up, down — in that order).
    fn initialize(cx: &mut Graph) -> Self {
        let gate_proj = cx.new_tensor("Gate Weight");
        let up_proj = cx.new_tensor("Up Weight");
        let down_proj = cx.new_tensor("Down Weight");
        Self {
            gate_proj,
            up_proj,
            down_proj,
        }
    }
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
    /// Registers each weight with the serializer under its `<name>/weight` key.
    fn serialize(&self, s: &mut Serializer) {
        s.tensor("gate_proj/weight", self.gate_proj);
        s.tensor("up_proj/weight", self.up_proj);
        s.tensor("down_proj/weight", self.down_proj);
    }
}
/// Rotary position embedding applied to attention queries and keys.
/// `HEAD_DIM_OVER_2` is passed separately and is expected to be half of `HEAD_DIM`.
pub struct RotaryEmbedding<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> {
    /// Per-frequency multipliers, one entry per rotated pair (`HEAD_DIM_OVER_2` values).
    pub inv_freq: GraphTensor<R1<HEAD_DIM_OVER_2>>,
}
/// Applies the rotary embedding to a `(q, k)` pair. The optional cache is only used to
/// offset the positions (see `get_sincos`).
impl<
        Batch: Dim,
        const NUM_HEADS: usize,
        Seq: Dim,
        PrevSeq: Dim,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    >
    Module<(
        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
        Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
    )> for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
{
    type Output = (
        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
        GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
    );

    /// Returns `(q_embed, k_embed)` where each is `rotate_half(x) * sin + x * cos`.
    fn forward(
        &self,
        (q, k, cache): (
            GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
            GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
            Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
        ),
    ) -> Self::Output {
        let (sin, cos) = self.get_sincos(q, cache);
        // Expand the (Seq, HEAD_DIM) tables so they line up with q and k.
        let sin = sin.expand();
        let cos = cos.expand();
        let q_embed = (Self::rotate_half(q) * sin) + (q * cos);
        let k_embed = (Self::rotate_half(k) * sin) + (k * cos);
        (q_embed, k_embed)
    }
}
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize>
    RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Builds the `(sin, cos)` tables, each of shape `(Seq, HEAD_DIM)`, for the positions
    /// covered by `seq_tensor`.
    ///
    /// Positions are produced at run time by an "ARange" graph op reading axis 2 of
    /// `seq_tensor`. When `cache` is present, the positions are offset by axis 2 of the
    /// cached key tensor, so new tokens continue counting after the cached ones.
    fn get_sincos<Batch: Dim, const NUM_HEADS: usize, Seq: Dim, PrevSeq: Dim>(
        &self,
        seq_tensor: GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
        cache: Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
    ) -> (
        GraphTensor<(Seq, Const<HEAD_DIM>)>,
        GraphTensor<(Seq, Const<HEAD_DIM>)>,
    ) {
        // SAFETY: NOTE(review) — relies on `graph_ref` pointing at a live `Graph` that is
        // not otherwise borrowed during this call; this cannot be verified from this file.
        let graph = unsafe { self.inv_freq.graph_ref.as_mut().unwrap() };
        // The closure only needs to know whether a second (cache) input was attached.
        let has_cache = cache.is_some();
        let mut op = graph
            .add_op(
                op::Function(
                    "ARange".to_string(),
                    Box::new(move |inp, i| {
                        let seq_len = inp[0].1.shape.shape()[2];
                        let offset = if has_cache {
                            inp[1].1.shape.shape()[2]
                        } else {
                            0
                        };
                        (
                            Some(Tensor {
                                data: Box::new(
                                    (0..seq_len)
                                        .map(|pos| (pos + offset) as f32)
                                        .collect::<Vec<_>>(),
                                ),
                            }),
                            TensorView {
                                tensor_id: i,
                                shape: ShapeTracker::new(vec![seq_len]),
                            },
                        )
                    }),
                ),
                vec![Seq::const_size()],
            )
            .input(seq_tensor.id);
        // Attach the cached keys as the second input so the op can read the offset.
        if let Some((cached_keys, _)) = cache {
            op = op.input(cached_keys.id);
        }
        let t: GraphTensor<(Seq,)> = GraphTensor::from_id(op.finish(), graph);
        // Outer product of positions and inverse frequencies: (Seq, 1) x (1, HEAD_DIM / 2).
        let freqs = t
            .expand::<(Seq, Const<1>), _>()
            .matmul(
                self.inv_freq
                    .expand::<(Const<1>, Const<HEAD_DIM_OVER_2>), _>(),
            )
            .realize::<(Seq, usize)>();
        // Duplicate along the last axis to reach the full head width.
        let emb = (freqs, freqs).concat_along(Axis::<1>);
        (emb.sin().realize(), emb.cos().realize())
    }

    /// Splits the last axis in half as `(x1, x2)` and returns `(-x2, x1)` concatenated.
    fn rotate_half<Batch: Dim, NumHeads: Dim, Seq: Dim>(
        x: GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)>,
    ) -> GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)> {
        let x1 = x.slice((.., .., .., ..HEAD_DIM_OVER_2));
        let x2 = x.slice((.., .., .., HEAD_DIM_OVER_2..));
        (-x2, x1).concat_along(Axis::<3>).realize()
    }
}
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> InitModule
for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
{
fn initialize(cx: &mut Graph) -> Self {
let s = Self {
inv_freq: cx.new_tensor("Inv Freq"),
};
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
s.inv_freq.set(
(0..HEAD_DIM_OVER_2)
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
);
s
}
}
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> SerializeModule
    for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Registers the single `inv_freq` tensor with the serializer.
    fn serialize(&self, s: &mut Serializer) {
        s.tensor("inv_freq", self.inv_freq);
    }
}
/// Multi-head self-attention with rotary position embeddings.
/// All four projection weights are square `(HIDDEN, HIDDEN)` matrices.
pub struct Attention<
    const NUM_HEADS: usize,
    const HIDDEN: usize,
    const HEAD_DIM: usize,
    const HEAD_DIM_OVER_2: usize,
> {
    /// Query projection weight.
    pub q_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
    /// Key projection weight.
    pub k_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
    /// Value projection weight.
    pub v_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
    /// Output projection weight, applied after the heads are merged.
    pub o_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
    /// Rotary embedding applied to queries and keys.
    pub rotary_embed: RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>,
}
/// Shared front half of attention, used by both the uncached and the KV-cached paths.
///
/// Projects `x` with the q/k/v weights, splits the hidden axis into
/// `(NUM_HEADS, HEAD_DIM)`, moves the head axis in front of the sequence axis and applies
/// the rotary embedding to the queries and keys. `cache` is only forwarded to the rotary
/// embedding. Returns `(q, k, v)`; `v` is returned without a rotary embedding.
fn attn_forward<
    const NUM_HEADS: usize,
    const HIDDEN: usize,
    const HEAD_DIM: usize,
    const HEAD_DIM_OVER_2: usize,
    Batch: Dim,
    Seq: Dim,
    PrevSeq: Dim,
>(
    attn: &Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>,
    x: GraphTensor<(Batch, Seq, Const<HIDDEN>)>,
    cache: Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
) -> (
    GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
    GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
    GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
) {
    // (Batch, Seq, HIDDEN) -> (Batch, Seq, NUM_HEADS, HEAD_DIM) -> (Batch, NUM_HEADS, Seq, HEAD_DIM)
    let q = x
        .matmul(attn.q_proj.permute())
        .dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
            Batch::const_size().to_reshape(0),
            Seq::const_size().to_reshape(1),
            ReshapeDim::Const(NUM_HEADS),
            ReshapeDim::Const(HEAD_DIM),
        ])
        .permute::<_, Axes4<0, 2, 1, 3>>();
    // Same reshape/permute for the keys.
    let k = x
        .matmul(attn.k_proj.permute())
        .dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
            Batch::const_size().to_reshape(0),
            Seq::const_size().to_reshape(1),
            ReshapeDim::Const(NUM_HEADS),
            ReshapeDim::Const(HEAD_DIM),
        ])
        .permute::<_, Axes4<0, 2, 1, 3>>();
    // Same reshape/permute for the values.
    let v = x
        .matmul(attn.v_proj.permute())
        .dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
            Batch::const_size().to_reshape(0),
            Seq::const_size().to_reshape(1),
            ReshapeDim::Const(NUM_HEADS),
            ReshapeDim::Const(HEAD_DIM),
        ])
        .permute::<_, Axes4<0, 2, 1, 3>>();
    // Rotary embedding is applied to queries and keys only.
    let (q, k) = attn.rotary_embed.forward((
        q.realize::<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>(),
        k.realize(),
        cache,
    ));
    (q, k, v)
}
/// Uncached self-attention over the whole current sequence with an additive mask.
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        Batch: Dim,
        CurSeq: Dim,
    >
    Module<(
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        GraphTensor<(CurSeq, CurSeq)>,
    )> for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>,
    );

    /// Returns the attention output together with the `(k, v)` pair computed for this
    /// sequence, which callers can keep as a KV cache.
    fn forward(
        &self,
        (x, attn_mask): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
            GraphTensor<(CurSeq, CurSeq)>,
        ),
    ) -> Self::Output {
        // No cache on this path, so rotary positions start at zero.
        let (q, k, v) = attn_forward(
            self,
            x,
            Option::<KVCache<_, usize, NUM_HEADS, HEAD_DIM>>::None,
        );
        // 1 / sqrt(HEAD_DIM)
        let inv_head_scale = (HEAD_DIM as f64).sqrt().recip() as f32;
        // Scaled scores plus the additive mask, softmaxed over axis 3.
        let w = q
            .batch_matmul(k.permute())
            .mul(inv_head_scale)
            .add(attn_mask.expand())
            .softmax::<3>();
        // Weighted values, then merge the heads back into a single HIDDEN axis.
        let o = w
            .batch_matmul(v)
            .permute::<(Batch, CurSeq, Const<NUM_HEADS>, Const<HEAD_DIM>), _>()
            .dyn_reshape::<(Batch, CurSeq, Const<HIDDEN>)>(vec![
                Batch::const_size().to_reshape(0),
                CurSeq::const_size().to_reshape(1),
                ReshapeDim::Const(HIDDEN),
            ]);
        (o.matmul(self.o_proj.permute()), (k, v))
    }
}
// KV cache forward
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Attention for `CurSeq` new positions on top of a cache of `PrevSeq` positions.
    ///
    /// The new keys/values are appended to the cached ones along the sequence axis
    /// (axis 2) and the concatenated tensors are returned as the updated cache.
    /// Note that no attention mask is added on this path.
    fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
        &self,
        (x, cache): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
            KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>,
        ),
    ) -> (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>,
    ) {
        let (q, k, v) = attn_forward(self, x, Some(cache));
        // Add KV cache
        // Both sides are realized to a dynamic sequence axis before concatenation.
        let k = (
            cache
                .0
                .realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
            k.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
        )
            .concat_along(Axis::<2>)
            .realize::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>)>();
        let v = (
            cache
                .1
                .realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
            v.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
        )
            .concat_along(Axis::<2>)
            .realize::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>)>();
        let w = q
            .batch_matmul(k.permute())
            .mul((HEAD_DIM as f64).sqrt().recip() as f32) // Inv head scale
            .softmax::<3>();
        // Weighted values, then merge the heads back into a single HIDDEN axis.
        let o = w
            .batch_matmul(v)
            .permute::<(Batch, CurSeq, Const<NUM_HEADS>, Const<HEAD_DIM>), _>()
            .dyn_reshape::<(Batch, CurSeq, Const<HIDDEN>)>(vec![
                Batch::const_size().to_reshape(0),
                CurSeq::const_size().to_reshape(1),
                ReshapeDim::Const(HIDDEN),
            ]);
        (o.matmul(self.o_proj.permute()), (k, v))
    }
}
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > InitModule for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Registers the q/k/v/o weight tensors on `cx`, then initialises the rotary embedding.
    fn initialize(cx: &mut Graph) -> Self {
        let q_proj = cx.new_tensor("Query Weight");
        let k_proj = cx.new_tensor("Key Weight");
        let v_proj = cx.new_tensor("Value Weight");
        let o_proj = cx.new_tensor("Output Weight");
        let rotary_embed = InitModule::initialize(cx);
        Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_embed,
        }
    }
}
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > SerializeModule for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Registers the four projection weights and the nested rotary embedding module.
    fn serialize(&self, s: &mut Serializer) {
        s.tensor("q_proj/weight", self.q_proj);
        s.tensor("k_proj/weight", self.k_proj);
        s.tensor("v_proj/weight", self.v_proj);
        s.tensor("o_proj/weight", self.o_proj);
        s.module("rotary_emb", &self.rotary_embed);
    }
}
/// One transformer decoder layer: normalised self-attention and a normalised MLP,
/// each wrapped in a residual connection.
pub struct DecoderLayer<
    const NUM_HEADS: usize,
    const HIDDEN: usize,
    const INTERMEDIATE: usize,
    const HEAD_DIM: usize,
    const HEAD_DIM_OVER_2: usize,
> {
    /// Self-attention sub-layer.
    pub self_attn: Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>,
    /// Feed-forward sub-layer.
    pub mlp: Mlp<INTERMEDIATE, HIDDEN>,
    /// Norm applied to the input of the attention sub-layer.
    pub input_layer_norm: RMSNorm<HIDDEN>,
    /// Norm applied to the input of the MLP sub-layer.
    pub post_attention_layer_norm: RMSNorm<HIDDEN>,
}
/// Uncached decoder layer pass over `(Batch, CurSeq, HIDDEN)` activations.
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        Batch: Dim,
        CurSeq: Dim,
    >
    Module<(
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        GraphTensor<(CurSeq, CurSeq)>,
    )> for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>,
    );

    /// Returns the layer output and the KV cache produced by the attention sub-layer.
    fn forward(
        &self,
        (x, attn_mask): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
            GraphTensor<(CurSeq, CurSeq)>,
        ),
    ) -> Self::Output {
        // Attention sub-layer with residual connection.
        let attn_input = self.input_layer_norm.forward(x);
        let (attn_out, kv_cache) = self.self_attn.forward((attn_input, attn_mask));
        let after_attn = x + attn_out;
        // MLP sub-layer with residual connection.
        let mlp_input = self.post_attention_layer_norm.forward(after_attn);
        let mlp_out = self.mlp.forward(mlp_input);
        (after_attn + mlp_out, kv_cache)
    }
}
// KV cache forward
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Same as the uncached pass, but attention reads and extends the given KV cache.
    /// Returns the layer output and the updated cache.
    fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
        &self,
        (x, cache): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
            KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>,
        ),
    ) -> (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>,
    ) {
        // Attention sub-layer with residual connection.
        let attn_input = self.input_layer_norm.forward(x);
        let (attn_out, updated_cache) = self.self_attn.forward_kv((attn_input, cache));
        let after_attn = x + attn_out;
        // MLP sub-layer with residual connection.
        let mlp_input = self.post_attention_layer_norm.forward(after_attn);
        let mlp_out = self.mlp.forward(mlp_input);
        (after_attn + mlp_out, updated_cache)
    }
}
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > InitModule for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Initialises the sub-modules on `cx` in order: attention, MLP, input norm,
    /// post-attention norm.
    fn initialize(cx: &mut Graph) -> Self {
        let self_attn = InitModule::initialize(cx);
        let mlp = InitModule::initialize(cx);
        let input_layer_norm = InitModule::initialize(cx);
        let post_attention_layer_norm = InitModule::initialize(cx);
        Self {
            self_attn,
            mlp,
            input_layer_norm,
            post_attention_layer_norm,
        }
    }
}
impl<
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
    > SerializeModule for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
{
    /// Registers the four sub-modules under their serialized names.
    fn serialize(&self, s: &mut Serializer) {
        s.module("self_attn", &self.self_attn);
        s.module("mlp", &self.mlp);
        s.module("input_layernorm", &self.input_layer_norm);
        s.module("post_attention_layernorm", &self.post_attention_layer_norm);
    }
}
/// The LLaMA decoder stack: token embedding, `LAYERS` decoder layers and a final norm.
pub struct Llama<
    const VOCAB: usize,
    const NUM_HEADS: usize,
    const HIDDEN: usize,
    const INTERMEDIATE: usize,
    const HEAD_DIM: usize,
    const HEAD_DIM_OVER_2: usize,
    const LAYERS: usize,
> {
    /// Token embedding table.
    pub embed_tokens: Embedding<VOCAB, HIDDEN>,
    /// Decoder layers, applied in order.
    pub layers: Vec<DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>>,
    /// Norm applied to the output of the last layer.
    pub norm: RMSNorm<HIDDEN>,
    /// Raw pointer to the graph the model was initialised on; dereferenced in `forward`
    /// to add the attention-mask op.
    pub graph_ref: *mut Graph,
}
/// Uncached pass of the decoder stack over a `(Batch, CurSeq)` token tensor.
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
        Batch: Dim,
        CurSeq: Dim,
    > Module<GraphTensor<(Batch, CurSeq)>>
    for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        Vec<KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>>,
    );

    /// Returns the normalised final hidden states and one KV cache per layer.
    fn forward(&self, input: GraphTensor<(Batch, CurSeq)>) -> Self::Output {
        let graph = unsafe { self.graph_ref.as_mut().unwrap() };
        // Causal mask op: a (seq, seq) matrix, built at run time from axis 1 of the
        // input, with -inf strictly above the diagonal and 0 elsewhere.
        let mask_op = op::Function(
            "AttentionMask".to_string(),
            Box::new(|inp, id| {
                let seq_len = inp[0].1.shape.shape()[1];
                let mask: Vec<f32> = (0..seq_len)
                    .flat_map(|row| {
                        (0..seq_len).map(move |col| {
                            if col > row {
                                f32::NEG_INFINITY
                            } else {
                                0.
                            }
                        })
                    })
                    .collect();
                (
                    Some(Tensor {
                        data: Box::new(mask),
                    }),
                    TensorView {
                        tensor_id: id,
                        shape: ShapeTracker::new(vec![seq_len, seq_len]),
                    },
                )
            }),
        );
        let mask_id = graph
            .add_op(mask_op, vec![CurSeq::const_size(), CurSeq::const_size()])
            .input(input.id)
            .finish();
        let attn_mask: GraphTensor<(CurSeq, CurSeq)> = GraphTensor::from_id(mask_id, graph);

        // Thread the hidden state through every layer, collecting each layer's cache.
        let mut hidden = self.embed_tokens.forward(input);
        let mut caches = Vec::new();
        for layer in &self.layers {
            let (next_hidden, kv_cache) = layer.forward((hidden, attn_mask));
            hidden = next_hidden;
            caches.push(kv_cache);
        }
        (self.norm.forward(hidden), caches)
    }
}
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// KV-cached pass of the decoder stack.
    ///
    /// Layers are paired with `caches` in order; if the two lengths differ, the extra
    /// entries on the longer side are ignored. Returns the normalised final hidden
    /// states and the updated cache of every layer that ran.
    pub fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
        &self,
        (input, caches): (
            GraphTensor<(Batch, CurSeq)>,
            Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
        ),
    ) -> (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
    ) {
        let mut hidden = self.embed_tokens.forward(input);
        let mut updated_caches = Vec::new();
        for (layer, layer_cache) in self.layers.iter().zip(caches) {
            let (next_hidden, kv_cache) = layer.forward_kv((hidden, layer_cache));
            hidden = next_hidden;
            updated_caches.push(kv_cache);
        }
        (self.norm.forward(hidden), updated_caches)
    }
}
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > InitModule
    for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// Initialises the final norm, the embedding and `LAYERS` decoder layers on `cx`
    /// (in that order) and remembers `cx` as a raw pointer.
    fn initialize(cx: &mut Graph) -> Self {
        let norm = InitModule::initialize(cx);
        let embed_tokens = InitModule::initialize(cx);
        let mut layers = Vec::new();
        for _ in 0..LAYERS {
            layers.push(InitModule::initialize(cx));
        }
        Self {
            norm,
            embed_tokens,
            layers,
            graph_ref: cx,
        }
    }
}
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > SerializeModule
    for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// Registers the norm, the embedding and every layer (as `layers/<index>`).
    fn serialize(&self, s: &mut Serializer) {
        s.module("norm", &self.norm);
        s.module("embed_tokens", &self.embed_tokens);
        self.layers.iter().enumerate().for_each(|(index, layer)| {
            s.module(&format!("layers/{index}"), layer);
        });
    }
}
/// The decoder stack plus a language-model head that projects hidden states to
/// `VOCAB`-sized outputs.
pub struct LlamaForCausalLM<
    const VOCAB: usize,
    const NUM_HEADS: usize,
    const HIDDEN: usize,
    const INTERMEDIATE: usize,
    const HEAD_DIM: usize,
    const HEAD_DIM_OVER_2: usize,
    const LAYERS: usize,
> {
    /// The underlying decoder stack.
    pub llama: Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>,
    /// Output projection weight, stored as `(VOCAB, HIDDEN)` and transposed at use.
    pub lm_head: GraphTensor<(Const<VOCAB>, Const<HIDDEN>)>,
}
/// Uncached pass: decoder stack followed by the language-model head.
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
        Batch: Dim,
        CurSeq: Dim,
    > Module<GraphTensor<(Batch, CurSeq)>>
    for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
        Vec<KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>>,
    );

    /// Returns `(Batch, CurSeq, VOCAB)` outputs and the per-layer KV caches.
    fn forward(&self, input: GraphTensor<(Batch, CurSeq)>) -> Self::Output {
        let (final_hidden, kv_caches) = self.llama.forward(input);
        let logits = final_hidden.matmul(self.lm_head.permute());
        (logits, kv_caches)
    }
}
// KV cache forward
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// KV-cached pass: decoder stack with caches, followed by the language-model head.
    /// Returns `(Batch, CurSeq, VOCAB)` outputs and the updated per-layer caches.
    pub fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
        &self,
        (input, caches): (
            GraphTensor<(Batch, CurSeq)>,
            Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
        ),
    ) -> (
        GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
        Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
    ) {
        let (final_hidden, updated_caches) = self.llama.forward_kv((input, caches));
        let logits = final_hidden.matmul(self.lm_head.permute());
        (logits, updated_caches)
    }
}
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > InitModule
    for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// Initialises the decoder stack on `cx`, then registers the head weight.
    fn initialize(cx: &mut Graph) -> Self {
        let llama = InitModule::initialize(cx);
        let lm_head = cx.new_tensor("LM Head");
        Self { llama, lm_head }
    }
}
impl<
        const VOCAB: usize,
        const NUM_HEADS: usize,
        const HIDDEN: usize,
        const INTERMEDIATE: usize,
        const HEAD_DIM: usize,
        const HEAD_DIM_OVER_2: usize,
        const LAYERS: usize,
    > SerializeModule
    for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
{
    /// Registers the decoder stack under `model` and the head weight under `lm_head/weight`.
    fn serialize(&self, s: &mut Serializer) {
        s.module("model", &self.llama);
        s.tensor("lm_head/weight", self.lm_head);
    }
}

View File

@@ -0,0 +1,28 @@
import argparse
import os
import torch
def main():
    """Split every PyTorch ``.bin`` checkpoint in a directory into one raw file per tensor.

    The directory is taken from the optional positional ``src`` argument
    (default ``llama-7b-hf``). Each state-dict key such as ``a.b.weight`` is written
    to ``<src>/a/b/weight`` as the tensor's raw bytes (``numpy.ndarray.tofile``).
    """
    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional; without it the default can never apply.
    parser.add_argument("src", nargs="?", help="root directory", default="llama-7b-hf")
    args = parser.parse_args()
    for f in os.listdir(args.src):
        if not f.endswith(".bin"):
            continue
        print(f"Loading {f}")
        # Load onto the CPU so conversion also works for checkpoints saved from a GPU.
        sd = torch.load(os.path.join(args.src, f), map_location="cpu")
        for key, tensor in sd.items():
            print("Saving", key, tensor.shape, tensor.dtype)
            # Dots in the key become directory separators.
            path = os.path.sep.join(key.split("."))
            os.makedirs(os.path.join(args.src, os.path.dirname(path)), exist_ok=True)
            np_array = tensor.numpy()
            # Raw tensor bytes must go to a binary-mode file.
            with open(os.path.join(args.src, path), "wb") as fp:
                np_array.tofile(fp)
            # Release memory eagerly; checkpoints are large.
            del np_array
        del sd
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
# Downloads the LLaMA 7B checkpoint next to this script and converts it with convert.py.
# All uses of SCRIPT_DIR are quoted so the script also works from a path containing spaces.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# Setup git LFS
echo "Setting up git LFS..."
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
    sudo apt install git-lfs
elif [[ "$OSTYPE" == "darwin"* ]]; then
    brew install git-lfs
fi
git lfs install
echo "Downloading Model..."
git lfs clone https://huggingface.co/decapoda-research/llama-7b-hf "$SCRIPT_DIR/llama-7b-hf"
# Convert the model
echo "Converting Model..."
python3 "$SCRIPT_DIR/convert.py" "$SCRIPT_DIR/llama-7b-hf"
echo "Done!"

View File

@@ -1,4 +1,3 @@
use half::f16;
use luminal::{op::Function, prelude::*};
/// Load the model in the same way dfdx-llama does
@@ -16,36 +15,29 @@ impl DfdxDeferredLoader {
}
impl Loader for DfdxDeferredLoader {
type Output = ();
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) {
let mut serializer = Serializer::default();
model.serialize(&mut serializer);
for (s, n) in serializer.state {
let shape: Vec<usize> = graph
for (s, n) in state_dict(model) {
let Some(n_elements) = graph
.graph
.node_weight_mut(n)
.unwrap()
.1
.iter()
.map(|i| match i {
RealDim::Const(m) => *m,
RealDim::Dyn => panic!("Dyn dimension in a weight"),
})
.collect();
.edges_directed(n, petgraph::Direction::Outgoing)
.find_map(|e| e.weight().as_data())
.map(|(_, _, s)| s.n_physical_elements().to_usize().unwrap())
else {
continue;
};
if let Some(inp_func) = graph
.graph
.node_weight_mut(n)
.unwrap()
.0
.as_any_mut()
.downcast_mut::<Function>()
{
let path = self.path.clone();
inp_func.1 = Box::new(move |_, i| {
inp_func.1 = Box::new(move |_| {
// Get memmapped tensor
let bytes = std::fs::read(format!("{path}/{s}")).unwrap();
let num_params: usize = shape.iter().product();
let data: Vec<f32> = if bytes.len() == num_params * 2 {
let data: Vec<f32> = if bytes.len() == n_elements * 2 {
// Half-precision
bytes
.chunks_exact(std::mem::size_of::<f16>())
@@ -53,7 +45,7 @@ impl Loader for DfdxDeferredLoader {
std::mem::transmute::<[u8; 2], f16>([chunk[0], chunk[1]]).to_f32()
})
.collect()
} else if bytes.len() == num_params * 4 {
} else if bytes.len() == n_elements * 4 {
// Full precision
bytes
.chunks_exact(std::mem::size_of::<f32>())
@@ -65,23 +57,16 @@ impl Loader for DfdxDeferredLoader {
.collect()
} else {
panic!(
"Expected {} or {} bytes, got {} when loading {}{}",
num_params * 2,
num_params * 4,
"Expected {} or {} bytes, got {} when loading {path}/{s}",
n_elements * 2,
n_elements * 4,
bytes.len(),
path,
s
)
};
(
Some(Tensor {
data: Box::new(data),
}),
TensorView {
tensor_id: i,
shape: ShapeTracker::new(shape.clone()),
},
)
vec![Tensor {
data: Box::new(data),
}]
});
};
}

164
examples/llama/src/main.rs Normal file
View File

@@ -0,0 +1,164 @@
mod loader;
mod model;
use std::{
io::{self, Write},
marker::PhantomData,
time::Instant,
};
use colored::Colorize;
use luminal::{prelude::*, shape::symbolic::Expression};
use rust_tokenizers::tokenizer::{
SentencePieceBpeTokenizer, Tokenizer,
TruncationStrategy::{self},
};
use crate::model::KVCache;
// Backend selection: `DeviceCompiler` is the compiler type used for the enabled device.
// NOTE(review): with both the `metal` and `cuda` features enabled, the first two aliases
// would both be active and define `DeviceCompiler` twice.
#[cfg(feature = "metal")]
type DeviceCompiler = luminal_metal::MetalCompiler<luminal::prelude::f16>;
#[cfg(feature = "cuda")]
type DeviceCompiler = luminal_cuda::CudaCompiler<luminal::prelude::f16>;
// Fallback when neither GPU feature is enabled.
#[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
type DeviceCompiler = CPUCompiler;
/// Runs the LLaMA example end to end: builds the graph, compiles it, loads the
/// weights, processes a fixed prompt, and then generates `tokens_to_generate`
/// tokens one at a time (picking the highest-scoring token via `sample_index`),
/// printing timing information along the way.
fn main() {
let prompt = "Here is a python implementation of merge sort:";
let tokens_to_generate = 128;
let tokenizer =
SentencePieceBpeTokenizer::from_file("setup/llama-7b-hf/tokenizer.model", false).unwrap();
print!("Defining graph");
io::stdout().flush().unwrap();
let now = Instant::now();
let mut cx = Graph::new();
// Token ids for the current step; the sequence length is the dynamic dimension 's'
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
// One (key, value) cache pair per layer, with dynamic previous-sequence length 'p'
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::LAYERS)
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
.collect();
// Start with empty caches (sequence length 0)
cache_src.set_dyn(vec![], &[1, model::HEADS, 0, model::HEAD_DIM]);
let model = model::Llama::initialize(&mut cx);
let (logits, mut cache_dest) =
model.forward((input, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
// Only the logits from position 's' - 1 onward (the last input token) are retrieved
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve();
cache_dest.keep();
loader::DfdxDeferredLoader::new("setup/llama-7b-hf").load(&model, &mut cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
print!("Compiling graph");
io::stdout().flush().unwrap();
let now = Instant::now();
cx.compile(
<(GenericCompiler, DeviceCompiler)>::default(),
(&mut input, &mut logits, &mut cache_src, &mut cache_dest),
);
// Keep model weights
let model_weights = downstream(state_set(&model), &cx);
cx.keep_tensors(&model_weights);
let cache_src_set = downstream(&cache_src, &cx);
let cache_dest_set = cache_dest.to_ids();
println!("\t\t - {}ms", now.elapsed().as_millis());
// Initial forward pass to load weights
print!("Loading model");
io::stdout().flush().unwrap();
let now = Instant::now();
// A single dummy token (id 0) is enough to run the graph once
input.set_dyn(vec![0.], &[1, 1]);
cx.set_dyn_dim('t', 1);
cx.execute();
// The outputs of this warm-up pass are discarded
logits.drop();
cache_dest.drop();
println!("\t\t - {}ms", now.elapsed().as_millis());
// Now that weights are loaded, delete the loading nodes so they don't run again
delete_inputs(&model_weights, &mut cx);
// Run prompt processing pass
let mut input_ids = encode(&tokenizer, prompt);
input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
&[1, input_ids.len()],
);
cx.set_dyn_dim('t', input_ids.len());
print!("Processing Prompt");
io::stdout().flush().unwrap();
let now = Instant::now();
cx.execute();
let elapsed_ms = now.elapsed().as_millis();
println!(
"\t - {elapsed_ms}ms ({:.2} tok/s)",
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
);
// NOTE(review): presumably removes the nodes feeding the source caches so that
// later passes use the transferred cache data instead — confirm against luminal
delete_inputs(&cache_src_set, &mut cx);
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
// Decode token
print!(
"{}{}",
prompt.white().bold(),
decode(&tokenizer, &[output_id]).bright_green()
);
io::stdout().flush().unwrap();
// Swap caches
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
// Decode loop
let mut token_decode_times = vec![];
for _ in 0..tokens_to_generate {
// Feed only the most recently generated token
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
// 'p' counts the tokens already in the cache, 't' the total including the new one
cx.set_dyn_dim('p', input_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len());
let now = Instant::now();
cx.execute();
token_decode_times.push(now.elapsed().as_micros());
// Sample tokens
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
io::stdout().flush().unwrap();
// Swap caches
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
}
// Average per-token decode time, converted from microseconds to milliseconds
let avg_token_time = token_decode_times
.iter()
.map(|t| *t as f32 / 1000.)
.sum::<f32>()
/ token_decode_times.len() as f32;
println!(
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
avg_token_time,
1000.0 / avg_token_time
);
}
/// Tokenizes `text` and returns its token ids with the start token (id 1)
/// placed in front.
///
/// The byte length of `text` is passed as the maximum length given to the
/// tokenizer, together with the `LongestFirst` truncation strategy.
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
    let encoded = tokenizer.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0);
    let mut ids = Vec::with_capacity(encoded.token_ids.len() + 1);
    ids.push(1); // Start token
    ids.extend(encoded.token_ids);
    ids
}
/// Converts token ids back into text, replacing the literal `<0x0A>` marker
/// with a newline character.
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
    let text = tokenizer.decode(token_ids, true, false);
    text.replace("<0x0A>", "\n")
}
// Currently just an argmax, do actual sampling here
/// Returns the index of the largest value in `dist`.
///
/// When several entries compare as equal (or cannot be ordered, e.g. NaN), the
/// later one wins.
///
/// # Panics
///
/// Panics if `dist` is empty.
fn sample_index(dist: &[f32]) -> i64 {
    let mut candidates = dist.iter().enumerate();
    let (mut best_idx, mut best_val) = candidates.next().unwrap();
    for (idx, val) in candidates {
        // Keep the current best only when it is strictly greater; otherwise the
        // newer element replaces it.
        let keep_current = matches!(
            best_val.partial_cmp(val),
            Some(std::cmp::Ordering::Greater)
        );
        if !keep_current {
            best_idx = idx;
            best_val = val;
        }
    }
    best_idx as i64
}

365
examples/llama/src/model.rs Normal file
View File

@@ -0,0 +1,365 @@
#![allow(clippy::type_complexity)]
use std::{marker::PhantomData, ops::Mul};
// LLaMa 1 7B Config
pub const VOCAB: usize = 32_000;
pub const HEAD_DIM: usize = 128;
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
pub const HIDDEN: usize = 4096;
pub const INTERMEDIATE: usize = 11008;
pub const HEADS: usize = 32;
pub const LAYERS: usize = 32;
use luminal::{
nn::{embedding::Embedding, norm::RMSNorm},
prelude::*,
shape::symbolic::{BigExpression, Expression},
};
// Full LLaMa model implementation, heavily based off of https://github.com/coreylowman/llama-dfdx/blob/main/src/modeling.rs
/// Key/value cache for a single layer: a `(keys, values)` pair of tensors, each
/// shaped `(Batch, HEADS, Seq, HEAD_DIM)`.
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
);
/// Gated feed-forward block with intermediate size `I` and hidden size `H`.
pub struct Mlp<const I: usize, const H: usize> {
/// Gate projection weight, shaped `(I, H)`
pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
/// Down projection weight, shaped `(H, I)`
pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
/// Up projection weight, shaped `(I, H)`
pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
}
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
where
    GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
    GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
{
    type Output = GraphTensor<Sh>;

    /// Applies the gated feed-forward: the swish-activated gate projection is
    /// multiplied with the up projection, and the product goes through the
    /// down projection.
    fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
        let gate_weight = self.gate_proj.permute();
        let activated_gate = input.matmul(gate_weight).swish();
        let up_weight = self.up_proj.permute();
        let gated = input.matmul(up_weight) * activated_gate;
        let down_weight = self.down_proj.permute();
        gated.matmul(down_weight)
    }
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
    /// Creates the three named weight tensors on `cx` in gate, up, down order.
    fn initialize(cx: &mut Graph) -> Self {
        let gate_proj = cx.named_tensor("Gate Weight");
        let up_proj = cx.named_tensor("Up Weight");
        let down_proj = cx.named_tensor("Down Weight");
        Self {
            gate_proj,
            down_proj,
            up_proj,
        }
    }
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
    /// Registers the three projection weights under their `*_proj/weight` keys.
    fn serialize(&self, s: &mut Serializer) {
        let Self {
            gate_proj,
            up_proj,
            down_proj,
        } = self;
        s.tensor("gate_proj/weight", *gate_proj);
        s.tensor("up_proj/weight", *up_proj);
        s.tensor("down_proj/weight", *down_proj);
    }
}
/// Rotary position embedding applied to per-head query/key tensors.
pub struct RotaryEmbedding {
/// Inverse frequencies, one per pair of head dimensions (`HEAD_DIM / 2` values)
pub inv_freq: GraphTensor<R1<HEAD_DIM_OVER_2>>,
}
impl<Batch: Dimension, Seq: Dimension>
    Module<(
        GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
        BigExpression,
    )> for RotaryEmbedding
{
    type Output = GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>;

    /// Rotates `inp` by position: `rotate_half(inp) * sin + inp * cos`, where the
    /// sin/cos tables start at position offset `prev_seq`.
    fn forward(
        &self,
        (inp, prev_seq): (
            GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
            BigExpression,
        ),
    ) -> Self::Output {
        let (sin_table, cos_table) = self.get_sincos::<Seq>(prev_seq);
        let rotated_term = Self::rotate_half(inp) * sin_table.expand();
        let direct_term = inp * cos_table.expand();
        rotated_term + direct_term
    }
}
impl RotaryEmbedding {
    /// Builds the `(sin, cos)` tables of shape `(Seq, HEAD_DIM)` for positions
    /// `prev_seq .. prev_seq + Seq`.
    fn get_sincos<Seq: Dimension>(
        &self,
        prev_seq: BigExpression,
    ) -> (
        GraphTensor<(Seq, Const<HEAD_DIM>)>,
        GraphTensor<(Seq, Const<HEAD_DIM>)>,
    ) {
        // Position indices, shifted by the number of previously seen positions
        let positions = self.inv_freq.graph().arange::<Seq>() + prev_seq;
        // Outer product of positions (column) and inverse frequencies (row)
        let position_col = positions.expand::<(Seq, Const<1>), _>();
        let freq_row = self
            .inv_freq
            .expand::<(Const<1>, Const<HEAD_DIM_OVER_2>), _>();
        let angles = position_col.matmul(freq_row);
        // Repeat the angles so they cover the full head dimension
        let full_angles = angles.concat_along::<(Seq, Const<HEAD_DIM>), Axis<1>, _>(angles);
        let sin_table = full_angles.sin().reshape();
        let cos_table = full_angles.cos().reshape();
        (sin_table, cos_table)
    }

    /// Splits the last axis into two halves `(a, b)` and returns `(-b, a)`
    /// concatenated along that axis.
    fn rotate_half<Batch: Dimension, Seq: Dimension>(
        x: GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
    ) -> GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)> {
        let first_half = x
            .slice((.., .., .., ..Expression::from(HEAD_DIM_OVER_2)))
            .contiguous();
        let second_half = x
            .slice((.., .., .., Expression::from(HEAD_DIM_OVER_2)..))
            .contiguous();
        let negated_second = -second_half;
        negated_second
            .concat_along::<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>), Axis<3>, _>(first_half)
    }
}
impl InitModule for RotaryEmbedding {
    /// Creates the named inverse-frequency tensor on `cx`.
    fn initialize(cx: &mut Graph) -> Self {
        let inv_freq = cx.named_tensor("Inv Freq");
        Self { inv_freq }
    }
}
impl SerializeModule for RotaryEmbedding {
    /// Registers the inverse-frequency tensor under the `inv_freq` key.
    fn serialize(&self, s: &mut Serializer) {
        let Self { inv_freq } = self;
        s.tensor("inv_freq", *inv_freq);
    }
}
/// Multi-head self-attention layer with rotary position embeddings.
pub struct Attention {
/// Query projection weight, `(HIDDEN, HIDDEN)`
pub q_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
/// Key projection weight, `(HIDDEN, HIDDEN)`
pub k_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
/// Value projection weight, `(HIDDEN, HIDDEN)`
pub v_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
/// Output projection weight, `(HIDDEN, HIDDEN)`
pub o_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
/// Rotary embedding applied to queries and keys
pub rotary_embed: RotaryEmbedding,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
Option<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
)> for Attention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
KVCache<Batch, TotSeq>,
);
/// Runs masked self-attention over `x`.
///
/// Takes the hidden states, an optional key/value cache covering `PrevSeq`
/// earlier positions, and a marker for the total sequence length `TotSeq`.
/// Returns the attention output together with the keys and values for all
/// `TotSeq` positions (the updated cache).
fn forward(
&self,
(x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
Option<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
),
) -> Self::Output {
// Project the input, then split into heads: (Batch, CurSeq, HEADS, HEAD_DIM)
// permuted with axes (0, 2, 1, 3)
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
// Rotary embeddings for queries and keys, offset by the number of cached positions
// NOTE(review): `queries` was already permuted above, while `keys` is passed
// without this extra `permute()`; presumably it resolves to an identity
// permutation — confirm
let queries = self
.rotary_embed
.forward((queries.permute(), PrevSeq::const_size().into()));
let keys = self
.rotary_embed
.forward((keys, PrevSeq::const_size().into()));
// With a cache, append the new keys/values after the cached ones along the sequence axis
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
(
k_cache.concat_along::<_, Axis<2>, _>(keys),
v_cache.concat_along::<_, Axis<2>, _>(values),
)
} else {
(keys.realize(), values.contiguous().realize())
};
// Attention scores, scaled by 1 / sqrt(HEAD_DIM)
let mut weights = queries
.matmul(keys.permute())
.mul((HEAD_DIM as f64).sqrt().recip() as f32);
// Additive mask built from `triu(1)` scaled by the most negative f16 value
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
// Pad the mask from (CurSeq, CurSeq) to (CurSeq, TotSeq)
// NOTE(review): assumes each pair is (padding before, padding after) for one
// axis, so the TotSeq - CurSeq padding columns go in front — confirm
weights += attention_mask
.pad::<(CurSeq, TotSeq), _, _>(&[
(0.into(), Expression::from(0)),
(TotSeq::const_size() - CurSeq::const_size(), 0.into()),
])
.expand();
// Softmax over axis 3, weight the values, then merge the heads back into HIDDEN
let outputs = weights
.softmax::<3>()
.matmul(values)
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN>)>();
(
outputs.matmul(self.o_proj.permute()),
(keys.contiguous(), values.contiguous()),
)
}
}
impl InitModule for Attention {
    /// Creates the four named projection weights, then the rotary embedding.
    fn initialize(cx: &mut Graph) -> Self {
        let q_proj = cx.named_tensor("Query Weight");
        let k_proj = cx.named_tensor("Key Weight");
        let v_proj = cx.named_tensor("Value Weight");
        let o_proj = cx.named_tensor("Output Weight");
        let rotary_embed = InitModule::initialize(cx);
        Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_embed,
        }
    }
}
impl SerializeModule for Attention {
    /// Registers the projection weights and the nested rotary embedding module.
    fn serialize(&self, s: &mut Serializer) {
        let Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_embed,
        } = self;
        s.tensor("q_proj/weight", *q_proj);
        s.tensor("k_proj/weight", *k_proj);
        s.tensor("v_proj/weight", *v_proj);
        s.tensor("o_proj/weight", *o_proj);
        s.module("rotary_emb", rotary_embed);
    }
}
/// One transformer layer: self-attention and an MLP, each preceded by an
/// RMS norm and wrapped in a residual connection.
pub struct TransformerBlock {
/// Self-attention sub-layer
pub self_attn: Attention,
/// Feed-forward sub-layer
pub mlp: Mlp<INTERMEDIATE, HIDDEN>,
/// Norm applied to the input of the attention sub-layer
pub input_layer_norm: RMSNorm<HIDDEN>,
/// Norm applied to the input of the feed-forward sub-layer
pub post_attention_layer_norm: RMSNorm<HIDDEN>,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
    Module<(
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        Option<KVCache<Batch, PrevSeq>>,
        PhantomData<TotSeq>,
    )> for TransformerBlock
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
        KVCache<Batch, TotSeq>,
    );

    /// Runs attention and the feed-forward block, each with a residual
    /// connection, and returns the new hidden states plus the updated cache.
    fn forward(
        &self,
        (mut residual, cache, _): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
            Option<KVCache<Batch, PrevSeq>>,
            PhantomData<TotSeq>,
        ),
    ) -> Self::Output {
        // Attention sub-layer on the normed input
        let attn_input = self.input_layer_norm.forward(residual);
        let (attn_out, updated_cache) =
            self.self_attn
                .forward((attn_input, cache, PhantomData::<TotSeq>));
        // First residual connection
        residual += attn_out;

        // Feed-forward sub-layer on the normed residual stream
        let mlp_input = self.post_attention_layer_norm.forward(residual);
        let mlp_out = self.mlp.forward(mlp_input);

        // Second residual connection
        (residual + mlp_out, updated_cache)
    }
}
impl InitModule for TransformerBlock {
    /// Initializes the sub-modules in order: attention, MLP, input norm,
    /// post-attention norm.
    fn initialize(cx: &mut Graph) -> Self {
        let self_attn = InitModule::initialize(cx);
        let mlp = InitModule::initialize(cx);
        let input_layer_norm = InitModule::initialize(cx);
        let post_attention_layer_norm = InitModule::initialize(cx);
        Self {
            self_attn,
            mlp,
            input_layer_norm,
            post_attention_layer_norm,
        }
    }
}
impl SerializeModule for TransformerBlock {
    /// Registers the four sub-modules under their checkpoint key names.
    fn serialize(&self, s: &mut Serializer) {
        let Self {
            self_attn,
            mlp,
            input_layer_norm,
            post_attention_layer_norm,
        } = self;
        s.module("self_attn", self_attn);
        s.module("mlp", mlp);
        s.module("input_layernorm", input_layer_norm);
        s.module("post_attention_layernorm", post_attention_layer_norm);
    }
}
/// The full LLaMA model: token embedding, a stack of transformer layers, a
/// final norm, and the output projection onto the vocabulary.
pub struct Llama {
// Token embeddings
pub embedding: Embedding<VOCAB, HIDDEN>,
// Transformer layers
pub layers: Vec<TransformerBlock>,
// Final Norm layer
pub norm: RMSNorm<HIDDEN>,
// LM Head Layer
pub lm_head: GraphTensor<R2<VOCAB, HIDDEN>>,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
    Module<(
        GraphTensor<(Batch, CurSeq)>,
        Option<Vec<KVCache<Batch, PrevSeq>>>,
        PhantomData<TotSeq>,
    )> for Llama
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
        Vec<KVCache<Batch, TotSeq>>,
    );

    /// Embeds the input tokens, runs every transformer layer (handing each its
    /// own cache entry when a cache is given), and projects the normed result
    /// onto the vocabulary. Returns the logits and one updated cache per layer.
    fn forward(
        &self,
        (input, cache, _): (
            GraphTensor<(Batch, CurSeq)>,
            Option<Vec<KVCache<Batch, PrevSeq>>>,
            PhantomData<TotSeq>,
        ),
    ) -> Self::Output {
        // Embed tokens
        let mut hidden = self.embedding.forward(input);

        // Run through layers and collect new caches
        let mut updated_caches = Vec::with_capacity(self.layers.len());
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let layer_cache = cache.as_ref().map(|per_layer| per_layer[layer_idx]);
            let (next_hidden, layer_out_cache) =
                layer.forward((hidden, layer_cache, PhantomData::<TotSeq>));
            hidden = next_hidden;
            updated_caches.push(layer_out_cache);
        }

        // Run through last norm and output projection
        let normed = self.norm.forward(hidden);
        let logits = normed.matmul(self.lm_head.permute());
        (logits, updated_caches)
    }
}
impl InitModule for Llama {
    /// Initializes the final norm, the embedding, `LAYERS` transformer blocks
    /// and the LM head, in that order.
    fn initialize(cx: &mut Graph) -> Self {
        let norm = InitModule::initialize(cx);
        let embedding = InitModule::initialize(cx);
        let mut layers = Vec::with_capacity(LAYERS);
        for _ in 0..LAYERS {
            layers.push(InitModule::initialize(cx));
        }
        let lm_head = cx.named_tensor("LM Head");
        Self {
            embedding,
            layers,
            norm,
            lm_head,
        }
    }
}
impl SerializeModule for Llama {
    /// Registers the norm, the embedding, every layer (keyed by its index) and
    /// the LM head weight.
    fn serialize(&self, s: &mut Serializer) {
        s.module("model/norm", &self.norm);
        s.module("model/embed_tokens", &self.embedding);
        for (i, layer) in self.layers.iter().enumerate() {
            let layer_key = format!("model/layers/{i}");
            s.module(&layer_key, layer);
        }
        s.tensor("lm_head/weight", self.lm_head);
    }
}

14
examples/mistral/.gitignore vendored Normal file
View File

@@ -0,0 +1,14 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

View File

@@ -0,0 +1,20 @@
[package]
name = "mistral"
version = "0.1.0"
edition = "2021"
[features]
metal = ["dep:luminal_metal", "dep:metal-rs"]
cuda = ["dep:luminal_cuda"]
[dependencies]
luminal = {path="../.."}
luminal_metal = {path="../../crates/luminal_metal", optional=true}
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
rust_tokenizers = "8.1.0"
clap = { version = "4.4.18", features = ["derive"] }
byteorder = "1.5.0"
memmap2 = "0.9.4"
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
colored = "2.1.0"
itertools = "0.12.1"

View File

@@ -0,0 +1,10 @@
# Three Laws of Robotics
**The Three Laws of Robotics** (often shortened to **The Three Laws** or **Asimov's Laws**) are a set of rules devised by science fiction author Isaac Asimov, which were to be followed by robots in several of his stories. The rules were introduced in his 1942 short story "Runaround" (included in the 1950 collection I, Robot), although similar restrictions had been implied in earlier stories.
## The Laws
The Three Laws, presented to be from the fictional "Handbook of Robotics, 56th Edition, 2058 A.D.", are:
- The First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm.
- The Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
- The Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.

View File

@@ -0,0 +1 @@
[INST]Write me a python implementation of merge sort[/INST]

View File

@@ -0,0 +1,209 @@
[INST] Complete the following
## SCENE VII. The forest.
A table set out. Enter DUKE SENIOR, AMIENS, and Lords like outlaws
### DUKE SENIOR
I think he be transform'd into a beast;
For I can no where find him like a man.
### First Lord
My lord, he is but even now gone hence:
Here was he merry, hearing of a song.
### DUKE SENIOR
If he, compact of jars, grow musical,
We shall have shortly discord in the spheres.
Go, seek him: tell him I would speak with him.
Enter JAQUES
### First Lord
He saves my labour by his own approach.
### DUKE SENIOR
Why, how now, monsieur! what a life is this,
That your poor friends must woo your company?
What, you look merrily!
### JAQUES
A fool, a fool! I met a fool i' the forest,
A motley fool; a miserable world!
As I do live by food, I met a fool
Who laid him down and bask'd him in the sun,
And rail'd on Lady Fortune in good terms,
In good set terms and yet a motley fool.
'Good morrow, fool,' quoth I. 'No, sir,' quoth he,
'Call me not fool till heaven hath sent me fortune:'
And then he drew a dial from his poke,
And, looking on it with lack-lustre eye,
Says very wisely, 'It is ten o'clock:
Thus we may see,' quoth he, 'how the world wags:
'Tis but an hour ago since it was nine,
And after one hour more 'twill be eleven;
And so, from hour to hour, we ripe and ripe,
And then, from hour to hour, we rot and rot;
And thereby hangs a tale.' When I did hear
The motley fool thus moral on the time,
My lungs began to crow like chanticleer,
That fools should be so deep-contemplative,
And I did laugh sans intermission
An hour by his dial. O noble fool!
A worthy fool! Motley's the only wear.
### DUKE SENIOR
What fool is this?
### JAQUES
O worthy fool! One that hath been a courtier,
And says, if ladies be but young and fair,
They have the gift to know it: and in his brain,
Which is as dry as the remainder biscuit
After a voyage, he hath strange places cramm'd
With observation, the which he vents
In mangled forms. O that I were a fool!
I am ambitious for a motley coat.
### DUKE SENIOR
Thou shalt have one.
### JAQUES
It is my only suit;
Provided that you weed your better judgments
Of all opinion that grows rank in them
That I am wise. I must have liberty
Withal, as large a charter as the wind,
To blow on whom I please; for so fools have;
And they that are most galled with my folly,
They most must laugh. And why, sir, must they so?
The 'why' is plain as way to parish church:
He that a fool doth very wisely hit
Doth very foolishly, although he smart,
Not to seem senseless of the bob: if not,
The wise man's folly is anatomized
Even by the squandering glances of the fool.
Invest me in my motley; give me leave
To speak my mind, and I will through and through
Cleanse the foul body of the infected world,
If they will patiently receive my medicine.
### DUKE SENIOR
Fie on thee! I can tell what thou wouldst do.
### JAQUES
What, for a counter, would I do but good?
### DUKE SENIOR
Most mischievous foul sin, in chiding sin:
For thou thyself hast been a libertine,
As sensual as the brutish sting itself;
And all the embossed sores and headed evils,
That thou with licence of free foot hast caught,
Wouldst thou disgorge into the general world.
### JAQUES
Why, who cries out on pride,
That can therein tax any private party?
Doth it not flow as hugely as the sea,
Till that the weary very means do ebb?
What woman in the city do I name,
When that I say the city-woman bears
The cost of princes on unworthy shoulders?
Who can come in and say that I mean her,
When such a one as she such is her neighbour?
Or what is he of basest function
That says his bravery is not of my cost,
Thinking that I mean him, but therein suits
His folly to the mettle of my speech?
There then; how then? what then? Let me see wherein
My tongue hath wrong'd him: if it do him right,
Then he hath wrong'd himself; if he be free,
Why then my taxing like a wild-goose flies,
Unclaim'd of any man. But who comes here?
Enter ORLANDO, with his sword drawn
### ORLANDO
Forbear, and eat no more.
### JAQUES
Why, I have eat none yet.
### ORLANDO
Nor shalt not, till necessity be served.
### JAQUES
Of what kind should this cock come of?
### DUKE SENIOR
Art thou thus bolden'd, man, by thy distress,
Or else a rude despiser of good manners,
That in civility thou seem'st so empty?
### ORLANDO
You touch'd my vein at first: the thorny point
Of bare distress hath ta'en from me the show
Of smooth civility: yet am I inland bred
And know some nurture. But forbear, I say:
He dies that touches any of this fruit
Till I and my affairs are answered.
### JAQUES
An you will not be answered with reason, I must die.
### DUKE SENIOR
What would you have? Your gentleness shall force
More than your force move us to gentleness.
### ORLANDO
I almost die for food; and let me have it.
### DUKE SENIOR
Sit down and feed, and welcome to our table.
### ORLANDO
Speak you so gently? Pardon me, I pray you:
I thought that all things had been savage here;
And therefore put I on the countenance
Of stern commandment. But whate'er you are
That in this desert inaccessible,
Under the shade of melancholy boughs,
Lose and neglect the creeping hours of time
If ever you have look'd on better days,
If ever been where bells have knoll'd to church,
If ever sat at any good man's feast,
If ever from your eyelids wiped a tear
And know what 'tis to pity and be pitied,
Let gentleness my strong enforcement be:
In the which hope I blush, and hide my sword.
### DUKE SENIOR
True is it that we have seen better days,
And have with holy bell been knoll'd to church
And sat at good men's feasts and wiped our eyes
Of drops that sacred pity hath engender'd:
And therefore sit you down in gentleness
And take upon command what help we have
That to your wanting may be minister'd.
### ORLANDO
Then but forbear your food a little while,
Whiles, like a doe, I go to find my fawn
And give it food. There is an old poor man,
Who after me hath many a weary step
Limp'd in pure love: till he be first sufficed,
Oppress'd with two weak evils, age and hunger,
I will not touch a bit.
### DUKE SENIOR
Go find him out,
And we will nothing waste till you return.
### ORLANDO
I thank ye; and be blest for your good comfort!
Exit
### DUKE SENIOR
Thou seest we are not all alone unhappy:
This wide and universal theatre
Presents more woeful pageants than the scene
Wherein we play in.
[/INST]

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Downloads the Mistral-7B-Instruct-v0.2 tokenizer and the Q8_0 GGUF weights
# into the directory that contains this script.

# Stop at the first failing command so a failed download is not reported as done.
set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

echo "Downloading Tokenizer"
# URLs are quoted because `?` is a glob character; paths are quoted so a
# directory name with spaces works. --fail makes curl exit non-zero on an HTTP
# error instead of saving the error page as the output file.
curl --fail --location "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true" --output "$SCRIPT_DIR/mistral_tokenizer.model"
echo "Downloading Model"
curl --fail --location "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q8_0.gguf?download=true" --output "$SCRIPT_DIR/mistral-7b-instruct-v0.2.Q8_0.gguf"
echo "Done Downloading Model"

View File

@@ -0,0 +1,302 @@
//! Support for the GGUF file format.
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
pub const DEFAULT_ALIGNMENT: u64 = 32;
/// File magic recognised at the start of a GGUF file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Gguf,
}

impl TryFrom<u32> for Magic {
    type Error = ();

    /// Accepts the GGUF magic in either byte order.
    ///
    /// # Panics
    ///
    /// Panics on any other value.
    fn try_from(value: u32) -> Result<Self, ()> {
        if value == 0x46554747 || value == 0x47475546 {
            return Ok(Self::Gguf);
        }
        panic!("unknown magic 0x{value:08x}")
    }
}
/// GGUF magic combined with the format version that follows it in the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgufV1,
    GgufV2,
    GgufV3,
}

impl VersionedMagic {
    /// Reads the little-endian magic and version words from `reader`.
    ///
    /// # Panics
    ///
    /// Panics on a read failure, an unknown magic, or a version other than
    /// 1, 2 or 3.
    pub fn read<R: std::io::Read>(reader: &mut R) -> Result<Self, ()> {
        let raw_magic = reader.read_u32::<LittleEndian>().unwrap();
        let magic = Magic::try_from(raw_magic).unwrap();
        let version = reader.read_u32::<LittleEndian>().unwrap();
        // `Magic` has a single variant, so the version alone selects the result.
        Ok(match version {
            1 => Self::GgufV1,
            2 => Self::GgufV2,
            3 => Self::GgufV3,
            _ => panic!("gguf: unsupported magic/version {magic:?}/{version}"),
        })
    }
}
/// Parsed header of a GGUF file: version, metadata, and where each tensor lives.
#[derive(Debug)]
pub struct Content {
/// Magic/version read from the start of the file
pub magic: VersionedMagic,
/// Metadata key/value pairs
pub metadata: HashMap<String, Value>,
/// Per-tensor `(element count, byte offset, dtype)`; the offset is relative to `tensor_data_offset`
pub tensor_infos: HashMap<String, (usize, usize, GgmlDType)>, // (element count, byte offset, dtype)
/// Position of the tensor data section: the end of the header rounded up to the alignment
pub tensor_data_offset: u64,
}
/// Reads a length-prefixed string from `reader`.
///
/// The length prefix is 32-bit for GGUF v1 and 64-bit for v2/v3. Trailing NUL
/// bytes are removed and invalid UTF-8 is replaced lossily.
///
/// # Panics
///
/// Panics if reading from `reader` fails.
pub fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String, ()> {
    let byte_count = if matches!(magic, VersionedMagic::GgufV1) {
        reader.read_u32::<LittleEndian>().unwrap() as usize
    } else {
        reader.read_u64::<LittleEndian>().unwrap() as usize
    };
    let mut raw = vec![0u8; byte_count];
    reader.read_exact(&mut raw).unwrap();
    // GGUF strings are supposed to be non-null terminated but in practice this happens.
    let kept_len = raw.iter().rposition(|&byte| byte != 0).map_or(0, |last| last + 1);
    raw.truncate(kept_len);
    // GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
    let text = String::from_utf8_lossy(&raw);
    Ok(text.into_owned())
}
/// Type tag of a GGUF metadata value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
/// The value is a 8-bit unsigned integer.
U8,
/// The value is a 8-bit signed integer.
I8,
/// The value is a 16-bit unsigned little-endian integer.
U16,
/// The value is a 16-bit signed little-endian integer.
I16,
/// The value is a 32-bit unsigned little-endian integer.
U32,
/// The value is a 32-bit signed little-endian integer.
I32,
/// The value is a 64-bit unsigned little-endian integer.
U64,
/// The value is a 64-bit signed little-endian integer.
I64,
/// The value is a 32-bit IEEE754 floating point number.
F32,
/// The value is a 64-bit IEEE754 floating point number.
F64,
/// The value is a boolean.
/// 1-byte value where 0 is false and 1 is true.
/// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
Bool,
/// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
/// The value is an array of other values, with the length and type prepended.
///
/// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
/// A decoded GGUF metadata value; each variant holds the payload for the
/// `ValueType` of the same name.
#[derive(Debug, Clone)]
pub enum Value {
U8(u8),
I8(i8),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
I64(i64),
F32(f32),
F64(f64),
Bool(bool),
String(String),
/// Elements of an array value
Array(Vec<Value>),
}
impl Value {
    /// Reads one metadata value of type `value_type` from `reader`.
    ///
    /// Numbers are little-endian. Arrays store their element type, then their
    /// length (32-bit in GGUF v1, 64-bit in v2/v3), then the elements.
    ///
    /// # Panics
    ///
    /// Panics on a read failure, a boolean byte other than 0 or 1, or an
    /// unrecognized array element type.
    pub fn read<R: std::io::Read>(
        reader: &mut R,
        value_type: ValueType,
        magic: &VersionedMagic,
    ) -> Result<Self, ()> {
        Ok(match value_type {
            ValueType::U8 => Self::U8(reader.read_u8().unwrap()),
            ValueType::I8 => Self::I8(reader.read_i8().unwrap()),
            ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>().unwrap()),
            ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>().unwrap()),
            ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>().unwrap()),
            ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>().unwrap()),
            ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>().unwrap()),
            ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>().unwrap()),
            ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>().unwrap()),
            ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>().unwrap()),
            ValueType::Bool => {
                let raw = reader.read_u8().unwrap();
                match raw {
                    0 => Self::Bool(false),
                    1 => Self::Bool(true),
                    b => panic!("unexpected bool value {b}"),
                }
            }
            ValueType::String => {
                let text = read_string(reader, magic).unwrap();
                Self::String(text)
            }
            ValueType::Array => {
                let element_tag = reader.read_u32::<LittleEndian>().unwrap();
                let element_type = ValueType::from_u32(element_tag).unwrap();
                let count = match magic {
                    VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
                    VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                        reader.read_u64::<LittleEndian>().unwrap() as usize
                    }
                };
                let mut items = Vec::with_capacity(count);
                items.extend((0..count).map(|_| Value::read(reader, element_type, magic).unwrap()));
                Self::Array(items)
            }
        })
    }
}
impl ValueType {
pub fn from_u32(v: u32) -> Result<Self, ()> {
let v = match v {
0 => Self::U8,
1 => Self::I8,
2 => Self::U16,
3 => Self::I16,
4 => Self::U32,
5 => Self::I32,
6 => Self::F32,
7 => Self::Bool,
8 => Self::String,
9 => Self::Array,
10 => Self::U64,
11 => Self::I64,
12 => Self::F64,
v => panic!("unrecognized value-type {v:#08x}"),
};
Ok(v)
}
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self, ()> {
let magic = VersionedMagic::read(reader).unwrap();
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>().unwrap() as usize
}
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>().unwrap() as usize
}
};
// Read metadata
let mut metadata = HashMap::new();
for _idx in 0..metadata_kv_count {
let key = read_string(reader, &magic).unwrap();
let value_type = reader.read_u32::<LittleEndian>().unwrap();
let value_type = ValueType::from_u32(value_type).unwrap();
let value = Value::read(reader, value_type, &magic).unwrap();
metadata.insert(key, value);
}
// Read tensor infos
let mut tensor_infos = HashMap::new();
for _idx in 0..tensor_count {
let tensor_name = read_string(reader, &magic).unwrap();
let n_dimensions = reader.read_u32::<LittleEndian>().unwrap();
let n_elements = match magic {
VersionedMagic::GgufV1 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader
.read_u32_into::<LittleEndian>(&mut dimensions)
.unwrap();
dimensions.into_iter().map(|c| c as usize).product()
}
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader
.read_u64_into::<LittleEndian>(&mut dimensions)
.unwrap();
dimensions.into_iter().map(|c| c as usize).product()
}
};
let ggml_dtype = reader.read_u32::<LittleEndian>().unwrap();
let offset = reader.read_u64::<LittleEndian>().unwrap();
tensor_infos.insert(
tensor_name,
(n_elements, offset as usize, GgmlDType::from_u32(ggml_dtype)),
);
}
let position = reader.stream_position().unwrap();
let alignment = match metadata.get("general.alignment") {
Some(Value::U8(v)) => *v as u64,
Some(Value::U16(v)) => *v as u64,
Some(Value::U32(v)) => *v as u64,
Some(Value::I8(v)) if *v >= 0 => *v as u64,
Some(Value::I16(v)) if *v >= 0 => *v as u64,
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
Ok(Self {
magic,
metadata,
tensor_infos,
tensor_data_offset,
})
}
}
/// Element type of a tensor stored in a GGUF file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}

impl GgmlDType {
    /// Maps the numeric dtype code stored in the file to a `GgmlDType`.
    ///
    /// Codes 4 and 5 are not assigned to any variant.
    ///
    /// # Panics
    ///
    /// Panics if `u` is not a known code.
    fn from_u32(u: u32) -> Self {
        let known_codes = [
            (0, Self::F32),
            (1, Self::F16),
            (2, Self::Q4_0),
            (3, Self::Q4_1),
            (6, Self::Q5_0),
            (7, Self::Q5_1),
            (8, Self::Q8_0),
            (9, Self::Q8_1),
            (10, Self::Q2K),
            (11, Self::Q3K),
            (12, Self::Q4K),
            (13, Self::Q5K),
            (14, Self::Q6K),
            (15, Self::Q8K),
        ];
        known_codes
            .iter()
            .find(|(code, _)| *code == u)
            .map(|(_, dtype)| *dtype)
            .unwrap_or_else(|| panic!("unknown dtype for tensor {u}"))
    }
}

View File

@@ -0,0 +1,173 @@
use std::{
fs::File,
io::{Read, Seek},
};
use itertools::Itertools;
use luminal::{op::Function, prelude::*};
use crate::gguf::*;
#[cfg(feature = "metal")]
use {
luminal_metal::MetalBuffer,
memmap2::Mmap,
metal_rs::{Device, MTLResourceOptions},
};
/// Loader that fills model weights with Metal buffers read from the GGUF file
/// at the stored path.
#[cfg(feature = "metal")]
pub struct MetalQ8Loader(String);

#[cfg(feature = "metal")]
impl MetalQ8Loader {
    /// Creates a loader for the GGUF file at `path`.
    pub fn new<S: Into<String>>(path: S) -> Self {
        let file_path: String = path.into();
        Self(file_path)
    }
}
#[cfg(feature = "metal")]
impl Loader for MetalQ8Loader {
type Output = Vec<NodeIndex>;
/// Replaces the loading function of every weight node in `model` with a
/// closure that creates a Metal buffer from that weight's bytes in the GGUF
/// file. Returns the node indices of the weights stored as Q8_0.
///
/// Panics if the file cannot be read, a weight is missing from the file, or
/// a weight uses a dtype other than F32 or Q8_0.
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) -> Self::Output {
// Read metadata from file
let mut reader = File::open(&self.0).unwrap();
let Content {
mut tensor_infos,
tensor_data_offset,
..
} = Content::read(&mut reader).unwrap();
// Create weight loading closures
let mut q8_weights = vec![];
for (weight_name, node_index) in state_dict(model) {
// Only nodes whose op is a `Function` get a loader attached
if let Some(loading_node) = graph
.graph
.node_weight_mut(node_index)
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
{
let file_path = self.0.clone();
// Tensor names in the file use '.' where the model's state dict uses '/'
let (n_elements, buffer_offset, data_type) =
tensor_infos.remove(&weight_name.replace('/', ".")).unwrap();
let n_bytes = match data_type {
GgmlDType::F32 => n_elements * 4,
GgmlDType::Q8_0 => {
q8_weights.push(node_index);
// One byte per element plus one extra byte per 16 elements
// NOTE(review): matches 34-byte blocks of 32 weights; assumes
// n_elements is a multiple of 32 — confirm
n_elements + (n_elements / 16)
}
_ => panic!("Unsupported dtype: {data_type:?}"),
};
loading_node.1 = Box::new(move |_| {
// The file is memory-mapped each time the closure runs
let mmap_buffer =
unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() };
// NOTE(review): the pointer offset assumes
// buffer_offset + tensor_data_offset + n_bytes stays within the mapped
// file; presumably `new_buffer_with_data` copies the bytes before the
// mapping is dropped — confirm
let buffer = Device::system_default().unwrap().new_buffer_with_data(
unsafe {
mmap_buffer
.as_ptr()
.add(buffer_offset + tensor_data_offset as usize)
as *const _
},
n_bytes as u64,
MTLResourceOptions::StorageModeShared,
);
vec![Tensor {
data: Box::new(MetalBuffer(buffer)),
}]
});
}
}
q8_weights
}
}
/// Loader that reads weights from the GGUF file at the stored path and
/// converts them to `f32` on the CPU.
#[cfg(not(feature = "metal"))]
pub struct Q8Loader(String);

#[cfg(not(feature = "metal"))]
impl Q8Loader {
    /// Creates a loader for the GGUF file at `path`.
    pub fn new<S: Into<String>>(path: S) -> Self {
        let file_path: String = path.into();
        Self(file_path)
    }
}
#[cfg(not(feature = "metal"))]
impl Loader for Q8Loader {
type Output = Vec<NodeIndex>;
/// Replaces the loading function of every weight node in `model` with a
/// closure that reads that weight's bytes from the GGUF file and converts
/// them to `f32`. Returns the node indices of the weights stored as Q8_0.
///
/// Panics if the file cannot be read, a weight is missing from the file, or
/// a weight uses a dtype other than F32 or Q8_0.
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) -> Self::Output {
// On-disk layout of one Q8_0 block: a scale followed by 32 signed weights.
// `packed` removes padding so the struct can be read straight from bytes.
#[repr(C, packed)]
#[derive(Clone, Copy)]
struct Q8Block {
delta: f16,
weights: [i8; 32],
}
// Read metadata from file
let mut reader = File::open(&self.0).unwrap();
let Content {
mut tensor_infos,
tensor_data_offset,
..
} = Content::read(&mut reader).unwrap();
// Create weight loading closures
let mut q8_weights = vec![];
for (weight_name, node_index) in state_dict(model) {
// Only nodes whose op is a `Function` get a loader attached
if let Some(loading_node) = graph
.graph
.node_weight_mut(node_index)
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
{
let file_path = self.0.clone();
// Tensor names in the file use '.' where the model's state dict uses '/'
let (n_elements, buffer_offset, data_type) =
tensor_infos.remove(&weight_name.replace('/', ".")).unwrap();
let n_bytes = match data_type {
GgmlDType::F32 => n_elements * 4,
GgmlDType::Q8_0 => {
q8_weights.push(node_index);
// One byte per element plus one extra byte per 16 elements, i.e.
// 34 bytes per 32 weights
// NOTE(review): assumes n_elements is a multiple of 32 — confirm
n_elements + (n_elements / 16)
}
_ => panic!("Unsupported dtype: {data_type:?}"),
};
loading_node.1 = Box::new(move |_| {
// Load all bytes
let mut bytes = vec![0; n_bytes];
let mut file = File::open(&file_path).unwrap();
file.seek(std::io::SeekFrom::Start(
buffer_offset as u64 + tensor_data_offset,
))
.unwrap();
file.read_exact(&mut bytes).unwrap();
// Dequantize into f32
let data: Vec<f32> = match data_type {
// Every 4 bytes form one little-endian f32
GgmlDType::F32 => bytes
.into_iter()
.chunks(4)
.into_iter()
.map(|c| {
let c = c.collect::<Vec<_>>();
f32::from_le_bytes([c[0], c[1], c[2], c[3]])
})
.collect(),
// Every 34 bytes form one block; each weight is multiplied by the block's scale
GgmlDType::Q8_0 => bytes
.into_iter()
.chunks(34)
.into_iter()
.map(|c| {
let chunk = c.collect::<Vec<_>>();
// NOTE(review): relies on Q8Block being 34 bytes with alignment 1
// (assumes f16 is 2 bytes), so `align_to` yields exactly one block
// in its middle slice; a shorter final chunk would panic on `[0]`
unsafe { chunk.align_to::<Q8Block>().1[0] }
})
.flat_map(|chunk| {
chunk
.weights
.into_iter()
.map(move |i| i as f32 * chunk.delta.to_f32())
})
.collect(),
_ => panic!("Unsupported dtype: {data_type:?}"),
};
vec![Tensor::new(data)]
});
}
}
q8_weights
}
}

View File

@@ -0,0 +1,185 @@
use std::{
io::{self, Write},
marker::PhantomData,
time::Instant,
};
use clap::Parser;
use colored::Colorize;
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
mod gguf;
mod loader;
mod model;
use crate::model::KVCache;
use luminal::{prelude::*, shape::symbolic::Expression};
// Command args parser
// Holds the two options read from the command line: how many tokens to generate
// and the prompt text. The default prompt is embedded at compile time from
// `../prompts/merge_sort.txt`.
// NOTE(review): the `///` comments on the fields are presumably picked up by clap
// as help text, so treat them as user-facing strings — confirm before rewording.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
pub struct CLIArgs {
/// Number of tokens to generate
#[clap(short = 't', long = "gen_tokens", default_value = "128")]
gen_tokens: i32,
/// Prompt for the model
#[clap(short = 'p', long = "prompt", default_value = include_str!("../prompts/merge_sort.txt"))]
prompt: String,
}
// Entry point: builds the Mistral graph, attaches weight loaders, compiles the graph,
// runs one warm-up pass to load weights, processes the prompt, then generates
// `gen_tokens` further tokens greedily while printing timing information.
fn main() {
let cli_args = CLIArgs::parse();
let tokenizer =
SentencePieceBpeTokenizer::from_file("setup/mistral_tokenizer.model", false).unwrap();
print!("Defining graph");
io::stdout().flush().unwrap();
let now = Instant::now();
// Set up graph
let mut cx = Graph::new();
// Token ids, shape (1, 's') with a dynamic sequence length 's'
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
// One (key, value) cache pair per layer, with dynamic previous-sequence length 'p'
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
.collect();
// Start with empty caches (sequence dimension of 0)
cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
let model = model::MistralLM::initialize(&mut cx);
let (logits, mut cache_dest) =
model.forward((input, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
// Only the logits of the last position ('s' - 1 onwards) are retrieved
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve();
cache_dest.keep();
// Set up model loading
#[cfg(feature = "metal")]
let quantized_weight_nodes =
loader::MetalQ8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf")
.load(&model, &mut cx);
// Without Metal the returned list of quantized nodes is not needed and is discarded
#[cfg(not(feature = "metal"))]
loader::Q8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf").load(&model, &mut cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
print!("Compiling graph");
io::stdout().flush().unwrap();
let now = Instant::now();
// Exactly one backend compiler is selected by the enabled cargo features
cx.compile(
(
GenericCompiler::default(),
#[cfg(feature = "metal")]
luminal_metal::MetalQuantizedCompiler::<f32>::new(quantized_weight_nodes),
#[cfg(feature = "cuda")]
luminal_cuda::CudaCompiler::<f32>::default(),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
luminal::compilers::CPUCompiler::default(),
),
(&mut input, &mut logits, &mut cache_src, &mut cache_dest),
);
// Keep model weights
let model_weights = downstream(state_set(&model), &cx);
cx.keep_tensors(&model_weights);
let cache_src_set = downstream(&cache_src, &cx);
let cache_dest_set = cache_dest.to_ids();
println!("\t\t - {}ms", now.elapsed().as_millis());
// Initial forward pass to load weights
print!("Loading model");
io::stdout().flush().unwrap();
let now = Instant::now();
// A single dummy token (id 0) is enough to trigger the loading closures
input.set_dyn(vec![0.], &[1, 1]);
cx.set_dyn_dim('t', 1);
cx.execute();
// Results of the warm-up pass are discarded
logits.drop();
cache_dest.drop();
println!("\t\t - {}ms", now.elapsed().as_millis());
// Now that weights are loaded, delete the loading nodes so they don't run again
delete_inputs(&model_weights, &mut cx);
// Run prompt processing pass
let mut input_ids = encode(&tokenizer, &cli_args.prompt);
input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
&[1, input_ids.len()],
);
cx.set_dyn_dim('t', input_ids.len());
print!("Processing Prompt");
io::stdout().flush().unwrap();
let now = Instant::now();
cx.execute();
let elapsed_ms = now.elapsed().as_millis();
println!(
"\t - {elapsed_ms}ms ({:.2} tok/s)",
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
);
delete_inputs(&cache_src_set, &mut cx);
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
// Decode token
print!(
"{}{}",
cli_args.prompt.white().bold(),
decode(&tokenizer, &[output_id]).bright_green()
);
io::stdout().flush().unwrap();
// Swap caches
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
// Decode loop
// Per-token execution times in microseconds
let mut token_decode_times = vec![];
for _ in 0..cli_args.gen_tokens {
// Feed only the most recent token; earlier tokens are covered by the cache
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
// 'p' = number of already-cached positions, 't' = total positions after this step
cx.set_dyn_dim('p', input_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len());
let now = Instant::now();
cx.execute();
token_decode_times.push(now.elapsed().as_micros());
// Sample tokens
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
io::stdout().flush().unwrap();
// Swap caches
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
}
// Mean decode time in milliseconds (microseconds / 1000).
// NOTE(review): if no tokens were generated this divides by zero and prints NaN.
let avg_token_time = token_decode_times
.iter()
.map(|t| *t as f32 / 1000.)
.sum::<f32>()
/ token_decode_times.len() as f32;
println!(
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
avg_token_time,
1000.0 / avg_token_time
);
}
/// Tokenizes `text` and returns the token ids with the start token (id 1) prepended.
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
    let encoded = tokenizer.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0);
    let mut ids = Vec::with_capacity(encoded.token_ids.len() + 1);
    ids.push(1); // Start token
    ids.extend(encoded.token_ids);
    ids
}
/// Converts token ids back into text, turning the `<0x0A>` byte token into a newline.
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
    let raw_text = tokenizer.decode(token_ids, true, false);
    raw_text.replace("<0x0A>", "\n")
}
// Currently just an argmax, do actual sampling here
/// Returns the index of the largest value in `dist`.
///
/// When several entries compare equal (or are incomparable, e.g. NaN), the later
/// one wins. Panics if `dist` is empty.
fn sample_index(dist: &[f32]) -> i64 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in dist.iter().enumerate() {
        best = match best {
            // Keep the current winner only when it is strictly greater.
            Some((_, top)) if top > value => best,
            _ => Some((idx, value)),
        };
    }
    best.unwrap().0 as i64
}

View File

@@ -0,0 +1,350 @@
use std::{marker::PhantomData, ops::Div};
use luminal::{
nn::{embedding::Embedding, norm::RMSNorm},
prelude::*,
shape::symbolic::{BigExpression, Expression},
};
// Mistral 7B Config
// Vocabulary size: rows of the embedding table and width of the output logits
pub const VOCAB_SIZE: usize = 32000;
// Width of the hidden state passed between layers
pub const HIDDEN_DIM: usize = 4096;
// Number of transformer blocks
pub const NUM_LAYERS: usize = 32;
// Number of query heads
pub const N_HEADS: usize = 32;
// Number of key/value heads (fewer than query heads: grouped-query attention)
pub const N_KV_HEADS: usize = 8;
// Inner width of the feed-forward block
pub const MLP_DIM: usize = 14336;
// Query heads that share one key/value head
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
// Width of a single attention head
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
// Number of (even, odd) element pairs per head used by the rotary embedding
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
// Output width of the key and value projections
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
// A (keys, values) pair for one layer, each shaped (Batch, N_KV_HEADS, Seq, HEAD_DIM)
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
);
/// Gated feed-forward block with inner width `I` and hidden width `H`.
///
/// All three weights are stored as (output, input) matrices and are permuted
/// before being used in a matmul.
pub struct Mlp<const I: usize, const H: usize> {
    pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
    pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
    pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
}
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
where
    GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
    GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
{
    type Output = GraphTensor<Sh>;
    /// Computes `down(up(input) * swish(gate(input)))`.
    fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
        let activation = input.matmul(self.gate_proj.permute()).swish();
        let gated = input.matmul(self.up_proj.permute()) * activation;
        gated.matmul(self.down_proj.permute())
    }
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
    /// Creates the three named weight tensors on `cx`.
    fn initialize(cx: &mut Graph) -> Self {
        let gate_proj = cx.named_tensor("Gate Weight");
        let up_proj = cx.named_tensor("Up Weight");
        let down_proj = cx.named_tensor("Down Weight");
        Mlp {
            gate_proj,
            down_proj,
            up_proj,
        }
    }
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
    /// Registers the weights under their `ffn_*` names.
    fn serialize(&self, s: &mut Serializer) {
        let Self {
            gate_proj,
            down_proj,
            up_proj,
        } = self;
        s.tensor("ffn_gate/weight", *gate_proj);
        s.tensor("ffn_up/weight", *up_proj);
        s.tensor("ffn_down/weight", *down_proj);
    }
}
/// Applies rotary position embeddings to `input`, rotating adjacent
/// (even, odd) element pairs of every head.
///
/// `prev_seq` is added to each position index, so positions continue after the
/// tokens that were processed earlier.
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
    input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
    prev_seq: BigExpression,
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
    // Rotation frequency for each element pair
    let exponents = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
    let inv_freqs = exponents.inv_pow(1000000.0).recip();
    // Position of every token, offset by the already-processed length
    let positions = input.graph().arange::<Seq>() + prev_seq;
    // Outer product of positions and frequencies gives the rotation angles
    let angles = positions
        .expand::<(_, Const<1>), _>()
        .matmul(inv_freqs.expand());
    // View the head dimension as pairs, then pull out the even and odd members
    let pairs = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
    let even: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> = pairs
        .slice((.., .., .., .., ..Expression::from(1)))
        .contiguous()
        .realize();
    let odd: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> = pairs
        .slice((.., .., .., .., Expression::from(1)..))
        .contiguous()
        .realize();
    // 2D rotation of each pair by its angle
    let rotated_even = even * angles.cos().expand() - odd * angles.sin().expand();
    let rotated_odd = even * angles.sin().expand() + odd * angles.cos().expand();
    // Interleave the pairs again and restore the original shape
    rotated_even
        .concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
            rotated_odd,
        )
        .reshape()
}
/// Grouped-query self-attention layer. Each projection weight is stored as an
/// (output, input) matrix and is permuted before use.
pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
}
// Input tuple: (hidden states, optional KV cache of length PrevSeq, marker for the
// total sequence length TotSeq). Output: (attention output, updated KV cache).
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
Option<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
)> for SelfAttention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
Option<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
),
) -> Self::Output {
// Apply the Projections
// Each projection is split into heads and moved to (Batch, heads, CurSeq, HEAD_DIM)
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
// Rotary embed queries and keys
// Positions are offset by the previous sequence length
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().into());
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().into());
// Add KV cache
// New keys/values are appended after the cached ones along the sequence axis
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
(
k_cache.concat_along::<_, Axis<2>, _>(keys),
v_cache.concat_along::<_, Axis<2>, _>(values),
)
} else {
(keys.realize(), values.contiguous().realize())
};
// Repeat the KV States for Grouped-Query Attention
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
// Calculate attention weights
// Scores are scaled by 1 / sqrt(HEAD_DIM)
let mut attention_weights = queries
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
.matmul(repeated_keys.permute())
.div((HEAD_DIM as f32).sqrt());
// Causal mask over the current tokens, scaled by the most negative f16 value
// (presumably `triu(1)` marks the entries strictly above the diagonal — confirm)
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
// The mask is padded with TotSeq - CurSeq leading columns on its second axis,
// which presumably leaves the cached positions unmasked — confirm pad semantics
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq), _, _>(&[
(0.into(), Expression::from(0)),
(TotSeq::const_size() - CurSeq::const_size(), 0.into()),
])
.expand();
// Calculate final outputs
let output = attention_weights
.softmax::<4>()
// Apply distribution to values
.matmul(repeated_values)
// Merge heads
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
let output = output
// Apply output projection
.matmul(self.o_proj.permute());
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
}
}
impl InitModule for SelfAttention {
fn initialize(cx: &mut Graph) -> Self {
Self {
q_proj: cx.named_tensor("Q Proj"),
k_proj: cx.named_tensor("K Proj"),
v_proj: cx.named_tensor("V Proj"),
o_proj: cx.named_tensor("O Proj"),
}
}
}
impl SerializeModule for SelfAttention {
fn serialize(&self, s: &mut Serializer) {
s.tensor("attn_q/weight", self.q_proj);
s.tensor("attn_v/weight", self.v_proj);
s.tensor("attn_k/weight", self.k_proj);
s.tensor("attn_output/weight", self.o_proj);
}
}
/// One transformer layer: normalized self-attention followed by a normalized
/// feed-forward block, each wrapped in a residual connection.
pub struct TransformerBlock {
    pub attention: SelfAttention,
    pub attention_norm: RMSNorm<HIDDEN_DIM>,
    pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
    pub feed_forward_norm: RMSNorm<HIDDEN_DIM>,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
    Module<(
        GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
        Option<KVCache<Batch, PrevSeq>>,
        PhantomData<TotSeq>,
    )> for TransformerBlock
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
        KVCache<Batch, TotSeq>,
    );
    /// Runs the layer and returns the new hidden state plus this layer's updated cache.
    fn forward(
        &self,
        (mut x, cache, _): (
            GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
            Option<KVCache<Batch, PrevSeq>>,
            PhantomData<TotSeq>,
        ),
    ) -> Self::Output {
        // Attention sub-layer
        let attn_input = self.attention_norm.forward(x);
        let (attn_output, updated_cache) =
            self.attention
                .forward((attn_input, cache, PhantomData::<TotSeq>));
        // First residual connection
        x += attn_output;
        // Feed-forward sub-layer
        let ff_input = self.feed_forward_norm.forward(x);
        let ff_output = self.feed_forward.forward(ff_input);
        // Second residual connection
        (x + ff_output, updated_cache)
    }
}
impl InitModule for TransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: {
let mut norm = RMSNorm::initialize(cx);
norm.epsilon = 1e-5;
norm
},
feed_forward: InitModule::initialize(cx),
feed_forward_norm: {
let mut norm = RMSNorm::initialize(cx);
norm.epsilon = 1e-5;
norm
},
}
}
}
impl SerializeModule for TransformerBlock {
fn serialize(&self, s: &mut Serializer) {
s.module("", &self.attention);
s.module("attn_norm", &self.attention_norm);
s.module("ffn_norm", &self.feed_forward_norm);
s.module("", &self.feed_forward);
}
}
/// The full language model: embedding, a stack of transformer blocks, a final
/// norm and the output projection onto the vocabulary.
pub struct MistralLM {
    // Token embeddings
    pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
    // Transformer layers
    pub layers: Vec<TransformerBlock>,
    // Final Norm layer
    pub norm: RMSNorm<HIDDEN_DIM>,
    // LM Head Layer
    pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
    Module<(
        GraphTensor<(Batch, CurSeq)>,
        Option<Vec<KVCache<Batch, PrevSeq>>>,
        PhantomData<TotSeq>,
    )> for MistralLM
{
    type Output = (
        GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
        Vec<KVCache<Batch, TotSeq>>,
    );
    /// Maps token ids to logits and returns one updated KV cache per layer.
    fn forward(
        &self,
        (input, cache, _): (
            GraphTensor<(Batch, CurSeq)>,
            Option<Vec<KVCache<Batch, PrevSeq>>>,
            PhantomData<TotSeq>,
        ),
    ) -> Self::Output {
        // Look up the token embeddings
        let mut hidden = self.embedding.forward(input);
        // Feed the hidden state through every block, gathering the caches as we go
        let mut updated_caches = Vec::with_capacity(self.layers.len());
        for (layer_idx, block) in self.layers.iter().enumerate() {
            let layer_cache = cache.as_ref().map(|per_layer| per_layer[layer_idx]);
            let (next_hidden, new_cache) =
                block.forward((hidden, layer_cache, PhantomData::<TotSeq>));
            hidden = next_hidden;
            updated_caches.push(new_cache);
        }
        // Final norm, then project onto the vocabulary
        let logits = self.norm.forward(hidden).matmul(self.lm_head.permute());
        (logits, updated_caches)
    }
}
impl InitModule for MistralLM {
fn initialize(cx: &mut Graph) -> Self {
Self {
embedding: InitModule::initialize(cx),
norm: {
let mut norm = RMSNorm::initialize(cx);
norm.epsilon = 1e-5;
norm
},
lm_head: cx.named_tensor("LM Head"),
layers: (0..NUM_LAYERS)
.map(|_| InitModule::initialize(cx))
.collect(),
}
}
}
impl SerializeModule for MistralLM {
fn serialize(&self, s: &mut Serializer) {
s.module("token_embd", &self.embedding);
s.module("output_norm", &self.norm);
s.tensor("output/weight", self.lm_head);
for (i, layer) in self.layers.iter().enumerate() {
s.module(&format!("blk/{i}"), layer);
}
}
}

View File

@@ -1,17 +1,16 @@
use luminal::{nn::linear::Linear, prelude::*};
fn main() {
// Create a new graph
let mut cx = Graph::new();
let model: Linear<4, 5> = InitModule::initialize(&mut cx);
let a = cx.new_tensor::<R1<4>>("Input");
let b = model.forward(a);
a.set(vec![1., 2., 3., 4.]);
b.mark();
cx.execute();
println!(
"B: {:?}",
b.retrieve().unwrap().real_data(b.view().unwrap()).unwrap()
);
// Randomly initialize a linear layer with an input size of 4 and an output size of 5
let model = Linear::<4, 5>::initialize(&mut cx);
// Make an input tensor
let a = cx.tensor::<R1<4>>().set(vec![1., 2., 3., 4.]);
// Feed tensor through model
let b = model.forward(a).retrieve();
// Execute the graph
cx.execute_debug();
// Print the results
println!("B: {:?}", b.data());
}

BIN
out.bin

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

After

Width:  |  Height:  |  Size: 57 KiB

View File

@@ -0,0 +1,13 @@
# These are supported funding model platforms
github: coreylowman
patreon: dfdx
open_collective: # Replace with a single Open Collective username
ko_fi: coreylowman
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -0,0 +1,23 @@
on: [pull_request]
jobs:
cargo-check:
name: cargo-check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --features ci-check
- uses: actions-rs/cargo@v1
with:
command: check
args: --no-default-features --features ci-check,no-std,cudnn,cublas,cublaslt,nvrtc,driver,curand,nccl

View File

@@ -0,0 +1,18 @@
on: [pull_request]
jobs:
clippy:
name: clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --no-default-features --features ci-check,no-std,cudnn,cublas,cublaslt,nvrtc,driver,curand,nccl -- -D warnings

View File

@@ -0,0 +1,19 @@
on: [pull_request]
jobs:
cargo-fmt:
name: cargo-fmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check

3
resources/luminal_cudarc/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/Cargo.lock
/target
/.vscode/

View File

View File

@@ -0,0 +1,40 @@
[package]
name = "luminal_cudarc"
version = "0.10.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Safe wrappers around CUDA apis"
readme = "README.md"
keywords = [
"cuda",
"nvidia",
"gpu",
"nvrtc",
"cublas",
]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[package.metadata.docs.rs]
features = ["ci-check", "f16", "cudnn"]
[features]
default = ["std", "driver", "nvrtc", "cublas", "curand"]
nvrtc = []
driver = ["nvrtc"]
cublas = ["driver"]
cublaslt = ["driver"]
cudnn = ["driver"]
curand = ["driver"]
nccl = ["driver"]
std = []
no-std = ["no-std-compat/std", "dep:spin"]
f16 = ["dep:half"]
ci-check = []
static-linking=[]
[dependencies]
spin = { version = "0.9.8", optional = true, features = ["rwlock"], default-features = false }
no-std-compat = { version = "0.4.1", optional = true, features = [ "alloc" ] }
half = { version = "2.3.1", optional = true, default-features = false, features = ["num-traits", "rand_distr"] }

View File

@@ -174,28 +174,3 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,5 +1,3 @@
Copyright (c) 2015
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the

View File

@@ -0,0 +1,89 @@
# cudarc: minimal and safe api over the cuda toolkit
[![](https://dcbadge.vercel.app/api/server/AtUhGqBDP5)](https://discord.gg/AtUhGqBDP5)
[![crates.io](https://img.shields.io/crates/v/cudarc?style=for-the-badge)](https://crates.io/crates/cudarc)
[![docs.rs](https://img.shields.io/docsrs/cudarc?label=docs.rs%20latest&style=for-the-badge)](https://docs.rs/cudarc)
Check out cudarc on [crates.io](https://crates.io/crates/cudarc) and [docs.rs](https://docs.rs/cudarc/latest/cudarc/).
Safe abstractions over:
1. [CUDA driver API](https://docs.nvidia.com/cuda/cuda-driver-api/index.html)
2. [NVRTC API](https://docs.nvidia.com/cuda/nvrtc/index.html)
3. [cuRAND API](https://docs.nvidia.com/cuda/curand/index.html)
4. [cuBLAS API](https://docs.nvidia.com/cuda/cublas/index.html)
5. [cuBLASLt API](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
**Pre-alpha state**, expect breaking changes and not all cuda functions
contain a safe wrapper. **Contributions welcome for any that aren't included!**
# Design
Goals are:
1. As safe as possible (there will still be a lot of unsafe due to ffi & async)
2. As ergonomic as possible
3. Allow mixing of high level `safe` apis, with low level `sys` apis
To that end there are three levels to each wrapper (by default the safe api is exported):
```rust
use cudarc::driver::{safe, result, sys};
use cudarc::nvrtc::{safe, result, sys};
use cudarc::cublas::{safe, result, sys};
use cudarc::cublaslt::{safe, result, sys};
use cudarc::curand::{safe, result, sys};
```
where:
1. `sys` is the raw ffi apis generated with bindgen
2. `result` is a very small wrapper around sys to return `Result` from each function
3. `safe` is a wrapper around result/sys to provide safe abstractions
*We heavily recommend sticking with the safe APIs.*
# API Preview
It's easy to create a new device and transfer data to the gpu:
```rust
let dev = cudarc::driver::CudaDevice::new(0)?;
// allocate buffers
let inp = dev.htod_copy(vec![1.0f32; 100])?;
let mut out = dev.alloc_zeros::<f32>(100)?;
```
You can also use the nvrtc api to compile kernels at runtime:
```rust
let ptx = cudarc::nvrtc::compile_ptx("
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, const size_t numel) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}")?;
// and dynamically load it into the device
dev.load_ptx(ptx, "my_module", &["sin_kernel"])?;
```
`cudarc` provides a very simple interface to launch kernels, tuples
are the arguments!
```rust
let sin_kernel = dev.get_func("my_module", "sin_kernel").unwrap();
let cfg = LaunchConfig::for_num_elems(100);
unsafe { sin_kernel.launch(cfg, (&mut out, &inp, 100usize)) }?;
```
And of course it's easy to copy things back to host after you're done:
```rust
let out_host: Vec<f32> = dev.dtoh_sync_copy(&out)?;
assert_eq!(out_host, [1.0; 100].map(f32::sin));
```
# License
Dual-licensed to be compatible with the Rust project.
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.

View File

@@ -0,0 +1,130 @@
use std::path::{Path, PathBuf};
/// Build-script entry point: re-runs when `build.rs` changes and links the CUDA
/// libraries unless the `ci-check` feature is enabled.
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    if cfg!(not(feature = "ci-check")) {
        link_cuda();
    }
}
/// Locates the CUDA toolkit and emits the Cargo link directives for the
/// libraries selected by the enabled features.
///
/// The toolkit root is the first candidate that contains `include/cuda.h`;
/// every existing library directory below it is added to the link search path.
/// With the `cudnn` feature the same lookup is done for `include/cudnn.h`.
///
/// # Panics
/// Panics if no candidate root contains the required header.
// Unused when the `ci-check` feature disables linking.
#[allow(unused)]
fn link_cuda() {
    println!("cargo:rerun-if-env-changed=CUDA_ROOT");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_ROOT_DIR");
    // `CUDNN_LIB` is also consulted when searching for roots, so Cargo must
    // re-run this script when it changes.
    println!("cargo:rerun-if-env-changed=CUDNN_LIB");

    let candidates: Vec<PathBuf> = root_candidates().collect();

    let toolkit_root = candidates
        .iter()
        .find(|path| path.join("include").join("cuda.h").is_file())
        .unwrap_or_else(|| {
            panic!(
                "Unable to find `include/cuda.h` under any of: {:?}. Set the `CUDA_ROOT` environment variable to `$CUDA_ROOT/include/cuda.h` to override path.",
                candidates
            )
        });

    for path in lib_candidates(toolkit_root) {
        println!("cargo:rustc-link-search=native={}", path.display());
    }

    #[cfg(feature = "driver")]
    println!("cargo:rustc-link-lib=dylib=cuda");
    #[cfg(feature = "nccl")]
    println!("cargo:rustc-link-lib=dylib=nccl");

    #[cfg(feature = "static-linking")]
    {
        println!("cargo:rustc-link-lib=dylib=stdc++");
        #[cfg(any(feature = "cublas", feature = "cublaslt"))]
        {
            println!("cargo:rustc-link-lib=dylib=cudart");
            println!("cargo:rustc-link-lib=static=cublasLt_static");
        }
        #[cfg(feature = "cublas")]
        println!("cargo:rustc-link-lib=static=cublas_static");
        #[cfg(feature = "curand")]
        {
            println!("cargo:rustc-link-lib=dylib=culibos");
            println!("cargo:rustc-link-lib=static=curand_static");
        }
        #[cfg(feature = "nvrtc")]
        {
            println!("cargo:rustc-link-lib=static=nvrtc_static");
            println!("cargo:rustc-link-lib=static=nvptxcompiler_static");
            println!("cargo:rustc-link-lib=static=nvrtc-builtins_static");
        }
    }

    #[cfg(not(feature = "static-linking"))]
    {
        #[cfg(feature = "nvrtc")]
        println!("cargo:rustc-link-lib=dylib=nvrtc");
        #[cfg(feature = "curand")]
        println!("cargo:rustc-link-lib=dylib=curand");
        #[cfg(feature = "cublas")]
        println!("cargo:rustc-link-lib=dylib=cublas");
        #[cfg(any(feature = "cublas", feature = "cublaslt"))]
        println!("cargo:rustc-link-lib=dylib=cublasLt");
    }

    #[cfg(feature = "cudnn")]
    {
        let cudnn_root = candidates
            .iter()
            .find(|path| path.join("include").join("cudnn.h").is_file())
            .unwrap_or_else(|| {
                panic!(
                    "Unable to find `include/cudnn.h` under any of: {:?}. Set the `CUDNN_LIB` environment variable to `$CUDNN_LIB/include/cudnn.h` to override path.",
                    candidates
                )
            });

        for path in lib_candidates(cudnn_root) {
            println!("cargo:rustc-link-search=native={}", path.display());
        }
    }
    #[cfg(feature = "cudnn")]
    println!("cargo:rustc-link-lib=dylib=cudnn");
}
/// Yields every directory that may be a CUDA installation root: first the
/// values of the relevant environment variables that are set, then a fixed
/// list of well-known install locations.
fn root_candidates() -> impl Iterator<Item = PathBuf> {
    const ENV_KEYS: &[&str] = &[
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDNN_LIB",
    ];
    const DEFAULT_ROOTS: &[&str] = &[
        "/usr",
        "/usr/local/cuda",
        "/opt/cuda",
        "/usr/lib/cuda",
        "C:/Program Files/NVIDIA GPU Computing Toolkit",
        "C:/CUDA",
    ];

    // Variables that are unset (or not valid unicode) are silently skipped.
    let from_env = ENV_KEYS
        .iter()
        .filter_map(|key| std::env::var(key).ok())
        .map(PathBuf::from);
    let defaults = DEFAULT_ROOTS.iter().map(PathBuf::from);
    from_env.chain(defaults)
}
/// Returns the known library sub-directories of `root` that exist on disk,
/// in a fixed search order.
fn lib_candidates(root: &Path) -> Vec<PathBuf> {
    const SUBDIRS: &[&str] = &[
        "lib",
        "lib/x64",
        "lib/Win32",
        "lib/x86_64",
        "lib/x86_64-linux-gnu",
        "lib64",
        "lib64/stubs",
        "targets/x86_64-linux",
        "targets/x86_64-linux/lib",
        "targets/x86_64-linux/lib/stubs",
    ];

    let mut existing = Vec::new();
    for subdir in SUBDIRS {
        let candidate = root.join(subdir);
        if candidate.is_dir() {
            existing.push(candidate);
        }
    }
    existing
}

View File

@@ -0,0 +1,19 @@
use cudarc::driver::{CudaDevice, CudaSlice, DriverError};
/// Shows four ways of creating a `CudaSlice` on device 0. The slices are
/// discarded immediately; driver errors are propagated to the caller.
fn main() -> Result<(), DriverError> {
let dev = CudaDevice::new(0)?;
// unsafe initialization of unset memory
let _: CudaSlice<f32> = unsafe { dev.alloc::<f32>(10) }?;
// this will have memory initialized as 0
let _: CudaSlice<f64> = dev.alloc_zeros::<f64>(10)?;
// initialize with a rust vec
let _: CudaSlice<usize> = dev.htod_copy(vec![0; 10])?;
// or finally, initialize with a slice. this is synchronous though.
let _: CudaSlice<u32> = dev.htod_sync_copy(&[1, 2, 3])?;
Ok(())
}

View File

@@ -0,0 +1,31 @@
use cudarc::driver::{CudaDevice, CudaSlice, DriverError};
/// Shows the copy operations between host and device buffers and checks the
/// round-tripped values with assertions.
fn main() -> Result<(), DriverError> {
let dev = CudaDevice::new(0)?;
let a: CudaSlice<f64> = dev.alloc_zeros::<f64>(10)?;
let mut b = dev.alloc_zeros::<f64>(10)?;
// you can do device to device copies of course
dev.dtod_copy(&a, &mut b)?;
// but also host to device copies with already allocated buffers
dev.htod_copy_into(vec![2.0; 10], &mut b)?;
// if you want to use slices, you can do synchronous copy
dev.htod_sync_copy_into(&[3.0; 10], &mut b)?;
// you can transfer back using reclaim:
// (this consumes `a`, which still holds its initial zeros)
let mut a_host: Vec<f64> = dev.sync_reclaim(a)?;
assert_eq!(a_host, [0.0; 10]);
// or copy back without losing ownership:
let b_host = dev.dtoh_sync_copy(&b)?;
assert_eq!(b_host, [3.0; 10]);
// or use a slice
dev.dtoh_sync_copy_into(&b, &mut a_host)?;
assert_eq!(a_host, b_host);
Ok(())
}

View File

@@ -0,0 +1,32 @@
use cudarc::{
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
nvrtc::Ptx,
};
/// Loads `sin_kernel` from a pre-compiled PTX file, launches it over three
/// elements and prints the result next to the values computed on the host.
fn main() -> Result<(), DriverError> {
let dev = CudaDevice::new(0)?;
// You can load a function from a pre-compiled PTX like so:
dev.load_ptx(Ptx::from_file("./examples/sin.ptx"), "sin", &["sin_kernel"])?;
// and then retrieve the function with `get_func`
let f = dev.get_func("sin", "sin_kernel").unwrap();
let a_host = [1.0, 2.0, 3.0];
let a_dev = dev.htod_copy(a_host.into())?;
// the output buffer starts out as a clone of the input buffer
let mut b_dev = a_dev.clone();
let n = 3;
let cfg = LaunchConfig::for_num_elems(n);
// kernel arguments are passed as a tuple: (output, input, element count)
unsafe { f.launch(cfg, (&mut b_dev, &a_dev, n as i32)) }?;
let a_host_2 = dev.sync_reclaim(a_dev)?;
let b_host = dev.sync_reclaim(b_dev)?;
println!("Found {:?}", b_host);
println!("Expected {:?}", a_host.map(f32::sin));
// only the (unchanged) input is asserted; the output is just printed above
assert_eq!(&a_host, a_host_2.as_slice());
Ok(())
}

View File

@@ -0,0 +1,42 @@
use cudarc::{
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
nvrtc::Ptx,
};
/// Same as the plain kernel-launch example, but the kernel is launched on a
/// stream forked from the default stream and joined back before copying results.
fn main() -> Result<(), DriverError> {
let dev = CudaDevice::new(0)?;
dev.load_ptx(Ptx::from_file("./examples/sin.ptx"), "sin", &["sin_kernel"])?;
let n = 3;
let cfg = LaunchConfig::for_num_elems(n);
let a_host = [1.0, 2.0, 3.0];
let a_dev = dev.htod_copy(a_host.into())?;
let mut b_dev = a_dev.clone();
// create a stream with `fork_default_stream()`
// This synchronizes with the default stream, so since
// we put this call **after** the `htod_copy` & `clone` above,
// cuda will complete those orders **before** work on this stream
// can start.
let stream = dev.fork_default_stream()?;
let f = dev.get_func("sin", "sin_kernel").unwrap();
// we launch it differently too
unsafe { f.launch_on_stream(&stream, cfg, (&mut b_dev, &a_dev, n as i32)) }?;
// and we must join with the default work stream in order for copies
// to work correctly.
// NOTE: this is actually async with respect to the host!
dev.wait_for(&stream)?;
let a_host_2 = dev.sync_reclaim(a_dev)?;
let b_host = dev.sync_reclaim(b_dev)?;
println!("Found {:?}", b_host);
println!("Expected {:?}", a_host.map(f32::sin));
// only the (unchanged) input is asserted; the output is just printed above
assert_eq!(&a_host, a_host_2.as_slice());
Ok(())
}

View File

@@ -0,0 +1,52 @@
use cudarc::{driver::*, nvrtc::compile_ptx};
/// Host-side argument struct that is passed by value to `my_custom_kernel`.
///
/// It is marked `#[repr(C)]`, which is what allows it to be handed to cuda.
/// The fields are declared in the same order, and with the corresponding
/// types, as `MyCoolStruct` in the CUDA source (`PTX_SRC`); do not reorder
/// them without changing the CUDA declaration as well.
#[repr(C)]
struct MyCoolRustStruct {
// Counterpart of `float a` in the CUDA struct.
a: f32,
// Counterpart of `double b` in the CUDA struct.
b: f64,
// Counterpart of `unsigned int c` in the CUDA struct.
c: u32,
// Counterpart of `size_t d` in the CUDA struct.
d: usize,
}
/// Marker impl that lets `MyCoolRustStruct` be used as a kernel launch
/// argument; without it the struct cannot be sent to cuda.
// SAFETY: the struct is `#[repr(C)]` and made only of plain numeric fields.
// NOTE(review): presumably `DeviceRepr` requires a C-compatible layout that
// matches the type the kernel expects — confirm against the cudarc docs.
unsafe impl DeviceRepr for MyCoolRustStruct {}
/// CUDA source that `main` compiles with `compile_ptx` (despite the constant's
/// name, this is CUDA C++ source text, not PTX).
///
/// It declares `MyCoolStruct`, the device-side counterpart of
/// `MyCoolRustStruct`, and a kernel `my_custom_kernel` that takes one such
/// struct by value and asserts the exact field values that `main` passes in
/// (1.0, 2.34, 57, 420).
const PTX_SRC: &str = "
// here's the same struct in cuda
struct MyCoolStruct {
float a;
double b;
unsigned int c;
size_t d;
};
extern \"C\" __global__ void my_custom_kernel(MyCoolStruct thing) {
assert(thing.a == 1.0);
assert(thing.b == 2.34);
assert(thing.c == 57);
assert(thing.d == 420);
}
";
/// Compiles `PTX_SRC`, then launches `my_custom_kernel` with a
/// `MyCoolRustStruct` passed by value as the only kernel argument.
///
/// # Errors
/// Propagates any [`DriverError`] from device creation, module loading, or
/// the kernel launch.
///
/// # Panics
/// Panics if compiling `PTX_SRC` fails or if `my_custom_kernel` cannot be
/// found in the loaded module.
fn main() -> Result<(), DriverError> {
    let device = CudaDevice::new(0)?;

    let compiled = compile_ptx(PTX_SRC).unwrap();
    device.load_ptx(compiled, "module", &["my_custom_kernel"])?;

    // These values match the asserts inside the kernel; change any of them
    // to see a device-side assert fire.
    let kernel_arg = MyCoolRustStruct {
        a: 1.0,
        b: 2.34,
        c: 57,
        d: 420,
    };

    let kernel = device.get_func("module", "my_custom_kernel").unwrap();
    let one_element = LaunchConfig::for_num_elems(1);

    // `MyCoolRustStruct` implements `DeviceRepr`, so it can go straight into
    // the argument tuple.
    // SAFETY: the kernel takes exactly one `MyCoolStruct` by value, and
    // `MyCoolRustStruct` is its `#[repr(C)]` counterpart.
    unsafe { kernel.launch(one_element, (kernel_arg,)) }?;
    Ok(())
}

View File

@@ -0,0 +1,65 @@
use cudarc::driver::*;
use cudarc::nvrtc::compile_ptx;
use std::thread;
/// CUDA source for the `hello_world` kernel, compiled in `main` with
/// `compile_ptx`. The kernel takes a single `int` and prints it with `printf`.
const KERNEL_SRC: &str = "
extern \"C\" __global__ void hello_world(int i) {
printf(\"Hello from the cuda kernel in thread %d\\n\", i);
}
";
/// Joins every worker and returns the first error any of them reported.
///
/// A panic inside a worker is re-raised on the calling thread. If a worker
/// returns an error, the remaining handles are dropped; the enclosing
/// `thread::scope` still waits for those threads before it returns.
fn join_workers(
    workers: Vec<thread::ScopedJoinHandle<'_, Result<(), DriverError>>>,
) -> Result<(), DriverError> {
    for worker in workers {
        worker
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))?;
    }
    Ok(())
}

/// Launches the `hello_world` kernel from ten threads, in two different ways:
/// first with one device shared by all threads, then with a separate device
/// created inside each thread.
///
/// # Errors
/// Returns the first [`DriverError`] from device creation, PTX loading,
/// binding the device to a thread, or a kernel launch. This includes errors
/// raised inside the worker threads, which were previously discarded.
///
/// # Panics
/// Panics if compiling `KERNEL_SRC` fails or if `hello_world` cannot be found
/// in a loaded module. A panic in a worker thread is propagated.
fn main() -> Result<(), DriverError> {
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };

    {
        // Option 1: use the same device on each thread.
        // This requires calling the CudaDevice::bind_to_thread() method.
        // Note that all kernels are submitted to the same stream/context,
        // so the kernels will still execute sequentially, in the order
        // they are submitted to the gpu.
        let dev = CudaDevice::new(0)?;
        let ptx = compile_ptx(KERNEL_SRC).unwrap();
        dev.load_ptx(ptx, "kernel", &["hello_world"])?;

        // explicit borrow so we don't have to re-clone the device for each thread
        let dev = &dev;
        thread::scope(move |s| {
            let mut workers = Vec::with_capacity(10);
            for i in 0..10i32 {
                workers.push(s.spawn(move || {
                    // NOTE: this is the important call to have
                    // without this, you'll get a CUDA_ERROR_INVALID_CONTEXT
                    dev.bind_to_thread()?;
                    let f = dev.get_func("kernel", "hello_world").unwrap();
                    // SAFETY: `hello_world` takes a single `int`, which is
                    // exactly what the `(i,)` tuple provides.
                    unsafe { f.launch(cfg, (i,)) }
                }));
            }
            // Surface worker errors instead of dropping the join handles.
            join_workers(workers)
        })?;
    }

    {
        // Option 2: create a new device in each thread.
        // This requires loading the PTX for each device, since they won't
        // share loaded modules on the Rust side of things.
        let ptx = compile_ptx(KERNEL_SRC).unwrap();
        thread::scope(|s| {
            let mut workers = Vec::with_capacity(10);
            for i in 0..10i32 {
                let ptx = ptx.clone();
                workers.push(s.spawn(move || {
                    let dev = CudaDevice::new(0)?;
                    dev.load_ptx(ptx, "kernel", &["hello_world"])?;
                    let f = dev.get_func("kernel", "hello_world").unwrap();
                    // SAFETY: `hello_world` takes a single `int`, which is
                    // exactly what the `(i + 100,)` tuple provides.
                    unsafe { f.launch(cfg, (i + 100,)) }
                }));
            }
            // Surface worker errors instead of dropping the join handles.
            join_workers(workers)
        })?;
    }

    Ok(())
}

View File

@@ -0,0 +1 @@
target/

View File

@@ -0,0 +1,14 @@
[package]
name = "build-workflow"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
bindgen = "0.66.1"
cc = "1.0.82"
regex = "1.9.3"
[dependencies]
cudarc = { path = "../.." }

Some files were not shown because too many files have changed in this diff Show More