mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
774 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be4fb7dd9f | ||
|
|
6b2c216e45 | ||
|
|
d309b4f338 | ||
|
|
d562a0321b | ||
|
|
1b61fd2e4d | ||
|
|
e0c9b2b1ff | ||
|
|
d976f71585 | ||
|
|
00acf7aebe | ||
|
|
29751edf20 | ||
|
|
235db905da | ||
|
|
6337d90bce | ||
|
|
8c434b5081 | ||
|
|
84ba491b56 | ||
|
|
c8044504c5 | ||
|
|
9132ad8d94 | ||
|
|
c20b257657 | ||
|
|
89d9bbe105 | ||
|
|
2f34d413e1 | ||
|
|
a9875fde4d | ||
|
|
c19e211629 | ||
|
|
3e49033616 | ||
|
|
f0135920aa | ||
|
|
8776a1c3de | ||
|
|
b6022900b0 | ||
|
|
cc94802ed0 | ||
|
|
ed2fc61c73 | ||
|
|
ebbc86a312 | ||
|
|
b065cdd22b | ||
|
|
c7a0944eda | ||
|
|
b2d6a48eab | ||
|
|
ca71f0dd16 | ||
|
|
94306b086a | ||
|
|
7f2b9cf336 | ||
|
|
df1f5c3ca8 | ||
|
|
ffb7e4c706 | ||
|
|
598e303649 | ||
|
|
866dfb7804 | ||
|
|
1cfefed1ce | ||
|
|
d5c6ef451c | ||
|
|
e363385d3f | ||
|
|
258d1be49f | ||
|
|
18f15d98e2 | ||
|
|
0b5ce105a3 | ||
|
|
d66bf3412a | ||
|
|
5ede551fcb | ||
|
|
e10a19668c | ||
|
|
881fa13a13 | ||
|
|
e279912bb8 | ||
|
|
6dc2a996d2 | ||
|
|
7ee1cad15c | ||
|
|
b8a0f08cea | ||
|
|
edc96f3626 | ||
|
|
dd369f18a9 | ||
|
|
9326fe3cc8 | ||
|
|
3bd99c9f24 | ||
|
|
bd56364160 | ||
|
|
9547004247 | ||
|
|
647f119d3c | ||
|
|
8952443ebd | ||
|
|
5947e5cd3d | ||
|
|
10d94710f7 | ||
|
|
d13af7c562 | ||
|
|
c2bbe446da | ||
|
|
b0a732e5b0 | ||
|
|
59cf7998c9 | ||
|
|
a6f38be402 | ||
|
|
bc92e3137f | ||
|
|
30310a173d | ||
|
|
c00935b451 | ||
|
|
15e4ee6aa3 | ||
|
|
9ec1e75fe6 | ||
|
|
5898076da5 | ||
|
|
5b17c1880e | ||
|
|
1afea6bd86 | ||
|
|
8dff3619b9 | ||
|
|
111452a68e | ||
|
|
d147ed5063 | ||
|
|
162859dedb | ||
|
|
56de7fa4c3 | ||
|
|
7cc02dd51d | ||
|
|
e5963f1c9a | ||
|
|
9d32721ca7 | ||
|
|
bc6b8fb283 | ||
|
|
12381b2624 | ||
|
|
2821145268 | ||
|
|
959528efad | ||
|
|
6a5a45eeae | ||
|
|
4166e27055 | ||
|
|
f55cf6c0f7 | ||
|
|
6ddabf2995 | ||
|
|
54461a6d33 | ||
|
|
b5d6f424d9 | ||
|
|
f846af5901 | ||
|
|
f9c766dca7 | ||
|
|
218db50c79 | ||
|
|
3fddb7e5a8 | ||
|
|
7bd8de272b | ||
|
|
80915d3f3a | ||
|
|
791f1395d5 | ||
|
|
b5a13381a9 | ||
|
|
c64e408471 | ||
|
|
b1770a0b0e | ||
|
|
37dc4428af | ||
|
|
2d198b6be7 | ||
|
|
67e8e439c0 | ||
|
|
908d2c9222 | ||
|
|
c401a95af2 | ||
|
|
e2864d852f | ||
|
|
f043ba2d5e | ||
|
|
cf8412d3bf | ||
|
|
5b4bde0070 | ||
|
|
9fead8dad3 | ||
|
|
0d44507f3c | ||
|
|
3272749663 | ||
|
|
5f917dcbcf | ||
|
|
85a08aca3f | ||
|
|
192858edf1 | ||
|
|
9a5e6f6e69 | ||
|
|
6884bd010d | ||
|
|
9dd852c27e | ||
|
|
198fe76cb3 | ||
|
|
9696c4ce09 | ||
|
|
9a2f8fadd3 | ||
|
|
b59fefaa11 | ||
|
|
8348d06902 | ||
|
|
8f7f6a6ab3 | ||
|
|
13e6dc6da5 | ||
|
|
244711d46e | ||
|
|
9695bcef84 | ||
|
|
2f20b9959c | ||
|
|
308938ec02 | ||
|
|
b1c435b6be | ||
|
|
4219d8ec7b | ||
|
|
8bd7598678 | ||
|
|
e89bdbb612 | ||
|
|
ebb0df6c69 | ||
|
|
8f2d13df3d | ||
|
|
69c207b599 | ||
|
|
fa04b05b5d | ||
|
|
54912c4f6a | ||
|
|
1c0f525e57 | ||
|
|
26c0de512f | ||
|
|
0c27cb02a8 | ||
|
|
b822800ffe | ||
|
|
b54da0ddde | ||
|
|
9295ff8d72 | ||
|
|
e5dcff3f34 | ||
|
|
a1acd5883b | ||
|
|
556e386621 | ||
|
|
9f9256f08a | ||
|
|
f3c53c1193 | ||
|
|
9f668ee333 | ||
|
|
617ef95c09 | ||
|
|
c539946c25 | ||
|
|
7e9f1c7fc0 | ||
|
|
cf0e6ad2f6 | ||
|
|
9813b188f3 | ||
|
|
bf7c1c5608 | ||
|
|
ec09c0202b | ||
|
|
71365cf2d4 | ||
|
|
481d074f5a | ||
|
|
a240e2adc8 | ||
|
|
c3643925ef | ||
|
|
a6b368fa14 | ||
|
|
ab9df3d94e | ||
|
|
c727113351 | ||
|
|
d203df40d5 | ||
|
|
c506d1e783 | ||
|
|
56ce86f194 | ||
|
|
54a8ebc60d | ||
|
|
b3e07bd638 | ||
|
|
94a6a0a9e9 | ||
|
|
fb279c9ee6 | ||
|
|
3ae34ad3b3 | ||
|
|
6b08212df8 | ||
|
|
03d2d02d00 | ||
|
|
0f09b19199 | ||
|
|
fcf232699f | ||
|
|
1ed89b5656 | ||
|
|
69da97727b | ||
|
|
9edf9cdc0b | ||
|
|
2f13fd6100 | ||
|
|
ed278c9be3 | ||
|
|
9e04457895 | ||
|
|
e6c4291db6 | ||
|
|
f62e6ad85e | ||
|
|
0ba62fde38 | ||
|
|
d62f2e217a | ||
|
|
f385ea287e | ||
|
|
140ee69480 | ||
|
|
2c93b7788c | ||
|
|
4fdc8f38eb | ||
|
|
c0645fe35e | ||
|
|
5b5812defa | ||
|
|
349e3d2472 | ||
|
|
fa67608d48 | ||
|
|
527c20d146 | ||
|
|
ff1da67423 | ||
|
|
efd7489a1c | ||
|
|
4dd7cd7cfd | ||
|
|
33274b905e | ||
|
|
3670378bc6 | ||
|
|
275180be20 | ||
|
|
40a62e70be | ||
|
|
95462aa89e | ||
|
|
7a9f9e04d0 | ||
|
|
cf35b286f2 | ||
|
|
e1cf44a4e0 | ||
|
|
b891b8b595 | ||
|
|
67366e1a2f | ||
|
|
ee8206e2ca | ||
|
|
5cdc559241 | ||
|
|
daa7166534 | ||
|
|
2cf0bc29c8 | ||
|
|
139ae0ddad | ||
|
|
703f4d3847 | ||
|
|
d79042d334 | ||
|
|
f9b52f0058 | ||
|
|
5b50192830 | ||
|
|
ae431e0dd4 | ||
|
|
35626309ac | ||
|
|
a38168a91c | ||
|
|
64ebab654f | ||
|
|
ec0ea40bbe | ||
|
|
49ae10a25e | ||
|
|
1a1ba5216b | ||
|
|
0bbc6215d8 | ||
|
|
4e5300c4d4 | ||
|
|
166d4a12a5 | ||
|
|
e4f90c304b | ||
|
|
fa966c8c7c | ||
|
|
9a0261acd2 | ||
|
|
743bacb125 | ||
|
|
d0afd42eb2 | ||
|
|
4c9691c49d | ||
|
|
9aaff41dfa | ||
|
|
a8b6508155 | ||
|
|
a23e536fa0 | ||
|
|
e654f3e72d | ||
|
|
1a6ce5df82 | ||
|
|
a6cd8d9b0f | ||
|
|
8a62e090a3 | ||
|
|
b550de47e4 | ||
|
|
5bc2477352 | ||
|
|
370973108d | ||
|
|
88ed1ded6d | ||
|
|
e9b8a883d0 | ||
|
|
4a7db75715 | ||
|
|
72b3cba68b | ||
|
|
e7c78e9b46 | ||
|
|
0bc32b9c92 | ||
|
|
9b81ef2326 | ||
|
|
cfc8e7dae2 | ||
|
|
09666f93ab | ||
|
|
b489a86fa9 | ||
|
|
4d4338fb58 | ||
|
|
805ebb1931 | ||
|
|
a57b316216 | ||
|
|
94e08ae947 | ||
|
|
21aee96114 | ||
|
|
ac802a3273 | ||
|
|
70f4fff5c2 | ||
|
|
f2e1c17c8c | ||
|
|
9493c11a53 | ||
|
|
7c72d5b06f | ||
|
|
a15cfbae65 | ||
|
|
34ab545763 | ||
|
|
e67d3e6598 | ||
|
|
621536a1dd | ||
|
|
6d9f9176cd | ||
|
|
2e81b54446 | ||
|
|
38acdf315e | ||
|
|
30dff8597c | ||
|
|
2ebd5f2deb | ||
|
|
162b8c38a1 | ||
|
|
1fb155ddfd | ||
|
|
241b9f527b | ||
|
|
53dc4dd9df | ||
|
|
7c7558fcb3 | ||
|
|
5262e32346 | ||
|
|
664fad5f84 | ||
|
|
4c3e530ef3 | ||
|
|
d582111d04 | ||
|
|
e9384dc714 | ||
|
|
b97da50c9d | ||
|
|
517124b424 | ||
|
|
0ef3121ac6 | ||
|
|
542f74f404 | ||
|
|
8662ba864d | ||
|
|
0fc68006d5 | ||
|
|
eac3a57b6d | ||
|
|
f46bc1cb99 | ||
|
|
8f7004c4c3 | ||
|
|
185facb1d5 | ||
|
|
07d0febef1 | ||
|
|
d35a40eacb | ||
|
|
8a744e6035 | ||
|
|
50b47f8610 | ||
|
|
c1af144891 | ||
|
|
dd123fec89 | ||
|
|
92cca97a76 | ||
|
|
a5d01c7576 | ||
|
|
84fbf805c3 | ||
|
|
51545ee82c | ||
|
|
e16771035f | ||
|
|
10ee2c7343 | ||
|
|
f637fff192 | ||
|
|
3e0cafbae3 | ||
|
|
d6c9c977d8 | ||
|
|
1a0f59943e | ||
|
|
5032d894b8 | ||
|
|
2046ee9ade | ||
|
|
d48ac14458 | ||
|
|
24ff638e43 | ||
|
|
7bb4e856ec | ||
|
|
be93cfe817 | ||
|
|
140aeb4591 | ||
|
|
907dadc6a0 | ||
|
|
c9a1e5c47d | ||
|
|
f36d98363c | ||
|
|
c05c0e0575 | ||
|
|
123b48d5ec | ||
|
|
c4553fc132 | ||
|
|
b38be86191 | ||
|
|
da3970082a | ||
|
|
c2a11bf114 | ||
|
|
75a141d8ba | ||
|
|
ee17a48dbe | ||
|
|
02df7e7f8d | ||
|
|
22ae700048 | ||
|
|
dbe6a42018 | ||
|
|
7387ca1b19 | ||
|
|
9e03c3421f | ||
|
|
8a6d088ff3 | ||
|
|
305e8f104c | ||
|
|
472eae1576 | ||
|
|
6c234daba2 | ||
|
|
dfb8691923 | ||
|
|
2784738e41 | ||
|
|
b86b27e0c7 | ||
|
|
ba3faa49df | ||
|
|
c833a65153 | ||
|
|
cab6b2fff2 | ||
|
|
35e5da1ff4 | ||
|
|
67aac97299 | ||
|
|
2b884d6304 | ||
|
|
9fa0b8d0a5 | ||
|
|
422fd32d74 | ||
|
|
1d88be2001 | ||
|
|
3d5c3180be | ||
|
|
47b61ac847 | ||
|
|
18560d0852 | ||
|
|
1400aecf1d | ||
|
|
9e3bea8cac | ||
|
|
b4bf84840e | ||
|
|
941a8b93eb | ||
|
|
666cbe6c5a | ||
|
|
33b7f0914f | ||
|
|
2b2e06d6fa | ||
|
|
eaa4ad8ef5 | ||
|
|
750a6e9e8b | ||
|
|
0028b5ca78 | ||
|
|
d4b18a0e35 | ||
|
|
22d7c563cb | ||
|
|
4671708601 | ||
|
|
9b3948a3ff | ||
|
|
4c415fba7b | ||
|
|
f775833e10 | ||
|
|
b1b06b1e15 | ||
|
|
bf8f3d91d2 | ||
|
|
b40fb1a94b | ||
|
|
2e52833bb5 | ||
|
|
7e2518bbba | ||
|
|
808cf7849e | ||
|
|
1a454b23f8 | ||
|
|
bc4483706b | ||
|
|
8d0cff2b0b | ||
|
|
35097e8e2b | ||
|
|
1fdc8de899 | ||
|
|
995293e5da | ||
|
|
e9d7604f0b | ||
|
|
85824bb1ee | ||
|
|
652f0e365f | ||
|
|
554331f567 | ||
|
|
7311c8f48c | ||
|
|
9a904b6dcc | ||
|
|
8b4234eb60 | ||
|
|
b121bcb20b | ||
|
|
d63ceba488 | ||
|
|
67d8d6b992 | ||
|
|
0130d5dfd9 | ||
|
|
ab8f7187e6 | ||
|
|
0c291d594b | ||
|
|
1d828f7982 | ||
|
|
6f3d52f345 | ||
|
|
ed1f76808d | ||
|
|
a00fe78aa1 | ||
|
|
58a56f9fc0 | ||
|
|
2254b4c96c | ||
|
|
b6a0caa79b | ||
|
|
4e7c6c27ce | ||
|
|
4eb0a8e1fb | ||
|
|
e222cb7a97 | ||
|
|
858f198b43 | ||
|
|
4b5872b5d1 | ||
|
|
d2269eebf7 | ||
|
|
7f8b21f71f | ||
|
|
ca1703745f | ||
|
|
c3f2547349 | ||
|
|
5c24050775 | ||
|
|
927fb9fac2 | ||
|
|
167944b422 | ||
|
|
4cec36f4b5 | ||
|
|
66fbf23d67 | ||
|
|
7e401c69c7 | ||
|
|
5c4076bc8c | ||
|
|
d2cb4f0d48 | ||
|
|
ef16ee6b23 | ||
|
|
c518caacf2 | ||
|
|
acef1725f3 | ||
|
|
6cad14a20b | ||
|
|
c33333724d | ||
|
|
8149440f8f | ||
|
|
b4717747d5 | ||
|
|
9fc98f3288 | ||
|
|
24347bf69c | ||
|
|
c5aa4d2975 | ||
|
|
9ec05b25a8 | ||
|
|
bd83d880a9 | ||
|
|
8037d370ee | ||
|
|
4a0a86577e | ||
|
|
75ea980bd2 | ||
|
|
de5049577c | ||
|
|
5f99756be4 | ||
|
|
33724c7214 | ||
|
|
4f75032c7e | ||
|
|
a45b4b6e85 | ||
|
|
912db261fe | ||
|
|
fad53704fd | ||
|
|
29aeac0531 | ||
|
|
a426971470 | ||
|
|
1bd50bff21 | ||
|
|
d2d733b931 | ||
|
|
1ad6edd9ce | ||
|
|
d924809d85 | ||
|
|
24b1b324e6 | ||
|
|
531b28f75a | ||
|
|
ce40bb7f58 | ||
|
|
be667fb936 | ||
|
|
de2a2c8bb8 | ||
|
|
55e68dff43 | ||
|
|
e922d565a7 | ||
|
|
1763e85aa7 | ||
|
|
5d10422881 | ||
|
|
a004408327 | ||
|
|
cc0b34a640 | ||
|
|
21596a01d7 | ||
|
|
68f0c6f6ca | ||
|
|
3a2ab1d176 | ||
|
|
ff7289ef39 | ||
|
|
0b370359c4 | ||
|
|
d2b720da3f | ||
|
|
ef1054a921 | ||
|
|
0b30af2a7a | ||
|
|
5d97b4ee52 | ||
|
|
e179494ac4 | ||
|
|
db2fc3cbb0 | ||
|
|
935caa24ce | ||
|
|
268a9b2cf8 | ||
|
|
5d8238bcf4 | ||
|
|
a1c4f18725 | ||
|
|
19ec1f1d36 | ||
|
|
3032c685cd | ||
|
|
c890ebdbe1 | ||
|
|
a26d2fe86f | ||
|
|
312305fcb7 | ||
|
|
a402a29f93 | ||
|
|
ed964105ec | ||
|
|
414a3dcc83 | ||
|
|
c51e87385f | ||
|
|
fec403b9f5 | ||
|
|
90e06d90e5 | ||
|
|
25bf6ee63a | ||
|
|
9e5880b130 | ||
|
|
9220e7b1e0 | ||
|
|
4a97c8bee9 | ||
|
|
93d45509ad | ||
|
|
c0d0ec0c32 | ||
|
|
c4b4233e20 | ||
|
|
61e59b27ec | ||
|
|
d16d22492e | ||
|
|
3a25325d37 | ||
|
|
4599fec534 | ||
|
|
1637d0fdb8 | ||
|
|
13de77b68c | ||
|
|
700d8f71e2 | ||
|
|
84eea2a0eb | ||
|
|
b2be7b2583 | ||
|
|
5c396368b6 | ||
|
|
7335d07755 | ||
|
|
b63746fe84 | ||
|
|
ef964536e9 | ||
|
|
f96e3a903e | ||
|
|
7ec82a97d6 | ||
|
|
741b167910 | ||
|
|
268fb4e9aa | ||
|
|
3b0b264ba5 | ||
|
|
1c7a3b8ed9 | ||
|
|
7c307c886e | ||
|
|
e00a89c647 | ||
|
|
c6d37ed5c5 | ||
|
|
deef279977 | ||
|
|
835527333c | ||
|
|
b2735b8dc6 | ||
|
|
7050a8bd7a | ||
|
|
2c6ac7124e | ||
|
|
cc0c2bf8cb | ||
|
|
e335bb24df | ||
|
|
ef0768ebef | ||
|
|
a6c8c4c254 | ||
|
|
1f81ffb182 | ||
|
|
23b7937507 | ||
|
|
ac23472220 | ||
|
|
0e07eb7614 | ||
|
|
f6e2fd1be2 | ||
|
|
ddc6644a87 | ||
|
|
6d987df3e2 | ||
|
|
7b2fd581b6 | ||
|
|
6f810111c4 | ||
|
|
3b154540da | ||
|
|
0675610007 | ||
|
|
2ae67dd894 | ||
|
|
8623843e72 | ||
|
|
98ef29fec0 | ||
|
|
7d37b56c20 | ||
|
|
54c48df279 | ||
|
|
7460fcde9d | ||
|
|
fc2a56039a | ||
|
|
b29f8e3a0f | ||
|
|
3031ead6dc | ||
|
|
20951c0721 | ||
|
|
75b1064922 | ||
|
|
1a4135515b | ||
|
|
acbb1b6e2c | ||
|
|
c0632cb689 | ||
|
|
144e3b7a98 | ||
|
|
dfd21a343b | ||
|
|
0faadea621 | ||
|
|
0dd8f4b7c7 | ||
|
|
96e39c2535 | ||
|
|
909d5b7836 | ||
|
|
1125351f4c | ||
|
|
345622f452 | ||
|
|
53b9bd6e61 | ||
|
|
e7d0a08150 | ||
|
|
0939f50ce2 | ||
|
|
84d7a0cedc | ||
|
|
c9c540057b | ||
|
|
a2edbe14ec | ||
|
|
694fa93d30 | ||
|
|
4214a33525 | ||
|
|
404322b4ab | ||
|
|
afd3eeee88 | ||
|
|
84adc99c33 | ||
|
|
3f4b592c60 | ||
|
|
d61c848f6a | ||
|
|
94c7d00517 | ||
|
|
e799363d0d | ||
|
|
d0d7f74e42 | ||
|
|
de5835822d | ||
|
|
77fb4305e8 | ||
|
|
4cdb364e4a | ||
|
|
fbebf6d485 | ||
|
|
4fde0f4524 | ||
|
|
6f3cff1cd4 | ||
|
|
b90847c43f | ||
|
|
e5c7c8b2a2 | ||
|
|
9e453719e3 | ||
|
|
63b04f1e9a | ||
|
|
904baefa68 | ||
|
|
c82a00981a | ||
|
|
d1add4231f | ||
|
|
678591a1a5 | ||
|
|
f4a07f5259 | ||
|
|
e5e904498c | ||
|
|
89740bdd30 | ||
|
|
c6b72fa317 | ||
|
|
80b917b02f | ||
|
|
971361feac | ||
|
|
4a553724a2 | ||
|
|
b87b30f045 | ||
|
|
a3a69f53da | ||
|
|
cb659f3c25 | ||
|
|
8135540b22 | ||
|
|
3a10b6f4db | ||
|
|
82d4a96ae1 | ||
|
|
802091e15e | ||
|
|
4292259db1 | ||
|
|
b6efabf216 | ||
|
|
7e5471bdfa | ||
|
|
5fa5aff813 | ||
|
|
5035ad1d99 | ||
|
|
63797c90f9 | ||
|
|
8b475ea4f2 | ||
|
|
1cd4fb2e73 | ||
|
|
946ea8dfb8 | ||
|
|
8c264fb2a5 | ||
|
|
b96b792612 | ||
|
|
909ea995b6 | ||
|
|
4e197b512f | ||
|
|
aa3f8cce3d | ||
|
|
00dcc29eb1 | ||
|
|
032cec5c5a | ||
|
|
f0d6fedc90 | ||
|
|
bb9ff4f113 | ||
|
|
1c3f6735f8 | ||
|
|
2d210641d3 | ||
|
|
4078d895c7 | ||
|
|
d67820b6ba | ||
|
|
6869047b44 | ||
|
|
ba7c3972b5 | ||
|
|
f931504a09 | ||
|
|
52c18171a1 | ||
|
|
b89dbefb3c | ||
|
|
07936bc8e4 | ||
|
|
647eda7895 | ||
|
|
5bb703084c | ||
|
|
e7283e9105 | ||
|
|
7102d06e73 | ||
|
|
cf54dee88e | ||
|
|
854864ac5e | ||
|
|
3eceaae45f | ||
|
|
e604a8cba0 | ||
|
|
c351acb075 | ||
|
|
d880efc1db | ||
|
|
e118b293fd | ||
|
|
254996063a | ||
|
|
a85b2ac301 | ||
|
|
d2f8471943 | ||
|
|
b6a7a3bc1e | ||
|
|
3b3007cbdd | ||
|
|
adc2092275 | ||
|
|
96831f2d4e | ||
|
|
baf8664d10 | ||
|
|
d071fd5397 | ||
|
|
44f9415811 | ||
|
|
b104364edb | ||
|
|
24bbf0ead9 | ||
|
|
e8af292958 | ||
|
|
8d14f83bc3 | ||
|
|
2ff89167c2 | ||
|
|
87854bbdf0 | ||
|
|
a3a4a972d7 | ||
|
|
75fbb709d7 | ||
|
|
6a311347bf | ||
|
|
2787fdd8b6 | ||
|
|
634f5c26ee | ||
|
|
e7683ac3ff | ||
|
|
72fdc3bcfe | ||
|
|
271977d1dd | ||
|
|
e6de090ed3 | ||
|
|
65c0224ae5 | ||
|
|
e957a4c99a | ||
|
|
4e6d5b733c | ||
|
|
be591b2f4a | ||
|
|
bb90b73533 | ||
|
|
0abf5c2379 | ||
|
|
ccbf55923d | ||
|
|
eb842428b7 | ||
|
|
e61aa736db | ||
|
|
8794afb246 | ||
|
|
ddf32b6215 | ||
|
|
811fe65412 | ||
|
|
0ad73d19ed | ||
|
|
0b845dc7ee | ||
|
|
a0449b4d6b | ||
|
|
7e58e1f299 | ||
|
|
f1da8c3cb7 | ||
|
|
b87f0124b7 | ||
|
|
49db4cdea8 | ||
|
|
b72e0a2270 | ||
|
|
67965bc275 | ||
|
|
a8abee1422 | ||
|
|
4da5e94adf | ||
|
|
0dc1e71148 | ||
|
|
7d7972d54c | ||
|
|
41de512cdc | ||
|
|
07b2b1f28c | ||
|
|
2aec49d0e5 | ||
|
|
ffa50d43c5 | ||
|
|
ef06f5a746 | ||
|
|
aebdbe5ca8 | ||
|
|
acdcfc14fb | ||
|
|
a6b403e667 | ||
|
|
e5cfe80029 | ||
|
|
798ac9dd69 | ||
|
|
64a05e2f14 | ||
|
|
99f5843c42 | ||
|
|
3f2250e51f | ||
|
|
9a3de0103d | ||
|
|
1848ef4905 | ||
|
|
b8725ec9aa | ||
|
|
fbba2eb1db | ||
|
|
8554a1fcfc | ||
|
|
2f4e189f93 | ||
|
|
1bda13aec0 | ||
|
|
49cadac789 | ||
|
|
1fe9f3a068 | ||
|
|
edb102f7a2 | ||
|
|
97376b36bc | ||
|
|
41d88b0c4a | ||
|
|
ee70a44f8b | ||
|
|
1be715f322 | ||
|
|
519319c9b2 | ||
|
|
76abe671e4 | ||
|
|
14541394dc | ||
|
|
78a10f89ed | ||
|
|
f20a9fd2ed | ||
|
|
18eb48735d | ||
|
|
6274ba8169 | ||
|
|
a63bae227e | ||
|
|
6182590829 | ||
|
|
0922bcb903 | ||
|
|
6eb62664a5 | ||
|
|
dcb2072f36 | ||
|
|
c5bd1a9ce9 | ||
|
|
da1192bd01 | ||
|
|
ef3b917f5e | ||
|
|
2f32bcbb8f | ||
|
|
71adf60a71 | ||
|
|
8a1c51317c | ||
|
|
783d01dd6f | ||
|
|
8c0567146f | ||
|
|
8a4a98fa27 | ||
|
|
37b363a92f | ||
|
|
e9352d0506 | ||
|
|
6efdcdb2b9 | ||
|
|
eb1355c65a | ||
|
|
a135938588 | ||
|
|
aa3bf1ef51 | ||
|
|
dcaa13b20e | ||
|
|
c8ca146d0c | ||
|
|
8c73edb584 | ||
|
|
f6a52704d9 | ||
|
|
40814bc323 | ||
|
|
c652f8050a | ||
|
|
c4bf441fc1 | ||
|
|
8ca9add11e | ||
|
|
1282de3d05 | ||
|
|
f25c40bb08 | ||
|
|
3e12aa3492 | ||
|
|
78696adb53 | ||
|
|
51f649da8a | ||
|
|
3a35e59691 | ||
|
|
f5098784d7 | ||
|
|
b3a21eaa52 | ||
|
|
cc1d92e62f | ||
|
|
3fdc34e286 | ||
|
|
10f3eaad39 | ||
|
|
ede46bd1e0 | ||
|
|
aa48af32ea | ||
|
|
4acc6d1114 | ||
|
|
b781be3cc2 | ||
|
|
d4a04a5055 | ||
|
|
3985301749 | ||
|
|
51b6d2536d | ||
|
|
2001353e9e | ||
|
|
da8b5f62d2 | ||
|
|
1cf47a06c6 | ||
|
|
2b68b022f9 | ||
|
|
7de2e883b9 | ||
|
|
ebbbfa1998 | ||
|
|
7120aede15 | ||
|
|
14c93f0e96 |
22
.github/workflows/rust.yml
vendored
22
.github/workflows/rust.yml
vendored
@@ -1,22 +0,0 @@
|
||||
name: Rust
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Build
|
||||
run: cargo build --verbose
|
||||
- name: Run tests
|
||||
run: cargo test --verbose
|
||||
34
.github/workflows/test.yml
vendored
Normal file
34
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
cpu_test:
|
||||
name: CPU Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Build
|
||||
run: cargo build --no-default-features --verbose
|
||||
- name: Run tests
|
||||
run: cargo test --no-default-features --verbose
|
||||
# macos_test:
|
||||
# name: MacOS Tests
|
||||
# runs-on: macos-13
|
||||
# timeout-minutes: 20
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v3
|
||||
# - name: Build
|
||||
# run: cargo build --verbose
|
||||
# - name: Run tests
|
||||
# run: cargo test --verbose -- --test-threads 1
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,7 +1,14 @@
|
||||
/target
|
||||
/crates/**/target
|
||||
|
||||
.DS_Store
|
||||
.vscode
|
||||
*.vscode
|
||||
Cargo.lock
|
||||
*.st
|
||||
*.npx
|
||||
*.npz
|
||||
*.npz
|
||||
/**/llama-7b-hf
|
||||
/**/mistral-7b-hf
|
||||
/**/setup_weights/target
|
||||
*.model
|
||||
*.gguf
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -1,31 +1,29 @@
|
||||
[package]
|
||||
name = "luminal"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
description = "Deep learning at the speed of light."
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
#default = ["cuda"]
|
||||
cuda = ["dep:cudarc"]
|
||||
|
||||
[dependencies]
|
||||
luminal_macro = { path = "./resources/luminal_macro" }
|
||||
itertools = "0.11.0"
|
||||
matrixmultiply = "0.3.7"
|
||||
matrixmultiply = "0.3.8"
|
||||
num-traits = "0.2.16"
|
||||
petgraph = {path="./resources/petgraph"}
|
||||
petgraph = "0.6.4"
|
||||
rand = "0.8.5"
|
||||
strum = { version = "0.25.0", features = ["derive"] }
|
||||
urlencoding = "2.1.2"
|
||||
webbrowser = "0.8.10"
|
||||
dyn-clone = "1.0.12"
|
||||
cudarc = {version="0.9.13", optional=true}
|
||||
|
||||
safetensors = "0.3.1"
|
||||
memmap2 = "0.7.1"
|
||||
half = "2.3.1"
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
|
||||
tinyvec = "1.6.0"
|
||||
term_size = "0.3.2"
|
||||
colored = "2.0.4"
|
||||
regex = "1.9.5"
|
||||
rustc-hash = "1.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = "0.13"
|
||||
tokenizers = "0.13.3"
|
||||
dfdx = { version = "0.13", features = ["f16"] }
|
||||
|
||||
201
LICENSE-APACHE
Normal file
201
LICENSE-APACHE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
21
LICENSE-MIT
Normal file
21
LICENSE-MIT
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Joe Fioti
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
115
README.md
115
README.md
@@ -1,76 +1,109 @@
|
||||
# luminal
|
||||

|
||||
[](https://github.com/Sidekick-AI/dataflow/actions)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/VQf3j8WWNd)
|
||||
|
||||
**Deep learning at the speed of light.**
|
||||
|
||||
Luminal is a deep learning library that prioritizes **static computation** and **operator fusion** to achieve high performance.
|
||||
Luminal is a deep learning library that uses **composable compilers** to achieve high performance.
|
||||
|
||||
```rust
|
||||
use luminal::prelude::*;
|
||||
|
||||
// Setup graph and tensors
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.new_tensor::<R2<3, 1>>("A");
|
||||
let b = cx.new_tensor::<R2<1, 4>>("B");
|
||||
let a = cx.tensor::<R2<3, 1>>()
|
||||
.set([[1.0], [2.0], [3.0]]);
|
||||
let b = cx.tensor::<R2<1, 4>>()
|
||||
.set([[1.0, 2.0, 3.0, 4.0]]);
|
||||
|
||||
// Do stuff...
|
||||
let c = a.matmul(b);
|
||||
// Do math...
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
// Set inputs and mark outputs
|
||||
a.set(vec![1.0, 2.0, 3.0]);
|
||||
b.set(vec![1.0, 2.0, 3.0, 3.0]);
|
||||
c.mark();
|
||||
|
||||
// Optimize and run graph
|
||||
cx.optimize(GenericOptimizer::default());
|
||||
// Compile and run graph
|
||||
cx.compile(<(GenericCompiler, CPUCompiler)>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
// Get result
|
||||
println!("Result: {:?}", c.retrieve().unwrap().data);
|
||||
println!("Result: {:?}", c);
|
||||
```
|
||||
|
||||
## Why does this look so different from other DL libraries?
|
||||
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. So when you see `x + y`, the addition actually happens right there. This is great for debugging, it works exactly as most developers expect.
|
||||
## Getting Started
|
||||
**Mistral 7B**
|
||||
```bash
|
||||
cd ./examples/mistral
|
||||
# Download the model
|
||||
bash ./setup/setup.sh
|
||||
# Run the model
|
||||
cargo run --release --features metal # MacOS (Recommended)
|
||||
cargo run --release --features cuda # Nvidia
|
||||
cargo run --release # CPU
|
||||
```
|
||||
|
||||
However, this isn't great for performance because what makes sense for a developer doesn't make sense for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
## Features
|
||||
### Speed
|
||||
Luminal can run Q8 Mistral 7B on M-series Macbooks at 15-25 tokens per second. The goal is to become the fastest ML framework for any model on any device.
|
||||
|
||||
Luminal takes a different approach, more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Here everything's static. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, and executed later.
|
||||
### Simplicity
|
||||
The core of luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
## But Why?
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our optimizers have global knowledge and can do much more aggressive optimization **without any sync points**.
|
||||
### RISC-style architecture
|
||||
Everything in luminal boils down to 11 primitive ops:
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Contiguous`
|
||||
|
||||
Of course, we can still split the network into multiple seperate graphs if we want to insert dynamic control flow part-way through, which means this method doesn't preclude optimizations like KV caching, because the KV cached forward pass is just a seperate graph!
|
||||
These ops are enough to support transformers, convnets, etc.
|
||||
|
||||
Some huge benefits are now unlocked:
|
||||
### Native
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
Correctness matters. So we write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
|
||||
|
||||
## Ideology
|
||||
### Why does this look so different from other DL libraries?
|
||||
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. In PyTorch, when you see `x + y`, the addition actually happens right there. This is great for debugging because it works exactly as most developers expect.
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### Compile everything
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd will be handled by a compiler!
|
||||
|
||||
Now we can do:
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Devices and Dtypes are handled through optimizers (just run the CUDA optimizer to convert the graph to use CUDA kernels, then the fp16 optimizer to convert to half-precision kernels)
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
|
||||
## RISC-style architecture
|
||||
Luminal can be ran on new accelerators by implementing 11 primitive ops. Take a look at `src/optimizers/cuda/prim.rs` to see 1-to-1 CUDA translations of the primops.
|
||||
### Compile-time Shape Checks
|
||||
All operations are shape checked at compile time, so no more shape mismatches! Credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
|
||||
|
||||
Accellerators are free to implement their own custom ops, and their own optimizers to convert luminal primitive ops to their bespoke ops.
|
||||
|
||||
## Compile-time Shape Checks
|
||||
All operations are shape checked at compile time, so no more shape mismatches! All credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
|
||||
|
||||
## View the Graph
|
||||
Once you've written all your computation code, run `cx.display_graph()` to see the entire computation graph in all it's glory. Pretty messy looking! Now run `cx.optimize(GeneralOptimizer::default())` and display the graph again. Much better.
|
||||
### View the Graph
|
||||
Once you've written all your computation code, run `cx.display()` to see the entire computation graph in all it's glory. Pretty messy looking! Now run `cx.compile(GenericCompiler::default())` and display the graph again. Much better.
|
||||
|
||||
## Where are we?
|
||||
Currently luminal is extremely alpha. Please don't use this in prod.
|
||||
|
||||
- Llama 1 is implemented in `examples/llama`. You'll need to follow the instructions in [llama-dfdx](https://github.com/coreylowman/llama-dfdx) to download and convert the llama weights, and point this example loading path at them.
|
||||
- The llama example shows how to implement a loader for a custom format. Safetensors loaders are already implemented, and are the recommended way to load a model.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Performance on M-series macs with LLMs is within 20% of llama.cpp (a *heavily* optimized library)
|
||||
- Mistral 7B and Llama 7B are implemented in `examples/`. See instructions above for running.
|
||||
- We have a small library of NN modules in `nn`, including transformers.
|
||||
- A signifigant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the tinygrad ops set.
|
||||
- Currently there are very few optimizers, so primops are mostly used to run these models, which are very slow.
|
||||
- Next release will bring a signifigant amount of optimizers which should fuse primops into much faster ops. The aim for 0.2 is to be usably fast, not SOTA yet.
|
||||
- A signifigant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
- The aim for 0.3 is to achieve SOTA performance on an M1 pro (50 tok/s), and near SOTA on single nvidia gpus (>100 tok/s), as well as support many mainstream models (Whisper, Stable Diffusion, Yolo v9, etc.)
|
||||
|
||||
Some things on the roadmap:
|
||||
- Write common sense cuda ops and optimizer (matmuls, mul-add, etc.)
|
||||
- Optimize cuda and metal matmul kernels
|
||||
- Fine-grained metal and cuda IR
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Write specialized CUDA kernels for full transformer architecture (FlashAttention, etc.)
|
||||
- Automatic differentiation of graphs
|
||||
- Autograd engine
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM training
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
## License
|
||||
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
22
crates/luminal_cuda/Cargo.toml
Normal file
22
crates/luminal_cuda/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "luminal_cuda"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
description = "Cuda compiler for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_cudarc = { version="0.10.0", features = [
|
||||
"cublas",
|
||||
"f16",
|
||||
]}
|
||||
itertools = "0.12.1"
|
||||
rustc-hash = "1.1.0"
|
||||
num-traits = "0.2.18"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = { version = "0.13", features = ["f16"] }
|
||||
rand = "0.8.5"
|
||||
612
crates/luminal_cuda/src/binary.rs
Normal file
612
crates/luminal_cuda/src/binary.rs
Normal file
@@ -0,0 +1,612 @@
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
use luminal_cudarc::{
|
||||
driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig},
|
||||
nvrtc::{compile_ptx_with_opts, CompileOptions},
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
op::*,
|
||||
prelude::{petgraph::visit::EdgeRef, *},
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
get_idx_valid_exps, hash,
|
||||
other::CudaARange,
|
||||
prim::{CudaAdd, CudaCopyToDevice, CudaLessThan, CudaMul, CudaSumReduce},
|
||||
render_dyn_dim_inputs, select_const, CudaData, CudaFloat,
|
||||
};
|
||||
|
||||
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
|
||||
pub struct CudaSub<T> {
|
||||
function: CudaFunction,
|
||||
device: Arc<CudaDevice>,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CudaFloat> CudaSub<T> {
|
||||
pub fn new(
|
||||
a_shape: ShapeTracker,
|
||||
b_shape: ShapeTracker,
|
||||
dev: Arc<CudaDevice>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (a_idx, a_valid) = get_idx_valid_exps(a_shape);
|
||||
let (b_idx, b_valid) = get_idx_valid_exps(b_shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape]);
|
||||
let type_name = T::type_name();
|
||||
let mut code = format!(
|
||||
"
|
||||
#include \"cuda_fp16.h\"
|
||||
extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp_a, const {type_name} *inp_b, int numel{rendered}) {{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < numel) {{
|
||||
out[idx] =
|
||||
(({a_valid}) == 0 ? {} : inp_a[{a_idx}])
|
||||
- (({b_valid}) == 0 ? {} : inp_b[{b_idx}]);
|
||||
}}
|
||||
}}",
|
||||
if T::is_f32() {
|
||||
"0.0"
|
||||
} else {
|
||||
"__float2half(0.0)"
|
||||
},
|
||||
if T::is_f32() {
|
||||
"0.0"
|
||||
} else {
|
||||
"__float2half(0.0)"
|
||||
},
|
||||
);
|
||||
let name = format!("kernel_{}", hash(&code));
|
||||
code = code.replace("kernel", &name);
|
||||
if !dev.has_func(&name, &name) {
|
||||
dev.load_ptx(
|
||||
compile_ptx_with_opts(
|
||||
code,
|
||||
CompileOptions {
|
||||
arch: Some("sm_75"),
|
||||
include_paths: vec!["/usr/local/cuda/include".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
&name,
|
||||
&[name.clone().leak()],
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Self {
|
||||
function: dev.get_func(&name, &name).unwrap(),
|
||||
device: dev,
|
||||
_phantom: Default::default(),
|
||||
dyn_symbols,
|
||||
dyn_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Operator for CudaSub<T>
|
||||
where
|
||||
T: std::fmt::Debug
|
||||
+ Copy
|
||||
+ luminal_cudarc::driver::DeviceRepr
|
||||
+ std::marker::Unpin
|
||||
+ luminal_cudarc::driver::ValidAsZeroBits,
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let a = tensors[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let b = tensors[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
|
||||
let out = self.device.alloc_zeros::<T>(inp_size).unwrap();
|
||||
let mut params = vec![
|
||||
(&out).as_kernel_param(),
|
||||
(&a.0).as_kernel_param(),
|
||||
(&b.0).as_kernel_param(),
|
||||
inp_size.as_kernel_param(),
|
||||
];
|
||||
let mut dims = [0; 10];
|
||||
let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
|
||||
for (i, d) in self.dyn_symbols.iter().enumerate() {
|
||||
dims[i] = dyn_map[d] as i32;
|
||||
params.push(unsafe {
|
||||
dims[0]
|
||||
.as_kernel_param()
|
||||
.add(i * std::mem::size_of::<i32>())
|
||||
});
|
||||
}
|
||||
unsafe {
|
||||
self.function
|
||||
.clone()
|
||||
.launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct CudaSubtractionCompiler<T: CudaFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat> Compiler for CudaSubtractionCompiler<T>
|
||||
where
|
||||
CudaData<T>: luminal::prelude::Data,
|
||||
{
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = CudaDevice::new(0).unwrap();
|
||||
let (mut neg_one, mut mul, mut add) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let mut searcher = select_const!(-1.0, T)
|
||||
.ptr(&mut neg_one)
|
||||
.edge(SelectOp::new().ty::<CudaMul<T>>().ptr(&mut mul))
|
||||
.edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut add))
|
||||
.search(graph);
|
||||
|
||||
while searcher.next_match() {
|
||||
if check_no_delete(graph, &[neg_one, mul, add]) {
|
||||
continue;
|
||||
}
|
||||
let (a, a_edge) = graph
|
||||
.graph
|
||||
.edges_directed(add, petgraph::Direction::Incoming)
|
||||
.find(|e| e.source() != mul)
|
||||
.map(|e| (e.source(), e.weight().as_data().unwrap()))
|
||||
.unwrap();
|
||||
let (b, b_edge) = graph
|
||||
.graph
|
||||
.edges_directed(mul, petgraph::Direction::Incoming)
|
||||
.find(|e| e.source() != neg_one)
|
||||
.map(|e| (e.source(), e.weight().as_data().unwrap()))
|
||||
.unwrap();
|
||||
let b_final_shape = graph
|
||||
.graph
|
||||
.edges_connecting(mul, add)
|
||||
.next()
|
||||
.unwrap()
|
||||
.weight()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2;
|
||||
if !b_final_shape.is_contiguous()
|
||||
|| b_final_shape.is_sliced()
|
||||
|| b_final_shape.is_padded()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let sub = graph
|
||||
.add_op(CudaSub::<T>::new(
|
||||
a_edge.2,
|
||||
b_edge.2,
|
||||
dev.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(a, a_edge.1, a_edge.2)
|
||||
.input(b, b_edge.1, b_edge.2)
|
||||
.finish();
|
||||
move_outgoing_edge(add, sub, &mut graph.graph);
|
||||
|
||||
if graph.get_dests(neg_one).len() == 1 {
|
||||
graph.graph.remove_node(neg_one);
|
||||
}
|
||||
graph.graph.remove_node(mul);
|
||||
graph.graph.remove_node(add);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
|
||||
pub struct CudaEqual<T> {
|
||||
function: CudaFunction,
|
||||
device: Arc<CudaDevice>,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CudaFloat> CudaEqual<T> {
|
||||
pub fn new(
|
||||
a_shape: ShapeTracker,
|
||||
b_shape: ShapeTracker,
|
||||
dev: Arc<CudaDevice>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (a_idx, a_valid) = get_idx_valid_exps(a_shape);
|
||||
let (b_idx, b_valid) = get_idx_valid_exps(b_shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape]);
|
||||
let type_name = T::type_name();
|
||||
let mut code = format!(
|
||||
"
|
||||
#include \"cuda_fp16.h\"
|
||||
extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp_a, const {type_name} *inp_b, int numel{rendered}) {{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < numel) {{
|
||||
{type_name} a_val = ({a_valid}) == 0 ? {} : inp_a[{a_idx}];
|
||||
{type_name} b_val = ({b_valid}) == 0 ? {} : inp_b[{b_idx}];
|
||||
out[idx] = ({type_name})(a_val == b_val);
|
||||
}}
|
||||
}}",
|
||||
if T::is_f32() {
|
||||
"0.0"
|
||||
} else {
|
||||
"__float2half(0.0)"
|
||||
},
|
||||
if T::is_f32() {
|
||||
"0.0"
|
||||
} else {
|
||||
"__float2half(0.0)"
|
||||
},
|
||||
);
|
||||
let name = format!("kernel_{}", hash(&code));
|
||||
code = code.replace("kernel", &name);
|
||||
if !dev.has_func(&name, &name) {
|
||||
dev.load_ptx(
|
||||
compile_ptx_with_opts(
|
||||
code,
|
||||
CompileOptions {
|
||||
arch: Some("sm_75"),
|
||||
include_paths: vec!["/usr/local/cuda/include".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
&name,
|
||||
&[name.clone().leak()],
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Self {
|
||||
function: dev.get_func(&name, &name).unwrap(),
|
||||
device: dev,
|
||||
_phantom: Default::default(),
|
||||
dyn_symbols,
|
||||
dyn_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Operator for CudaEqual<T>
|
||||
where
|
||||
T: std::fmt::Debug
|
||||
+ Copy
|
||||
+ luminal_cudarc::driver::DeviceRepr
|
||||
+ std::marker::Unpin
|
||||
+ luminal_cudarc::driver::ValidAsZeroBits,
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let a = tensors[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let b = tensors[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
|
||||
let out = self.device.alloc_zeros::<T>(inp_size).unwrap();
|
||||
let mut params = vec![
|
||||
(&out).as_kernel_param(),
|
||||
(&a.0).as_kernel_param(),
|
||||
(&b.0).as_kernel_param(),
|
||||
inp_size.as_kernel_param(),
|
||||
];
|
||||
let mut dims = [0; 10];
|
||||
let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
|
||||
for (i, d) in self.dyn_symbols.iter().enumerate() {
|
||||
dims[i] = dyn_map[d] as i32;
|
||||
params.push(unsafe {
|
||||
dims[0]
|
||||
.as_kernel_param()
|
||||
.add(i * std::mem::size_of::<i32>())
|
||||
});
|
||||
}
|
||||
unsafe {
|
||||
self.function
|
||||
.clone()
|
||||
.launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct CudaEqualCompiler<T: CudaFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat> Compiler for CudaEqualCompiler<T>
|
||||
where
|
||||
CudaData<T>: luminal::prelude::Data,
|
||||
{
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = CudaDevice::new(0).unwrap();
|
||||
let (mut less_than1, mut less_than2, mut add, mut one, mut sub) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let s = select_const!(1.0, T).ptr(&mut one).edge(
|
||||
SelectOp::new()
|
||||
.ty::<CudaLessThan<T>>()
|
||||
.ptr(&mut less_than1)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<CudaLessThan<T>>()
|
||||
.ptr(&mut less_than2)
|
||||
.edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut add)),
|
||||
)
|
||||
.edge(SelectOp::new().ty::<CudaSub<T>>().ptr(&mut sub)),
|
||||
);
|
||||
|
||||
let mut searcher = s.search(graph);
|
||||
while searcher.next_match() {
|
||||
let lt1_inputs = graph
|
||||
.graph
|
||||
.neighbors_directed(less_than1, petgraph::Direction::Incoming)
|
||||
.sorted()
|
||||
.collect::<Vec<_>>();
|
||||
let lt2_inputs = graph
|
||||
.graph
|
||||
.neighbors_directed(less_than2, petgraph::Direction::Incoming)
|
||||
.sorted()
|
||||
.collect::<Vec<_>>();
|
||||
if lt1_inputs != lt2_inputs {
|
||||
continue;
|
||||
}
|
||||
let inputs = graph
|
||||
.graph
|
||||
.edges_directed(less_than1, petgraph::Direction::Incoming)
|
||||
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
|
||||
.map(|e| e.source())
|
||||
.collect::<Vec<_>>();
|
||||
let (a, b) = (inputs[0], inputs[1]);
|
||||
if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) {
|
||||
continue;
|
||||
}
|
||||
let a_edge = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(a, less_than1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap();
|
||||
let b_edge = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(b, less_than1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap();
|
||||
let equals = graph
|
||||
.add_op(CudaEqual::<T>::new(
|
||||
a_edge.2,
|
||||
b_edge.2,
|
||||
dev.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(a, a_edge.1, a_edge.2)
|
||||
.input(b, b_edge.1, b_edge.2)
|
||||
.finish();
|
||||
move_outgoing_edge(sub, equals, &mut graph.graph);
|
||||
|
||||
graph.graph.remove_node(sub);
|
||||
graph.safe_remove_node(add, 0);
|
||||
graph.safe_remove_node(one, 0);
|
||||
graph.safe_remove_node(less_than2, 0);
|
||||
graph.safe_remove_node(less_than1, 0);
|
||||
searcher.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Clone, LuminalEqFalse)]
|
||||
pub struct CudaGather<T> {
|
||||
function: CudaFunction,
|
||||
device: Arc<CudaDevice>,
|
||||
pub embed_dim: usize,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CudaFloat> CudaGather<T> {
|
||||
pub fn new(dev: Arc<CudaDevice>, embed_dim: usize) -> Self {
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include \"cuda_fp16.h\"
|
||||
extern \"C\" __global__ void gather({type_name} *out, const {type_name} *weights, const float *inp, int n_embeddings, int embedding_dim) {{
|
||||
int x = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int y = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
if (x < n_embeddings && y < embedding_dim) {{
|
||||
out[x * embedding_dim + y] = weights[(int)inp[x] * embedding_dim + y];
|
||||
}}
|
||||
}}");
|
||||
dev.load_ptx(
|
||||
compile_ptx_with_opts(
|
||||
code,
|
||||
CompileOptions {
|
||||
arch: Some("sm_75"),
|
||||
include_paths: vec!["/usr/local/cuda/include".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
"gather",
|
||||
&["gather"],
|
||||
)
|
||||
.unwrap();
|
||||
Self {
|
||||
function: dev.get_func("gather", "gather").unwrap(),
|
||||
device: dev,
|
||||
embed_dim,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Operator for CudaGather<T>
|
||||
where
|
||||
T: std::fmt::Debug + Copy + luminal_cudarc::driver::DeviceRepr + std::marker::Unpin + CudaFloat,
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
// Inp 1 should be Vec<f32> and inp 2 should be a CudaSlice<T>
|
||||
let indexes = inputs[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<Vec<f32>>()
|
||||
.unwrap();
|
||||
let weights = inputs[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
|
||||
let mut indexes_buffer = unsafe { self.device.alloc::<f32>(indexes.len()).unwrap() };
|
||||
self.device
|
||||
.htod_copy_into(indexes.clone(), &mut indexes_buffer)
|
||||
.unwrap();
|
||||
let mut out = self
|
||||
.device
|
||||
.alloc_zeros::<T>(indexes.len() * self.embed_dim)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
self.function
|
||||
.clone()
|
||||
.launch(
|
||||
LaunchConfig {
|
||||
grid_dim: (
|
||||
indexes.len().div_ceil(16) as u32,
|
||||
self.embed_dim.div_ceil(16) as u32,
|
||||
1,
|
||||
),
|
||||
block_dim: (16, 16, 1),
|
||||
shared_mem_bytes: 0,
|
||||
},
|
||||
(
|
||||
&mut out,
|
||||
&weights.0,
|
||||
&indexes_buffer,
|
||||
indexes.len(),
|
||||
self.embed_dim,
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct MetalGatherCompiler<T: CudaFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat> Compiler for MetalGatherCompiler<T>
|
||||
where
|
||||
CudaData<T>: luminal::prelude::Data,
|
||||
{
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = CudaDevice::new(0).unwrap();
|
||||
let (mut ind_copy, mut arange, mut equal, mut mul, mut sum_reduce) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let s = SelectOp::new()
|
||||
.ty::<CudaARange<T>>()
|
||||
.ptr(&mut arange)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<CudaCopyToDevice<T>>()
|
||||
.ptr(&mut ind_copy)
|
||||
.edge(SelectOp::new().ty::<CudaEqual<T>>().ptr(&mut equal)),
|
||||
)
|
||||
.edge(SelectOp::new().ty::<CudaMul<T>>().ptr(&mut mul))
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<CudaSumReduce<T>>()
|
||||
.ptr(&mut sum_reduce),
|
||||
);
|
||||
let mut searcher = s.search(graph);
|
||||
while searcher.next_match() {
|
||||
if check_no_delete(graph, &[arange, equal, mul, sum_reduce]) {
|
||||
continue;
|
||||
}
|
||||
let embedding_dim = graph
|
||||
.graph
|
||||
.edges_directed(mul, petgraph::Direction::Incoming)
|
||||
.find(|e| e.source() != equal && !e.weight().is_schedule())
|
||||
.unwrap()
|
||||
.weight()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2
|
||||
.shape()[2]
|
||||
.to_usize()
|
||||
.unwrap();
|
||||
let gather = graph
|
||||
.add_op(CudaGather::<T>::new(dev.clone(), embedding_dim))
|
||||
.finish();
|
||||
move_incoming_edge(ind_copy, gather, &mut graph.graph);
|
||||
graph.safe_remove_node(equal, 1);
|
||||
move_incoming_edge(mul, gather, &mut graph.graph);
|
||||
move_outgoing_edge(sum_reduce, gather, &mut graph.graph);
|
||||
graph.graph.remove_node(sum_reduce);
|
||||
graph.safe_remove_node(mul, 0);
|
||||
graph.safe_remove_node(ind_copy, 0);
|
||||
graph.safe_remove_node(arange, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
184
crates/luminal_cuda/src/lib.rs
Normal file
184
crates/luminal_cuda/src/lib.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
mod binary;
|
||||
mod matmul;
|
||||
mod other;
|
||||
mod prim;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal_cudarc::driver::{CudaSlice, DeviceRepr};
|
||||
|
||||
use std::{collections::hash_map::DefaultHasher, fmt::Write, hash::Hasher};
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
use self::symbolic::{BigExpression, Term};
|
||||
|
||||
pub type CudaCompiler<T> = (
|
||||
prim::CudaPrimitiveCompiler<T>,
|
||||
binary::CudaSubtractionCompiler<T>,
|
||||
binary::CudaEqualCompiler<T>,
|
||||
other::ARangeCompiler<T>,
|
||||
binary::MetalGatherCompiler<T>,
|
||||
matmul::CudaMatMulCompiler<T>,
|
||||
prim::CopyCompiler<T>,
|
||||
);
|
||||
|
||||
pub trait CudaFloat:
|
||||
std::fmt::Debug
|
||||
+ Copy
|
||||
+ luminal_cudarc::driver::DeviceRepr
|
||||
+ std::marker::Unpin
|
||||
+ luminal_cudarc::driver::ValidAsZeroBits
|
||||
{
|
||||
fn to_f32(self) -> f32;
|
||||
fn from_f32(a: f32) -> Self;
|
||||
fn is_f32() -> bool;
|
||||
fn type_name() -> &'static str;
|
||||
}
|
||||
|
||||
impl CudaFloat for f32 {
|
||||
fn from_f32(a: f32) -> Self {
|
||||
a
|
||||
}
|
||||
fn to_f32(self) -> f32 {
|
||||
self
|
||||
}
|
||||
fn is_f32() -> bool {
|
||||
true
|
||||
}
|
||||
fn type_name() -> &'static str {
|
||||
"float"
|
||||
}
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct CudaData<T>(CudaSlice<T>);
|
||||
|
||||
impl<T: DeviceRepr> Clone for CudaData<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.try_clone().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl Data for CudaData<f32> {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaFloat for f16 {
|
||||
fn from_f32(a: f32) -> Self {
|
||||
f16::from_f32(a)
|
||||
}
|
||||
fn to_f32(self) -> f32 {
|
||||
self.to_f32()
|
||||
}
|
||||
fn is_f32() -> bool {
|
||||
false
|
||||
}
|
||||
fn type_name() -> &'static str {
|
||||
"__half"
|
||||
}
|
||||
}
|
||||
impl Data for CudaData<f16> {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_to_cuda_string(expr: BigExpression) -> String {
|
||||
let mut symbols = vec![];
|
||||
for term in expr.terms {
|
||||
let new_symbol = match term {
|
||||
Term::Num(n) => n.to_string(),
|
||||
Term::Var(c) => {
|
||||
if c == 'z' {
|
||||
"(int)idx".to_string()
|
||||
} else {
|
||||
c.to_string()
|
||||
}
|
||||
}
|
||||
Term::Max => format!(
|
||||
"max((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
Term::Min => format!(
|
||||
"min((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
_ => format!(
|
||||
"({}{term:?}{})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
};
|
||||
symbols.push(new_symbol);
|
||||
}
|
||||
symbols.pop().unwrap()
|
||||
}
|
||||
|
||||
fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) {
|
||||
(
|
||||
expr_to_cuda_string(shape.index_expression()),
|
||||
expr_to_cuda_string(shape.valid_expression()),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_dyn_dim_inputs(shapes: &[ShapeTracker]) -> (Vec<char>, String) {
|
||||
let symbols: Vec<char> = shapes
|
||||
.iter()
|
||||
.flat_map(|st| {
|
||||
st.shape()
|
||||
.into_iter()
|
||||
.chain(
|
||||
st.padding
|
||||
.into_iter()
|
||||
.flat_map(|i| [i.0.into(), i.1.into()]),
|
||||
)
|
||||
.chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
|
||||
})
|
||||
.flat_map(|d| d.to_symbols())
|
||||
.unique()
|
||||
.collect();
|
||||
(
|
||||
symbols.clone(),
|
||||
symbols.into_iter().fold(String::default(), |mut acc, c| {
|
||||
write!(&mut acc, ", const int {c}").unwrap();
|
||||
acc
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select_const {
|
||||
($i: expr, $t: tt) => {
|
||||
luminal::compiler_utils::SelectOp::new().check(|o, _| {
|
||||
if let Some(c) = o.as_any().downcast_ref::<$crate::prim::CudaConstant<$t>>() {
|
||||
if let luminal::op::ConstantValue::Float(f) = c.0 {
|
||||
(f - $i).abs() < 0.0001
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
};
|
||||
}
|
||||
|
||||
fn hash<T: std::hash::Hash>(obj: T) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
obj.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
353
crates/luminal_cuda/src/matmul.rs
Normal file
353
crates/luminal_cuda/src/matmul.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
use luminal_cudarc::{
|
||||
cublas::{sys::cublasOperation_t::*, CudaBlas},
|
||||
driver::{CudaDevice, DevicePtr, DevicePtrMut},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
prim::{CudaMul, CudaSumReduce},
|
||||
CudaData, CudaFloat,
|
||||
};
|
||||
use luminal::{
|
||||
graph::NodeIndex,
|
||||
op::{InputTensor, Operator},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
/// Multiplies a MxK matrix with a KxN matrix, resulting in a MxN matrix
|
||||
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
|
||||
pub struct CudaMatmul2D<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat + 'static> Operator for CudaMatmul2D<T>
|
||||
where
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let (m, k, n) = (
|
||||
a_shape[0].to_usize().unwrap() as i32,
|
||||
a_shape[1].to_usize().unwrap() as i32,
|
||||
b_shape[1].to_usize().unwrap() as i32,
|
||||
);
|
||||
let a = inp[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let b = inp[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let mut out = self.1.alloc_zeros::<T>((m * n) as usize).unwrap();
|
||||
let (a_row_major, b_row_major) = (
|
||||
inp[0].1.indexes[1] > inp[0].1.indexes[0],
|
||||
inp[1].1.indexes[1] > inp[1].1.indexes[0],
|
||||
);
|
||||
let (transa, transb) = match (a_row_major, b_row_major) {
|
||||
(true, true) => (CUBLAS_OP_N, CUBLAS_OP_N),
|
||||
(false, false) => (CUBLAS_OP_T, CUBLAS_OP_T),
|
||||
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
|
||||
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
|
||||
};
|
||||
if T::is_f32() {
|
||||
unsafe {
|
||||
luminal_cudarc::cublas::result::sgemm(
|
||||
*self.0.handle(),
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
&1.0_f32 as *const f32,
|
||||
*b.0.device_ptr() as *const f32,
|
||||
if b_row_major { n } else { k },
|
||||
*a.0.device_ptr() as *const f32,
|
||||
if a_row_major { k } else { m },
|
||||
&0.0_f32 as *const f32,
|
||||
*out.device_ptr_mut() as *mut f32,
|
||||
n,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
luminal_cudarc::cublas::result::hgemm(
|
||||
*self.0.handle(),
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
&f16::from_f32(1.0) as *const f16,
|
||||
*b.0.device_ptr() as *const f16,
|
||||
if b_row_major { n } else { k },
|
||||
*a.0.device_ptr() as *const f16,
|
||||
if a_row_major { k } else { m },
|
||||
&f16::from_f32(0.0) as *const f16,
|
||||
*out.device_ptr_mut() as *mut f16,
|
||||
n,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplies a BxMxK matrix with a BxKxN matrix, resulting in a BxMxN matrix
|
||||
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
|
||||
pub struct CudaBatchMatmul2D<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat + 'static> Operator for CudaBatchMatmul2D<T>
|
||||
where
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let a_strides = inp[0].1.strides();
|
||||
let (batch_size, m, k, n) = (
|
||||
a_shape[0].to_usize().unwrap() as i32,
|
||||
a_shape[1].to_usize().unwrap() as i32,
|
||||
a_shape[2].to_usize().unwrap() as i32,
|
||||
b_shape[1].to_usize().unwrap() as i32,
|
||||
);
|
||||
let a = inp[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let b = inp[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<CudaData<T>>()
|
||||
.unwrap();
|
||||
let mut out = self
|
||||
.1
|
||||
.alloc_zeros::<T>((m * n * batch_size) as usize)
|
||||
.unwrap();
|
||||
let (a_row_major, b_row_major) = (
|
||||
inp[0].1.indexes[2] > inp[0].1.indexes[1],
|
||||
inp[1].1.indexes[1] > inp[1].1.indexes[0],
|
||||
);
|
||||
let (transa, transb) = match (a_row_major, b_row_major) {
|
||||
(true, true) => (CUBLAS_OP_N, CUBLAS_OP_N),
|
||||
(false, false) => (CUBLAS_OP_T, CUBLAS_OP_T),
|
||||
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
|
||||
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
|
||||
};
|
||||
if T::is_f32() {
|
||||
unsafe {
|
||||
luminal_cudarc::cublas::result::sgemm_strided_batched(
|
||||
*self.0.handle(),
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
&1.0_f32 as *const f32,
|
||||
*b.0.device_ptr() as *const f32,
|
||||
if b_row_major { n } else { k },
|
||||
0,
|
||||
*a.0.device_ptr() as *const f32,
|
||||
if a_row_major { k } else { m },
|
||||
a_strides[0].to_usize().unwrap() as i64,
|
||||
&0.0_f32 as *const f32,
|
||||
*out.device_ptr_mut() as *mut f32,
|
||||
n,
|
||||
(m * n) as i64,
|
||||
batch_size,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
luminal_cudarc::cublas::result::hgemm_strided_batched(
|
||||
*self.0.handle(),
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
&f16::from_f32(1.0) as *const f16,
|
||||
*b.0.device_ptr() as *const f16,
|
||||
if b_row_major { n } else { k },
|
||||
0,
|
||||
*a.0.device_ptr() as *const f16,
|
||||
if a_row_major { k } else { m },
|
||||
a_strides[0].to_usize().unwrap() as i64,
|
||||
&f16::from_f32(0.0) as *const f16,
|
||||
*out.device_ptr_mut() as *mut f16,
|
||||
n,
|
||||
(m * n) as i64,
|
||||
batch_size,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CudaMatMulCompiler<T>(PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat + 'static> Compiler for CudaMatMulCompiler<T>
|
||||
where
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
let dev = CudaDevice::new(0).unwrap();
|
||||
// Look for the matmul pattern
|
||||
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
|
||||
// Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]
|
||||
// Actually starts at [A,B] | [B, C]
|
||||
let s = SelectEdge::new(
|
||||
SelectOp::new()
|
||||
.ty::<CudaMul<T>>()
|
||||
.shapes([['A', 'C', 'B'], ['A', 'C', 'B']])
|
||||
.fakes([
|
||||
[Some(false), Some(true), Some(false)],
|
||||
[Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul),
|
||||
SelectOp::new()
|
||||
.ty::<CudaSumReduce<T>>()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<CudaSumReduce<T>>() {
|
||||
o.2 == 2
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
);
|
||||
let mut searcher = s.search(graph);
|
||||
while searcher.next_match() {
|
||||
if graph.no_delete.contains(&mul) {
|
||||
// The intermediate mul can't be deleted
|
||||
continue;
|
||||
}
|
||||
// Insert MatMul2D op
|
||||
let mut srcs = graph.get_sources(mul);
|
||||
// Undo expansions and permute
|
||||
srcs[0].2.remove_dim(1);
|
||||
srcs[1].2.remove_dim(0);
|
||||
srcs[1].2.permute(&[1, 0]);
|
||||
let new_op = graph
|
||||
.add_op(CudaMatmul2D::<T>(
|
||||
Arc::new(CudaBlas::new(dev.clone()).unwrap()),
|
||||
dev.clone(),
|
||||
Default::default(),
|
||||
))
|
||||
.input(srcs[0].0, 0, srcs[0].2)
|
||||
.input(srcs[1].0, 0, srcs[1].2)
|
||||
.finish();
|
||||
|
||||
// Create edges to dests
|
||||
move_outgoing_edge(sum_reduce, new_op, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
sum_reduce,
|
||||
new_op,
|
||||
);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
mul,
|
||||
new_op,
|
||||
);
|
||||
|
||||
// Remove the old ops
|
||||
graph.graph.remove_node(mul);
|
||||
graph.graph.remove_node(sum_reduce);
|
||||
}
|
||||
|
||||
// Look for the batch matmul pattern
|
||||
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
|
||||
// Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]
|
||||
// Actually starts at [A,B] | [B, C]
|
||||
let mut searcher = SelectEdge::new(
|
||||
SelectOp::new()
|
||||
.ty::<CudaMul<T>>()
|
||||
.shapes([['D', 'A', 'C', 'B'], ['D', 'A', 'C', 'B']])
|
||||
.fakes([
|
||||
[Some(false), Some(false), Some(true), Some(false)],
|
||||
[Some(true), Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul),
|
||||
SelectOp::new()
|
||||
.ty::<CudaSumReduce<T>>()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<CudaSumReduce<T>>() {
|
||||
o.2 == 3
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
)
|
||||
.search(graph);
|
||||
while searcher.next_match() {
|
||||
if graph.no_delete.contains(&mul) {
|
||||
// The intermediate mul can't be deleted
|
||||
continue;
|
||||
}
|
||||
// Insert BatchMatMul2D op
|
||||
let mut srcs = graph.get_sources(mul);
|
||||
// Undo expansions and permute
|
||||
srcs[0].2.remove_dim(2);
|
||||
srcs[1].2.remove_dim(1);
|
||||
srcs[1].2.remove_dim(0);
|
||||
srcs[1].2.permute(&[1, 0]);
|
||||
let new_op = graph
|
||||
.add_op(CudaBatchMatmul2D::<T>(
|
||||
Arc::new(CudaBlas::new(dev.clone()).unwrap()),
|
||||
dev.clone(),
|
||||
Default::default(),
|
||||
))
|
||||
.input(srcs[0].0, 0, srcs[0].2)
|
||||
.input(srcs[1].0, 0, srcs[1].2)
|
||||
.finish();
|
||||
|
||||
// Create edges to dests
|
||||
move_outgoing_edge(sum_reduce, new_op, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
sum_reduce,
|
||||
new_op,
|
||||
);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
mul,
|
||||
new_op,
|
||||
);
|
||||
|
||||
// Remove the old ops
|
||||
graph.graph.remove_node(mul);
|
||||
graph.graph.remove_node(sum_reduce);
|
||||
}
|
||||
}
|
||||
}
|
||||
195
crates/luminal_cuda/src/other.rs
Normal file
195
crates/luminal_cuda/src/other.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
use luminal_cudarc::{
|
||||
driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig},
|
||||
nvrtc::{compile_ptx_with_opts, CompileOptions},
|
||||
};
|
||||
|
||||
use luminal::{
|
||||
op::*,
|
||||
prelude::{petgraph::visit::EdgeRef, *},
|
||||
shape::symbolic::BigExpression,
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
binary::CudaSub,
|
||||
prim::{CudaAdd, CudaContiguous, CudaSumReduce},
|
||||
select_const, CudaData, CudaFloat,
|
||||
};
|
||||
|
||||
#[derive(LuminalPrint, Clone, LuminalEqFalse)]
|
||||
pub struct CudaARange<T> {
|
||||
function: CudaFunction,
|
||||
device: Arc<CudaDevice>,
|
||||
pub size: BigExpression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CudaFloat> CudaARange<T> {
|
||||
pub fn new(
|
||||
dev: Arc<CudaDevice>,
|
||||
size: BigExpression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let type_name = T::type_name();
|
||||
let code = format!(
|
||||
"
|
||||
#include \"cuda_fp16.h\"
|
||||
extern \"C\" __global__ void arange({type_name} *out, int n_elements) {{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n_elements) {{
|
||||
out[idx] = ({type_name})idx;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
dev.load_ptx(
|
||||
compile_ptx_with_opts(
|
||||
code,
|
||||
CompileOptions {
|
||||
arch: Some("sm_75"),
|
||||
include_paths: vec!["/usr/local/cuda/include".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
"arange",
|
||||
&["arange"],
|
||||
)
|
||||
.unwrap();
|
||||
Self {
|
||||
function: dev.get_func("arange", "arange").unwrap(),
|
||||
device: dev,
|
||||
size,
|
||||
_phantom: Default::default(),
|
||||
dyn_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Operator for CudaARange<T>
|
||||
where
|
||||
T: std::fmt::Debug + Copy + luminal_cudarc::driver::DeviceRepr + std::marker::Unpin + CudaFloat,
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let n_elements = self
|
||||
.size
|
||||
.exec(unsafe { self.dyn_map.as_ref().unwrap() })
|
||||
.unwrap();
|
||||
let mut out = self.device.alloc_zeros::<T>(n_elements).unwrap();
|
||||
unsafe {
|
||||
self.function
|
||||
.clone()
|
||||
.launch(
|
||||
LaunchConfig::for_num_elems(n_elements as u32),
|
||||
(&mut out, n_elements as i32),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(CudaData(out)),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct ARangeCompiler<T: CudaFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: CudaFloat> Compiler for ARangeCompiler<T>
|
||||
where
|
||||
CudaData<T>: Data,
|
||||
{
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = CudaDevice::new(0).unwrap();
|
||||
let (
|
||||
mut one_const,
|
||||
mut contig1,
|
||||
mut contig2,
|
||||
mut contig3,
|
||||
mut contig4,
|
||||
mut sum_reduce,
|
||||
mut subtraction_constant,
|
||||
mut subtraction,
|
||||
) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
|
||||
// TODO: Make sure this actually checks the shape transformations to ensure pooling happens
|
||||
let contig = SelectOp::new().ty::<CudaContiguous<T>>();
|
||||
let pre_sub_pattern = select_const!(1.0, T)
|
||||
.ptr(&mut one_const)
|
||||
.edge(contig.clone().ptr(&mut contig1))
|
||||
.edge(contig.clone().ptr(&mut contig2))
|
||||
.edge(contig.clone().ptr(&mut contig3))
|
||||
.edge(contig.clone().ptr(&mut contig4))
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<CudaSumReduce<T>>()
|
||||
.ptr(&mut sum_reduce),
|
||||
);
|
||||
let mut s1 = pre_sub_pattern
|
||||
.clone()
|
||||
.edge(
|
||||
select_const!(1.0, T)
|
||||
.ptr(&mut subtraction_constant)
|
||||
.edge(SelectOp::new().ty::<CudaSub<T>>().ptr(&mut subtraction)),
|
||||
)
|
||||
.search(graph);
|
||||
let mut s2 = pre_sub_pattern
|
||||
.edge(
|
||||
select_const!(-1.0, T)
|
||||
.ptr(&mut subtraction_constant)
|
||||
.edge(SelectOp::new().ty::<CudaAdd<T>>().ptr(&mut subtraction)),
|
||||
)
|
||||
.search(graph);
|
||||
|
||||
while s1.next_match() || s2.next_match() {
|
||||
let arange_amount = {
|
||||
let sh = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(one_const, contig1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2;
|
||||
sh.dims[sh.indexes[sh.len() - 1]]
|
||||
};
|
||||
let arange_op = graph
|
||||
.add_op(CudaARange::<T>::new(
|
||||
dev.clone(),
|
||||
arange_amount.into(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish();
|
||||
move_outgoing_edge(subtraction, arange_op, &mut graph.graph);
|
||||
|
||||
graph.graph.remove_node(subtraction);
|
||||
graph.safe_remove_node(subtraction_constant, 0);
|
||||
graph.safe_remove_node(sum_reduce, 0);
|
||||
graph.safe_remove_node(contig4, 0);
|
||||
graph.safe_remove_node(contig3, 0);
|
||||
graph.safe_remove_node(contig2, 0);
|
||||
graph.safe_remove_node(contig1, 0);
|
||||
graph.safe_remove_node(one_const, 0);
|
||||
s1.clear_cached_results();
|
||||
s2.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
1717
crates/luminal_cuda/src/prim.rs
Normal file
1717
crates/luminal_cuda/src/prim.rs
Normal file
File diff suppressed because it is too large
Load Diff
1030
crates/luminal_cuda/src/tests/fp16.rs
Normal file
1030
crates/luminal_cuda/src/tests/fp16.rs
Normal file
File diff suppressed because it is too large
Load Diff
973
crates/luminal_cuda/src/tests/fp32.rs
Normal file
973
crates/luminal_cuda/src/tests/fp32.rs
Normal file
@@ -0,0 +1,973 @@
|
||||
use dfdx::prelude::{Module as DfdxModule, *};
|
||||
use itertools::Itertools;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
use luminal::{
|
||||
nn::{activation::ReLU, linear::Linear, norm::RMSNorm},
|
||||
prelude::{symbolic::Expression, Module, *},
|
||||
};
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use dfdx::prelude::{
|
||||
Axes as DAxes, Axes2 as DAxes2, Axes3 as DAxes3, Axes4 as DAxes4, Axes5 as DAxes5,
|
||||
Axis as DAxis, Const as DConst, *,
|
||||
};
|
||||
#[allow(unused_imports)]
|
||||
use luminal::{
|
||||
prelude::{
|
||||
Axes as LAxes, Axes2 as LAxes2, Axes3 as LAxes3, Axes4 as LAxes4, Axes5 as LAxes5,
|
||||
Axis as LAxis, Const as LConst, *,
|
||||
},
|
||||
tests::{
|
||||
assert_close, assert_close_precision, assert_exact, random_vec, random_vec_rng, test_graphs,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::CudaCompiler;
|
||||
|
||||
#[test]
|
||||
fn test_contiguous() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
|
||||
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<3>, DConst::<4>));
|
||||
let d_b = d_a.permute::<Rank2<4, 3>, _>().reshape::<Rank2<12, 1>>();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R2<1, 12>>().set(data.clone());
|
||||
let mut b = a.softmax::<1>().retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<12>));
|
||||
let d_b = d_a.softmax::<DAxis<1>>();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rotate() {
|
||||
let mut cx = Graph::new();
|
||||
const D: usize = 2;
|
||||
const S: usize = 2;
|
||||
const H: usize = 2;
|
||||
let data = random_vec(D * S * H);
|
||||
let a = cx
|
||||
.tensor::<R4<1, D, S, H>>()
|
||||
.set(data)
|
||||
.keep()
|
||||
.permute::<_, LAxes4<0, 2, 1, 3>>();
|
||||
let x1 = a.slice((.., .., .., ..Expression::from(H / 2)));
|
||||
let x2 = a.slice((.., .., .., Expression::from(H / 2)..));
|
||||
let mut rotated_a = (-x2)
|
||||
.concat_along::<R4<1, S, D, H>, LAxis<3>, _>(x1)
|
||||
.retrieve();
|
||||
cx.execute();
|
||||
let unopt = rotated_a.data();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut rotated_a);
|
||||
cx.execute();
|
||||
|
||||
assert_close(&unopt, &rotated_a.data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.constant_expr('a');
|
||||
let mut a = (a * a).retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut a);
|
||||
|
||||
cx.set_dyn_dim('a', 10);
|
||||
cx.execute();
|
||||
assert_exact(&a.data(), &[100.0]);
|
||||
a.drop();
|
||||
cx.set_dyn_dim('a', 25);
|
||||
cx.execute();
|
||||
assert_exact(&a.data(), &[625.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(data.clone());
|
||||
let mut b = a.log2().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
assert_close(
|
||||
&b.data(),
|
||||
&data.into_iter().map(|i| i.log2()).collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exp2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(data.clone());
|
||||
let mut b = a.exp2().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
assert_close(
|
||||
&b.data(),
|
||||
&data.into_iter().map(|i: f32| i.exp2()).collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recip() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 4096.]);
|
||||
let mut b = a.recip().retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 4096.]);
|
||||
let d_b = d_a.recip();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sin() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut b = a.sin().retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_a.sin();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut b = a / a.sqrt();
|
||||
b.retrieve();
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_a.clone() / d_a.sqrt();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = (a + b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a + d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a - b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a - d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_square() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = rand::thread_rng();
|
||||
let data = random_vec_rng(40960, &mut rng);
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'s'>, luminal::prelude::Const<4096>)>()
|
||||
.set_dyn(data.clone(), &[1, 10, 4096]);
|
||||
let mut b = a * a;
|
||||
b.retrieve();
|
||||
|
||||
cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec::<Rank3<1, 10, 4096>>(
|
||||
data,
|
||||
(
|
||||
dfdx::prelude::Const::<1>,
|
||||
dfdx::prelude::Const::<10>,
|
||||
dfdx::prelude::Const::<4096>,
|
||||
),
|
||||
);
|
||||
let d_b = d_a.clone() * d_a;
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a * b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a * d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul2() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<(LConst<1>, LConst<1>, Dyn<'a'>, Dyn<'a'>)>()
|
||||
.set_dyn(vec![82.4, 783.0, 99.6, 974.5], &[1, 1, 2, 2]);
|
||||
let b = cx.tensor::<R0>().set(vec![0.57735026]);
|
||||
let mut c = (a * b.expand()).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([[[[82.4, 783.0], [99.6, 974.5]]]]);
|
||||
let d_b = d_dev.tensor(0.57735026);
|
||||
let d_c = d_a * d_b.broadcast::<_, dfdx::shapes::Axes4<0, 1, 2, 3>>();
|
||||
|
||||
assert_exact(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a / b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a / d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a.max(b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a.maximum(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mod() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(3);
|
||||
let b_data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R1<3>>().set(b_data.clone());
|
||||
let mut c = a % b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
// No dfdx equivalent
|
||||
|
||||
assert_close(
|
||||
&c.data(),
|
||||
&a_data
|
||||
.into_iter()
|
||||
.zip(b_data)
|
||||
.map(|(a, b)| a % b)
|
||||
.collect_vec(),
|
||||
);
|
||||
}
|
||||
|
||||
// Reduction op tests
|
||||
|
||||
#[test]
|
||||
fn test_sum_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.sum_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.sum_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.sum_reduce::<_, LAxis<0>>().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>));
|
||||
let d_b = d_a.clone().sum::<_, DAxis<2>>();
|
||||
let d_c = d_a.clone().sum::<_, DAxis<1>>();
|
||||
let d_d = d_a.sum::<_, DAxis<0>>();
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_reduce2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(32 * 10 * 10 * 128);
|
||||
let a = cx.tensor::<R5<1, 32, 10, 10, 128>>().set(data.clone());
|
||||
let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut d);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
data,
|
||||
(
|
||||
DConst::<1>,
|
||||
DConst::<32>,
|
||||
DConst::<10>,
|
||||
DConst::<10>,
|
||||
DConst::<128>,
|
||||
),
|
||||
);
|
||||
let d_d = d_a.sum::<_, DAxis<2>>();
|
||||
|
||||
assert_exact(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.max_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.max_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.max_reduce::<_, LAxis<0>>().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>));
|
||||
let d_b = d_a.clone().max::<_, DAxis<2>>();
|
||||
let d_c = d_a.clone().max::<_, DAxis<1>>();
|
||||
let d_d = d_a.max::<_, DAxis<0>>();
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.mean_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.mean_reduce::<_, LAxis<0>>().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>));
|
||||
let d_b = d_a.clone().mean::<_, DAxis<2>>();
|
||||
let d_c = d_a.clone().mean::<_, DAxis<1>>();
|
||||
let d_d = d_a.mean::<_, DAxis<0>>();
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_simple() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(256 * 256);
|
||||
let b_data = random_vec(256 * 256);
|
||||
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<256>, DConst::<256>));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (DConst::<256>, DConst::<256>));
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul() {
|
||||
let d_dev = Cpu::default();
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
|
||||
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
for m in (1..23).step_by(4) {
|
||||
for k in (1..35).step_by(3) {
|
||||
for n in (1..70).step_by(7) {
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let b_data = random_vec_rng(k * n, &mut rng);
|
||||
a.set_dyn(a_data.clone(), &[m, k]);
|
||||
b.set_dyn(b_data.clone(), &[k, n]);
|
||||
cx.execute();
|
||||
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (m, k));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (k, n));
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
c.drop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attn_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a_data = random_vec_rng(32 * 11 * 128, &mut rng);
|
||||
let b_data = random_vec_rng(32 * 11 * 128, &mut rng);
|
||||
let a = cx
|
||||
.named_tensor::<R4<1, 32, 11, 128>>("Input")
|
||||
.set(a_data.clone())
|
||||
.keep();
|
||||
let b = cx
|
||||
.named_tensor::<R4<1, 32, 128, 11>>("Input")
|
||||
.set(b_data.clone())
|
||||
.keep();
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
a_data,
|
||||
(DConst::<1>, DConst::<32>, DConst::<11>, DConst::<128>),
|
||||
);
|
||||
let d_b = d_dev.tensor_from_vec(
|
||||
b_data,
|
||||
(DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>),
|
||||
);
|
||||
let d_c = d_a.matmul(d_b);
|
||||
assert_close_precision(&c.data(), &d_c.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_matmul() {
|
||||
let m = 12;
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>();
|
||||
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
for batch in (1..23).step_by(4) {
|
||||
for k in (1..35).step_by(3) {
|
||||
for n in (1..48).step_by(7) {
|
||||
let a_data = random_vec_rng(batch * m * k, &mut rng);
|
||||
let b_data = random_vec_rng(k * n, &mut rng);
|
||||
a.set_dyn(a_data.clone(), &[batch, m, k]);
|
||||
b.set_dyn(b_data.clone(), &[k, n]);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (batch, m, k));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (k, n));
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close_precision(&c.data(), &d_c.to_dtype::<f32>().as_vec(), 2);
|
||||
c.drop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_matmul_transpose() {
|
||||
const B: usize = 1;
|
||||
const M: usize = 48; // Any
|
||||
const K: usize = 4096; // >= 16, multiple of 16
|
||||
const N: usize = 4096; // >= 256, multiple of 256
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(B * M * K, &mut rng);
|
||||
let a = cx.named_tensor::<R3<B, M, K>>("A").set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.named_tensor::<R2<N, K>>("B").set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(B * K * M, &mut rng);
|
||||
let a_t = cx.named_tensor::<R3<B, K, M>>("A_T").set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.named_tensor::<R2<K, N>>("B_T").set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a.matmul(b.permute::<_, LAxes2<1, 0>>()).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, LAxes3<0, 2, 1>>()
|
||||
.matmul(b.permute::<_, LAxes2<1, 0>>())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t.permute::<_, LAxes3<0, 2, 1>>().matmul(b_t).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, CudaCompiler<f32>)>::default(),
|
||||
(&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<B>, DConst::<M>, DConst::<K>));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (DConst::<N>, DConst::<K>));
|
||||
let d_a_t = d_dev.tensor_from_vec(a_t_data, (DConst::<B>, DConst::<K>, DConst::<M>));
|
||||
let d_b_t = d_dev.tensor_from_vec(b_t_data, (DConst::<K>, DConst::<N>));
|
||||
let d_a_b = d_a.clone().matmul(d_b.clone().permute::<_, DAxes2<1, 0>>());
|
||||
let d_a_b_t = d_a.matmul(d_b_t.clone());
|
||||
let d_a_t_b = d_a_t
|
||||
.clone()
|
||||
.permute::<_, DAxes3<0, 2, 1>>()
|
||||
.matmul(d_b.permute::<_, DAxes2<1, 0>>());
|
||||
let d_a_t_b_t = d_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(d_b_t);
|
||||
|
||||
assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 1);
|
||||
assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 1);
|
||||
assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 1);
|
||||
assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_transpose() {
|
||||
const M: usize = 1024; // Any
|
||||
const K: usize = 16; // >= 16
|
||||
const N: usize = 767; // >= 256, multiple of 256
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(M * K, &mut rng);
|
||||
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(K * M, &mut rng);
|
||||
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a.matmul(b.permute()).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, LAxes2<1, 0>>()
|
||||
.matmul(b.permute())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t.permute::<_, LAxes2<1, 0>>().matmul(b_t).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, CudaCompiler<f32>)>::default(),
|
||||
(&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<M>, DConst::<K>));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (DConst::<N>, DConst::<K>));
|
||||
let d_a_t = d_dev.tensor_from_vec(a_t_data, (DConst::<K>, DConst::<M>));
|
||||
let d_b_t = d_dev.tensor_from_vec(b_t_data, (DConst::<K>, DConst::<N>));
|
||||
let d_a_b = d_a.clone().matmul(d_b.clone().permute());
|
||||
let d_a_b_t = d_a.matmul(d_b_t.clone());
|
||||
let d_a_t_b = d_a_t
|
||||
.clone()
|
||||
.permute::<_, DAxes2<1, 0>>()
|
||||
.matmul(d_b.permute());
|
||||
let d_a_t_b_t = d_a_t.permute::<_, DAxes2<1, 0>>().matmul(d_b_t);
|
||||
|
||||
assert_close(&a_b.data(), &d_a_b.as_vec());
|
||||
assert_close(&a_b_t.data(), &d_a_b_t.as_vec());
|
||||
assert_close(&a_t_b.data(), &d_a_t_b.as_vec());
|
||||
assert_close(&a_t_b_t.data(), &d_a_t_b_t.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relu_and_linear() {
|
||||
// Test single and batch, unoptimized and optimized
|
||||
let mut cx = Graph::new();
|
||||
let input_data = random_vec(32);
|
||||
let w1 = random_vec(32 * 64);
|
||||
let w2 = random_vec(32 * 64);
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 32>>("Batch")
|
||||
.set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
|
||||
|
||||
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
|
||||
model.0.weight.set(w1.clone());
|
||||
model.2.weight.set(w2.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
let mut batch_out = model.forward(batch).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let unoptimized_b = b.data();
|
||||
let unoptimized_batch_out = batch_out.data();
|
||||
b.drop();
|
||||
batch_out.drop();
|
||||
cx.compile(
|
||||
<(GenericCompiler, CudaCompiler<f32>)>::default(),
|
||||
(&mut b, &mut batch_out),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
assert_close_precision(&unoptimized_b, &b.data(), 2);
|
||||
assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 2);
|
||||
|
||||
// Test against dfdx
|
||||
let dev = Cpu::default();
|
||||
let mut model = <(
|
||||
dfdx::nn::modules::builders::UnbiasedLinear<32, 64>,
|
||||
dfdx::nn::modules::builders::ReLU,
|
||||
dfdx::nn::modules::builders::UnbiasedLinear<64, 32>,
|
||||
)>::build_on_device(&dev);
|
||||
// Set weights
|
||||
model.0.weight = dev
|
||||
.tensor_from_vec(w1, (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>))
|
||||
.permute();
|
||||
model.2.weight = dev
|
||||
.tensor_from_vec(w2, (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>))
|
||||
.permute();
|
||||
let a = dev.tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,));
|
||||
let out = model.forward(a);
|
||||
|
||||
assert_close_precision(&unoptimized_b, &out.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rms_norm() {
|
||||
// Test single and batch, unoptimized and optimized
|
||||
let inp_data = random_vec(15 * 32);
|
||||
let weight_data = random_vec(32);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R2<15, 32>>().set(inp_data.clone());
|
||||
|
||||
let model = RMSNorm::<32>::initialize(&mut cx);
|
||||
model.weight.set(weight_data.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
|
||||
cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
// Test against dfdx
|
||||
let dev = Cpu::default();
|
||||
let weight = dev.tensor_from_vec(weight_data, (DConst::<32>,));
|
||||
let a = dev.tensor_from_vec(inp_data, (DConst::<15>, DConst::<32>));
|
||||
let var_f32 = a.clone().square().mean::<_, DAxis<1>>();
|
||||
let std_f32 = (var_f32 + 1e-6).sqrt();
|
||||
let x_f32 = a / std_f32.broadcast();
|
||||
let out = weight.broadcast() * x_f32;
|
||||
|
||||
assert_close(&b.data(), &out.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_layer_norm() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(15 * 16 * 32);
|
||||
let a = cx.tensor::<R3<15, 16, 32>>().set(a_data.clone());
|
||||
let mut b = a.layer_norm::<0, _>(1e-5).retrieve();
|
||||
let mut c = a.layer_norm::<2, _>(1e-5).retrieve();
|
||||
cx.compile(
|
||||
<(GenericCompiler, CudaCompiler<f32>)>::default(),
|
||||
(&mut b, &mut c),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<15>, DConst::<16>, DConst::<32>));
|
||||
let d_b = d_a.clone().normalize::<DAxis<0>>(1e-5);
|
||||
let d_c = d_a.normalize::<DAxis<2>>(1e-5);
|
||||
|
||||
assert_close_precision(&b.data(), &d_b.as_vec(), 2);
|
||||
assert_close_precision(&c.data(), &d_c.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transformer_encoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<32, 64, 1> =
|
||||
InitModule::initialize(&mut cx);
|
||||
let w_k_weight = random_vec(32 * 32);
|
||||
model.attention.w_k.weight.set(w_k_weight.clone());
|
||||
let w_q_weight = random_vec(32 * 32);
|
||||
model.attention.w_q.weight.set(w_q_weight.clone());
|
||||
let w_v_weight = random_vec(32 * 32);
|
||||
model.attention.w_v.weight.set(w_v_weight.clone());
|
||||
let w_o_weight = random_vec(32 * 32);
|
||||
model.attention.w_o.weight.set(w_o_weight.clone());
|
||||
let ff_0_weight = random_vec(32 * 64);
|
||||
model.ff.0.weight.set(ff_0_weight.clone());
|
||||
let ff_1_weight = random_vec(64 * 32);
|
||||
model.ff.2.weight.set(ff_1_weight.clone());
|
||||
|
||||
let a_data = random_vec(2 * 32);
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>()
|
||||
.set_dyn(a_data.clone(), &[1, 2, 3])
|
||||
.keep();
|
||||
cx.keep_tensors(state_dict(&model));
|
||||
let mut b = model.forward(a).retrieve();
|
||||
cx.execute();
|
||||
let unopt_b = b.data();
|
||||
b.drop();
|
||||
|
||||
cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
assert_close_precision(&unopt_b, &b.data(), 2);
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> =
|
||||
d_dev
|
||||
.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<32, 1, 64>, f32>();
|
||||
d_model.self_attn.w_k.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_v.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_q.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_o.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_o.weight = d_dev
|
||||
.tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>))
|
||||
.permute();
|
||||
d_model.self_attn.w_k.weight = d_dev
|
||||
.tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>))
|
||||
.permute();
|
||||
d_model.self_attn.w_q.weight = d_dev
|
||||
.tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>))
|
||||
.permute();
|
||||
d_model.self_attn.w_v.weight = d_dev
|
||||
.tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>))
|
||||
.permute();
|
||||
d_model.ff.0 .0.weight = d_dev
|
||||
.tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>))
|
||||
.permute();
|
||||
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,));
|
||||
d_model.ff.0 .2.weight = d_dev
|
||||
.tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>))
|
||||
.permute();
|
||||
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
|
||||
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
|
||||
d_model.norm1.epsilon = 1e-5;
|
||||
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm2.epsilon = 1e-5;
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>));
|
||||
let d_b = d_model.forward(d_a);
|
||||
|
||||
assert_close_precision(&b.data(), &d_b.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_common_buffer() {
|
||||
let data = random_vec(32);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<32>>();
|
||||
a.set(data.clone());
|
||||
let a1 = cx.tensor::<R1<32>>();
|
||||
a1.set(data.clone());
|
||||
let exped = a * a1;
|
||||
let mut b = exped.log2().retrieve();
|
||||
let mut c = exped.sin().retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut b, &mut c));
|
||||
cx.execute();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding() {
|
||||
let mut cx = Graph::new();
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 3>>("Batch")
|
||||
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0])
|
||||
.keep();
|
||||
let a = cx
|
||||
.named_tensor::<R1<3>>("Single")
|
||||
.set(vec![1.0, 0.0, 1.0])
|
||||
.keep();
|
||||
|
||||
let model: luminal::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx);
|
||||
model
|
||||
.weight
|
||||
.set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
|
||||
let mut b = model.forward(a).retrieve();
|
||||
let mut batch_out = model.forward(batch).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut b, &mut batch_out));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let mut d_model: modules::Embedding<3, 4, f32, Cpu> =
|
||||
<dfdx::nn::modules::builders::Embedding<3, 4>>::build_on_device(&d_dev);
|
||||
d_model.weight = d_dev.tensor_from_vec(
|
||||
vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.],
|
||||
(DConst::<3>, DConst::<4>),
|
||||
);
|
||||
let d_a = d_dev.tensor_from_vec(vec![1, 0, 1], (DConst::<3>,));
|
||||
let d_batch = d_dev.tensor_from_vec(vec![1, 0, 2, 1, 0, 1], (DConst::<2>, DConst::<3>));
|
||||
|
||||
let d_b = d_model.forward(d_a);
|
||||
let d_batch_out = d_model.forward(d_batch);
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&batch_out.data(), &d_batch_out.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice() {
|
||||
let data = random_vec(256);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<256>>().set(data.clone());
|
||||
let mut c: GraphTensor<R1<20>> = a
|
||||
.slice((..Expression::from(20),))
|
||||
.realize()
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<256>,));
|
||||
let d_c = d_a.slice((..20,));
|
||||
|
||||
assert_exact(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pad() {
|
||||
// Pad a 8x2 mat to 10x4
|
||||
let data = random_vec(8 * 2);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R2<8, 2>>().set(data.clone());
|
||||
let mut c = a
|
||||
.pad::<R2<10, 4>, _, _>(&[(0, 2), (0, 2)])
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (8, 2));
|
||||
// There is no pad function in dfdx, so we concat with zero tensors
|
||||
let d_b = (d_a, d_dev.zeros_like(&(2, 2))).concat_along(DAxis::<0>);
|
||||
let d_c = (d_b, d_dev.zeros_like(&(10, 2))).concat_along(DAxis::<1>);
|
||||
|
||||
assert_exact(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pad_contig() {
|
||||
let m = 13;
|
||||
let k = 24;
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let mut a = cx
|
||||
.tensor::<(Dyn<'M'>, Dyn<'K'>)>()
|
||||
.set_dyn(a_data, &[m, k])
|
||||
.retrieve();
|
||||
let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a
|
||||
.pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')])
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
|
||||
(a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), (&mut a, &mut b, &mut c));
|
||||
cx.execute();
|
||||
|
||||
// Close because b and c are going through 16 bits, while a is not
|
||||
assert_close(&a.data(), &b.data());
|
||||
assert_close(&a.data(), &c.data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_movement() {
|
||||
let data = random_vec(32);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<32>>().set(data.clone());
|
||||
let b: GraphTensor<R1<42>> = a.pad(&[(0, 10)]).contiguous().retrieve();
|
||||
let mut c: GraphTensor<R1<25>> = b
|
||||
.slice((..Expression::from(25),))
|
||||
.realize()
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
|
||||
cx.compile(CudaCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (DConst::<32>,));
|
||||
let d_c = d_a.slice((..25,));
|
||||
|
||||
assert_exact(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
2
crates/luminal_cuda/src/tests/mod.rs
Normal file
2
crates/luminal_cuda/src/tests/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
mod fp16;
|
||||
mod fp32;
|
||||
19
crates/luminal_metal/Cargo.toml
Normal file
19
crates/luminal_metal/Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "luminal_metal"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
description = "Metal compiler for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.12.1"
|
||||
luminal = { path = "../.." }
|
||||
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"] }
|
||||
num-traits = "0.2.18"
|
||||
rand = "0.8.5"
|
||||
rustc-hash = "1.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = { version = "0.13", features = ["f16"] }
|
||||
652
crates/luminal_metal/src/binary.rs
Normal file
652
crates/luminal_metal/src/binary.rs
Normal file
@@ -0,0 +1,652 @@
|
||||
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
|
||||
|
||||
use itertools::Itertools;
|
||||
use metal_rs::{
|
||||
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
|
||||
ComputePipelineState, Device, MTLResourceOptions, MTLSize,
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
compile_function, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims,
|
||||
render_dyn_dim_inputs, select_const, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel,
|
||||
MetalKernelWrapper, SetInt,
|
||||
};
|
||||
|
||||
use super::prim::*;
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::{
|
||||
petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction},
|
||||
*,
|
||||
},
|
||||
shape::symbolic::BigExpression,
|
||||
};
|
||||
|
||||
use super::other::MetalARange;
|
||||
|
||||
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
|
||||
pub struct MetalSub<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalSub<T> {
|
||||
pub fn new(
|
||||
a_shape: ShapeTracker,
|
||||
b_shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (a_idx_exp, a_valid_exp) = get_idx_valid_exps(a_shape);
|
||||
let (b_idx_exp, b_valid_exp) = get_idx_valid_exps(b_shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape], 4);
|
||||
let type_name = T::type_name();
|
||||
let code = format!(
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements) {{
|
||||
out[idx] =
|
||||
(({a_valid_exp}) == 0 ? 0.0 : inp_a[{a_idx_exp}])
|
||||
- (({b_valid_exp}) == 0 ? 0.0 : inp_b[{b_idx_exp}]);
|
||||
}}
|
||||
}}
|
||||
");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_symbols,
|
||||
dyn_map,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalSub<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(inputs[1].0), 0);
|
||||
encoder.set_buffer(2, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(3, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
4,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalSub<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[
|
||||
(get_buffer_from_tensor(&tensors[0].0), tensors[0].1),
|
||||
(get_buffer_from_tensor(&tensors[1].0), tensors[1].1),
|
||||
],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
// This op can accept non contiguous inputs
|
||||
if key == "non_contiguous" {
|
||||
return Some(Box::new(()));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("input0 - input1".to_string()));
|
||||
}
|
||||
if key == "recompile_shapes" {
|
||||
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
|
||||
*self = Self::new(
|
||||
input_shapes[0],
|
||||
input_shapes[1],
|
||||
self.device.clone(),
|
||||
self.queue.clone(),
|
||||
self.dyn_map,
|
||||
)
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct MetalSubtractionCompiler<T: MetalFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for MetalSubtractionCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let (mut neg_one, mut mul, mut add) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let mut searcher = select_const!(-1.0, T)
|
||||
.ptr(&mut neg_one)
|
||||
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul))
|
||||
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add))
|
||||
.search(graph);
|
||||
|
||||
while searcher.next_match() {
|
||||
if check_no_delete(graph, &[neg_one, mul, add]) {
|
||||
continue;
|
||||
}
|
||||
let (a, a_edge) = graph
|
||||
.graph
|
||||
.edges_directed(add, petgraph::Direction::Incoming)
|
||||
.find(|e| e.source() != mul)
|
||||
.map(|e| (e.source(), e.weight().as_data().unwrap()))
|
||||
.unwrap();
|
||||
let (b, b_edge) = graph
|
||||
.graph
|
||||
.edges_directed(mul, petgraph::Direction::Incoming)
|
||||
.find(|e| e.source() != neg_one)
|
||||
.map(|e| (e.source(), e.weight().as_data().unwrap()))
|
||||
.unwrap();
|
||||
let b_final_shape = graph
|
||||
.graph
|
||||
.edges_connecting(mul, add)
|
||||
.next()
|
||||
.unwrap()
|
||||
.weight()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2;
|
||||
if !b_final_shape.is_contiguous()
|
||||
|| b_final_shape.is_sliced()
|
||||
|| b_final_shape.is_padded()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let sub = graph
|
||||
.add_op(MetalSub::<T>::new(
|
||||
a_edge.2,
|
||||
b_edge.2,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(a, a_edge.1, a_edge.2)
|
||||
.input(b, b_edge.1, b_edge.2)
|
||||
.finish();
|
||||
move_outgoing_edge(add, sub, &mut graph.graph);
|
||||
|
||||
if graph.get_dests(neg_one).len() == 1 {
|
||||
graph.graph.remove_node(neg_one);
|
||||
}
|
||||
graph.graph.remove_node(mul);
|
||||
graph.graph.remove_node(add);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqTrue, LuminalPrint, Clone)]
|
||||
pub struct MetalEqual<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalEqual<T> {
|
||||
pub fn new(
|
||||
a_shape: ShapeTracker,
|
||||
b_shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (a_idx_exp, a_valid_exp) = get_idx_valid_exps(a_shape);
|
||||
let (b_idx_exp, b_valid_exp) = get_idx_valid_exps(b_shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[a_shape, b_shape], 4);
|
||||
let type_name = T::type_name();
|
||||
let code = format!(
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements) {{
|
||||
{type_name} a_val = (({a_valid_exp}) == 0 ? 0.0 : inp_a[{a_idx_exp}]);
|
||||
{type_name} b_val = (({b_valid_exp}) == 0 ? 0.0 : inp_b[{b_idx_exp}]);
|
||||
out[idx] = ({type_name})(a_val == b_val);
|
||||
}}
|
||||
}}
|
||||
");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_symbols,
|
||||
dyn_map,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalEqual<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(inputs[1].0), 0);
|
||||
encoder.set_buffer(2, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(3, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
4,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalEqual<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[
|
||||
(get_buffer_from_tensor(&tensors[0].0), tensors[0].1),
|
||||
(get_buffer_from_tensor(&tensors[1].0), tensors[1].1),
|
||||
],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
// This op can accept non contiguous inputs
|
||||
if key == "non_contiguous" {
|
||||
return Some(Box::new(()));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("input0 == input1 ? 1.0 : 0.0".to_string()));
|
||||
}
|
||||
if key == "recompile_shapes" {
|
||||
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
|
||||
*self = Self::new(
|
||||
input_shapes[0],
|
||||
input_shapes[1],
|
||||
self.device.clone(),
|
||||
self.queue.clone(),
|
||||
self.dyn_map,
|
||||
)
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct MetalEqualCompiler<T: MetalFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for MetalEqualCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let (mut less_than1, mut less_than2, mut add, mut one, mut sub) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let s = select_const!(1.0, T).ptr(&mut one).edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalLessThan<T>>()
|
||||
.ptr(&mut less_than1)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalLessThan<T>>()
|
||||
.ptr(&mut less_than2)
|
||||
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add)),
|
||||
)
|
||||
.edge(SelectOp::new().ty::<MetalSub<T>>().ptr(&mut sub)),
|
||||
);
|
||||
|
||||
let mut searcher = s.search(graph);
|
||||
while searcher.next_match() {
|
||||
let lt1_inputs = graph
|
||||
.graph
|
||||
.neighbors_directed(less_than1, Direction::Incoming)
|
||||
.sorted()
|
||||
.collect::<Vec<_>>();
|
||||
let lt2_inputs = graph
|
||||
.graph
|
||||
.neighbors_directed(less_than2, Direction::Incoming)
|
||||
.sorted()
|
||||
.collect::<Vec<_>>();
|
||||
if lt1_inputs != lt2_inputs {
|
||||
continue;
|
||||
}
|
||||
let inputs = graph
|
||||
.graph
|
||||
.edges_directed(less_than1, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
|
||||
.map(|e| e.source())
|
||||
.collect::<Vec<_>>();
|
||||
let (a, b) = (inputs[0], inputs[1]);
|
||||
if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) {
|
||||
continue;
|
||||
}
|
||||
let a_edge = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(a, less_than1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap();
|
||||
let b_edge = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(b, less_than1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap();
|
||||
let equals = graph
|
||||
.add_op(MetalEqual::<T>::new(
|
||||
a_edge.2,
|
||||
b_edge.2,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(a, a_edge.1, a_edge.2)
|
||||
.input(b, b_edge.1, b_edge.2)
|
||||
.finish();
|
||||
move_outgoing_edge(sub, equals, &mut graph.graph);
|
||||
|
||||
graph.graph.remove_node(sub);
|
||||
graph.safe_remove_node(add, 0);
|
||||
graph.safe_remove_node(one, 0);
|
||||
graph.safe_remove_node(less_than2, 0);
|
||||
graph.safe_remove_node(less_than1, 0);
|
||||
searcher.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
|
||||
pub struct MetalGather<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
pub embed_dim: usize,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalGather<T> {
|
||||
fn new(device: Device, queue: CommandQueue, embed_dim: usize) -> Self {
|
||||
let type_name = T::type_name();
|
||||
Self {pipeline: compile_function("metal_gather", &format!(
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void metal_gather(device float *inp [[buffer(0)]], device {type_name} *weights [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_embeddings [[buffer(3)]], device int& embedding_dim [[buffer(4)]], uint2 i_ [[thread_position_in_grid]]) {{
|
||||
if (i_.x < n_embeddings && i_.y < embedding_dim) {{
|
||||
out[i_.x * embedding_dim + i_.y] = weights[(int)inp[i_.x] * embedding_dim + i_.y];
|
||||
}}
|
||||
}}"), &device), device, embed_dim, queue, _phantom: Default::default()}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalGather<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup buffers
|
||||
let indexes = tensors[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<Vec<f32>>()
|
||||
.unwrap();
|
||||
let index_buffer = self.device.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(indexes.as_ptr()) },
|
||||
(indexes.len() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let b_inp = tensors[1]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<MetalBuffer>()
|
||||
.unwrap();
|
||||
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let out = self.device.new_buffer(
|
||||
(indexes.len() * self.embed_dim * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
let encoder = command_buffer
|
||||
.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(&index_buffer), 0);
|
||||
encoder.set_buffer(1, Some(b_inp), 0);
|
||||
encoder.set_buffer(2, Some(&out), 0);
|
||||
encoder.set_u32(3, indexes.len() as u32);
|
||||
encoder.set_u32(4, self.embed_dim as u32);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_threads(
|
||||
MTLSize {
|
||||
width: indexes.len() as u64,
|
||||
height: self.embed_dim as u64,
|
||||
depth: 1,
|
||||
},
|
||||
MTLSize {
|
||||
width: 16,
|
||||
height: 16,
|
||||
depth: 1,
|
||||
},
|
||||
);
|
||||
encoder.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct MetalGatherCompiler<T: MetalFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for MetalGatherCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let (mut ind_copy, mut arange, mut equal, mut mul, mut sum_reduce) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
let s = SelectOp::new()
|
||||
.ty::<MetalARange<T>>()
|
||||
.ptr(&mut arange)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalCopyToDevice<T>>()
|
||||
.ptr(&mut ind_copy)
|
||||
.edge(SelectOp::new().ty::<MetalEqual<T>>().ptr(&mut equal)),
|
||||
)
|
||||
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul))
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalSumReduce<T>>()
|
||||
.ptr(&mut sum_reduce),
|
||||
);
|
||||
let mut searcher = s.search(graph);
|
||||
while searcher.next_match() {
|
||||
if check_no_delete(graph, &[arange, equal, mul, sum_reduce]) {
|
||||
continue;
|
||||
}
|
||||
let embedding_dim = graph
|
||||
.graph
|
||||
.edges_directed(mul, Direction::Incoming)
|
||||
.find(|e| e.source() != equal && !e.weight().is_schedule())
|
||||
.unwrap()
|
||||
.weight()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2
|
||||
.shape()[2]
|
||||
.to_usize()
|
||||
.unwrap();
|
||||
let gather = graph
|
||||
.add_op(MetalGather::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
embedding_dim,
|
||||
))
|
||||
.finish();
|
||||
move_incoming_edge(ind_copy, gather, &mut graph.graph);
|
||||
graph.safe_remove_node(equal, 1);
|
||||
move_incoming_edge(mul, gather, &mut graph.graph);
|
||||
move_outgoing_edge(sum_reduce, gather, &mut graph.graph);
|
||||
graph.graph.remove_node(sum_reduce);
|
||||
graph.safe_remove_node(mul, 0);
|
||||
graph.safe_remove_node(ind_copy, 0);
|
||||
graph.safe_remove_node(arange, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use luminal::{prelude::*, tests::assert_close};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
#[test]
|
||||
fn test_subtraction() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<R1<10>>()
|
||||
.set(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
|
||||
let b = cx.tensor::<R0>().set(vec![1.]);
|
||||
let mut c = (a - b.expand()).retrieve();
|
||||
let mut d = (-a + b.expand()).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
let unopt_c = c.data();
|
||||
c.drop();
|
||||
let unopt_d = d.data();
|
||||
d.drop();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
assert_close(&unopt_c, &c.data());
|
||||
assert_close(&unopt_d, &d.data());
|
||||
}
|
||||
}
|
||||
320
crates/luminal_metal/src/command_buffer.rs
Normal file
320
crates/luminal_metal/src/command_buffer.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use std::{any::Any, cell::UnsafeCell, ops::Deref, sync::Arc};
|
||||
|
||||
use itertools::Itertools;
|
||||
use metal_rs::{Buffer, CommandBuffer, CommandQueue, Device};
|
||||
use petgraph::{
|
||||
stable_graph::NodeIndex,
|
||||
visit::EdgeRef,
|
||||
Direction::{self},
|
||||
};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{MetalBuffer, MetalKernel, MetalKernelWrapper};
|
||||
|
||||
use super::get_buffer_from_tensor;
|
||||
|
||||
#[derive(Default, LuminalPrint)]
|
||||
pub struct CommandBufferCompiler;
|
||||
|
||||
impl Compiler for CommandBufferCompiler {
|
||||
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, _: T) {
|
||||
let is_metal: FxHashSet<NodeIndex> = graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.filter(|i| {
|
||||
graph
|
||||
.graph
|
||||
.node_weight_mut(*i)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.is_some()
|
||||
})
|
||||
.collect();
|
||||
// Do forward pass
|
||||
let mut forward_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for node in graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|n| graph.graph.edges_directed(*n, Direction::Incoming).count() == 0)
|
||||
.sorted()
|
||||
{
|
||||
let mut stack = vec![node];
|
||||
while let Some(node) = stack.pop() {
|
||||
// Get rank as max of predecessors
|
||||
let rank = graph
|
||||
.graph
|
||||
.neighbors_directed(node, Direction::Incoming)
|
||||
.filter_map(|i| forward_map.get(&i).map(|r| (i, *r)))
|
||||
.map(|(node_index, rank)| {
|
||||
if is_metal.contains(&node) != is_metal.contains(&node_index) {
|
||||
rank + 1
|
||||
} else {
|
||||
rank
|
||||
}
|
||||
})
|
||||
.max()
|
||||
.unwrap_or_default();
|
||||
// Max it with the current entry in the map or insert
|
||||
if let Some(entry) = forward_map.get_mut(&node) {
|
||||
if rank > *entry {
|
||||
*entry = rank;
|
||||
stack.extend(graph.graph.neighbors_directed(node, Direction::Outgoing));
|
||||
}
|
||||
} else {
|
||||
forward_map.insert(node, rank);
|
||||
stack.extend(graph.graph.neighbors_directed(node, Direction::Outgoing));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do backward pass
|
||||
let mut backward_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for node in graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|n| graph.graph.edges_directed(*n, Direction::Outgoing).count() == 0)
|
||||
.sorted()
|
||||
{
|
||||
let mut stack = vec![node];
|
||||
while let Some(node) = stack.pop() {
|
||||
// Get rank as max of successors
|
||||
let rank = graph
|
||||
.graph
|
||||
.neighbors_directed(node, Direction::Outgoing)
|
||||
.filter_map(|i| backward_map.get(&i).map(|r| (i, *r)))
|
||||
.map(|(node_index, rank)| {
|
||||
if is_metal.contains(&node) != is_metal.contains(&node_index) {
|
||||
rank + 1
|
||||
} else {
|
||||
rank
|
||||
}
|
||||
})
|
||||
.max()
|
||||
.unwrap_or_default();
|
||||
// Max it with the current entry in the map or insert
|
||||
if let Some(entry) = backward_map.get_mut(&node) {
|
||||
if rank > *entry {
|
||||
*entry = rank;
|
||||
stack.extend(graph.graph.neighbors_directed(node, Direction::Incoming));
|
||||
}
|
||||
} else {
|
||||
backward_map.insert(node, rank);
|
||||
stack.extend(graph.graph.neighbors_directed(node, Direction::Incoming));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Get sets (Rank -> # of nodes with that rank)
|
||||
let forward_sets = forward_map
|
||||
.iter()
|
||||
.sorted_by_key(|(_, v)| **v)
|
||||
.group_by(|(_, v)| **v)
|
||||
.into_iter()
|
||||
.map(|(k, g)| (k, g.count()))
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
let backward_sets = backward_map
|
||||
.iter()
|
||||
.sorted_by_key(|(_, v)| **v)
|
||||
.group_by(|(_, v)| **v)
|
||||
.into_iter()
|
||||
.map(|(k, g)| (k, g.count()))
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
|
||||
// Assign nodes to sets
|
||||
let mut node_sets: FxHashMap<(bool, usize), FxHashSet<NodeIndex>> = FxHashMap::default();
|
||||
for node in graph.graph.node_indices().filter(|i| is_metal.contains(i)) {
|
||||
let forward_bigger =
|
||||
forward_sets[&forward_map[&node]] >= backward_sets[&backward_map[&node]];
|
||||
node_sets
|
||||
.entry((
|
||||
forward_bigger,
|
||||
if forward_bigger {
|
||||
forward_map[&node]
|
||||
} else {
|
||||
backward_map[&node]
|
||||
},
|
||||
))
|
||||
.and_modify(|set| {
|
||||
set.insert(node);
|
||||
})
|
||||
.or_insert({
|
||||
let mut set = FxHashSet::default();
|
||||
set.insert(node);
|
||||
set
|
||||
});
|
||||
}
|
||||
// Add sets to graph
|
||||
let dev = Device::system_default().unwrap();
|
||||
let mut queue = dev.new_command_queue();
|
||||
let mut num_buffers_on_queue = 0;
|
||||
for set in node_sets.values() {
|
||||
if num_buffers_on_queue >= 63 {
|
||||
num_buffers_on_queue = 0;
|
||||
queue = dev.new_command_queue();
|
||||
} else {
|
||||
num_buffers_on_queue += 1;
|
||||
}
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let buffer = Arc::new(UnsafeCell::new(queue.new_command_buffer().to_owned()));
|
||||
let exec = graph
|
||||
.add_op(ExecuteMetalKernels {
|
||||
queue: queue.clone(),
|
||||
buffer: buffer.clone(),
|
||||
})
|
||||
.finish();
|
||||
for node in set {
|
||||
// Create schedule dependency
|
||||
graph.add_schedule_dependency(*node, exec);
|
||||
// Wrap node in MetalKernelOperation
|
||||
let wrapper = graph
|
||||
.graph
|
||||
.node_weight_mut(*node)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.unwrap()
|
||||
.downcast::<MetalKernelWrapper>()
|
||||
.unwrap();
|
||||
*graph.graph.node_weight_mut(*node).unwrap() = Box::new(CommandBufferWrapper {
|
||||
wrapper,
|
||||
buffer: buffer.clone(),
|
||||
dyn_map: &graph.dyn_map,
|
||||
});
|
||||
// Create schedule dependencies from exec to consumers
|
||||
for outside_node in graph
|
||||
.graph
|
||||
.edges_directed(*node, Direction::Outgoing)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.map(|e| e.target())
|
||||
.filter(|n| !set.contains(n))
|
||||
.collect::<Vec<_>>()
|
||||
{
|
||||
graph.add_schedule_dependency(exec, outside_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqFalse, LuminalPrint)]
|
||||
struct ExecuteMetalKernels {
|
||||
queue: CommandQueue,
|
||||
buffer: Arc<UnsafeCell<CommandBuffer>>,
|
||||
}
|
||||
|
||||
impl Operator for ExecuteMetalKernels {
|
||||
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let buffer = unsafe { &mut *self.buffer.get() };
|
||||
buffer.commit();
|
||||
buffer.wait_until_completed();
|
||||
*buffer = self.queue.new_command_buffer().to_owned();
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, LuminalEqFalse)]
|
||||
struct CommandBufferWrapper {
|
||||
wrapper: Box<MetalKernelWrapper>,
|
||||
buffer: Arc<UnsafeCell<CommandBuffer>>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CommandBufferWrapper {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalKernel({:?})", self.wrapper.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalKernel for CommandBufferWrapper {
|
||||
fn intermediate_buffer_sizes(
|
||||
&self,
|
||||
input_shapes: &[ShapeTracker],
|
||||
) -> Vec<symbolic::BigExpression> {
|
||||
self.wrapper.0.intermediate_buffer_sizes(input_shapes)
|
||||
}
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<symbolic::BigExpression> {
|
||||
self.wrapper.0.output_buffer_sizes(input_shapes)
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
_: &metal_rs::CommandBufferRef,
|
||||
intermediate_buffers: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
self.wrapper.0.metal_forward(
|
||||
inputs,
|
||||
unsafe { &*self.buffer.get() },
|
||||
intermediate_buffers,
|
||||
output_buffers,
|
||||
);
|
||||
}
|
||||
fn without_command_buffer(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
intermediate_buffers: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
self.metal_forward(
|
||||
inputs,
|
||||
unsafe { &*self.buffer.get() },
|
||||
intermediate_buffers,
|
||||
output_buffers,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for CommandBufferWrapper {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
self.without_storage_buffers(
|
||||
&inp.iter()
|
||||
.map(|(t, sh)| (get_buffer_from_tensor(t).deref(), *sh))
|
||||
.collect::<Vec<_>>(),
|
||||
unsafe { &*self.buffer.get() },
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
)
|
||||
.into_iter()
|
||||
.map(|b| Tensor::new(MetalBuffer(b)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[test]
|
||||
fn test_common_buffer() {
|
||||
use luminal::{
|
||||
prelude::*,
|
||||
tests::{assert_close, random_vec},
|
||||
};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let b = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let c = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let mut d = ((a + b) * c).retrieve();
|
||||
|
||||
cx.execute();
|
||||
let d_unopt = d.data();
|
||||
d.drop();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut d);
|
||||
cx.execute();
|
||||
|
||||
assert_close(&d.data(), &d_unopt);
|
||||
}
|
||||
459
crates/luminal_metal/src/elementwise_fusion.rs
Normal file
459
crates/luminal_metal/src/elementwise_fusion.rs
Normal file
@@ -0,0 +1,459 @@
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::{any::Any, marker::PhantomData, ops::Deref, sync::Arc};
|
||||
|
||||
use itertools::Itertools;
|
||||
use metal_rs::{
|
||||
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
|
||||
ComputePipelineState, Device, MTLResourceOptions,
|
||||
};
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::{
|
||||
petgraph::{visit::EdgeRef, Direction},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper};
|
||||
|
||||
use self::symbolic::BigExpression;
|
||||
|
||||
use super::{
|
||||
compile_function, get_idx_valid_exps, input_dyn_dims, prim::MetalConstant,
|
||||
render_dyn_dim_inputs, DispatchNElements, SetInt,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ElementwiseFusionCompiler<T>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
let device = Device::system_default().unwrap();
|
||||
let queue = device.new_command_queue();
|
||||
// Find two elementwise ops that have a contiguous edge
|
||||
let (mut a, mut b) = (NodeIndex::default(), NodeIndex::default());
|
||||
let mut selector = SelectOp::new()
|
||||
.check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
|
||||
.ptr(&mut a)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
|
||||
.ptr(&mut b),
|
||||
)
|
||||
.search(graph);
|
||||
let mut fused_ops = FxHashSet::default();
|
||||
|
||||
while selector.next_match() {
|
||||
// More than one connecting edge
|
||||
if graph.no_delete.contains(&a)
|
||||
|| (graph
|
||||
.graph
|
||||
.edges_directed(a, Direction::Outgoing)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.count()
|
||||
> 1
|
||||
&& !graph
|
||||
.graph
|
||||
.node_weight(a)
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.is::<MetalConstant<T>>())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// Connecting shape isn't contiguous
|
||||
let (edge_id, (to_input, _, connecting_shape)) = graph
|
||||
.graph
|
||||
.edges_connecting(a, b)
|
||||
.find_map(|e| e.weight().as_data().map(|i| (e.id(), i)))
|
||||
.unwrap();
|
||||
if !connecting_shape.is_contiguous()
|
||||
|| connecting_shape.is_sliced()
|
||||
|| connecting_shape.is_padded()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fuse into a FusedElementwiseOp
|
||||
let new_op;
|
||||
let mut a_equation = graph
|
||||
.node_custom::<String, _>(a, "elementwise", ())
|
||||
.unwrap();
|
||||
let mut curr_input = to_input;
|
||||
// Keep track of original edges to a and b
|
||||
let a_orig_edges = graph
|
||||
.graph
|
||||
.edges_directed(a, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
|
||||
.sorted_by_key(|i| i.1)
|
||||
.collect::<Vec<_>>();
|
||||
let b_orig_edges = graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
|
||||
.sorted_by_key(|i| i.1)
|
||||
.collect::<Vec<_>>();
|
||||
// Remove edge a -> b, and decrement indexes of all edges higher than it
|
||||
graph.graph.remove_edge(edge_id);
|
||||
for edge in graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.map(|e| e.id())
|
||||
.collect_vec()
|
||||
{
|
||||
if let Some(Dependency::Data { input_order, .. }) =
|
||||
graph.graph.edge_weight_mut(edge)
|
||||
{
|
||||
if *input_order > curr_input {
|
||||
*input_order -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add edges if they don't exist
|
||||
for input_edge in graph
|
||||
.graph
|
||||
.edges_directed(a, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
|
||||
.sorted_by_key(|i| i.1)
|
||||
.collect_vec()
|
||||
{
|
||||
// Find edge or add it
|
||||
if !graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
|
||||
.any(|(src, _, out_ind, _)| src == input_edge.0 && out_ind == input_edge.2)
|
||||
{
|
||||
// Move all edges >= curr_input up by one
|
||||
for edge in graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.map(|e| e.id())
|
||||
.collect_vec()
|
||||
{
|
||||
if let Some(Dependency::Data { input_order, .. }) =
|
||||
graph.graph.edge_weight_mut(edge)
|
||||
{
|
||||
if *input_order >= curr_input {
|
||||
*input_order += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add edge
|
||||
graph.graph.add_edge(
|
||||
input_edge.0,
|
||||
b,
|
||||
Dependency::Data {
|
||||
input_order: curr_input,
|
||||
output_order: input_edge.2,
|
||||
shape: input_edge.3,
|
||||
},
|
||||
);
|
||||
curr_input += 1;
|
||||
}
|
||||
}
|
||||
// Alter a_equation to reflect the correct input indexes
|
||||
let mut replacements = vec![];
|
||||
for (src, inp_ind, out_ind) in a_orig_edges {
|
||||
let n = graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
|
||||
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
|
||||
.unwrap();
|
||||
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
|
||||
}
|
||||
a_equation = multi_replace(&a_equation, &replacements);
|
||||
// Alter b_equation to reflect the correct input indexes
|
||||
replacements.clear();
|
||||
for (src, inp_ind, out_ind) in b_orig_edges {
|
||||
if inp_ind > to_input {
|
||||
let n = graph
|
||||
.graph
|
||||
.edges_directed(b, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
|
||||
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
|
||||
.unwrap();
|
||||
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(fused_op) = graph
|
||||
.graph
|
||||
.node_weight_mut(b)
|
||||
.unwrap()
|
||||
.as_any_mut()
|
||||
.downcast_mut::<FusedElementwiseOp<T>>()
|
||||
{
|
||||
// B is already fused, just combine with b
|
||||
new_op = b;
|
||||
// Render a into b as input to_input
|
||||
fused_op.equation = multi_replace(&fused_op.equation, &replacements)
|
||||
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
|
||||
} else {
|
||||
let mut b_equation = graph
|
||||
.node_custom::<String, _>(b, "elementwise", ())
|
||||
.unwrap();
|
||||
b_equation = multi_replace(&b_equation, &replacements)
|
||||
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
|
||||
// B is not a fused op, let's create a new one
|
||||
new_op = graph
|
||||
.add_op(FusedElementwiseOp::<T> {
|
||||
kernel: None,
|
||||
dyn_map: &graph.dyn_map,
|
||||
dyn_chars: vec![],
|
||||
equation: b_equation,
|
||||
queue: queue.clone(),
|
||||
device: device.clone(),
|
||||
_phantom: Default::default(),
|
||||
})
|
||||
.finish();
|
||||
move_incoming_edge(b, new_op, &mut graph.graph);
|
||||
move_outgoing_edge(b, new_op, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
b,
|
||||
new_op,
|
||||
);
|
||||
graph.graph.remove_node(b);
|
||||
fused_ops.remove(&b);
|
||||
}
|
||||
// Remove a
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
a,
|
||||
new_op,
|
||||
);
|
||||
if graph
|
||||
.graph
|
||||
.edges_directed(a, Direction::Outgoing)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.count()
|
||||
== 0
|
||||
{
|
||||
graph.graph.remove_node(a);
|
||||
}
|
||||
fused_ops.remove(&a);
|
||||
fused_ops.insert(new_op);
|
||||
selector.reset();
|
||||
}
|
||||
// Compile all the kernels we placed
|
||||
let type_name = T::type_name();
|
||||
for fused_op in fused_ops {
|
||||
let edges = graph
|
||||
.graph
|
||||
.edges_directed(fused_op, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data())
|
||||
.collect_vec();
|
||||
if let Some(op) = graph
|
||||
.graph
|
||||
.node_weight_mut(fused_op)
|
||||
.unwrap()
|
||||
.as_any_mut()
|
||||
.downcast_mut::<FusedElementwiseOp<T>>()
|
||||
{
|
||||
let (dyn_chars, rendered) = render_dyn_dim_inputs(
|
||||
&edges.iter().map(|i| i.2).collect_vec(),
|
||||
edges.len() + 2,
|
||||
);
|
||||
for (inp_ind, _, sh) in &edges {
|
||||
let (ind, val) = get_idx_valid_exps(*sh);
|
||||
if (sh.is_contiguous() && !sh.is_sliced() && !sh.is_padded())
|
||||
|| (!sh.is_sliced() && !sh.is_padded())
|
||||
{
|
||||
op.equation = op.equation.replace(
|
||||
&format!("input{inp_ind}"),
|
||||
&format!("(float)input{inp_ind}[{ind}]"),
|
||||
);
|
||||
} else {
|
||||
op.equation = op.equation.replace(
|
||||
&format!("input{inp_ind}"),
|
||||
&format!("(({val} != 0) ? (float)input{inp_ind}[{ind}] : 0.0)"),
|
||||
);
|
||||
}
|
||||
}
|
||||
let kernel = format!(
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel({} device {type_name} *out [[buffer({})]], device uint& n_elements [[buffer({})]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements) {{
|
||||
out[idx] = ({type_name})({});
|
||||
}}
|
||||
}}",
|
||||
edges
|
||||
.iter()
|
||||
.map(|(inp_ind, _, _)| format!(
|
||||
"device {type_name}* input{inp_ind} [[buffer({inp_ind})]],"
|
||||
))
|
||||
.collect_vec()
|
||||
.join(" "),
|
||||
edges.len(),
|
||||
edges.len() + 1,
|
||||
op.equation
|
||||
);
|
||||
op.kernel = Some(compile_function("mkernel", &kernel, &device));
|
||||
op.dyn_chars = dyn_chars;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn multi_replace(input: &str, replacements: &[(String, String)]) -> String {
|
||||
// Use Unicode Private Use Areas as unlikely placeholders
|
||||
// Starting at U+E000
|
||||
let mut placeholder_start = 0xE000;
|
||||
|
||||
let mut output = input.to_string();
|
||||
|
||||
// Generate placeholder characters for each replacement pair
|
||||
let mut placeholders: Vec<(String, char)> = Vec::new();
|
||||
for (from, _) in replacements {
|
||||
let placeholder = std::char::from_u32(placeholder_start).unwrap();
|
||||
placeholder_start += 1;
|
||||
placeholders.push((from.clone(), placeholder));
|
||||
}
|
||||
|
||||
// First pass: Replace all target strings with placeholders
|
||||
for (from, placeholder) in &placeholders {
|
||||
output = output.replace(from, &placeholder.to_string());
|
||||
}
|
||||
|
||||
// Second pass: Replace placeholders with final strings
|
||||
for ((_, placeholder), (_, to)) in placeholders.iter().zip(replacements) {
|
||||
output = output.replace(&placeholder.to_string(), to);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
|
||||
pub struct FusedElementwiseOp<T> {
|
||||
kernel: Option<ComputePipelineState>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
dyn_chars: Vec<char>,
|
||||
equation: String,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
impl<T> MetalKernel for FusedElementwiseOp<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
if input_shapes.len() == 1 {
|
||||
// Assume since it's a unary op, we're outputting 1-1 elements from input
|
||||
vec![input_shapes[0].n_physical_elements() * std::mem::size_of::<T>()]
|
||||
} else {
|
||||
// If it isn't a unary op, output the contiguous buffer length
|
||||
vec![input_shapes[0].n_elements() * std::mem::size_of::<T>()]
|
||||
}
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(self.kernel.as_ref().unwrap());
|
||||
let out_size = inputs
|
||||
.iter()
|
||||
.map(|i| i.1.n_elements().to_usize().unwrap())
|
||||
.max()
|
||||
.unwrap();
|
||||
|
||||
// Set function inputs
|
||||
for (i, (buf, _)) in inputs.iter().enumerate() {
|
||||
encoder.set_buffer(i as u64, Some(*buf), 0);
|
||||
}
|
||||
encoder.set_buffer(inputs.len() as u64, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(inputs.len() + 1, out_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_chars,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
inputs.len() + 2,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(out_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for FusedElementwiseOp<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let out = self.device.new_buffer(
|
||||
self.output_buffer_sizes(&tensors.iter().map(|(_, s)| *s).collect_vec())[0]
|
||||
.exec(unsafe { self.dyn_map.as_ref().unwrap() })
|
||||
.unwrap() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&tensors
|
||||
.iter()
|
||||
.map(|(t, s)| (get_buffer_from_tensor(t).deref(), *s))
|
||||
.collect_vec(),
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
// This op can accept non contiguous inputs
|
||||
if key == "non_contiguous" {
|
||||
return Some(Box::new(()));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new(self.equation.clone()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use luminal::{
|
||||
prelude::*,
|
||||
tests::{assert_close, random_vec},
|
||||
};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
#[test]
|
||||
fn test_fusion() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
|
||||
let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
|
||||
let mut c = (a.exp2() - b.sin()).relu().retrieve();
|
||||
|
||||
cx.execute();
|
||||
let unopt_c = c.data();
|
||||
c.drop();
|
||||
|
||||
cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
assert_close(&c.data(), &unopt_c);
|
||||
}
|
||||
}
|
||||
307
crates/luminal_metal/src/kernels/bf16.h
Normal file
307
crates/luminal_metal/src/kernels/bf16.h
Normal file
@@ -0,0 +1,307 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
||||
// Check for nan
|
||||
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
||||
_fp_encoding_traits<float>::inf_mask) {
|
||||
return uint16_t(as_type<uint32_t>(0x7FC0));
|
||||
}
|
||||
// Take bits
|
||||
uint32_t float_bits = as_type<uint32_t>(x);
|
||||
|
||||
// Round to nearest even
|
||||
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
||||
|
||||
// Take upper 16 bits
|
||||
return float_bits >> 16;
|
||||
}
|
||||
|
||||
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
||||
// Upper 16 bits are the data and lower 16 bits are 0s
|
||||
return as_type<float>((uint32_t)x << 16);
|
||||
}
|
||||
|
||||
struct _MLX_BFloat16;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat struct
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct _MLX_BFloat16 {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Constructors
|
||||
uint16_t bits_;
|
||||
_MLX_BFloat16() thread = default;
|
||||
_MLX_BFloat16() threadgroup = default;
|
||||
_MLX_BFloat16() device = default;
|
||||
_MLX_BFloat16() constant = default;
|
||||
|
||||
struct bits_to_bfloat_struct {};
|
||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||
return bits_to_bfloat_struct();
|
||||
}
|
||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||
: bits_(bits) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions to bfloat
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions from bfloat
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const thread {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const threadgroup {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const device {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const constant {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat operators
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops
|
||||
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
||||
return -static_cast<float>(x);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Binary operators
|
||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
} \
|
||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Arithmetic Operators
|
||||
#define bfloat_binop(_op_, _operator_) \
|
||||
bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
|
||||
_MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||
|
||||
bfloat_binop(+, operator+);
|
||||
bfloat_binop(-, operator-);
|
||||
bfloat_binop(*, operator*);
|
||||
bfloat_binop(/, operator/);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Comparison ops
|
||||
#define bfloat_compop(__op__, __operator__) \
|
||||
bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \
|
||||
float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||
|
||||
bfloat_compop(>, operator>);
|
||||
bfloat_compop(<, operator<);
|
||||
bfloat_compop(>=, operator>=);
|
||||
bfloat_compop(<=, operator<=);
|
||||
bfloat_compop(==, operator==);
|
||||
bfloat_compop(!=, operator!=);
|
||||
|
||||
#undef bfloat_compop
|
||||
#undef bfloat_binop_base
|
||||
#undef bfloat_binop_helper
|
||||
#undef bfloat_binop
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Inplace Operators
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
|
||||
addr_space _MLX_BFloat16 &lhs, itype rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
} \
|
||||
constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \
|
||||
_MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||
|
||||
#define bfloat_inplace_op(itype) \
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||
|
||||
bfloat_inplace_op(float);
|
||||
bfloat_inplace_op(half);
|
||||
bfloat_inplace_op(int16_t);
|
||||
bfloat_inplace_op(int32_t);
|
||||
bfloat_inplace_op(int64_t);
|
||||
bfloat_inplace_op(uint16_t);
|
||||
bfloat_inplace_op(uint32_t);
|
||||
bfloat_inplace_op(uint64_t);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
#undef bfloat_inplace_op
|
||||
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
|
||||
addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat typedef
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat numeric limits
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
namespace metal {
|
||||
|
||||
template <>
|
||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||
static constexpr constant int digits = 8;
|
||||
static constexpr constant int digits10 = 2;
|
||||
static constexpr constant int max_digits10 = 4;
|
||||
static constexpr constant int radix = 2;
|
||||
static constexpr constant int min_exponent = -125;
|
||||
static constexpr constant int min_exponent10 = -37;
|
||||
static constexpr constant int max_exponent = 128;
|
||||
static constexpr constant int max_exponent10 = 38;
|
||||
|
||||
static constexpr bfloat16_t min() {
|
||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t lowest() {
|
||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t max() {
|
||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t round_error() {
|
||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t quiet_NaN() {
|
||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t signaling_NaN() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t denorm_min() {
|
||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
};
|
||||
|
||||
METAL_FUNC bool isnan(_MLX_BFloat16 x) { return x != x; }
|
||||
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif // defined(__HAVE_BFLOAT__)
|
||||
|
||||
#include "bf16_math.h"
|
||||
365
crates/luminal_metal/src/kernels/bf16_math.h
Normal file
365
crates/luminal_metal/src/kernels/bf16_math.h
Normal file
@@ -0,0 +1,365 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
Following the Metal Shading Language Specification (Metal 3.1)
|
||||
|
||||
"bfloat is an extended itypeing point type that only allows implicit conversion
|
||||
to a type of greater itypeing point rank. While bfloat can be implicitly
|
||||
converted to itype, it cannot be implicitly converted to half, and neither
|
||||
itype nor half can be implicitly converted to bfloat."
|
||||
|
||||
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
||||
for bfloat and calling with an argument of type bfloat will result in that
|
||||
argument getting implicitly converted to itype which then returns an output
|
||||
that is (likely) a itype which cannot be implicitly converted into a bfloat
|
||||
|
||||
This leads to situations where
|
||||
bfloat a = 5.0bf;
|
||||
bfloat b = metal::abs(a); // this will throw an error since abs return itype
|
||||
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
||||
|
||||
For the moment, I will be adding overloaded instantiations of the math
|
||||
functions to accordingly automatically handle the casting
|
||||
|
||||
*/
|
||||
|
||||
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
|
||||
\
|
||||
METAL_FUNC otype abs(itype x) { \
|
||||
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype acos(itype x) { \
|
||||
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype acosh(itype x) { \
|
||||
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype asin(itype x) { \
|
||||
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype asinh(itype x) { \
|
||||
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype atan(itype y_over_x) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype atan2(itype y, itype x) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype atanh(itype x) { \
|
||||
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype ceil(itype x) { \
|
||||
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype cos(itype x) { \
|
||||
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype cosh(itype x) { \
|
||||
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype cospi(itype x) { \
|
||||
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype divide(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype exp(itype x) { \
|
||||
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype exp10(itype x) { \
|
||||
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype exp2(itype x) { \
|
||||
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fabs(itype x) { \
|
||||
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fdim(itype x, itype y) { \
|
||||
ctype t = static_cast<ctype>(x - y); \
|
||||
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
|
||||
} \
|
||||
METAL_FUNC otype floor(itype x) { \
|
||||
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fma(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fma( \
|
||||
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
|
||||
} \
|
||||
METAL_FUNC otype fmax(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmax3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmedian3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fmin(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmin3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fmod(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype fract(itype x) { \
|
||||
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype frexp(itype x, thread int &exp) { \
|
||||
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
|
||||
} \
|
||||
METAL_FUNC otype ldexp(itype x, int k) { \
|
||||
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype log(itype x) { \
|
||||
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype log10(itype x) { \
|
||||
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype log2(itype x) { \
|
||||
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype max(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype max3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmax3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype median3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmedian3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype min(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype min3(itype x, itype y, itype z) { \
|
||||
return static_cast<otype>(__metal_fmin3(static_cast<ctype>(x), \
|
||||
static_cast<ctype>(y), \
|
||||
static_cast<ctype>(z), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype nextafter(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
|
||||
} \
|
||||
METAL_FUNC otype pow(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype powr(itype x, itype y) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype rint(itype x) { \
|
||||
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype round(itype x) { \
|
||||
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype rsqrt(itype x) { \
|
||||
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype sin(itype x) { \
|
||||
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype sinh(itype x) { \
|
||||
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype sinpi(itype x) { \
|
||||
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype sqrt(itype x) { \
|
||||
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype tan(itype x) { \
|
||||
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype tanh(itype x) { \
|
||||
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype tanpi(itype x) { \
|
||||
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
|
||||
} \
|
||||
METAL_FUNC otype trunc(itype x) { \
|
||||
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
|
||||
}
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
|
||||
__METAL_MAYBE_FAST_MATH__);
|
||||
|
||||
namespace fast {
|
||||
|
||||
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
|
||||
__METAL_FAST_MATH__);
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace precise {
|
||||
|
||||
instantiate_metal_math_funcs(bfloat16_t, bfloat16_t, float,
|
||||
__METAL_PRECISE_MATH__);
|
||||
|
||||
} // namespace precise
|
||||
|
||||
} // namespace metal
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal simd for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_metal_simd_comm_funcs(itype, otype, ctype, itype_to_ctype, \
|
||||
ctype_to_otype) \
|
||||
\
|
||||
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_and_fill_down(itype data, itype filling_data, \
|
||||
ushort delta, ushort modulo) { \
|
||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_and_fill_down(itype data, itype filling_data, \
|
||||
ushort delta) { \
|
||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, \
|
||||
__metal_get_simdgroup_size(ushort()))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_and_fill_up(itype data, itype filling_data, \
|
||||
ushort delta, ushort modulo) { \
|
||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_and_fill_up(itype data, itype filling_data, \
|
||||
ushort delta) { \
|
||||
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
||||
itype_to_ctype(data), itype_to_ctype(filling_data), delta, \
|
||||
__metal_get_simdgroup_size(ushort()))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
|
||||
return ctype_to_otype( \
|
||||
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
|
||||
}
|
||||
|
||||
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
|
||||
\
|
||||
METAL_FUNC otype simd_max(itype data) { \
|
||||
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_min(itype data) { \
|
||||
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
|
||||
return static_cast<otype>( \
|
||||
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_product(itype data) { \
|
||||
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_sum(itype data) { \
|
||||
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
|
||||
} \
|
||||
\
|
||||
METAL_FUNC otype simd_xor(itype data) { \
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
||||
#else
|
||||
|
||||
#define bfloat16_to_uint16(x) x.bits_
|
||||
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
|
||||
|
||||
#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(bfloat16_t, bfloat16_t, uint16_t,
|
||||
bfloat16_to_uint16, uint16_to_bfloat16);
|
||||
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
||||
|
||||
} // namespace metal
|
||||
115
crates/luminal_metal/src/kernels/complex.h
Normal file
115
crates/luminal_metal/src/kernels/complex.h
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct complex64_t;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_complex64 =
|
||||
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_complex64 =
|
||||
!is_same_v<T, complex64_t> &&
|
||||
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
||||
|
||||
struct complex64_t {
|
||||
float real;
|
||||
float imag;
|
||||
|
||||
// Constructors
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||
|
||||
// Conversions to complex64_t
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||
|
||||
// Conversions from complex64_t
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const thread {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const threadgroup {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const device {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const constant {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr complex64_t operator-(complex64_t x) { return {-x.real, -x.imag}; }
|
||||
|
||||
constexpr bool operator>=(complex64_t a, complex64_t b) {
|
||||
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
|
||||
}
|
||||
|
||||
constexpr bool operator>(complex64_t a, complex64_t b) {
|
||||
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
|
||||
}
|
||||
|
||||
constexpr bool operator<=(complex64_t a, complex64_t b) {
|
||||
return operator>=(b, a);
|
||||
}
|
||||
|
||||
constexpr bool operator<(complex64_t a, complex64_t b) {
|
||||
return operator>(b, a);
|
||||
}
|
||||
|
||||
constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||
return a.real == b.real && a.imag == b.imag;
|
||||
}
|
||||
|
||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||
return {a.real + b.real, a.imag + b.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
return {a.real - b.real, a.imag - b.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
auto denom = b.real * b.real + b.imag * b.imag;
|
||||
auto x = a.real * b.real + a.imag * b.imag;
|
||||
auto y = a.imag * b.real - a.real * b.imag;
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
return {real, imag};
|
||||
}
|
||||
16
crates/luminal_metal/src/kernels/defines.h
Normal file
16
crates/luminal_metal/src/kernels/defines.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __METAL__
|
||||
#define MTL_CONST constant
|
||||
#else
|
||||
#define MTL_CONST
|
||||
#endif
|
||||
|
||||
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
|
||||
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
|
||||
static MTL_CONST constexpr int REDUCE_N_READS = 16;
|
||||
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
|
||||
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
||||
539
crates/luminal_metal/src/kernels/gemm.h
Normal file
539
crates/luminal_metal/src/kernels/gemm.h
Normal file
@@ -0,0 +1,539 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
// #pragma once
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BROWS,
|
||||
int BCOLS,
|
||||
int BK,
|
||||
int vec_size,
|
||||
int tgp_size,
|
||||
bool transpose,
|
||||
bool ldK,
|
||||
int tgp_padding = 0>
|
||||
struct BlockLoader {
|
||||
// Destination dimensions
|
||||
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
|
||||
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
|
||||
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
|
||||
|
||||
// Stride along block row within the block
|
||||
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
// Stride along reduction axis between blocks
|
||||
const int tstride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tstride(
|
||||
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / n_vecs),
|
||||
bj(vec_size * (thread_idx % n_vecs)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < dst_fd; i += bstride) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = src[i * src_ld + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
|
||||
|
||||
// Iterate over rows of block
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < dst_fd; i += bstride) {
|
||||
// Row is in bounds, we check against column
|
||||
if ((bi + i) < src_tile_dim.y) {
|
||||
// Use fast thread memory for bound checks
|
||||
short tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||
}
|
||||
|
||||
// Read all valid indices into tmp_val
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||
}
|
||||
|
||||
// Zero out unneeded values
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Row is out of bounds, we just fill tgp memory with zeros
|
||||
else {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tstride;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int tgp_padding_a = 0,
|
||||
int tgp_padding_b = 0,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct BlockMMA {
|
||||
// Warp tile size along M
|
||||
MLX_MTL_CONST int TM = BM / (WM * 8);
|
||||
// Warp tile size along N
|
||||
MLX_MTL_CONST int TN = BN / (WN * 8);
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
MLX_MTL_CONST int TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
MLX_MTL_CONST int TN_stride = 8 * WN;
|
||||
|
||||
// Leading dimensions of threadgroup A, B blocks
|
||||
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
MLX_MTL_CONST short simd_stride_a =
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
||||
MLX_MTL_CONST short simd_stride_b =
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
||||
|
||||
// Jump between elements
|
||||
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
||||
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
||||
|
||||
// Offsets within threadgroup
|
||||
const int tm;
|
||||
const int tn;
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Iterate over BK in blocks of 8
|
||||
#pragma clang loop unroll(full)
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
short2 offset_a =
|
||||
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
||||
short2 offset_b =
|
||||
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
||||
|
||||
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
||||
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
||||
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
||||
As__ += simd_stride_a;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
||||
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
||||
Bs__ += simd_stride_b;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < TN; j++) {
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||
}
|
||||
|
||||
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
||||
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
||||
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<T, AccumType>>
|
||||
struct GEMMKernel {
|
||||
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
MLX_MTL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
MLX_MTL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
||||
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
BM,
|
||||
BK,
|
||||
BK,
|
||||
vec_size,
|
||||
tgp_size,
|
||||
transpose_a,
|
||||
true,
|
||||
tgp_padding_a>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
BK,
|
||||
BN,
|
||||
BK,
|
||||
vec_size,
|
||||
tgp_size,
|
||||
transpose_b,
|
||||
false,
|
||||
tgp_padding_b>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
tgp_padding_a,
|
||||
tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* C [[buffer(2)]],
|
||||
const constant int& M [[buffer(3)]],
|
||||
const constant int& N [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& batch_stride_a [[buffer(6)]],
|
||||
const constant int& batch_stride_b [[buffer(7)]],
|
||||
const constant int& batch_size_b [[buffer(8)]],
|
||||
const constant int& batch_stride_c [[buffer(9)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Adjust for batch
|
||||
A += batch_stride_a * tid.z;
|
||||
B += batch_stride_b * (tid.z / batch_size_b);
|
||||
C += batch_stride_c * tid.z;
|
||||
|
||||
// Adjust for transpose
|
||||
const int lda_dev = transpose_a ? M : K;
|
||||
const int ldb_dev = transpose_b ? K : N;
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid.y * BM;
|
||||
const int c_col = tid.x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * K;
|
||||
B += transpose_b ? c_col * K : c_col;
|
||||
C += c_row * N + c_col;
|
||||
|
||||
// Prepare threadgroup memory for loading
|
||||
threadgroup T* As = tgp_memory;
|
||||
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
|
||||
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned && K_aligned) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN aligned, K unaligned loop
|
||||
else if (MN_aligned && !K_aligned) {
|
||||
// Main loop
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
// Loop tail
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
loader_a.load_safe(short2(K - k, BM));
|
||||
loader_b.load_safe(short2(BN, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
|
||||
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
|
||||
|
||||
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (k < K) {
|
||||
loader_a.load_safe(short2(K - k, BM));
|
||||
loader_b.load_safe(short2(BN, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
mma_op.store_result(C, N);
|
||||
return;
|
||||
|
||||
} else {
|
||||
int k = 0;
|
||||
for (; k + BK <= K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_safe(short2(BK, src_tile_dims.y));
|
||||
loader_b.load_safe(short2(src_tile_dims.x, BK));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (k < K) {
|
||||
loader_a.load_safe(short2(K - k, src_tile_dims.y));
|
||||
loader_b.load_safe(short2(src_tile_dims.x, K - k));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
mma_op.store_result_safe(C, N, src_tile_dims);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
95
crates/luminal_metal/src/kernels/gemm.metal
Normal file
95
crates/luminal_metal/src/kernels/gemm.metal
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
const constant int &M [[buffer(3)]],
|
||||
const constant int &N [[buffer(4)]],
|
||||
const constant int &K [[buffer(5)]],
|
||||
const constant int &batch_stride_a [[buffer(6)]],
|
||||
const constant int& batch_stride_b [[buffer(7)]],
|
||||
const constant int& batch_size_b [[buffer(8)]],
|
||||
const constant int& batch_stride_c [[buffer(9)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
|
||||
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, C,
|
||||
M, N, K,
|
||||
batch_stride_a, batch_stride_b, batch_size_b, batch_stride_c,
|
||||
tgp_memory,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *C [[buffer(2)]], \
|
||||
const constant int &M [[buffer(3)]], \
|
||||
const constant int &N [[buffer(4)]], \
|
||||
const constant int &K [[buffer(5)]], \
|
||||
const constant int &batch_stride_a [[buffer(6)]], \
|
||||
const constant int& batch_stride_b [[buffer(7)]], \
|
||||
const constant int& batch_size_b [[buffer(8)]], \
|
||||
const constant int& batch_stride_c [[buffer(9)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
// TODO: Accumulation in different type
|
||||
575
crates/luminal_metal/src/kernels/gemv.metal
Normal file
575
crates/luminal_metal/src/kernels/gemv.metal
Normal file
@@ -0,0 +1,575 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/defines.h"
|
||||
#include "KERNEL_PATH/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Matrix vector multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
struct GEMVKernel {
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Threadgroup in_vec cache
|
||||
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
// Block position
|
||||
int out_row = (tid.x * BM + simd_gid) * TM;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if(out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * in_vec_size;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Prefetch in_vector for threadgroup use
|
||||
if(simd_gid == 0) {
|
||||
// Main load loop
|
||||
if(bn + TN <= in_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = in_vec[bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load for all rows
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] = in_vec_block[tn];
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
|
||||
// Load for the row
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
result[tm] = simd_sum(result[tm]);
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if(simd_lid == 0) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
// Appease compiler
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
T result[TN] = {0};
|
||||
T inter[TN];
|
||||
T v_coeff[TM];
|
||||
|
||||
// Threadgroup accumulation results
|
||||
threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN;
|
||||
|
||||
int out_col = (tid.x * BN + lid.x) * TN;
|
||||
int in_row = lid.y * TM;
|
||||
|
||||
// Edgecase handling
|
||||
if (out_col < out_vec_size) {
|
||||
|
||||
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||
|
||||
// Per thread accumulation main loop
|
||||
int bm = in_row;
|
||||
for(; bm < in_vec_size; bm += BM * TM) {
|
||||
// Adding a threadgroup_barrier improves performance slightly
|
||||
// This is possibly it may help exploit cache better
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if(bm + TM <= in_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
|
||||
} else { // Edgecase handling
|
||||
for(int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Threadgroup collection
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < TN; i++) {
|
||||
tgp_results[lid.y * TN + i] = result[i];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Threadgroup accumulation and writing out results
|
||||
if(lid.y == 0 && out_col < out_vec_size) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 1; i < BM; i++) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
result[j] += tgp_results[i * TN + j];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
out_vec[out_col + j] = result[j];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Matrix vector multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||
|
||||
instantiate_gemv_blocks(float32, float);
|
||||
instantiate_gemv_blocks(float16, half);
|
||||
instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
|
||||
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
|
||||
|
||||
instantiate_gemv_t_blocks(float32, float);
|
||||
instantiate_gemv_t_blocks(float16, half);
|
||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
|
||||
228
crates/luminal_metal/src/kernels/softmax.metal
Normal file
228
crates/luminal_metal/src/kernels/softmax.metal
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/defines.h"
|
||||
#include "KERNEL_PATH/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause it is gonna be x
|
||||
// will be in (-oo, 0] anyway and subsequently it will be divided by
|
||||
// sum(exp(x_i)).
|
||||
return fast::exp(x);
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_single_row(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
T ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
ld[i] = in[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
T maxval = Limits<T>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
maxval = simd_max(maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[0] = maxval;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
T normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T exp_x = softmax_exp(ld[i] - maxval);
|
||||
ld[i] = exp_x;
|
||||
normalizer += exp_x;
|
||||
}
|
||||
normalizer = simd_sum(normalizer);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[0] = normalizer;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = 1 / local_normalizer[0];
|
||||
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
T prevmax;
|
||||
T maxval = Limits<T>::finite_min;
|
||||
T normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
T vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[offset + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||
}
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
|
||||
// lsize) parts. We need to combine them.
|
||||
// 1. We start by finding the max across simd groups
|
||||
// 2. We then change the partial normalizers to account for a possible
|
||||
// change in max
|
||||
// 3. We sum all normalizers
|
||||
prevmax = maxval;
|
||||
maxval = simd_max(maxval);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
normalizer = simd_sum(normalizer);
|
||||
|
||||
// Now the normalizer and max value is correct for each simdgroup. We write
|
||||
// them shared memory and combine them.
|
||||
prevmax = maxval;
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
normalizer *= softmax_exp(prevmax - maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Finally given the normalizer and max value we can directly write the
|
||||
// softmax output
|
||||
out += gid * axis_size;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (offset + i < axis_size) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_softmax_single_row(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_looped(name, itype) \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_softmax_single_row(name, itype) \
|
||||
instantiate_softmax_looped(name, itype)
|
||||
|
||||
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
312
crates/luminal_metal/src/kernels/steel/gemm/gemm.h
Normal file
312
crates/luminal_metal/src/kernels/steel/gemm/gemm.h
Normal file
@@ -0,0 +1,312 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "loader.h"
|
||||
#include "mma.h"
|
||||
#include "transforms.h"
|
||||
#include "../utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel class
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct GEMMKernel {
|
||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
STEEL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
transpose_a ? BK : BM,
|
||||
transpose_a ? BM : BK,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
!transpose_a,
|
||||
tgp_size>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
transpose_b ? BN : BK,
|
||||
transpose_b ? BK : BN,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
transpose_b,
|
||||
tgp_size>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
thread mma_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
thread const short& lbk,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
|
||||
if (!M_aligned) {
|
||||
short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
}
|
||||
|
||||
if (!N_aligned) {
|
||||
short2 tile_dims_B =
|
||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
}
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(mask_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(mask_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned_) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.set_mask(tile_dims_A_last, mask_A);
|
||||
loader_b.set_mask(tile_dims_B_last, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldc);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(C, params->ldc);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, C,
|
||||
params,
|
||||
As, Bs,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *C [[buffer(2)]], \
|
||||
const constant GEMMParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
@@ -0,0 +1,260 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformAdd<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
const device T *C [[buffer(2)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel =
|
||||
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
||||
transpose_a, transpose_b,
|
||||
MN_aligned, K_aligned,
|
||||
AccumType, Epilogue>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col * params->fdc;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(params->alpha, params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
|
||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
|
||||
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
@@ -0,0 +1,280 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "KERNEL_PATH/bf16.h"
|
||||
#include "KERNEL_PATH/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device U *C [[buffer(2)]],
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
const int tid_x = tid.x;
|
||||
const int tid_y = tid.y;
|
||||
const int tid_z = tid.z;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const int k_start = params->split_k_partition_size * tid_z;
|
||||
|
||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K % BK;
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||
if(!K_aligned || gemm_k_iter_remaining > 0)
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iter_remaining,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
}
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
mma_op.store_result(C, params->ldc);
|
||||
} else {
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device otype *C [[buffer(2)]], \
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Split k accumulation kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformNone<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
D[0] = Epilogue::apply(out);
|
||||
|
||||
}
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum_axpby(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
const device OutT *C [[buffer(5)]],
|
||||
const constant int& ldc [[buffer(6)]],
|
||||
const constant int& fdc [[buffer(7)]],
|
||||
const constant float& alpha [[buffer(8)]],
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
C += gid.x * fdc + gid.y * ldc;
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
Epilogue op(alpha, beta);
|
||||
D[0] = op.apply(out, *C);
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_accum(oname, otype, aname, atype) \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
|
||||
[[kernel]] void gemm_splitk_accum<atype, otype>( \
|
||||
const device atype *C_split [[buffer(0)]], \
|
||||
device otype *D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
uint2 gid [[thread_position_in_grid]]); \
|
||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
|
||||
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
|
||||
const device atype *C_split [[buffer(0)]], \
|
||||
device otype *D [[buffer(1)]], \
|
||||
const constant int& k_partitions [[buffer(2)]], \
|
||||
const constant int& partition_stride [[buffer(3)]], \
|
||||
const constant int& ldd [[buffer(4)]], \
|
||||
const device otype *C [[buffer(5)]], \
|
||||
const constant int& ldc [[buffer(6)]], \
|
||||
const constant int& fdc [[buffer(7)]], \
|
||||
const constant float& alpha [[buffer(8)]], \
|
||||
const constant float& beta [[buffer(9)]], \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
||||
instantiate_accum(float16, half, float32, float);
|
||||
instantiate_accum(float32, float, float32, float);
|
||||
160
crates/luminal_metal/src/kernels/steel/gemm/loader.h
Normal file
160
crates/luminal_metal/src/kernels/steel/gemm/loader.h
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoader {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void set_mask(
|
||||
thread const short2& src_tile_dims,
|
||||
thread bool mask[n_rows][vec_size]) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < n_rows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
mask[i][j] =
|
||||
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Use fast thread memory for bound checks
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
264
crates/luminal_metal/src/kernels/steel/gemm/mma.h
Normal file
264
crates/luminal_metal/src/kernels/steel/gemm/mma.h
Normal file
@@ -0,0 +1,264 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "transforms.h"
|
||||
#include "../utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out C
|
||||
C[offset] = outs[0];
|
||||
C[offset + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {
|
||||
epilogue_op.apply(accum[0], C[offset_c]),
|
||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
||||
|
||||
// Write out D
|
||||
D[offset_d] = outs[0];
|
||||
D[offset_d + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
79
crates/luminal_metal/src/kernels/steel/gemm/params.h
Normal file
79
crates/luminal_metal/src/kernels/steel/gemm/params.h
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM param classes
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
struct GEMMParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int lda;
|
||||
const int ldb;
|
||||
const int ldc;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_a;
|
||||
const int batch_stride_b;
|
||||
const int batch_stride_c;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_k_iterations_aligned;
|
||||
};
|
||||
|
||||
struct GEMMSpiltKParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int lda;
|
||||
const int ldb;
|
||||
const int ldc;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int split_k_partitions;
|
||||
const int split_k_partition_stride;
|
||||
const int split_k_partition_size;
|
||||
|
||||
const int gemm_k_iterations_aligned;
|
||||
};
|
||||
|
||||
struct GEMMAddMMParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int lda;
|
||||
const int ldb;
|
||||
const int ldc;
|
||||
const int ldd;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_a;
|
||||
const int batch_stride_b;
|
||||
const int batch_stride_c;
|
||||
const int batch_stride_d;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_k_iterations_aligned;
|
||||
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
||||
const int fdc;
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
63
crates/luminal_metal/src/kernels/steel/gemm/transforms.h
Normal file
63
crates/luminal_metal/src/kernels/steel/gemm/transforms.h
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms and Epilogues
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAdd {
|
||||
TransformAdd(const float, const float) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
||||
return static_cast<OutT>(x) + c;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAxpby {
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
||||
TransformAxpby(const float alpha_, const float beta_)
|
||||
: alpha(alpha_), beta(beta_) {}
|
||||
|
||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
||||
return static_cast<OutT>(x * alpha + (beta * c));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
struct BlockSwizzle {
|
||||
static METAL_FUNC int2
|
||||
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
||||
const int tid_x = (tid.x) >> swizzle_log;
|
||||
const int tid_y =
|
||||
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
||||
return int2(tid_x, tid_y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
5
crates/luminal_metal/src/kernels/steel/host.h
Normal file
5
crates/luminal_metal/src/kernels/steel/host.h
Normal file
@@ -0,0 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gemm/params.h"
|
||||
9
crates/luminal_metal/src/kernels/steel/utils.h
Normal file
9
crates/luminal_metal/src/kernels/steel/utils.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include "host.h"
|
||||
|
||||
#define STEEL_CONST static constant constexpr const
|
||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
||||
212
crates/luminal_metal/src/kernels/utils.h
Normal file
212
crates/luminal_metal/src/kernels/utils.h
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "bf16.h"
|
||||
#include "complex.h"
|
||||
#include <metal_math>
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename U> struct Limits {
|
||||
static const constant U max;
|
||||
static const constant U min;
|
||||
static const constant U finite_max;
|
||||
static const constant U finite_min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> struct Limits<type> { \
|
||||
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
metal::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> struct Limits<type> { \
|
||||
static constexpr constant type max = \
|
||||
metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type min = \
|
||||
-metal::numeric_limits<type>::infinity(); \
|
||||
static constexpr constant type finite_max = \
|
||||
metal::numeric_limits<type>::max(); \
|
||||
static constexpr constant type finite_min = \
|
||||
-metal::numeric_limits<type>::max(); \
|
||||
};
|
||||
|
||||
instantiate_float_limit(half);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
|
||||
template <> struct Limits<bool> {
|
||||
static constexpr constant bool max = true;
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline size_t elem_to_loc(uint elem, device const int *shape,
|
||||
device const size_t *strides, int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc(uint elem, constant const int *shape,
|
||||
constant const size_t *strides, int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(uint3 elem, constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {static_cast<uint>(elem.x * a_strides[NDIM - 1] +
|
||||
elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(elem.x * b_strides[NDIM - 1] +
|
||||
elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline size_t elem_to_loc_nd(uint3 elem, constant const int shape[NDIM],
|
||||
constant const size_t strides[NDIM]) {
|
||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_1(uint elem, constant const size_t &stride) {
|
||||
return elem * stride;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
inline size_t elem_to_loc(uint3 elem, constant const int *shape,
|
||||
constant const size_t *strides, int ndim) {
|
||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape,
|
||||
constant const size_t *a_strides,
|
||||
constant const size_t *b_strides, int ndim) {
|
||||
uint2 loc = {static_cast<uint>(elem.x * a_strides[ndim - 1] +
|
||||
elem.y * a_strides[ndim - 2]),
|
||||
static_cast<uint>(elem.x * b_strides[ndim - 1] +
|
||||
elem.y * b_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint elem_to_loc_nd(uint elem, device const int *shape,
|
||||
device const size_t *strides);
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<1>(uint elem, device const int *shape,
|
||||
device const size_t *strides) {
|
||||
return (elem % shape[0]) * strides[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<2>(uint elem, device const int *shape,
|
||||
device const size_t *strides) {
|
||||
uint loc = (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<3>(uint elem, device const int *shape,
|
||||
device const size_t *strides) {
|
||||
uint loc = (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<4>(uint elem, device const int *shape,
|
||||
device const size_t *strides) {
|
||||
uint loc = (elem % shape[3]) * strides[3];
|
||||
elem /= shape[3];
|
||||
loc += (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) { return (N + M - 1) / M; }
|
||||
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
bfloat16_t ret =
|
||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
return ret;
|
||||
}
|
||||
458
crates/luminal_metal/src/lib.rs
Normal file
458
crates/luminal_metal/src/lib.rs
Normal file
@@ -0,0 +1,458 @@
|
||||
use std::{
|
||||
any::{Any, TypeId},
|
||||
fmt::{Debug, Write},
|
||||
ops::Deref,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
mod binary;
|
||||
mod command_buffer;
|
||||
mod elementwise_fusion;
|
||||
mod matmul;
|
||||
mod other;
|
||||
mod prim;
|
||||
mod quantized;
|
||||
mod storage_buffer;
|
||||
mod unary;
|
||||
|
||||
use itertools::Itertools;
|
||||
use metal_rs::*;
|
||||
pub use quantized::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use luminal::{
|
||||
op::InputTensor,
|
||||
prelude::{
|
||||
symbolic::{BigExpression, Term},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
/// Compile graphs to run on Metal-supported macOS devices in supported data formats
|
||||
pub type MetalCompiler<T> = (
|
||||
prim::PrimitiveCompiler<T>,
|
||||
SpecialOpsCompiler<T>,
|
||||
other::CopyCompiler<T>,
|
||||
other::ContiguousElimination<T>,
|
||||
elementwise_fusion::ElementwiseFusionCompiler<T>,
|
||||
// BufferCompilers,
|
||||
);
|
||||
|
||||
/// Compilers to share command and storage buffers
|
||||
type BufferCompilers = (
|
||||
command_buffer::CommandBufferCompiler,
|
||||
storage_buffer::StorageBufferCompiler,
|
||||
);
|
||||
|
||||
/// Compiler to replace metal ops with specialized variants
|
||||
type SpecialOpsCompiler<T> = (
|
||||
binary::MetalSubtractionCompiler<T>,
|
||||
binary::MetalEqualCompiler<T>,
|
||||
other::ARangeCompiler<T>,
|
||||
binary::MetalGatherCompiler<T>,
|
||||
unary::MetalExpCompiler<T>,
|
||||
unary::MetalCosCompiler<T>,
|
||||
unary::MeanReduceCompiler<T>,
|
||||
unary::StdNormCompiler<T>,
|
||||
unary::SoftmaxCompiler<T>,
|
||||
unary::RopeCompiler<T>,
|
||||
matmul::MetalMatMulCompiler<T>,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalBuffer(pub Buffer);
|
||||
|
||||
impl Deref for MetalBuffer {
|
||||
type Target = Buffer;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Data for MetalBuffer {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MetalFloat: Copy + 'static {
|
||||
fn to_f32(self) -> f32;
|
||||
fn from_f32(a: f32) -> Self;
|
||||
fn is_f32() -> bool;
|
||||
fn type_name() -> &'static str;
|
||||
}
|
||||
|
||||
// Quantization types
|
||||
|
||||
pub trait MetalQuantizationType {
|
||||
type MatmulCompiler;
|
||||
}
|
||||
|
||||
/// 8-bit quantization. Equivalent to the ggml Q8_0 datatype
|
||||
pub struct Q8_0;
|
||||
|
||||
impl MetalQuantizationType for Q8_0 {
|
||||
type MatmulCompiler = matmul::MetalMatMulCompiler<f16>;
|
||||
}
|
||||
|
||||
impl MetalQuantizationType for f32 {
|
||||
type MatmulCompiler = matmul::MetalMatMulCompiler<Self>;
|
||||
}
|
||||
|
||||
impl MetalQuantizationType for f16 {
|
||||
type MatmulCompiler = matmul::MetalMatMulCompiler<Self>;
|
||||
}
|
||||
|
||||
// Main metal dtypes
|
||||
|
||||
impl MetalFloat for f32 {
|
||||
fn from_f32(a: f32) -> Self {
|
||||
a
|
||||
}
|
||||
fn to_f32(self) -> f32 {
|
||||
self
|
||||
}
|
||||
fn is_f32() -> bool {
|
||||
true
|
||||
}
|
||||
fn type_name() -> &'static str {
|
||||
"float"
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalFloat for f16 {
|
||||
fn from_f32(a: f32) -> Self {
|
||||
f16::from_f32(a)
|
||||
}
|
||||
fn to_f32(self) -> f32 {
|
||||
self.to_f32()
|
||||
}
|
||||
fn is_f32() -> bool {
|
||||
false
|
||||
}
|
||||
fn type_name() -> &'static str {
|
||||
"half"
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MetalKernel: Debug {
|
||||
/// Annotate the buffer sizes of the intermediate buffers
|
||||
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![]
|
||||
}
|
||||
/// Annotate the buffer sizes of the output buffers
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression>;
|
||||
/// Set up the kernel on the buffer
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
intermediate_buffers: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
);
|
||||
fn without_command_buffer(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
intermediate_buffers: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let command_buffer = queue.new_command_buffer();
|
||||
self.metal_forward(inputs, command_buffer, intermediate_buffers, output_buffers);
|
||||
}
|
||||
fn without_storage_buffers(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<Buffer> {
|
||||
let dev = Device::system_default().unwrap();
|
||||
// Allocate storage buffers
|
||||
let inp_shapes = inputs.iter().map(|(_, s)| *s).collect::<Vec<_>>();
|
||||
let intermediate_buffers = self
|
||||
.intermediate_buffer_sizes(&inp_shapes)
|
||||
.into_iter()
|
||||
.map(|n| {
|
||||
dev.new_buffer(
|
||||
n.exec(dyn_map).unwrap() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let intermediate_buffers_ref = intermediate_buffers.iter().collect::<Vec<_>>();
|
||||
let output_buffers = self
|
||||
.output_buffer_sizes(&inp_shapes)
|
||||
.into_iter()
|
||||
.map(|n| {
|
||||
dev.new_buffer(
|
||||
n.exec(dyn_map).unwrap() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let output_buffers_ref = output_buffers.iter().collect::<Vec<_>>();
|
||||
self.metal_forward(
|
||||
inputs,
|
||||
command_buffer,
|
||||
&intermediate_buffers_ref,
|
||||
&output_buffers_ref,
|
||||
);
|
||||
output_buffers
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalPrint, LuminalEqFalse, Clone)]
|
||||
pub struct MetalKernelWrapper(pub Arc<Box<dyn MetalKernel>>);
|
||||
|
||||
impl Default for MetalKernelWrapper {
|
||||
fn default() -> Self {
|
||||
Self(Arc::new(Box::new(())))
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalKernel for () {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
_: &[(&Buffer, ShapeTracker)],
|
||||
_: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
_: &[&Buffer],
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_lib(device: &Device, source: &str) -> Library {
|
||||
let options = CompileOptions::new();
|
||||
options.set_fast_math_enabled(true);
|
||||
// options.set_install_name(
|
||||
// &rand::thread_rng()
|
||||
// .sample_iter(&rand::distributions::Alphanumeric)
|
||||
// .take(7)
|
||||
// .map(char::from)
|
||||
// .collect::<String>(),
|
||||
// );
|
||||
device
|
||||
.new_library_with_source(
|
||||
&source.replace(
|
||||
"KERNEL_PATH",
|
||||
&format!("{}/src/kernels", env!("CARGO_MANIFEST_DIR")),
|
||||
),
|
||||
&options,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn select_function_from_lib(
|
||||
lib: &Library,
|
||||
function: &str,
|
||||
device: &Device,
|
||||
) -> ComputePipelineState {
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor
|
||||
.set_compute_function(Some(&lib.get_function(function, None).unwrap()));
|
||||
device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn compile_function(name: &str, code: &str, device: &Device) -> ComputePipelineState {
|
||||
let library = compile_lib(device, code);
|
||||
select_function_from_lib(&library, name, device)
|
||||
}
|
||||
|
||||
fn is<T: Any>(type_id: TypeId) -> bool {
|
||||
type_id == TypeId::of::<T>()
|
||||
}
|
||||
|
||||
trait DispatchNElements {
|
||||
fn dispatch_1d(&self, n: usize);
|
||||
}
|
||||
|
||||
impl DispatchNElements for ComputeCommandEncoderRef {
|
||||
fn dispatch_1d(&self, n: usize) {
|
||||
self.dispatch_thread_groups(
|
||||
MTLSize {
|
||||
width: n.div_ceil(1024) as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
},
|
||||
MTLSize {
|
||||
width: 1024,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
trait SetInt {
|
||||
fn set_i32(&self, index: usize, value: i32);
|
||||
fn set_u32(&self, index: usize, value: u32);
|
||||
fn set_f32(&self, index: usize, value: f32);
|
||||
fn set_i64(&self, index: usize, value: i64);
|
||||
fn set_u64(&self, index: usize, value: u64);
|
||||
fn set_f64(&self, index: usize, value: f64);
|
||||
}
|
||||
|
||||
impl SetInt for ComputeCommandEncoderRef {
|
||||
fn set_i32(&self, index: usize, value: i32) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<i32>() as u64,
|
||||
&value as *const i32 as *const _,
|
||||
);
|
||||
}
|
||||
fn set_u32(&self, index: usize, value: u32) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<u32>() as u64,
|
||||
&value as *const u32 as *const _,
|
||||
);
|
||||
}
|
||||
fn set_f32(&self, index: usize, value: f32) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<f32>() as u64,
|
||||
&value as *const f32 as *const _,
|
||||
);
|
||||
}
|
||||
fn set_i64(&self, index: usize, value: i64) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<i64>() as u64,
|
||||
&value as *const i64 as *const _,
|
||||
);
|
||||
}
|
||||
fn set_u64(&self, index: usize, value: u64) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<u64>() as u64,
|
||||
&value as *const u64 as *const _,
|
||||
);
|
||||
}
|
||||
fn set_f64(&self, index: usize, value: f64) {
|
||||
self.set_bytes(
|
||||
index as u64,
|
||||
std::mem::size_of::<f64>() as u64,
|
||||
&value as *const f64 as *const _,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn input_dyn_dims(
|
||||
dyn_symbols: &[char],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
index: usize,
|
||||
) {
|
||||
for (i, s) in dyn_symbols.iter().enumerate() {
|
||||
encoder.set_u32(i + index, dyn_map[s] as u32);
|
||||
}
|
||||
}
|
||||
|
||||
fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>, String) {
|
||||
let symbols: Vec<char> = shapes
|
||||
.iter()
|
||||
.flat_map(|st| {
|
||||
st.shape()
|
||||
.into_iter()
|
||||
.chain(
|
||||
st.padding
|
||||
.into_iter()
|
||||
.flat_map(|i| [i.0.into(), i.1.into()]),
|
||||
)
|
||||
.chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
|
||||
})
|
||||
.flat_map(|d| d.to_symbols())
|
||||
.unique()
|
||||
.collect();
|
||||
(
|
||||
symbols.clone(),
|
||||
symbols
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.fold(String::default(), |mut acc, (i, c)| {
|
||||
write!(&mut acc, ", device int& {c} [[buffer({})]]", i + offset).unwrap();
|
||||
acc
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn expr_to_metal_string(expr: BigExpression) -> String {
|
||||
let mut symbols = vec![];
|
||||
for term in expr.terms {
|
||||
let new_symbol = match term {
|
||||
Term::Num(n) => n.to_string(),
|
||||
Term::Var(c) => {
|
||||
if c == 'z' {
|
||||
"(int)idx".to_string()
|
||||
} else {
|
||||
c.to_string()
|
||||
}
|
||||
}
|
||||
Term::Max => format!(
|
||||
"max((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
Term::Min => format!(
|
||||
"min((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
_ => format!(
|
||||
"({}{term:?}{})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
};
|
||||
symbols.push(new_symbol);
|
||||
}
|
||||
symbols.pop().unwrap()
|
||||
}
|
||||
|
||||
fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) {
|
||||
(
|
||||
expr_to_metal_string(shape.index_expression()),
|
||||
expr_to_metal_string(shape.valid_expression()),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_buffer_from_tensor<'a>(tensor: &'a InputTensor) -> &'a MetalBuffer {
|
||||
tensor
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<MetalBuffer>()
|
||||
.expect("Tensor does not contain a metal buffer")
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select_const {
|
||||
($i: expr, $t: tt) => {
|
||||
luminal::compiler_utils::SelectOp::new().check(|o, _| {
|
||||
if let Some(c) = o.as_any().downcast_ref::<$crate::prim::MetalConstant<$t>>() {
|
||||
if let luminal::op::ConstantValue::Float(f) = c.0 {
|
||||
(f - $i).abs() < 0.0001
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
};
|
||||
}
|
||||
480
crates/luminal_metal/src/matmul.rs
Normal file
480
crates/luminal_metal/src/matmul.rs
Normal file
@@ -0,0 +1,480 @@
|
||||
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::*,
|
||||
shape::symbolic::BigExpression,
|
||||
};
|
||||
|
||||
use metal_rs::{objc::rc::autoreleasepool, *};
|
||||
|
||||
use crate::{
|
||||
compile_lib, get_buffer_from_tensor,
|
||||
prim::{MetalContiguous, MetalMul, MetalSumReduce},
|
||||
select_function_from_lib, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper, SetInt,
|
||||
};
|
||||
|
||||
/// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix
|
||||
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
|
||||
pub struct Matmul<T> {
|
||||
matmul_pipeline: ComputePipelineState,
|
||||
matvec_pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
const BM: u64 = 8;
|
||||
const BN: u64 = 32;
|
||||
impl<T> MetalKernel for Matmul<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
|
||||
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
|
||||
let batch_size = input_shapes[0]
|
||||
.shape()
|
||||
.into_iter()
|
||||
.take(input_shapes[0].len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(BigExpression::from(1));
|
||||
vec![batch_size * m * n * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let (a_shape, b_shape) = (
|
||||
inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
inputs[1]
|
||||
.1
|
||||
.shape()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let a_dims = a_shape.len();
|
||||
let m = a_shape[a_dims - 2];
|
||||
let batch_size = a_shape.iter().take(a_dims - 2).product::<usize>().max(1);
|
||||
let b_batch_size = b_shape
|
||||
.iter()
|
||||
.enumerate()
|
||||
.take(b_shape.len() - 2)
|
||||
.filter(|(i, _)| !inputs[1].1.fake[inputs[1].1.indexes[*i]])
|
||||
.map(|(_, i)| *i)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
let b_dims = b_shape.len();
|
||||
let k = b_shape[b_dims - 2];
|
||||
let n = b_shape[b_dims - 1];
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
if m == 1 && batch_size == 1 {
|
||||
// Matvec
|
||||
encoder.set_compute_pipeline_state(&self.matvec_pipeline);
|
||||
encoder.set_buffer(0, Some(inputs[1].0), 0);
|
||||
encoder.set_buffer(1, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(2, Some(output_buffers[0]), 0);
|
||||
encoder.set_i32(3, if m == 1 { k } else { m } as i32);
|
||||
encoder.set_i32(4, if m == 1 { n } else { m } as i32);
|
||||
encoder.set_i32(5, 0);
|
||||
encoder.set_i32(6, 0);
|
||||
encoder.set_threadgroup_memory_length(
|
||||
0,
|
||||
if inputs[1].1.indexes[inputs[1].1.len() - 1]
|
||||
> inputs[1].1.indexes[inputs[1].1.len() - 2]
|
||||
{
|
||||
BN * BM * 4
|
||||
} else {
|
||||
BN * 8
|
||||
},
|
||||
);
|
||||
let b = if inputs[1].1.is_contiguous() { BN } else { BM };
|
||||
encoder.dispatch_thread_groups(
|
||||
MTLSize::new((n as u64 + b * 4 - 1).div_ceil(b * 4), 1, 1),
|
||||
MTLSize::new(BN, BM, 1),
|
||||
);
|
||||
} else {
|
||||
// Matmul
|
||||
encoder.set_compute_pipeline_state(&self.matmul_pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(inputs[1].0), 0);
|
||||
encoder.set_buffer(2, Some(output_buffers[0]), 0);
|
||||
encoder.set_i32(3, m as i32);
|
||||
encoder.set_i32(4, n as i32);
|
||||
encoder.set_i32(5, k as i32);
|
||||
encoder.set_i32(6, (m * k) as i32); // A batch stride
|
||||
if inputs[1].1.len() > 2 // 3D or larger
|
||||
&& inputs[1].1.fake[inputs[1].1.indexes[inputs[1].1.len() - 3]] // 3rd to last dimension is fake
|
||||
&& inputs[1]
|
||||
.1
|
||||
.indexes
|
||||
.iter()
|
||||
.take(inputs[1].1.len().saturating_sub(4))
|
||||
.any(|i| !inputs[1].1.fake[*i])
|
||||
// At least one non-fake dimension before 3rd to last
|
||||
{
|
||||
encoder.set_i32(7, (k * n) as i32); // B batch stride
|
||||
// B batch size 2
|
||||
encoder.set_i32(8, b_shape[inputs[1].1.len() - 3] as i32);
|
||||
} else {
|
||||
encoder.set_i32(7, if b_batch_size == 1 { 0 } else { n * k } as i32); // B batch stride
|
||||
encoder.set_i32(8, 1); // B batch size
|
||||
}
|
||||
encoder.set_i32(9, (m * n) as i32); // C batch stride
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_thread_groups(
|
||||
MTLSize::new(
|
||||
(n + 31).div_ceil(32) as u64,
|
||||
(m + 31).div_ceil(32) as u64,
|
||||
batch_size as u64,
|
||||
),
|
||||
MTLSize::new(32, 2, 2),
|
||||
);
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static + Clone> Operator for Matmul<T> {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let n = b_shape.last().unwrap().to_usize().unwrap();
|
||||
let batch_size = a_shape
|
||||
.iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.take(a_shape.len() - 2)
|
||||
.product::<usize>();
|
||||
let m = a_shape[a_shape.len() - 2].to_usize().unwrap();
|
||||
|
||||
let out = self.device.new_buffer(
|
||||
(batch_size * m * n * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[
|
||||
(get_buffer_from_tensor(&inp[0].0), inp[0].1),
|
||||
(get_buffer_from_tensor(&inp[1].0), inp[1].1),
|
||||
],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct MetalMatMulCompiler<T>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for MetalMatMulCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let (mut sum_reduce, mut mul) = (NodeIndex::default(), NodeIndex::default());
|
||||
|
||||
// Look for the matmul pattern
|
||||
// Mul ([A, C(fake), B] | [A(fake), C, B]) -> SumReduce(2) -> [A, C]
|
||||
// Actually starts at [A,B] | [B, C]
|
||||
let mut searcher_2d = SelectOp::new()
|
||||
.ty::<MetalMul<T>>()
|
||||
.shapes([['M', 'N', 'K'], ['M', 'N', 'K']])
|
||||
.fakes([
|
||||
[None, Some(true), Some(false)],
|
||||
[Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
|
||||
o.dim == 2
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
)
|
||||
.search(graph);
|
||||
let mut searcher_3d = SelectOp::new()
|
||||
.ty::<MetalMul<T>>()
|
||||
.shapes([['D', 'A', 'C', 'B'], ['D', 'A', 'C', 'B']])
|
||||
.fakes([
|
||||
[Some(false), Some(false), Some(true), Some(false)],
|
||||
[None, Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalSumReduce<T>>()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
|
||||
o.dim == 3
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
)
|
||||
.search(graph);
|
||||
let mut searcher_4d = SelectOp::new()
|
||||
.ty::<MetalMul<T>>()
|
||||
.shapes([['E', 'D', 'A', 'C', 'B'], ['E', 'D', 'A', 'C', 'B']])
|
||||
.fakes([
|
||||
[
|
||||
Some(false),
|
||||
Some(false),
|
||||
Some(false),
|
||||
Some(true),
|
||||
Some(false),
|
||||
],
|
||||
[None, None, Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalSumReduce<T>>()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
|
||||
o.dim == 4
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
)
|
||||
.search(graph);
|
||||
let mut searcher_5d = SelectOp::new()
|
||||
.ty::<MetalMul<T>>()
|
||||
.shapes([
|
||||
['F', 'E', 'D', 'A', 'C', 'B'],
|
||||
['F', 'E', 'D', 'A', 'C', 'B'],
|
||||
])
|
||||
.fakes([
|
||||
[
|
||||
Some(false),
|
||||
Some(false),
|
||||
Some(false),
|
||||
Some(false),
|
||||
Some(true),
|
||||
Some(false),
|
||||
],
|
||||
[None, None, None, Some(true), Some(false), Some(false)],
|
||||
])
|
||||
.ptr(&mut mul)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalSumReduce<T>>()
|
||||
.check(|o, _| {
|
||||
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<T>>() {
|
||||
o.dim == 5
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ptr(&mut sum_reduce),
|
||||
)
|
||||
.search(graph);
|
||||
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
|
||||
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
|
||||
while searcher_2d.next_match()
|
||||
|| searcher_3d.next_match()
|
||||
|| searcher_4d.next_match()
|
||||
|| searcher_5d.next_match()
|
||||
{
|
||||
if graph.no_delete.contains(&mul) {
|
||||
// The intermediate mul can't be deleted
|
||||
continue;
|
||||
}
|
||||
// Insert Matmul op
|
||||
let srcs = graph.get_sources(mul);
|
||||
let (mut src1, mut src1_shape) = (srcs[0].0, srcs[0].2);
|
||||
let (mut src2, mut src2_shape) = (srcs[1].0, srcs[1].2);
|
||||
// Undo expansions and permute
|
||||
src1_shape.remove_dim(src1_shape.len() - 2);
|
||||
src2_shape.remove_dim(src2_shape.len() - 3);
|
||||
let mut dims = (0..src2_shape.len()).collect::<Vec<_>>();
|
||||
dims.swap(src2_shape.len() - 2, src2_shape.len() - 1);
|
||||
src2_shape.permute(&dims);
|
||||
// If src1 is padded or sliced, or batch dim isn't first, we need to make it contiguous
|
||||
if src1_shape
|
||||
.indexes
|
||||
.iter()
|
||||
.take(src1_shape.len() - 2)
|
||||
.enumerate()
|
||||
.any(|(a, b)| a != *b)
|
||||
|| src1_shape.is_sliced()
|
||||
|| src1_shape.is_padded()
|
||||
{
|
||||
src1 = graph
|
||||
.add_op(MetalContiguous::<T>::new(
|
||||
src1_shape,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(src1, 0, src1_shape)
|
||||
.finish();
|
||||
src1_shape = src1_shape.contiguous();
|
||||
}
|
||||
// If src2 is padded or sliced, or batch dim isn't first, we need to make it contiguous
|
||||
if src2_shape
|
||||
.indexes
|
||||
.iter()
|
||||
.take(src2_shape.len() - 2)
|
||||
.filter(|i| !src2_shape.fake[**i])
|
||||
.enumerate()
|
||||
.any(|(a, b)| a != *b)
|
||||
|| src2_shape.is_sliced()
|
||||
|| src2_shape.is_padded()
|
||||
{
|
||||
src2 = graph
|
||||
.add_op(MetalContiguous::<T>::new(
|
||||
src2_shape,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.input(src2, 0, src2_shape)
|
||||
.finish();
|
||||
src2_shape = src2_shape.contiguous();
|
||||
}
|
||||
let type_name = if T::is_f32() { "float32" } else { "float16" };
|
||||
let matmul_op = graph
|
||||
.add_op(Matmul::<T> {
|
||||
matmul_pipeline: select_function_from_lib(
|
||||
&matmul_library,
|
||||
&format!( "gemm_{}{}_{type_name}_{type_name}_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned", if src1_shape.is_contiguous() {"n"} else {"t"}, if src2_shape.indexes[src2_shape.len() - 1] > src2_shape.indexes[src2_shape.len() - 2] {"n"} else {"t"}),
|
||||
&dev
|
||||
),
|
||||
matvec_pipeline: select_function_from_lib(
|
||||
&matvec_library,
|
||||
&format!(
|
||||
"gemv_{}{type_name}_bm{BM}_bn{BN}_tm4_tn4",
|
||||
if src2_shape.indexes[src2_shape.len() - 1] > src2_shape.indexes[src2_shape.len() - 2] { "t_" } else { "" }
|
||||
),
|
||||
&dev
|
||||
),
|
||||
queue: queue.clone(),
|
||||
device: dev.clone(),
|
||||
_phantom: Default::default()
|
||||
})
|
||||
.input(src1, 0, src1_shape)
|
||||
.input(src2, 0, src2_shape)
|
||||
.finish();
|
||||
|
||||
// Create edges to dests
|
||||
move_outgoing_edge(sum_reduce, matmul_op, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
sum_reduce,
|
||||
matmul_op,
|
||||
);
|
||||
|
||||
// Remove the old ops
|
||||
graph.graph.remove_node(mul);
|
||||
graph.graph.remove_node(sum_reduce);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use dfdx::{
|
||||
tensor::TensorFromVec,
|
||||
tensor_ops::{PermuteTo, TryMatMul},
|
||||
};
|
||||
use luminal::{
|
||||
prelude::*,
|
||||
tests::{assert_close_precision, random_vec},
|
||||
};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
#[test]
|
||||
fn test_matrix_vector() {
|
||||
const M: usize = 53;
|
||||
const N: usize = 256;
|
||||
let mut cx = Graph::new();
|
||||
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
|
||||
let mut a = cx.named_tensor::<R2<1, M>>("Vec").set(a_vec.clone());
|
||||
let mut b = cx.named_tensor::<R2<N, M>>("Mat").set(b_mat.clone());
|
||||
let mut c = a.matmul(b.permute()).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
(&mut a, &mut b, &mut c),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = dfdx::tensor::Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_vec, (dfdx::shapes::Const::<M>,));
|
||||
let d_b =
|
||||
d_dev.tensor_from_vec(b_mat, (dfdx::shapes::Const::<N>, dfdx::shapes::Const::<M>));
|
||||
let d_c = d_a.matmul(d_b.permute());
|
||||
|
||||
assert_close_precision(&c.data(), &d_c.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_matrix_vector() {
|
||||
const M: usize = 256;
|
||||
const N: usize = 256;
|
||||
let mut cx = Graph::new();
|
||||
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
|
||||
let mut a = cx.named_tensor::<R3<1, 1, M>>("Vec").set(a_vec.clone());
|
||||
let mut b = cx.named_tensor::<R2<M, N>>("Mat").set(b_mat.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
(&mut a, &mut b, &mut c),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = dfdx::tensor::Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
a_vec,
|
||||
(
|
||||
dfdx::shapes::Const::<1>,
|
||||
dfdx::shapes::Const::<1>,
|
||||
dfdx::shapes::Const::<M>,
|
||||
),
|
||||
);
|
||||
let d_b =
|
||||
d_dev.tensor_from_vec(b_mat, (dfdx::shapes::Const::<M>, dfdx::shapes::Const::<N>));
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close_precision(&c.data(), &d_c.to_dtype::<f32>().as_vec(), 2);
|
||||
}
|
||||
}
|
||||
370
crates/luminal_metal/src/other.rs
Normal file
370
crates/luminal_metal/src/other.rs
Normal file
@@ -0,0 +1,370 @@
|
||||
use std::{any::Any, marker::PhantomData, sync::Arc};
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::{
|
||||
petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction},
|
||||
*,
|
||||
},
|
||||
shape::symbolic::BigExpression,
|
||||
};
|
||||
use metal_rs::{
|
||||
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
|
||||
ComputePipelineState, Device, MTLResourceOptions,
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
compile_function,
|
||||
prim::{MetalAdd, MetalContiguous, MetalCopyFromDevice, MetalCopyToDevice, MetalSumReduce},
|
||||
select_const, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
|
||||
SetInt,
|
||||
};
|
||||
|
||||
use super::binary::MetalSub;
|
||||
|
||||
/// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up
|
||||
#[derive(LuminalPrint, Default)]
|
||||
pub struct CopyCompiler<T>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for CopyCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
let (mut first, mut second) = (NodeIndex::default(), NodeIndex::default());
|
||||
let mut selector = SelectOp::new()
|
||||
.ty::<MetalCopyToDevice<T>>()
|
||||
.ptr(&mut first)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalCopyToDevice<T>>()
|
||||
.ptr(&mut second),
|
||||
)
|
||||
.search(graph);
|
||||
while selector.next_match() {
|
||||
// Ensure there are no dests from first that are not copies
|
||||
if graph
|
||||
.graph
|
||||
.edges_directed(first, petgraph::Direction::Outgoing)
|
||||
.filter(|e| {
|
||||
let target = graph.graph.node_weight(e.target()).unwrap().as_any();
|
||||
!target.is::<MetalCopyFromDevice<T>>() && !target.is::<MetalCopyToDevice<T>>()
|
||||
})
|
||||
.count()
|
||||
> 0
|
||||
|| graph.no_delete.contains(&first)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let Some((source, _, _)) = graph.get_sources(first).pop() else {
|
||||
continue;
|
||||
};
|
||||
move_outgoing_edge(second, source, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
second,
|
||||
source,
|
||||
);
|
||||
graph.graph.remove_node(second);
|
||||
for dest in graph
|
||||
.get_dests(first)
|
||||
.iter()
|
||||
.map(|(i, _)| *i)
|
||||
.collect::<Vec<_>>()
|
||||
{
|
||||
move_outgoing_edge(dest, source, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
dest,
|
||||
source,
|
||||
);
|
||||
graph.graph.remove_node(dest);
|
||||
}
|
||||
graph.graph.remove_node(first);
|
||||
selector.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Special kernel for producing aranges
|
||||
#[derive(Clone, LuminalEqFalse)]
|
||||
pub struct MetalARange<T: MetalFloat> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
pub size: BigExpression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> std::fmt::Debug for MetalARange<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalARange({:?})", self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalARange<T> {
|
||||
fn new(
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
size: BigExpression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let type_name = T::type_name();
|
||||
Self {
|
||||
pipeline: compile_function("metal_arange", &format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void metal_arange(device {type_name} *out [[buffer(0)]], device int& n_elements [[buffer(1)]], uint idx [[thread_position_in_grid]]) {{
|
||||
if (idx < n_elements) {{
|
||||
out[idx] = ({type_name})idx;
|
||||
}}
|
||||
}}"), &device),
|
||||
queue,
|
||||
device,
|
||||
size,
|
||||
dyn_map,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalKernel for MetalARange<T> {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![self.size.clone() * std::mem::size_of::<f16>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
_: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
// Calculate size
|
||||
let size = self
|
||||
.size
|
||||
.exec(unsafe { self.dyn_map.as_ref().unwrap() })
|
||||
.unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(1, size as u32);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalARange<T> {
|
||||
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Set up command buffer and output buffer
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let size = self
|
||||
.size
|
||||
.exec(unsafe { self.dyn_map.as_ref().unwrap() })
|
||||
.unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(size * std::mem::size_of::<f16>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(&[], command_buffer, &[], &[&out]);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the arange pattern with a special kernel. This must be ran **after** the subtraction compiler
|
||||
#[derive(Default, LuminalPrint)]
|
||||
pub struct ARangeCompiler<T: MetalFloat>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for ARangeCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
let (
|
||||
mut one_const,
|
||||
mut contig1,
|
||||
mut contig2,
|
||||
mut contig3,
|
||||
mut contig4,
|
||||
mut sum_reduce,
|
||||
mut subtraction_constant,
|
||||
mut subtraction,
|
||||
) = (
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
NodeIndex::default(),
|
||||
);
|
||||
|
||||
// TODO: Make sure this actually checks the shape transformations to ensure pooling happens
|
||||
let contig = SelectOp::new().ty::<MetalContiguous<T>>();
|
||||
let pre_sub_pattern = select_const!(1.0, T)
|
||||
.ptr(&mut one_const)
|
||||
.edge(contig.clone().ptr(&mut contig1))
|
||||
.edge(contig.clone().ptr(&mut contig2))
|
||||
.edge(contig.clone().ptr(&mut contig3))
|
||||
.edge(contig.clone().ptr(&mut contig4))
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.ty::<MetalSumReduce<T>>()
|
||||
.ptr(&mut sum_reduce),
|
||||
);
|
||||
let mut s1 = pre_sub_pattern
|
||||
.clone()
|
||||
.edge(
|
||||
select_const!(1.0, T)
|
||||
.ptr(&mut subtraction_constant)
|
||||
.edge(SelectOp::new().ty::<MetalSub<T>>().ptr(&mut subtraction)),
|
||||
)
|
||||
.search(graph);
|
||||
let mut s2 = pre_sub_pattern
|
||||
.edge(
|
||||
select_const!(-1.0, T)
|
||||
.ptr(&mut subtraction_constant)
|
||||
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut subtraction)),
|
||||
)
|
||||
.search(graph);
|
||||
|
||||
while s1.next_match() || s2.next_match() {
|
||||
let arange_amount = {
|
||||
let sh = graph
|
||||
.graph
|
||||
.edge_weight(
|
||||
graph
|
||||
.graph
|
||||
.edges_connecting(one_const, contig1)
|
||||
.next()
|
||||
.unwrap()
|
||||
.id(),
|
||||
)
|
||||
.unwrap()
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2;
|
||||
sh.dims[sh.indexes[sh.len() - 1]]
|
||||
};
|
||||
let arange_op = graph
|
||||
.add_op(MetalARange::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
arange_amount.into(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish();
|
||||
move_outgoing_edge(subtraction, arange_op, &mut graph.graph);
|
||||
|
||||
graph.graph.remove_node(subtraction);
|
||||
graph.safe_remove_node(subtraction_constant, 0);
|
||||
graph.safe_remove_node(sum_reduce, 0);
|
||||
graph.safe_remove_node(contig4, 0);
|
||||
graph.safe_remove_node(contig3, 0);
|
||||
graph.safe_remove_node(contig2, 0);
|
||||
graph.safe_remove_node(contig1, 0);
|
||||
graph.safe_remove_node(one_const, 0);
|
||||
s1.clear_cached_results();
|
||||
s2.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ContiguousElimination<T>(PhantomData<T>);
|
||||
|
||||
impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
// Look for contiguous calls going to ops that can accept non-contiguous inputs (marked non_contiguous)
|
||||
let (mut contig, mut op) = (NodeIndex::default(), NodeIndex::default());
|
||||
let pattern = SelectOp::new()
|
||||
.ty::<MetalContiguous<T>>()
|
||||
.ptr(&mut contig)
|
||||
.edge(
|
||||
SelectOp::new()
|
||||
.check(|op, _| op.custom("non_contiguous", Box::new(())).is_some())
|
||||
.ptr(&mut op),
|
||||
);
|
||||
let mut selector = pattern.search(graph);
|
||||
while selector.next_match() {
|
||||
if graph.no_delete.contains(&contig)
|
||||
|| graph
|
||||
.graph
|
||||
.edges_directed(contig, Direction::Outgoing)
|
||||
.count()
|
||||
> 1
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// Shape going from contig to op
|
||||
// let first_shape = graph
|
||||
// .graph
|
||||
// .edges_directed(contig, Direction::Incoming)
|
||||
// .find_map(|e| e.weight().as_data())
|
||||
// .unwrap()
|
||||
// .2;
|
||||
let second_shape = graph
|
||||
.graph
|
||||
.edges_connecting(contig, op)
|
||||
.find_map(|e| e.weight().as_data())
|
||||
.unwrap()
|
||||
.2;
|
||||
// Here we should check if second shape and first shape are mergeable instead of just checking if second_shape is contiguous
|
||||
if second_shape.is_contiguous()
|
||||
&& !second_shape.is_sliced()
|
||||
&& !second_shape.is_padded()
|
||||
{
|
||||
let source = graph
|
||||
.graph
|
||||
.neighbors_directed(contig, petgraph::Direction::Incoming)
|
||||
.next()
|
||||
.unwrap();
|
||||
move_incoming_edge(contig, op, &mut graph.graph);
|
||||
move_references(
|
||||
&mut remap,
|
||||
&mut graph.no_delete,
|
||||
&mut graph.to_retrieve,
|
||||
contig,
|
||||
source,
|
||||
);
|
||||
graph.graph.remove_node(contig);
|
||||
let new_shapes = graph
|
||||
.get_sources(op)
|
||||
.into_iter()
|
||||
.map(|(_, _, s)| s)
|
||||
.collect::<Vec<_>>();
|
||||
graph
|
||||
.graph
|
||||
.node_weight_mut(op)
|
||||
.unwrap()
|
||||
.custom("recompile_shapes", Box::new(new_shapes));
|
||||
selector.clear_cached_results();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1871
crates/luminal_metal/src/prim.rs
Normal file
1871
crates/luminal_metal/src/prim.rs
Normal file
File diff suppressed because it is too large
Load Diff
508
crates/luminal_metal/src/quantized.rs
Normal file
508
crates/luminal_metal/src/quantized.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
|
||||
|
||||
use metal_rs::{
|
||||
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
|
||||
ComputePipelineState, Device, MTLResourceOptions, MTLSize,
|
||||
};
|
||||
use petgraph::visit::EdgeRef;
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::*,
|
||||
shape::symbolic::BigExpression,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
binary::MetalGather, get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel,
|
||||
MetalKernelWrapper,
|
||||
};
|
||||
|
||||
use super::{compile_function, SetInt};
|
||||
|
||||
/// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix. This expects the first input to be a quantized 2D matrix
|
||||
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
|
||||
pub struct QuantizedMatmul<T> {
|
||||
matvec_pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> QuantizedMatmul<T> {
|
||||
fn new(device: Device, queue: CommandQueue) -> Self {
|
||||
let type_name = T::type_name();
|
||||
Self {
|
||||
matvec_pipeline: compile_function("mkernel", &format!("
|
||||
using namespace metal;
|
||||
#define QK8_0 32
|
||||
#define NB_Q8_0 8
|
||||
typedef struct {{
|
||||
half d; // delta
|
||||
int8_t qs[QK8_0]; // quants
|
||||
}} block_q8_0;
|
||||
|
||||
kernel void mkernel(
|
||||
device block_q8_0* x [[buffer(0)]], // Quantized 2D matrix
|
||||
device {type_name}* y [[buffer(1)]], // Float src vector
|
||||
device {type_name}* dst [[buffer(2)]], // Float dest vector
|
||||
constant int64_t & src_vec_size [[buffer(3)]], // Matrix n cols (src vector size) (Must be >= 32)
|
||||
constant int64_t & dest_vec_size [[buffer(4)]], // Matrix n rows (dest vector size) (Must be >= 4)
|
||||
constant int64_t & mat_batch_stride [[buffer(5)]], // Matrix batch stride
|
||||
constant int64_t & vec_batch_stride [[buffer(6)]], // Vector batch stride
|
||||
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
|
||||
uint thread_index_in_simdgroup[[thread_index_in_simdgroup]],
|
||||
uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]] // 2 simdgroups in a threadgroup
|
||||
) {{
|
||||
const int num_rows = 4;
|
||||
const int num_simdgroups_per_threadgroup = 2;
|
||||
const int quant_width = 32;
|
||||
|
||||
const int num_quants_per_row = src_vec_size / 32; // Number of quants per row
|
||||
|
||||
// This is the first row the simdgroup will work on (each simdgroup handles a block of 4 rows)
|
||||
const int first_row = (threadgroup_position_in_grid.x * num_simdgroups_per_threadgroup + simdgroup_index_in_threadgroup) * num_rows;
|
||||
|
||||
// Offsets
|
||||
x += first_row * num_quants_per_row + threadgroup_position_in_grid.z * (mat_batch_stride / 32);
|
||||
y += threadgroup_position_in_grid.z * vec_batch_stride;
|
||||
dst += (threadgroup_position_in_grid.z * dest_vec_size);
|
||||
|
||||
// thread-local cache of vector values to work on. This thread must only work on 8 at a time
|
||||
{type_name} yl[8];
|
||||
// thread-local cache of 4 row sums
|
||||
float sumf[num_rows] = {{0.f}};
|
||||
|
||||
const int ix = thread_index_in_simdgroup / 4;
|
||||
const int il = thread_index_in_simdgroup % 4;
|
||||
|
||||
y += thread_index_in_simdgroup * 8;
|
||||
|
||||
// each thread in a SIMD group deals with 8 quants at a time
|
||||
// we start at 0-7 (ix) depending on the simdgroup index, and jump 8 indexes each time
|
||||
for (int ib = ix; ib < num_quants_per_row; ib += 8) {{ // ib: current column position
|
||||
// Load vector values into the cache
|
||||
for (int i = 0; i < 8; ++i) {{
|
||||
yl[i] = y[i];
|
||||
}}
|
||||
|
||||
// Loop through 4 matrix rows
|
||||
for (int row = 0; row < 4; ++row) {{
|
||||
// Get pointer to matrix data
|
||||
device const int8_t* qs = x[ib + row * num_quants_per_row].qs + il * 8;
|
||||
float sumq = 0.f; // Partial sum
|
||||
// Loop through 8 columns
|
||||
for (int iq = 0; iq < 8; ++iq) {{
|
||||
sumq += qs[iq] * yl[iq]; // Multiply int with vector value (auto converts to float?)
|
||||
}}
|
||||
sumf[row] += sumq * x[ib + row * num_quants_per_row].d; // multiply by delta (scaling factor)
|
||||
}}
|
||||
y += 256; // Jump by 256
|
||||
}}
|
||||
|
||||
// each simdgroup is responsible for saving 4 final vector values (n rows)
|
||||
for (int row = 0; row < num_rows; ++row) {{
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
if (thread_index_in_simdgroup == 0 && first_row + row < dest_vec_size) {{
|
||||
dst[first_row + row] = ({type_name})tot;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"), &device),
|
||||
queue,
|
||||
device,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for QuantizedMatmul<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
|
||||
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
|
||||
let batch_size = input_shapes[0]
|
||||
.shape()
|
||||
.into_iter()
|
||||
.take(input_shapes[0].len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(BigExpression::from(1));
|
||||
vec![batch_size * m * n * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
assert!(
|
||||
!inputs[1].1.is_contiguous(),
|
||||
"Weight matrix must be column-major"
|
||||
);
|
||||
let (a_shape, b_shape) = (
|
||||
inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
inputs[1]
|
||||
.1
|
||||
.shape()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let a_dims = a_shape.len();
|
||||
let m = a_shape[a_dims - 2];
|
||||
let batch_size = a_shape.iter().take(a_dims - 2).product::<usize>().max(1);
|
||||
let b_dims = b_shape.len();
|
||||
let k = b_shape[b_dims - 2];
|
||||
let n = b_shape[b_dims - 1];
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
if batch_size == 1 {
|
||||
// Matvec
|
||||
encoder.set_compute_pipeline_state(&self.matvec_pipeline);
|
||||
encoder.set_buffer(0, Some(inputs[1].0), 0); // Matrix
|
||||
encoder.set_buffer(1, Some(inputs[0].0), 0); // Vector
|
||||
encoder.set_buffer(2, Some(output_buffers[0]), 0); // Dest vector
|
||||
encoder.set_i64(3, k as i64); // Src vec size
|
||||
encoder.set_i64(4, n as i64); // Dest vec size
|
||||
encoder.set_i64(5, 0); // Matrix batch stride
|
||||
encoder.set_i64(6, k as i64); // Vector batch stride
|
||||
encoder.dispatch_thread_groups(
|
||||
MTLSize::new(n.div_ceil(8) as u64, 1, m as u64),
|
||||
MTLSize::new(8, 8, 1),
|
||||
);
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static + Clone> Operator for QuantizedMatmul<T> {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let n = b_shape[1].to_usize().unwrap();
|
||||
let (batch_size, m) = if a_shape.len() == 3 {
|
||||
(
|
||||
a_shape[0].to_usize().unwrap(),
|
||||
a_shape[1].to_usize().unwrap(),
|
||||
)
|
||||
} else {
|
||||
(0, a_shape[0].to_usize().unwrap())
|
||||
};
|
||||
|
||||
let out = self.device.new_buffer(
|
||||
(batch_size * m * n * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[
|
||||
(get_buffer_from_tensor(&inp[0].0), inp[0].1),
|
||||
(get_buffer_from_tensor(&inp[1].0), inp[1].1),
|
||||
],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqFalse, LuminalPrint, Clone)]
|
||||
pub struct QuantizedGather<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
embed_dim: usize,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> QuantizedGather<T> {
|
||||
fn new(device: Device, queue: CommandQueue, embed_dim: usize) -> Self {
|
||||
let type_name = T::type_name();
|
||||
Self {pipeline: compile_function("metal_gather", &format!(
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
#define QK8_0 32
|
||||
typedef struct {{
|
||||
half d; // delta
|
||||
int8_t qs[QK8_0]; // quants
|
||||
}} block_q8_0;
|
||||
|
||||
kernel void metal_gather(device float *inp [[buffer(0)]], device block_q8_0 *weights [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_embeddings [[buffer(3)]], device int& embedding_dim [[buffer(4)]], uint2 idx [[thread_position_in_grid]]) {{
|
||||
if (idx.x < n_embeddings && idx.y < embedding_dim) {{
|
||||
int block_idx = ((int)inp[idx.x] * embedding_dim + idx.y) / QK8_0;
|
||||
out[idx.x * embedding_dim + idx.y] = weights[block_idx].qs[idx.y % QK8_0] * weights[block_idx].d;
|
||||
}}
|
||||
}}"), &device), device, embed_dim, queue, _phantom: Default::default()}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for QuantizedGather<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup buffers
|
||||
let indexes = tensors[0]
|
||||
.0
|
||||
.borrowed()
|
||||
.data
|
||||
.as_any()
|
||||
.downcast_ref::<Vec<f32>>()
|
||||
.unwrap();
|
||||
let index_buffer = self.device.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(indexes.as_ptr()) },
|
||||
(indexes.len() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let out = self.device.new_buffer(
|
||||
(indexes.len() * self.embed_dim * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
let encoder = command_buffer
|
||||
.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(&index_buffer), 0);
|
||||
encoder.set_buffer(1, Some(get_buffer_from_tensor(&tensors[1].0)), 0);
|
||||
encoder.set_buffer(2, Some(&out), 0);
|
||||
encoder.set_u32(3, indexes.len() as u32);
|
||||
encoder.set_u32(4, self.embed_dim as u32);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_threads(
|
||||
MTLSize {
|
||||
width: indexes.len() as u64,
|
||||
height: self.embed_dim as u64,
|
||||
depth: 1,
|
||||
},
|
||||
MTLSize {
|
||||
width: 16,
|
||||
height: 16,
|
||||
depth: 1,
|
||||
},
|
||||
);
|
||||
encoder.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MetalQuantizedCompiler<T>(Vec<NodeIndex>, PhantomData<T>);
|
||||
|
||||
impl<T> MetalQuantizedCompiler<T> {
|
||||
pub fn new<To: ToIds>(weights: To) -> Self {
|
||||
Self(weights.to_ids(), Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat + Default> Compiler for MetalQuantizedCompiler<T> {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
|
||||
let device = Device::system_default().unwrap();
|
||||
let queue = device.new_command_queue();
|
||||
let mut weight_ids = self.0.clone();
|
||||
let mut local_remap = remap.to_ids_mut();
|
||||
for w in &mut weight_ids {
|
||||
local_remap.push(w);
|
||||
}
|
||||
// Normal metal compilation
|
||||
graph.compile(
|
||||
<(
|
||||
super::prim::PrimitiveCompiler<T>,
|
||||
super::SpecialOpsCompiler<T>,
|
||||
super::other::CopyCompiler<T>,
|
||||
super::other::ContiguousElimination<T>,
|
||||
super::elementwise_fusion::ElementwiseFusionCompiler<T>,
|
||||
)>::default(),
|
||||
&mut local_remap,
|
||||
);
|
||||
// Modify ops directly downstream of weights
|
||||
for weight in downstream(&weight_ids, graph) {
|
||||
for (target, (inp_ind, _, _)) in graph
|
||||
.graph
|
||||
.edges_directed(weight, petgraph::Direction::Outgoing)
|
||||
.filter_map(|e| e.weight().as_data().map(|i| (e.target(), i)))
|
||||
.collect::<Vec<_>>()
|
||||
{
|
||||
assert_eq!(
|
||||
inp_ind, 1,
|
||||
"Quantized weight {target:?} is the wrong input!",
|
||||
);
|
||||
let op_node = graph.graph.node_weight_mut(target).unwrap();
|
||||
if let Some(gather) = op_node.as_any().downcast_ref::<MetalGather<T>>() {
|
||||
*op_node = Box::new(QuantizedGather::<T>::new(
|
||||
device.clone(),
|
||||
queue.clone(),
|
||||
gather.embed_dim,
|
||||
));
|
||||
} else if op_node.as_any().is::<super::matmul::Matmul<T>>() {
|
||||
*op_node = Box::new(QuantizedMatmul::<T>::new(device.clone(), queue.clone()));
|
||||
} else {
|
||||
panic!("Quantized weight {target:?} is an input to a node that isn't a matmul or gather!");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Finish normal metal compilation
|
||||
graph.compile(super::BufferCompilers::default(), &mut remap);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use dfdx::{
|
||||
tensor::TensorFromVec,
|
||||
tensor_ops::{PermuteTo, TryMatMul},
|
||||
};
|
||||
use luminal::{
|
||||
prelude::*,
|
||||
tests::{assert_close, random_vec_rng},
|
||||
};
|
||||
use metal_rs::{Device, MTLResourceOptions};
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use crate::{MetalBuffer, MetalQuantizedCompiler};
|
||||
|
||||
#[repr(C, packed)]
|
||||
struct BlockQ8_0 {
|
||||
_d: f16,
|
||||
_qs: [i8; 32],
|
||||
}
|
||||
|
||||
fn quantized_buffer(weights: &[BlockQ8_0], dev: &Device) -> Tensor {
|
||||
let buffer = dev.new_buffer_with_bytes_no_copy(
|
||||
weights.as_ptr() as *mut _,
|
||||
std::mem::size_of_val(weights) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
None,
|
||||
);
|
||||
Tensor {
|
||||
data: Box::new(MetalBuffer(buffer)),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantized_matvec() {
|
||||
let mut rng = thread_rng();
|
||||
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
|
||||
let vec_data = random_vec_rng(1024, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let weights = cx.tensor::<R2<512, 1024>>();
|
||||
let vec = cx.tensor::<R1<1024>>().set(vec_data.clone());
|
||||
let mut out = vec.matmul(weights.permute()).retrieve();
|
||||
|
||||
// "Load" weights in 8bit
|
||||
let blocks = mat_data
|
||||
.chunks_exact(32)
|
||||
.map(|chunk| {
|
||||
let mut array = [0; 32];
|
||||
for (i, n) in chunk.iter().enumerate() {
|
||||
array[i] = *n;
|
||||
}
|
||||
BlockQ8_0 {
|
||||
_d: f16::from_f32(1.0),
|
||||
_qs: array,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let dev = Device::system_default().unwrap();
|
||||
cx.tensors
|
||||
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));
|
||||
|
||||
cx.compile(
|
||||
MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
|
||||
&mut out,
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let mut cx1 = Graph::new();
|
||||
let weights = cx1
|
||||
.tensor::<R2<512, 1024>>()
|
||||
.set(mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>());
|
||||
let vec = cx1.tensor::<R1<1024>>().set(vec_data);
|
||||
let out_32 = vec.matmul(weights.permute()).retrieve();
|
||||
cx1.execute();
|
||||
|
||||
assert_close(&out.data(), &out_32.data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantized_matmul() {
|
||||
let mut rng = thread_rng();
|
||||
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
|
||||
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let weights = cx.tensor::<R2<512, 1024>>();
|
||||
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
|
||||
let mut out = inp_mat.matmul(weights.permute()).retrieve();
|
||||
|
||||
// "Load" weights in 8bit
|
||||
let blocks = mat_data
|
||||
.chunks_exact(32)
|
||||
.map(|chunk| {
|
||||
let mut array = [0; 32];
|
||||
for (i, n) in chunk.iter().enumerate() {
|
||||
array[i] = *n;
|
||||
}
|
||||
BlockQ8_0 {
|
||||
_d: f16::from_f32(1.0),
|
||||
_qs: array,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let dev = Device::system_default().unwrap();
|
||||
cx.tensors
|
||||
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));
|
||||
|
||||
cx.compile(
|
||||
MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
|
||||
&mut out,
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let cpu = dfdx::tensor::Cpu::default();
|
||||
let d_a = cpu.tensor_from_vec(
|
||||
mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>(),
|
||||
(dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>),
|
||||
);
|
||||
let d_b = cpu.tensor_from_vec(
|
||||
inp_mat_data,
|
||||
(dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>),
|
||||
);
|
||||
let d_c = d_b.matmul(d_a.permute());
|
||||
assert_close(&out.data(), &d_c.as_vec());
|
||||
}
|
||||
}
|
||||
61
crates/luminal_metal/src/selectors.rs
Normal file
61
crates/luminal_metal/src/selectors.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use crate::{
|
||||
prelude::{metal::prim::MetalAdd, *},
|
||||
select_const, select_ty,
|
||||
};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
use super::{
|
||||
binary::MetalSub,
|
||||
prim::{MetalConstant, MetalLessThan, MetalMul},
|
||||
};
|
||||
|
||||
pub fn less_than<T: MetalFloat>(
|
||||
s1: SelectEdge,
|
||||
s2: SelectEdge,
|
||||
ptrs: &mut Vec<NodeIndex>,
|
||||
) -> SelectEdge {
|
||||
s2.edge(s1.edge(select_ty!(MetalLessThan<T>).ptr(ptrs)))
|
||||
}
|
||||
|
||||
pub fn mul<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
s2.edge(s1.edge(select_ty!(MetalMul<T>).ptr(ptrs)))
|
||||
}
|
||||
pub fn add<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
s2.edge(s1.edge(select_ty!(MetalAdd<T>).ptr(ptrs)))
|
||||
}
|
||||
pub fn sub<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
s2.edge(s1.edge(select_ty!(MetalSub<T>).ptr(ptrs)))
|
||||
}
|
||||
pub fn less_than_equal<T: MetalFloat>(
|
||||
s1: SelectEdge,
|
||||
s2: SelectEdge,
|
||||
mut ptrs: &mut Vec<NodeIndex>,
|
||||
) -> SelectEdge {
|
||||
sub::<T>(
|
||||
select_const!(1.0, T).ptr(&mut ptrs).into(),
|
||||
less_than::<T>(s2, s1, &mut ptrs),
|
||||
ptrs,
|
||||
)
|
||||
}
|
||||
pub fn max<T: MetalFloat>(s1: SelectEdge, s2: SelectEdge, ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
let a = mul::<T>(
|
||||
less_than::<T>(s1.clone(), s2.clone(), ptrs),
|
||||
s2.clone(),
|
||||
ptrs,
|
||||
);
|
||||
let b = mul::<T>(less_than_equal::<T>(s2, s1.clone(), ptrs), s1, ptrs);
|
||||
add::<T>(a, b, ptrs)
|
||||
}
|
||||
pub fn relu<T: MetalFloat>(s1: SelectEdge, mut ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
max::<T>(s1, select_const!(0.0, T).ptr(&mut ptrs).into(), &mut ptrs)
|
||||
}
|
||||
pub fn abs<T: MetalFloat>(s1: SelectEdge, mut ptrs: &mut Vec<NodeIndex>) -> SelectEdge {
|
||||
add::<T>(
|
||||
relu::<T>(s1.clone(), &mut ptrs),
|
||||
relu::<T>(
|
||||
mul::<T>(s1, select_const!(-1.0, T).ptr(&mut ptrs).into(), &mut ptrs),
|
||||
&mut ptrs,
|
||||
),
|
||||
&mut ptrs,
|
||||
)
|
||||
}
|
||||
404
crates/luminal_metal/src/storage_buffer.rs
Normal file
404
crates/luminal_metal/src/storage_buffer.rs
Normal file
@@ -0,0 +1,404 @@
|
||||
use std::{
|
||||
cell::UnsafeCell,
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
ops::Deref,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
use metal_rs::{Buffer, Device, MTLResourceOptions};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::{
|
||||
petgraph::{algo::toposort, stable_graph::NodeIndex, visit::EdgeRef, Direction},
|
||||
symbolic::BigExpression,
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{MetalBuffer, MetalKernelWrapper};
|
||||
|
||||
use super::get_buffer_from_tensor;
|
||||
|
||||
#[derive(Default, LuminalPrint)]
|
||||
pub struct StorageBufferCompiler;
|
||||
|
||||
impl Compiler for StorageBufferCompiler {
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
|
||||
// First pass - get clear sets for each node
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut first_pass: FxHashMap<
|
||||
NodeIndex,
|
||||
(
|
||||
BTreeMap<NodeIndex, BTreeSet<NodeIndex>>,
|
||||
BTreeSet<NodeIndex>,
|
||||
),
|
||||
> = FxHashMap::default();
|
||||
let toposort = toposort(&graph.graph, None).unwrap();
|
||||
// Loop through nodes in graph
|
||||
for node in &toposort {
|
||||
// Run through parents to build new tenative set and clear set
|
||||
let (mut tenative_sets, mut clear_set) = (BTreeMap::default(), BTreeSet::default());
|
||||
for parent in graph
|
||||
.graph
|
||||
.edges_directed(*node, Direction::Incoming)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.map(|e| e.source())
|
||||
{
|
||||
let parent_children = graph
|
||||
.graph
|
||||
.edges_directed(parent, Direction::Outgoing)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.map(|e| e.target())
|
||||
.collect::<BTreeSet<_>>();
|
||||
tenative_sets.insert(parent, parent_children);
|
||||
if let Some((parent_tenative_set, parent_clear_set)) = first_pass.get(&parent) {
|
||||
for (node_index, new_tenative_set) in
|
||||
parent_tenative_set.iter().map(|(n, c)| {
|
||||
let mut c = c.clone();
|
||||
c.retain(|n| *n != parent);
|
||||
(*n, c)
|
||||
})
|
||||
{
|
||||
if let Some(set) = tenative_sets.get(&node_index) {
|
||||
*tenative_sets.get_mut(&node_index).unwrap() =
|
||||
btreeset_intersection(new_tenative_set, set);
|
||||
} else {
|
||||
tenative_sets.insert(node_index, new_tenative_set);
|
||||
}
|
||||
}
|
||||
clear_set.extend(
|
||||
tenative_sets
|
||||
.iter()
|
||||
.filter(|(_, v)| v.is_empty())
|
||||
.map(|(n, _)| *n),
|
||||
);
|
||||
tenative_sets.retain(|_, v| !v.is_empty());
|
||||
clear_set.extend(parent_clear_set);
|
||||
}
|
||||
}
|
||||
first_pass.insert(*node, (tenative_sets, clear_set));
|
||||
}
|
||||
|
||||
// Second pass - assign buffers
|
||||
let available_buffers = graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|n| !graph.no_delete.contains(n))
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.filter_map(|n| {
|
||||
if let Some(Ok(wrapper)) = graph
|
||||
.graph
|
||||
.node_weight_mut(n)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.map(|n| n.downcast::<MetalKernelWrapper>())
|
||||
{
|
||||
Some((n, wrapper))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|(n, wrapper)| {
|
||||
let input_shapes = graph
|
||||
.get_sources(n)
|
||||
.into_iter()
|
||||
.map(|(_, _, i)| i)
|
||||
.collect::<Vec<_>>();
|
||||
let output_buffers = wrapper.0.output_buffer_sizes(&input_shapes);
|
||||
let intermediate_buffers = wrapper.0.intermediate_buffer_sizes(&input_shapes);
|
||||
(n, (output_buffers, intermediate_buffers))
|
||||
})
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
// Loop through nodes in graph
|
||||
let mut buffers = vec![];
|
||||
let mut buffer_map = FxHashMap::default();
|
||||
let mut used = FxHashSet::<NodeIndex>::default();
|
||||
for node in &toposort {
|
||||
if graph.no_delete.contains(node) {
|
||||
continue;
|
||||
}
|
||||
let Some(Ok(wrapper)) = graph
|
||||
.graph
|
||||
.node_weight_mut(*node)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.map(|e| e.downcast::<MetalKernelWrapper>())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
buffer_map.insert(*node, (vec![], vec![]));
|
||||
let input_shapes = graph
|
||||
.get_sources(*node)
|
||||
.into_iter()
|
||||
.map(|(_, _, i)| i)
|
||||
.collect::<Vec<_>>();
|
||||
// Assign output buffers
|
||||
for required_buffer in wrapper.0.output_buffer_sizes(&input_shapes) {
|
||||
// Find an applicable buffer
|
||||
if let Some((buffer_index, source_node, _)) = first_pass[&node]
|
||||
.1
|
||||
.iter()
|
||||
.filter(|i| !graph.no_delete.contains(i))
|
||||
.filter(|i| !used.contains(i))
|
||||
.filter(|i| available_buffers.contains_key(i))
|
||||
.flat_map(|i| {
|
||||
available_buffers[i]
|
||||
.0
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.map(|(o, b)| (o, *i, b))
|
||||
})
|
||||
.find(|(_, _, size)| *size == required_buffer)
|
||||
{
|
||||
let buffer = buffer_map.get(&source_node).unwrap().0[buffer_index];
|
||||
buffer_map.get_mut(node).unwrap().0.push(buffer);
|
||||
// Remove this buffer from first_pass so it can't be used again
|
||||
used.insert(source_node);
|
||||
} else {
|
||||
// Allocate new buffer
|
||||
buffer_map.get_mut(node).unwrap().0.push(buffers.len());
|
||||
buffers.push(required_buffer);
|
||||
}
|
||||
}
|
||||
// Assign intermediate buffers
|
||||
for required_buffer in wrapper.0.intermediate_buffer_sizes(&input_shapes) {
|
||||
// Find an applicable buffer
|
||||
if let Some((buffer_index, source_node, _)) = first_pass[&node]
|
||||
.1
|
||||
.iter()
|
||||
.filter(|i| !graph.no_delete.contains(i))
|
||||
.filter(|i| !used.contains(i))
|
||||
.filter(|i| available_buffers.contains_key(i))
|
||||
.flat_map(|i| {
|
||||
available_buffers[i]
|
||||
.1
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.map(|(o, b)| (o, *i, b))
|
||||
})
|
||||
.find(|(_, _, size)| *size == required_buffer)
|
||||
{
|
||||
let buffer = buffer_map.get(&source_node).unwrap().1[buffer_index];
|
||||
buffer_map.get_mut(node).unwrap().1.push(buffer);
|
||||
used.insert(source_node);
|
||||
} else {
|
||||
// Allocate new buffer
|
||||
buffer_map.get_mut(node).unwrap().1.push(buffers.len());
|
||||
buffers.push(required_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Loop through no_delete nodes and add buffers just for them
|
||||
for node in &toposort {
|
||||
if !graph.no_delete.contains(node) {
|
||||
continue;
|
||||
}
|
||||
let Some(Ok(wrapper)) = graph
|
||||
.graph
|
||||
.node_weight_mut(*node)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.map(|e| e.downcast::<MetalKernelWrapper>())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
buffer_map.insert(*node, (vec![], vec![]));
|
||||
let input_shapes = graph
|
||||
.get_sources(*node)
|
||||
.into_iter()
|
||||
.map(|(_, _, i)| i)
|
||||
.collect::<Vec<_>>();
|
||||
// Assign output buffers
|
||||
for required_buffer in wrapper.0.output_buffer_sizes(&input_shapes) {
|
||||
// Allocate new buffer
|
||||
buffer_map.get_mut(node).unwrap().0.push(buffers.len());
|
||||
buffers.push(required_buffer);
|
||||
}
|
||||
// Assign intermediate buffers
|
||||
for required_buffer in wrapper.0.intermediate_buffer_sizes(&input_shapes) {
|
||||
// Allocate new buffer
|
||||
buffer_map.get_mut(node).unwrap().1.push(buffers.len());
|
||||
buffers.push(required_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
// We now have the buffers to allocate, and the buffers needed for each op.
|
||||
// Let's create the allocator op and wrap all the metal ops
|
||||
let shared_buffers = Arc::new(UnsafeCell::new(vec![]));
|
||||
let allocator = graph
|
||||
.add_op(AllocateMetalBuffers {
|
||||
dev: Device::system_default().unwrap(),
|
||||
dyn_map: &graph.dyn_map,
|
||||
buffer_sizes: buffers,
|
||||
buffers: shared_buffers.clone(),
|
||||
})
|
||||
.finish();
|
||||
// Ensure allocator is ran before any nodes that use the buffers
|
||||
for node in graph
|
||||
.graph
|
||||
.node_indices()
|
||||
// Starting node must have no incoming edges
|
||||
.filter(|e| {
|
||||
graph
|
||||
.graph
|
||||
.edges_directed(*e, Direction::Incoming)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.count()
|
||||
== 0
|
||||
})
|
||||
// Starting node must have at least one outgoing edge
|
||||
.filter(|e| {
|
||||
graph
|
||||
.graph
|
||||
.edges_directed(*e, Direction::Outgoing)
|
||||
.filter(|e| !e.weight().is_schedule())
|
||||
.count()
|
||||
> 0
|
||||
})
|
||||
.collect_vec()
|
||||
{
|
||||
graph.add_schedule_dependency(allocator, node);
|
||||
}
|
||||
// Wrap nodes in StorageBufferWrapper
|
||||
for (node, (output_buffers, intermediate_buffers)) in buffer_map
|
||||
.into_iter()
|
||||
.filter(|(_, b)| !b.0.is_empty() || !b.1.is_empty())
|
||||
{
|
||||
let wrapper = graph
|
||||
.graph
|
||||
.node_weight_mut(node)
|
||||
.unwrap()
|
||||
.custom("metal", Box::new(()))
|
||||
.unwrap()
|
||||
.downcast::<MetalKernelWrapper>()
|
||||
.unwrap();
|
||||
*graph.graph.node_weight_mut(node).unwrap() = Box::new(StorageBufferWrapper {
|
||||
wrapper,
|
||||
buffers: shared_buffers.clone(),
|
||||
output_buffers,
|
||||
intermediate_buffers,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn btreeset_intersection<T: Ord>(mut a: BTreeSet<T>, b: &BTreeSet<T>) -> BTreeSet<T> {
|
||||
a.retain(|i| b.contains(i));
|
||||
a
|
||||
}
|
||||
|
||||
#[derive(LuminalEqFalse, LuminalPrint)]
|
||||
struct AllocateMetalBuffers {
|
||||
dev: Device,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
buffer_sizes: Vec<BigExpression>,
|
||||
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
|
||||
}
|
||||
|
||||
impl Operator for AllocateMetalBuffers {
|
||||
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let buffers = unsafe { &mut *self.buffers.get() };
|
||||
let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() };
|
||||
// Allocate all buffers
|
||||
if buffers.is_empty() {
|
||||
*buffers = self
|
||||
.buffer_sizes
|
||||
.iter()
|
||||
.map(|e| {
|
||||
self.dev.new_buffer(
|
||||
e.exec(dyn_map).unwrap() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
} else {
|
||||
for (size, buffer) in self.buffer_sizes.iter().zip(buffers) {
|
||||
let size = size.exec(dyn_map).unwrap() as u64;
|
||||
if buffer.length() != size {
|
||||
// TODO: For some reason this causes bad outputs. Maybe we are relying on buffer length somewhere? We shouldn't be.
|
||||
// Also, it seems we are getting the benifits of this without actually doing it. Maybe metal is doing it in the background?
|
||||
// Similar allocation strategy to Rust's Vec
|
||||
// let mut length = buffer.length();
|
||||
// while length < size {
|
||||
// length *= 2;
|
||||
// }
|
||||
let length = size;
|
||||
*buffer = self
|
||||
.dev
|
||||
.new_buffer(length, MTLResourceOptions::StorageModeShared);
|
||||
}
|
||||
}
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LuminalEqFalse)]
|
||||
struct StorageBufferWrapper {
|
||||
wrapper: Box<MetalKernelWrapper>,
|
||||
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
|
||||
intermediate_buffers: Vec<usize>,
|
||||
output_buffers: Vec<usize>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StorageBufferWrapper {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.wrapper.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator for StorageBufferWrapper {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let buffers = unsafe { self.buffers.get().as_ref().unwrap() };
|
||||
let intermediate_buffers = self
|
||||
.intermediate_buffers
|
||||
.iter()
|
||||
.map(|i| &buffers[*i])
|
||||
.collect::<Vec<_>>();
|
||||
let output_buffers = self
|
||||
.output_buffers
|
||||
.iter()
|
||||
.map(|i| &buffers[*i])
|
||||
.collect::<Vec<_>>();
|
||||
self.wrapper.0.without_command_buffer(
|
||||
&inp.iter()
|
||||
.map(|(t, sh)| (get_buffer_from_tensor(t).deref(), *sh))
|
||||
.collect::<Vec<_>>(),
|
||||
&intermediate_buffers,
|
||||
&output_buffers,
|
||||
);
|
||||
output_buffers
|
||||
.iter()
|
||||
.map(|buf| Tensor::new(MetalBuffer((*buf).clone())))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shared_buffers() {
|
||||
use luminal::prelude::*;
|
||||
use luminal::tests::{assert_close_precision, random_vec};
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let b = a.exp2();
|
||||
let c = a.log2() * b;
|
||||
let d = b.recip();
|
||||
let mut e = (c + d).retrieve();
|
||||
|
||||
cx.execute();
|
||||
let e_unopt = e.data();
|
||||
e.drop();
|
||||
|
||||
cx.compile(crate::MetalCompiler::<f16>::default(), &mut e);
|
||||
cx.execute();
|
||||
|
||||
assert_close_precision(&e.data(), &e_unopt, 2);
|
||||
}
|
||||
1121
crates/luminal_metal/src/tests/fp16.rs
Normal file
1121
crates/luminal_metal/src/tests/fp16.rs
Normal file
File diff suppressed because it is too large
Load Diff
657
crates/luminal_metal/src/tests/fp32.rs
Normal file
657
crates/luminal_metal/src/tests/fp32.rs
Normal file
@@ -0,0 +1,657 @@
|
||||
use dfdx::prelude::{Module as DfdxModule, *};
|
||||
use itertools::Itertools;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
use luminal::{
|
||||
nn::{activation::ReLU, linear::Linear},
|
||||
prelude::{Module, *},
|
||||
tests::{assert_close, assert_close_precision, random_vec, random_vec_rng},
|
||||
};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
|
||||
#[test]
|
||||
fn test_contiguous() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
|
||||
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(data, (dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>));
|
||||
let d_b = d_a.permute::<Rank2<4, 3>, _>().reshape::<Rank2<12, 1>>();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(data.clone());
|
||||
let mut b = a.log2().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
assert_close(
|
||||
&b.data(),
|
||||
&data.into_iter().map(|i: f32| i.log2()).collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exp2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(data.clone());
|
||||
let mut b = a.exp2().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
assert_close(
|
||||
&b.data(),
|
||||
&data.into_iter().map(|i: f32| i.exp2()).collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recip() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 4096.]);
|
||||
let mut b = a.recip().retrieve();
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 4096.]);
|
||||
let d_b = d_a.recip();
|
||||
|
||||
assert_close(&b.data(), &d_b.to_dtype::<f32>().as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sin() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut b = a.sin().retrieve();
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_a.sin();
|
||||
|
||||
assert_close(&b.data(), &d_b.to_dtype::<f32>().as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut b = a.sqrt().retrieve();
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_a.sqrt();
|
||||
|
||||
assert_close(&b.data(), &d_b.to_dtype::<f32>().as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a + b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a + d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a - b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a - d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_square() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = rand::thread_rng();
|
||||
let data = (0..40960)
|
||||
.map(|_| rng.gen_range(-0.01..0.01))
|
||||
.collect::<Vec<f32>>();
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'s'>, luminal::prelude::Const<4096>)>()
|
||||
.set_dyn(data.clone(), &[1, 10, 4096]);
|
||||
let mut b = a * a;
|
||||
b.retrieve();
|
||||
|
||||
cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec::<Rank3<1, 10, 4096>>(
|
||||
data,
|
||||
(
|
||||
dfdx::prelude::Const::<1>,
|
||||
dfdx::prelude::Const::<10>,
|
||||
dfdx::prelude::Const::<4096>,
|
||||
),
|
||||
);
|
||||
let d_b = d_a.clone() * d_a;
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a * b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a * d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul2() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<(
|
||||
luminal::prelude::Const<1>,
|
||||
luminal::prelude::Const<1>,
|
||||
Dyn<'a'>,
|
||||
Dyn<'a'>,
|
||||
)>()
|
||||
.set_dyn(vec![82.4, 783.0, 99.6, 974.5], &[1, 1, 2, 2]);
|
||||
let b = cx.tensor::<R0>().set(vec![0.57735026]);
|
||||
let mut c = a * b.expand();
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([[[[82.4, 783.0], [99.6, 974.5]]]]);
|
||||
let d_b = d_dev.tensor(0.57735026);
|
||||
let d_c = d_a * d_b.broadcast::<_, dfdx::shapes::Axes4<0, 1, 2, 3>>();
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a / b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a / d_b;
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let b = cx.tensor::<R1<3>>().set(vec![1., 2., 3.]);
|
||||
let mut c = a.max(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([1., 2., 3.]);
|
||||
let d_b = d_dev.tensor([1., 2., 3.]);
|
||||
let d_c = d_a.maximum(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mod() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(3);
|
||||
let b_data = random_vec(3);
|
||||
let a = cx.tensor::<R1<3>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R1<3>>().set(b_data.clone());
|
||||
let mut c = a % b;
|
||||
c.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
// No dfdx equivalent
|
||||
|
||||
assert_close(
|
||||
&c.data(),
|
||||
&a_data
|
||||
.into_iter()
|
||||
.zip(b_data)
|
||||
.map(|(a, b)| a % b)
|
||||
.collect_vec(),
|
||||
);
|
||||
}
|
||||
|
||||
// Reduction op tests
|
||||
|
||||
#[test]
|
||||
fn test_sum_reduce() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(4096);
|
||||
let a = cx.tensor::<R3<1, 1, 4096>>();
|
||||
a.set(data.clone());
|
||||
let mut b = a.sum_reduce::<_, luminal::prelude::Axis<1>>().retrieve();
|
||||
let mut c = a.sum_reduce::<_, luminal::prelude::Axis<0>>().retrieve();
|
||||
let mut d = a.sum_reduce::<_, luminal::prelude::Axis<2>>().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
data,
|
||||
(
|
||||
dfdx::shapes::Const::<1>,
|
||||
dfdx::shapes::Const::<1>,
|
||||
dfdx::shapes::Const::<4096>,
|
||||
),
|
||||
);
|
||||
let d_b = d_a.clone().sum::<_, dfdx::shapes::Axis<1>>();
|
||||
let d_c = d_a.clone().sum::<_, dfdx::shapes::Axis<0>>();
|
||||
let d_d = d_a.sum::<_, dfdx::shapes::Axis<2>>();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_reduce() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R3<2, 2, 3>>();
|
||||
a.set(data.clone());
|
||||
let mut b = a.max_reduce::<_, luminal::prelude::Axis<1>>().retrieve();
|
||||
let mut c = a.max_reduce::<_, luminal::prelude::Axis<0>>().retrieve();
|
||||
let mut d = a.max_reduce::<_, luminal::prelude::Axis<2>>().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
data,
|
||||
(
|
||||
dfdx::shapes::Const::<2>,
|
||||
dfdx::shapes::Const::<2>,
|
||||
dfdx::shapes::Const::<3>,
|
||||
),
|
||||
);
|
||||
let d_b = d_a.clone().max::<_, dfdx::shapes::Axis<1>>();
|
||||
let d_c = d_a.clone().max::<_, dfdx::shapes::Axis<0>>();
|
||||
let d_d = d_a.max::<_, dfdx::shapes::Axis<2>>();
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.mean_reduce::<_, luminal::prelude::Axis<2>>().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
data,
|
||||
(
|
||||
dfdx::shapes::Const::<1>,
|
||||
dfdx::shapes::Const::<10>,
|
||||
dfdx::shapes::Const::<4096>,
|
||||
),
|
||||
);
|
||||
let d_b = d_a.mean::<_, dfdx::shapes::Axis<2>>();
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_simple() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(256 * 256);
|
||||
let b_data = random_vec(256 * 256);
|
||||
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
a_data,
|
||||
(dfdx::shapes::Const::<256>, dfdx::shapes::Const::<256>),
|
||||
);
|
||||
let d_b = d_dev.tensor_from_vec(
|
||||
b_data,
|
||||
(dfdx::shapes::Const::<256>, dfdx::shapes::Const::<256>),
|
||||
);
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(512 * 512);
|
||||
let b_data = random_vec(512 * 512);
|
||||
let a = cx.tensor::<R2<512, 512>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<512, 512>>().set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
a_data,
|
||||
(dfdx::shapes::Const::<512>, dfdx::shapes::Const::<512>),
|
||||
);
|
||||
let d_b = d_dev.tensor_from_vec(
|
||||
b_data,
|
||||
(dfdx::shapes::Const::<512>, dfdx::shapes::Const::<512>),
|
||||
);
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<R3<2, 2, 3>>()
|
||||
.set(vec![1., 2., 3., 1., 2., 1., 1., 2., 3., 1., 2., 1.]);
|
||||
let b = cx
|
||||
.tensor::<R2<3, 4>>()
|
||||
.set(vec![1., 2., 3., 1., 1., 2., 1., 2., -1., -2., 1., 2.]);
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([[[1., 2., 3.], [1., 2., 1.]], [[1., 2., 3.], [1., 2., 1.]]]);
|
||||
let d_b = d_dev.tensor([[1., 2., 3., 1.], [1., 2., 1., 2.], [-1., -2., 1., 2.]]);
|
||||
let d_c = d_a.matmul(d_b);
|
||||
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_transpose() {
|
||||
const M: usize = 1024; // Any
|
||||
const K: usize = 16; // >= 16
|
||||
const N: usize = 256; // >= 256, power of 2
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(M * K, &mut rng);
|
||||
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(K * M, &mut rng);
|
||||
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a.matmul(b.permute()).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b.permute())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b_t)
|
||||
.retrieve();
|
||||
|
||||
cx.compile(
|
||||
MetalCompiler::<f32>::default(),
|
||||
(&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (dfdx::shapes::Const::<M>, dfdx::shapes::Const::<K>));
|
||||
let d_b = d_dev.tensor_from_vec(b_data, (dfdx::shapes::Const::<N>, dfdx::shapes::Const::<K>));
|
||||
let d_a_t = d_dev.tensor_from_vec(
|
||||
a_t_data,
|
||||
(dfdx::shapes::Const::<K>, dfdx::shapes::Const::<M>),
|
||||
);
|
||||
let d_b_t = d_dev.tensor_from_vec(
|
||||
b_t_data,
|
||||
(dfdx::shapes::Const::<K>, dfdx::shapes::Const::<N>),
|
||||
);
|
||||
let d_a_b = d_a.clone().matmul(d_b.clone().permute());
|
||||
let d_a_b_t = d_a.matmul(d_b_t.clone());
|
||||
let d_a_t_b = d_a_t
|
||||
.clone()
|
||||
.permute::<_, dfdx::shapes::Axes2<1, 0>>()
|
||||
.matmul(d_b.permute());
|
||||
let d_a_t_b_t = d_a_t
|
||||
.permute::<_, dfdx::shapes::Axes2<1, 0>>()
|
||||
.matmul(d_b_t);
|
||||
|
||||
assert_close(&a_b.data(), &d_a_b.as_vec());
|
||||
assert_close(&a_b_t.data(), &d_a_b_t.as_vec());
|
||||
assert_close(&a_t_b.data(), &d_a_t_b.as_vec());
|
||||
assert_close(&a_t_b_t.data(), &d_a_t_b_t.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relu_and_linear() {
|
||||
// Test single and batch, unoptimized and optimized
|
||||
let mut cx = Graph::new();
|
||||
let input_data = random_vec(32);
|
||||
let w1 = random_vec(32 * 64);
|
||||
let w2 = random_vec(32 * 64);
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 32>>("Batch")
|
||||
.set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
|
||||
|
||||
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
|
||||
model.0.weight.set(w1.clone());
|
||||
model.2.weight.set(w2.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
let mut batch_out = model.forward(batch).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let unoptimized_b = b.data();
|
||||
let unoptimized_batch_out = batch_out.data();
|
||||
b.drop();
|
||||
batch_out.drop();
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f32>)>::default(),
|
||||
(&mut b, &mut batch_out),
|
||||
);
|
||||
cx.execute();
|
||||
|
||||
assert_close_precision(&unoptimized_b, &b.data(), 2);
|
||||
assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 2);
|
||||
|
||||
// Test against dfdx
|
||||
let dev = Cpu::default();
|
||||
let mut model = <(
|
||||
dfdx::nn::modules::builders::UnbiasedLinear<32, 64>,
|
||||
dfdx::nn::modules::builders::ReLU,
|
||||
dfdx::nn::modules::builders::UnbiasedLinear<64, 32>,
|
||||
)>::build_on_device(&dev);
|
||||
// Set weights
|
||||
model.0.weight = dev
|
||||
.tensor_from_vec(w1, (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>))
|
||||
.permute();
|
||||
model.2.weight = dev
|
||||
.tensor_from_vec(w2, (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>))
|
||||
.permute();
|
||||
let a = dev.tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,));
|
||||
let out = model.forward(a);
|
||||
|
||||
assert_close_precision(&unoptimized_b, &out.as_vec(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transformer_encoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<3, 4, 1> =
|
||||
InitModule::initialize(&mut cx);
|
||||
model
|
||||
.attention
|
||||
.w_k
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_q
|
||||
.weight
|
||||
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_v
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_o
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model
|
||||
.ff
|
||||
.0
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
|
||||
model
|
||||
.ff
|
||||
.2
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<3>)>()
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[1, 2, 3]);
|
||||
let mut b = model.forward(a).retrieve();
|
||||
|
||||
cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> =
|
||||
d_dev.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<3, 1, 4>, f32>();
|
||||
d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]);
|
||||
d_model.self_attn.w_o.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
|
||||
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_k.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
|
||||
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_q.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.],
|
||||
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_v.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.],
|
||||
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .0.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.],
|
||||
(dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (dfdx::shapes::Const::<4>,));
|
||||
d_model.ff.0 .2.weight = d_dev
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.],
|
||||
(dfdx::shapes::Const::<4>, dfdx::shapes::Const::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
|
||||
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (dfdx::shapes::Const::<3>,));
|
||||
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (dfdx::shapes::Const::<3>,));
|
||||
d_model.norm1.epsilon = 1e-5;
|
||||
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
|
||||
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (dfdx::shapes::Const::<3>,));
|
||||
d_model.norm2.epsilon = 1e-5;
|
||||
let d_a = d_dev.tensor_from_vec(
|
||||
vec![-1., 2., 3., 3., 3., -1.],
|
||||
(dfdx::shapes::Const::<2>, dfdx::shapes::Const::<3>),
|
||||
);
|
||||
let d_b = d_model.forward(d_a);
|
||||
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
2
crates/luminal_metal/src/tests/mod.rs
Normal file
2
crates/luminal_metal/src/tests/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
mod fp16;
|
||||
mod fp32;
|
||||
1432
crates/luminal_metal/src/unary.rs
Normal file
1432
crates/luminal_metal/src/unary.rs
Normal file
File diff suppressed because it is too large
Load Diff
38
docs/01 Introduction.md
Normal file
38
docs/01 Introduction.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Luminal Introduction
|
||||
|
||||
Let's get up to speed with how to use luminal, and how it works internally.
|
||||
|
||||
First we'll take a look at what the simplest program will look like:
|
||||
```rust
|
||||
use luminal::prelude::*;
|
||||
|
||||
// Setup graph and tensors (1)
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.new_tensor::<R1<3>>()
|
||||
.set(vec![1.0, 2.0, 3.0]);
|
||||
let b = cx.new_tensor::<R1<3>>()
|
||||
.set(vec![1.0, 2.0, 3.0]);
|
||||
|
||||
// Actual operations (2)
|
||||
let c = (a + b).retrieve();
|
||||
|
||||
// Run graph (3)
|
||||
cx.execute();
|
||||
|
||||
// Get result (4)
|
||||
println!("Result: {:?}", c);
|
||||
// Prints out [2.0, 4.0, 6.0]
|
||||
```
|
||||
Wow! A lot is going on here just to add two tensors together. That's because luminal isn't really designed for such simple computation, and there's little benifit to using it here. But we'll see it pay off when we start doing more complex operations.
|
||||
|
||||
So what's happening here?
|
||||
1) We're setting up a new `Graph` which tracks all computation and actually does execution. We're also defining two new tensors, both of shape (3,). At this point, these "tensors" are actually `GraphTensor`s that don't hold any data. Also, notice we pass in the shape as a type generic. *Types are known at compile time, similar to [dfdx](https://github.com/coreylowman/dfdx)!*
|
||||
2) Now we can start doing the thing we came here for: the addition. So we add two `GraphTensor`s together, and get a new `GraphTensor`. Notice this *does not* consume anything, and we're free to use a or b later on. This is because `GraphTensor` is a super lightweight tracking struct which implements copy. "But wait, we never set tbe values of a and b, how can we add them? **We aren't actually adding them here.** Instead, we're writing this addition to the graph, and getting out c, which points to the result when it's actually done.
|
||||
|
||||
Then we set the data for these tensors. But if `GraphTensor` doesn't hold data, how can we set it? Well we aren't actually setting it *in* the tensor, just passing it through to the graph to say *once you run, set this tensor to this value.* We also need to mark the output we want to retrieve later. This is so that when the graph runs, it doesn't delete the data for c part-way through execution (a common optimization for unused tensors). Notice we're setting the sources *after* we define the computation. This is backward from a lot of other libs, but it means we can redefine the data and rerun everything without redefining the computation later on.
|
||||
3) Once we call `cx.execute()`, we've already set all our sources, so our addition actually gets ran and stored in c!
|
||||
4) Now since we're done computing c, we can fetch the data for c and see the result.
|
||||
|
||||
Alright, that was a lot but now we've touched on all the main aspects of running a model in luminal.
|
||||
|
||||
[Let's take a look at each piece in more depth.](https://github.com/jafioti/luminal/blob/main/docs/02%20GraphTensor%20API.md)
|
||||
19
docs/02 GraphTensor API.md
Normal file
19
docs/02 GraphTensor API.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# GraphTensors
|
||||
|
||||
We're working with pretty complicated graphs to build our computation on, but we don't want to manually place all the nodes ourselves! So how can we build these static graphs in a nice, familiar way? GraphTensors!
|
||||
|
||||
Essentially GraphTensors are pointers to a specific node on the graph, as well as some metadata about the output of that node, such as its shape. We can make a new GraphTensor by doing:
|
||||
```rust
|
||||
let mut cx = Graph::new(); // We need a graph to build!
|
||||
let a: GraphTensor<R1<3>> = cx.tensor(); // Here we create a new node on the graph and get a GraphTensor back, pointing to it.
|
||||
```
|
||||
Notice the type of `a`: `GraphTensor<R1<3>>`. So what's that generic all about? It's the shape! We make tensor shapes part of the type, so they're tracked at compile time! In this case, the shape is rank 1, with 3 elements, or in other words, a vector of 3 dimensions. (Side note: `R1<N>` is a typedef of `(Const<N>,)`) It should be impossible to accidentally get a runtime shape mismatch.
|
||||
|
||||
Now we can use the `a` as you would in a library like PyTorch, performing linear algebra:
|
||||
```rust
|
||||
let b = a.exp().sqrt();
|
||||
let c = b + a;
|
||||
```
|
||||
Looks familiar!
|
||||
|
||||
[Let's take a look at how GraphTensors are used to build whole neural networks.](https://github.com/jafioti/luminal/blob/main/docs/03%20Modules.md)
|
||||
31
docs/03 Modules.md
Normal file
31
docs/03 Modules.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# NN Modules
|
||||
Like any good DL library, we organize our networks into `Module`s. Here is the module trait:
|
||||
```rust
|
||||
/// A module with a forward pass
|
||||
pub trait Module<I> {
|
||||
type Output;
|
||||
fn forward(&self, input: I) -> Self::Output;
|
||||
}
|
||||
```
|
||||
Super simple, we just define a forward function that takes an input and returns an output. A consequence of this is it allows us to define seperate forward passes for single and batched inputs!
|
||||
|
||||
Now let's take a look at how `Linear` is defined:
|
||||
```rust
|
||||
/// A simple linear layer
|
||||
pub struct Linear<const A: usize, const B: usize> {
|
||||
pub(crate) weight: GraphTensor<R2<A, B>>,
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
|
||||
type Output = GraphTensor<R1<B>>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
|
||||
input.matmul(self.weight)
|
||||
}
|
||||
}
|
||||
```
|
||||
Here we see a single weight matrix as the internal state, of size AxB. We've written a single forward function for single input vectors of shape (A,) and matmul it by our weight matrix to get an output of shape (B,).
|
||||
|
||||
Now all of these ops are recorded on the graph, to be compiled and ran later on.
|
||||
|
||||
[So how does this compilation work? Let's find out!](https://github.com/jafioti/luminal/blob/main/docs/04%20Compilers.md)
|
||||
27
docs/04 Compilers.md
Normal file
27
docs/04 Compilers.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Compilers
|
||||
|
||||
So now we have our graph all set up. We did our forward passes through the model, so now what? Do we run it?
|
||||
|
||||
We could! But it wouldn't be very fast. Right now your graph is full of **primops**, which are the simplest set of primitive operations in luminal. One of the key tenants of luminal is a small primop set, which makes it easy to add new backends and write compilers for. But another consequence of a small primset is that even simple operations usually end up creating quite a few operations, and even small neural networks can end up with hundreds or thousands of primops, which are slow to run directly. So it's time to compile the graph!
|
||||
|
||||
Compilers are structs that implement the `Compiler` trait, which simply specifies a single function:
|
||||
```rust
|
||||
pub trait Compiler {
|
||||
/// Run a compilation pass
|
||||
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, remap: T);
|
||||
}
|
||||
```
|
||||
So all a compiler does is take a mutable reference to the graph, something called remap (beyond the scope of this introduction), and does something to the graph. That something is compilation, usually in the form of finding patterns of nodes and replacing them with other nodes. For instance, there's no Subtract operation in the primops, so subtractions are implemented as `add(a, mul(b, -1))`. We can have a compiler that looks for that pattern of nodes and directly replaces it with a `Subtract` operation. We'll look at how to do this in the [Writing Compilers](https://github.com/jafioti/luminal/blob/main/docs/06%20Writing%20Compilers.md) section.
|
||||
|
||||
All you need to know for now is that we can use this compiler on the graph by doing:
|
||||
```rust
|
||||
cx.compile(SubtractionCompiler::default());
|
||||
```
|
||||
Now the graph will have the old mul + add pattern removed and Subtract ops placed in. There are plenty of different compilers for different purposes. Some of the popular ones:
|
||||
- GenericCompiler - A handful of hardware-agnostic optimizations like [CSE](https://en.wikipedia.org/wiki/Common_subexpression_elimination) to be ran before any hardware-specific compilers.
|
||||
- CudaCompiler<T> - The full stack of cuda compilers to convert a graph to a cuda-specialized graph with T as the datatype (either f32 or f16). Imported from luminal_cuda
|
||||
- MetalCompiler<T> - Same as CudaCompiler. Imported from luminal_metal
|
||||
|
||||
Compilers are entirely seperate from luminal, so they can be fully implemented by third party crates. For instance, everything specific to Cuda is contained in luminal_cuda.
|
||||
|
||||
[Now let's look into how to load weights from a file.](https://github.com/jafioti/luminal/blob/main/docs/05%20Serialization.md)
|
||||
1
docs/05 Serialization.md
Normal file
1
docs/05 Serialization.md
Normal file
@@ -0,0 +1 @@
|
||||
Coming Soon
|
||||
1
docs/06 Writing Compilers.md
Normal file
1
docs/06 Writing Compilers.md
Normal file
@@ -0,0 +1 @@
|
||||
Coming Soon
|
||||
10
docs/CONTRIBUTING.md
Normal file
10
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Contributing to luminal
|
||||

|
||||
|
||||
Please take a look at the [issues](https://github.com/jafioti/luminal/issues) and [roadmap](https://github.com/users/jafioti/projects/1) to see what's targeted for upcoming releases. Contributions for those features are preferred and will be reviewed and merged very rapidly. Other contributions are welcome, but please note luminal is and always will be a fairly minimal library.
|
||||
|
||||
The core design of luminal is heavily predicated on extensibility. Compilers alow for immense complexity to be removed from the core library and added with third party compilers. For instance, datatypes and devices are typically first class primitives. In luminal, they're compilers and the core has no idea about them. This is the general trend we'll stick to: core remains brutally simple, and everything that can be externalized to a compiler will be.
|
||||
|
||||
We will be adding training support soon, and as you guessed, it will entirely reside in a compiler. Just define the model's graph, run the output through an optimizer, and then run the `AutogradCompiler` before any other compilers. Boom, we got training, and the core of the library has no idea! (aside from some quality of life apis)
|
||||
|
||||
PRs that remove complexity are always welcome, but note that line count often is a bad proxy for complexity. Ideally the entire luminal core should be a few thousand lines of code, but anything remotely resembling code golf is not allowed.
|
||||
@@ -1,68 +0,0 @@
|
||||
## Luminal Introduction
|
||||
|
||||
Let's get up to speed with how to use luminal, and how it works internally.
|
||||
|
||||
First we'll take a look at what the simplest program will look like:
|
||||
```rust
|
||||
use luminal::prelude::*;
|
||||
|
||||
// Setup graph and tensors (1)
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.new_tensor::<R1<3>>();
|
||||
let b = cx.new_tensor::<R1<3>>();
|
||||
|
||||
// Actual operations (2)
|
||||
let c = a + b;
|
||||
|
||||
// Set inputs and mark outputs (3)
|
||||
a.set(vec![1.0, 2.0, 3.0]);
|
||||
b.set(vec![1.0, 2.0, 3.0]);
|
||||
c.mark();
|
||||
|
||||
// Run graph (4)
|
||||
cx.execute();
|
||||
|
||||
// Get result (5)
|
||||
println!("Result: {:?}", c.retrieve().unwrap().real_data(c.view().unwrap()).unwrap());
|
||||
// Prints out [2.0, 4.0, 6.0]
|
||||
```
|
||||
Wow! A lot is going on here just to add two tensors together. That's because luminal isn't really designed for such simple computation, and there's little benifit to using it here. But we'll see it pay off when we start doing more complex operations.
|
||||
|
||||
So what's happening here?
|
||||
1) We're setting up a new `Graph` which tracks all computation and actually does execution. We're also defining two new tensors, both of shape (3,). At this point, these "tensors" are actually `GraphTensor`s that don't hold any data. Also, notice we pass in the shape as a type generic. *Types are known at compile time, similar to [dfdx](https://github.com/coreylowman/dfdx)!*
|
||||
2) Now we can start doing the thing we came here for: the addition. So we add two `GraphTensor`s together, and get a new `GraphTensor`. Notice this *does not* consume anything, and we're free to use a or b later on. This is because `GraphTensor` is a super lightweight tracking struct which implements copy. "But wait, we never set tbe values of a and b, how can we add them? **We aren't actually adding them here.** Instead, we're writing this addition to the graph, and getting out c, which points to the result when it's actually done.
|
||||
3) Then we set the data for these tensors. But if `GraphTensor` doesn't hold data, how can we set it? Well we aren't actually setting it *in* the tensor, just passing it through to the graph to say *once you run, set this tensor to this value.* We also need to mark the output we want to retrieve later. This is so that when the graph runs, it doesn't delete the data for c part-way through execution (a common optimization for unused tensors). Notice we're setting the sources *after* we define the computation. This is backward from a lot of other libs, but it means we can redefine the data and rerun everything without redefining the computation later on.
|
||||
4) Once we call `cx.execute()`, we've already set all our sources, so our addition actually gets ran and stored in c!
|
||||
5) Now since we're done computing c, we can fetch the data for c and see the result. *This API is likely to change, as it's very ugly.*
|
||||
|
||||
Alright, that was a lot but now we've touched on all the main aspects of running a model in luminal.
|
||||
|
||||
## NN Modules
|
||||
Like any good DL library, we organize our networks into `Module`s. Here is the module trait:
|
||||
```rust
|
||||
/// A module with a forward pass
|
||||
pub trait Module<I> {
|
||||
type Output;
|
||||
fn forward(&self, input: I) -> Self::Output;
|
||||
}
|
||||
```
|
||||
Super simple, we just define a forward function that takes an input and returns an output. A consequence of this is it allows us to define seperate forward passes for single and batched inputs!
|
||||
|
||||
Now let's take a look at how `Linear` is defined:
|
||||
```rust
|
||||
/// A simple linear layer
|
||||
pub struct Linear<const A: usize, const B: usize> {
|
||||
pub(crate) weight: GraphTensor<R2<A, B>>,
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> Module<GraphTensor<R1<A>>> for Linear<A, B> {
|
||||
type Output = GraphTensor<R1<B>>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<R1<A>>) -> Self::Output {
|
||||
input.matmul(self.weight)
|
||||
}
|
||||
}
|
||||
```
|
||||
Here we see a single weight matrix as the internal state, of size AxB. We've written a single forward function for single input vectors of shape (A,) and matmul it by our weight matrix to get an output of shape (B,).
|
||||
|
||||
Again, notice we're only dealing with `GraphTensor`s here, so when this code actually gets ran, **no computation happens, it just gets recorded to the graph.**
|
||||
16
examples/llama/.gitignore
vendored
Normal file
16
examples/llama/.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
setup/llama-7b-hf
|
||||
.vscode
|
||||
20
examples/llama/Cargo.toml
Normal file
20
examples/llama/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "llama"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
metal = ["dep:luminal_metal", "dep:metal-rs"]
|
||||
cuda = ["dep:luminal_cuda"]
|
||||
|
||||
[dependencies]
|
||||
luminal = {path="../.."}
|
||||
luminal_metal = {path="../../crates/luminal_metal", optional=true}
|
||||
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
|
||||
rust_tokenizers = "8.1.0"
|
||||
clap = { version = "4.4.18", features = ["derive"] }
|
||||
byteorder = "1.5.0"
|
||||
memmap2 = "0.9.4"
|
||||
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
|
||||
colored = "2.1.0"
|
||||
itertools = "0.12.1"
|
||||
@@ -1,22 +0,0 @@
|
||||
// Common
|
||||
pub const VOCAB: usize = 32_000;
|
||||
pub const HEAD_DIM: usize = 128;
|
||||
pub const HEAD_DIM_OVER_2: usize = 64;
|
||||
|
||||
// 7B
|
||||
pub const HIDDEN: usize = 4096;
|
||||
pub const INTERMEDIATE: usize = 11008;
|
||||
pub const HEADS: usize = 32;
|
||||
pub const LAYERS: usize = 1;
|
||||
|
||||
// 13B
|
||||
// pub const HIDDEN: usize = 5120;
|
||||
// pub const INTERMEDIATE: usize = 13824;
|
||||
// pub const HEADS: usize = 40;
|
||||
// pub const LAYERS: usize = 40;
|
||||
|
||||
// 65B
|
||||
// pub const HIDDEN: usize = 8192;
|
||||
// pub const INTERMEDIATE: usize = 22016;
|
||||
// pub const HEADS: usize = 64;
|
||||
// pub const LAYERS: usize = 80;
|
||||
@@ -1,92 +0,0 @@
|
||||
mod config;
|
||||
mod loader;
|
||||
mod model;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use model::LlamaForCausalLM;
|
||||
|
||||
use crate::model::KVCache;
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn main() {
|
||||
let tokenizer = tokenizers::tokenizer::Tokenizer::from_pretrained("oobabooga/llama-tokenizer", None).unwrap();
|
||||
|
||||
let mut input: Vec<usize> = tokenizer.encode("The young boy ran over to the", false).unwrap().get_ids().iter().map(|i| *i as usize).collect();
|
||||
|
||||
println!("Creating Graph...");
|
||||
let mut cx = Graph::new();
|
||||
let model: LlamaForCausalLM<
|
||||
{ config::VOCAB },
|
||||
{ config::HEADS },
|
||||
{ config::HIDDEN },
|
||||
{ config::INTERMEDIATE },
|
||||
{ config::HEAD_DIM },
|
||||
{ config::HEAD_DIM_OVER_2 },
|
||||
{ config::LAYERS },
|
||||
> = InitModule::initialize(&mut cx);
|
||||
let inp = cx.new_tensor::<(usize, usize)>("Input");
|
||||
let (out, cache_src) = model.forward(inp);
|
||||
out.mark();
|
||||
for (k, v) in &cache_src {
|
||||
k.mark_no_delete();
|
||||
v.mark_no_delete();
|
||||
}
|
||||
|
||||
println!("Loading...");
|
||||
loader::DfdxDeferredLoader::new("../../Desktop/llama-dfdx-main/llama-7b-hf").load(&model, &mut cx);
|
||||
|
||||
println!("Inferencing...");
|
||||
// First pass
|
||||
inp.set_dyn(input.clone(), vec![1, input.len()]);
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
cx.display_shapes();
|
||||
cx.execute();
|
||||
println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
|
||||
|
||||
let out = out.retrieve().unwrap().real_data(out.view().unwrap()).unwrap();
|
||||
input.push(sample_index(&out[(input.len() - 1) * 32_000..]));
|
||||
println!("{}", tokenizer.decode(input.iter().map(|i| *i as u32).collect(), false).unwrap());
|
||||
|
||||
|
||||
// Build KV cache forward graph
|
||||
let (out, cache_dest): (_, Vec<KVCache<_, usize, {config::HEADS}, {config::HEAD_DIM}>>) = model.forward_kv((inp, cache_src.clone()));
|
||||
out.mark();
|
||||
for (k, v) in &cache_dest {
|
||||
k.mark_no_delete();
|
||||
v.mark_no_delete();
|
||||
}
|
||||
cx.prune([out.id], cache_src.iter().flat_map(|(k, v)| [k.id, v.id]));
|
||||
|
||||
loop {
|
||||
inp.set_dyn(vec![*input.last().unwrap()], vec![1, 1]);
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
cx.execute();
|
||||
println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
|
||||
|
||||
let o = out.retrieve().unwrap().real_data(out.view().unwrap()).unwrap();
|
||||
// Sample tokens
|
||||
input.push(sample_index(&o));
|
||||
println!("{}", tokenizer.decode(input.iter().map(|i| *i as u32).collect(), false).unwrap());
|
||||
|
||||
// Swap caches
|
||||
for ((src_k, src_v), (dest_k, dest_v)) in cache_src.iter().copied().zip(cache_dest.iter().copied()) {
|
||||
// Move dest caches to src
|
||||
cx.swap_tensors(src_k, dest_k);
|
||||
cx.swap_tensors(src_v, dest_v);
|
||||
// Drop dest caches
|
||||
dest_k.drop();
|
||||
dest_v.drop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Currently just an argmax, do actual sampling here
|
||||
fn sample_index(dist: &[f32]) -> usize {
|
||||
dist.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap()
|
||||
.0
|
||||
}
|
||||
@@ -1,749 +0,0 @@
|
||||
#![allow(clippy::type_complexity)]
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
use luminal::{
|
||||
nn::{activation::RMSNorm, embedding::Embedding},
|
||||
op,
|
||||
prelude::{movement::TryConcatAlong, *},
|
||||
};
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
// Full LLaMa model implementation, heavily based off of https://github.com/coreylowman/llama-dfdx/blob/main/src/modeling.rs
|
||||
|
||||
pub type KVCache<Batch, Seq, const NUM_HEADS: usize, const HEAD_DIM: usize> = (
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
|
||||
pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize, B: Dim, S: Dim> Module<GraphTensor<(B, S, Const<H>)>>
|
||||
for Mlp<I, H>
|
||||
{
|
||||
type Output = GraphTensor<(B, S, Const<H>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(B, S, Const<H>)>) -> Self::Output {
|
||||
let gate = input.matmul(self.gate_proj.permute());
|
||||
let gate = gate.sigmoid() * gate;
|
||||
let up = input.matmul(self.up_proj.permute()) * gate;
|
||||
up.matmul(self.down_proj.permute())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: cx.new_tensor("Gate Weight"),
|
||||
up_proj: cx.new_tensor("Up Weight"),
|
||||
down_proj: cx.new_tensor("Down Weight"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("gate_proj/weight", self.gate_proj);
|
||||
s.tensor("up_proj/weight", self.up_proj);
|
||||
s.tensor("down_proj/weight", self.down_proj);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RotaryEmbedding<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> {
|
||||
pub inv_freq: GraphTensor<R1<HEAD_DIM_OVER_2>>,
|
||||
}
|
||||
|
||||
impl<
|
||||
Batch: Dim,
|
||||
const NUM_HEADS: usize,
|
||||
Seq: Dim,
|
||||
PrevSeq: Dim,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
)> for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(q, k, cache): (
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let (sin, cos) = self.get_sincos(q, cache);
|
||||
let sin = sin.expand();
|
||||
let cos = cos.expand();
|
||||
let q_embed = (Self::rotate_half(q) * sin) + (q * cos);
|
||||
let k_embed = (Self::rotate_half(k) * sin) + (k * cos);
|
||||
(q_embed, k_embed)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize>
|
||||
RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn get_sincos<Batch: Dim, const NUM_HEADS: usize, Seq: Dim, PrevSeq: Dim>(
|
||||
&self,
|
||||
seq_tensor: GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
cache: Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
) -> (
|
||||
GraphTensor<(Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Seq, Const<HEAD_DIM>)>,
|
||||
) {
|
||||
let graph = unsafe { self.inv_freq.graph_ref.as_mut().unwrap() };
|
||||
let has_cache = cache.is_some();
|
||||
let mut op = graph
|
||||
.add_op(
|
||||
op::Function(
|
||||
"ARange".to_string(),
|
||||
Box::new(move |inp, i| {
|
||||
let offset = if has_cache {
|
||||
inp[1].1.shape.shape()[2]
|
||||
} else {
|
||||
0
|
||||
};
|
||||
(
|
||||
Some(Tensor {
|
||||
data: Box::new(
|
||||
(0..inp[0].1.shape.shape()[2])
|
||||
.map(|i| (i + offset) as f32)
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
}),
|
||||
TensorView {
|
||||
tensor_id: i,
|
||||
shape: ShapeTracker::new(vec![inp[0].1.shape.shape()[2]]),
|
||||
},
|
||||
)
|
||||
}),
|
||||
),
|
||||
vec![Seq::const_size()],
|
||||
)
|
||||
.input(seq_tensor.id);
|
||||
if has_cache {
|
||||
op = op.input(cache.unwrap().0.id);
|
||||
}
|
||||
let t: GraphTensor<(Seq,)> = GraphTensor::from_id(op.finish(), graph);
|
||||
let freqs = t
|
||||
.expand::<(Seq, Const<1>), _>()
|
||||
.matmul(
|
||||
self.inv_freq
|
||||
.expand::<(Const<1>, Const<HEAD_DIM_OVER_2>), _>(),
|
||||
)
|
||||
.realize::<(Seq, usize)>();
|
||||
let emb = (freqs, freqs).concat_along(Axis::<1>);
|
||||
(emb.sin().realize(), emb.cos().realize())
|
||||
}
|
||||
|
||||
fn rotate_half<Batch: Dim, NumHeads: Dim, Seq: Dim>(
|
||||
x: GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)>,
|
||||
) -> GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)> {
|
||||
let x1 = x.slice((.., .., .., ..HEAD_DIM_OVER_2));
|
||||
let x2 = x.slice((.., .., .., HEAD_DIM_OVER_2..));
|
||||
(-x2, x1).concat_along(Axis::<3>).realize()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> InitModule
|
||||
for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
let s = Self {
|
||||
inv_freq: cx.new_tensor("Inv Freq"),
|
||||
};
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
s.inv_freq.set(
|
||||
(0..HEAD_DIM_OVER_2)
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize> SerializeModule
|
||||
for RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("inv_freq", self.inv_freq);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Attention<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> {
|
||||
pub q_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub k_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub v_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub o_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub rotary_embed: RotaryEmbedding<HEAD_DIM, HEAD_DIM_OVER_2>,
|
||||
}
|
||||
|
||||
fn attn_forward<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
Batch: Dim,
|
||||
Seq: Dim,
|
||||
PrevSeq: Dim,
|
||||
>(
|
||||
attn: &Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>,
|
||||
x: GraphTensor<(Batch, Seq, Const<HIDDEN>)>,
|
||||
cache: Option<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
) -> (
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
) {
|
||||
let q = x
|
||||
.matmul(attn.q_proj.permute())
|
||||
.dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
|
||||
Batch::const_size().to_reshape(0),
|
||||
Seq::const_size().to_reshape(1),
|
||||
ReshapeDim::Const(NUM_HEADS),
|
||||
ReshapeDim::Const(HEAD_DIM),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
|
||||
let k = x
|
||||
.matmul(attn.k_proj.permute())
|
||||
.dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
|
||||
Batch::const_size().to_reshape(0),
|
||||
Seq::const_size().to_reshape(1),
|
||||
ReshapeDim::Const(NUM_HEADS),
|
||||
ReshapeDim::Const(HEAD_DIM),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let v = x
|
||||
.matmul(attn.v_proj.permute())
|
||||
.dyn_reshape::<(Batch, Seq, Const<NUM_HEADS>, Const<HEAD_DIM>)>(vec![
|
||||
Batch::const_size().to_reshape(0),
|
||||
Seq::const_size().to_reshape(1),
|
||||
ReshapeDim::Const(NUM_HEADS),
|
||||
ReshapeDim::Const(HEAD_DIM),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let (q, k) = attn.rotary_embed.forward((
|
||||
q.realize::<(Batch, Const<NUM_HEADS>, Seq, Const<HEAD_DIM>)>(),
|
||||
k.realize(),
|
||||
cache,
|
||||
));
|
||||
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
Batch: Dim,
|
||||
CurSeq: Dim,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
GraphTensor<(CurSeq, CurSeq)>,
|
||||
)> for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>,
|
||||
);
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(x, attn_mask): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
GraphTensor<(CurSeq, CurSeq)>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let (q, k, v) = attn_forward(
|
||||
self,
|
||||
x,
|
||||
Option::<KVCache<_, usize, NUM_HEADS, HEAD_DIM>>::None,
|
||||
);
|
||||
let inv_head_scale = (HEAD_DIM as f64).sqrt().recip() as f32;
|
||||
let w = q
|
||||
.batch_matmul(k.permute())
|
||||
.mul(inv_head_scale)
|
||||
.add(attn_mask.expand())
|
||||
.softmax::<3>();
|
||||
|
||||
let o = w
|
||||
.batch_matmul(v)
|
||||
.permute::<(Batch, CurSeq, Const<NUM_HEADS>, Const<HEAD_DIM>), _>()
|
||||
.dyn_reshape::<(Batch, CurSeq, Const<HIDDEN>)>(vec![
|
||||
Batch::const_size().to_reshape(0),
|
||||
CurSeq::const_size().to_reshape(1),
|
||||
ReshapeDim::Const(HIDDEN),
|
||||
]);
|
||||
|
||||
(o.matmul(self.o_proj.permute()), (k, v))
|
||||
}
|
||||
}
|
||||
|
||||
// KV cache forward
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
|
||||
&self,
|
||||
(x, cache): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>,
|
||||
),
|
||||
) -> (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>,
|
||||
) {
|
||||
let (q, k, v) = attn_forward(self, x, Some(cache));
|
||||
|
||||
// Add KV cache
|
||||
let k = (
|
||||
cache
|
||||
.0
|
||||
.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
|
||||
k.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
|
||||
)
|
||||
.concat_along(Axis::<2>)
|
||||
.realize::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>)>();
|
||||
let v = (
|
||||
cache
|
||||
.1
|
||||
.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
|
||||
v.realize::<(Batch, Const<NUM_HEADS>, usize, Const<HEAD_DIM>)>(),
|
||||
)
|
||||
.concat_along(Axis::<2>)
|
||||
.realize::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>)>();
|
||||
|
||||
let w = q
|
||||
.batch_matmul(k.permute())
|
||||
.mul((HEAD_DIM as f64).sqrt().recip() as f32) // Inv head scale
|
||||
.softmax::<3>();
|
||||
|
||||
let o = w
|
||||
.batch_matmul(v)
|
||||
.permute::<(Batch, CurSeq, Const<NUM_HEADS>, Const<HEAD_DIM>), _>()
|
||||
.dyn_reshape::<(Batch, CurSeq, Const<HIDDEN>)>(vec![
|
||||
Batch::const_size().to_reshape(0),
|
||||
CurSeq::const_size().to_reshape(1),
|
||||
ReshapeDim::Const(HIDDEN),
|
||||
]);
|
||||
|
||||
(o.matmul(self.o_proj.permute()), (k, v))
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> InitModule for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.new_tensor("Query Weight"),
|
||||
k_proj: cx.new_tensor("Key Weight"),
|
||||
v_proj: cx.new_tensor("Value Weight"),
|
||||
o_proj: cx.new_tensor("Output Weight"),
|
||||
rotary_embed: InitModule::initialize(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> SerializeModule for Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("q_proj/weight", self.q_proj);
|
||||
s.tensor("k_proj/weight", self.k_proj);
|
||||
s.tensor("v_proj/weight", self.v_proj);
|
||||
s.tensor("o_proj/weight", self.o_proj);
|
||||
s.module("rotary_emb", &self.rotary_embed);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecoderLayer<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> {
|
||||
pub self_attn: Attention<NUM_HEADS, HIDDEN, HEAD_DIM, HEAD_DIM_OVER_2>,
|
||||
pub mlp: Mlp<INTERMEDIATE, HIDDEN>,
|
||||
pub input_layer_norm: RMSNorm<HIDDEN>,
|
||||
pub post_attention_layer_norm: RMSNorm<HIDDEN>,
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
Batch: Dim,
|
||||
CurSeq: Dim,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
GraphTensor<(CurSeq, CurSeq)>,
|
||||
)> for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, attn_mask): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
GraphTensor<(CurSeq, CurSeq)>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let (y, kv_cache) = self
|
||||
.self_attn
|
||||
.forward((self.input_layer_norm.forward(x), attn_mask));
|
||||
let x = x + y;
|
||||
let y = self.mlp.forward(self.post_attention_layer_norm.forward(x));
|
||||
(x + y, kv_cache)
|
||||
}
|
||||
}
|
||||
|
||||
// KV cache forward
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
|
||||
&self,
|
||||
(x, cache): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>,
|
||||
),
|
||||
) -> (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>,
|
||||
) {
|
||||
let (y, kv_cache) = self
|
||||
.self_attn
|
||||
.forward_kv((self.input_layer_norm.forward(x), cache));
|
||||
let x = x + y;
|
||||
let y = self.mlp.forward(self.post_attention_layer_norm.forward(x));
|
||||
(x + y, kv_cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> InitModule for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
self_attn: InitModule::initialize(cx),
|
||||
mlp: InitModule::initialize(cx),
|
||||
input_layer_norm: InitModule::initialize(cx),
|
||||
post_attention_layer_norm: InitModule::initialize(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
> SerializeModule for DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>
|
||||
{
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("self_attn", &self.self_attn);
|
||||
s.module("mlp", &self.mlp);
|
||||
s.module("input_layernorm", &self.input_layer_norm);
|
||||
s.module("post_attention_layernorm", &self.post_attention_layer_norm);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Llama<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> {
|
||||
pub embed_tokens: Embedding<VOCAB, HIDDEN>,
|
||||
pub layers: Vec<DecoderLayer<NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2>>,
|
||||
pub norm: RMSNorm<HIDDEN>,
|
||||
pub graph_ref: *mut Graph,
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
Batch: Dim,
|
||||
CurSeq: Dim,
|
||||
> Module<GraphTensor<(Batch, CurSeq)>>
|
||||
for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Vec<KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
);
|
||||
fn forward(&self, input: GraphTensor<(Batch, CurSeq)>) -> Self::Output {
|
||||
let graph = unsafe { self.graph_ref.as_mut().unwrap() };
|
||||
let attn_mask: GraphTensor<(CurSeq, CurSeq)> = GraphTensor::from_id(
|
||||
graph
|
||||
.add_op(
|
||||
op::Function(
|
||||
"AttentionMask".to_string(),
|
||||
Box::new(|inp, i| {
|
||||
let seq_len = inp[0].1.shape.shape()[1];
|
||||
let mut data = vec![0.; seq_len * seq_len];
|
||||
for i in 0..seq_len {
|
||||
for j in (i + 1)..seq_len {
|
||||
data[i * seq_len + j] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
(
|
||||
Some(Tensor {
|
||||
data: Box::new(data),
|
||||
}),
|
||||
TensorView {
|
||||
tensor_id: i,
|
||||
shape: ShapeTracker::new(vec![
|
||||
inp[0].1.shape.shape()[1],
|
||||
inp[0].1.shape.shape()[1],
|
||||
]),
|
||||
},
|
||||
)
|
||||
}),
|
||||
),
|
||||
vec![CurSeq::const_size(), CurSeq::const_size()],
|
||||
)
|
||||
.input(input.id)
|
||||
.finish(),
|
||||
graph,
|
||||
);
|
||||
|
||||
let mut hidden_states = self.embed_tokens.forward(input);
|
||||
let mut caches = vec![];
|
||||
for layer_i in &self.layers {
|
||||
let (new_hidden_states, kv_cache) = layer_i.forward((hidden_states, attn_mask));
|
||||
hidden_states = new_hidden_states;
|
||||
caches.push(kv_cache);
|
||||
}
|
||||
(self.norm.forward(hidden_states), caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
pub fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
|
||||
&self,
|
||||
(input, caches): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
),
|
||||
) -> (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
) {
|
||||
let mut hidden_states = self.embed_tokens.forward(input);
|
||||
let mut new_caches = vec![];
|
||||
for (layer_i, cache) in self.layers.iter().zip(caches.into_iter()) {
|
||||
let (new_hidden_states, kv_cache) = layer_i.forward_kv((hidden_states, cache));
|
||||
hidden_states = new_hidden_states;
|
||||
new_caches.push(kv_cache);
|
||||
}
|
||||
(self.norm.forward(hidden_states), new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> InitModule
|
||||
for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
norm: InitModule::initialize(cx),
|
||||
embed_tokens: InitModule::initialize(cx),
|
||||
layers: (0..LAYERS).map(|_| InitModule::initialize(cx)).collect(),
|
||||
graph_ref: cx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> SerializeModule
|
||||
for Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("norm", &self.norm);
|
||||
s.module("embed_tokens", &self.embed_tokens);
|
||||
for (i, l) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("layers/{i}"), l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LlamaForCausalLM<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> {
|
||||
pub llama: Llama<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>,
|
||||
pub lm_head: GraphTensor<(Const<VOCAB>, Const<HIDDEN>)>,
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
Batch: Dim,
|
||||
CurSeq: Dim,
|
||||
> Module<GraphTensor<(Batch, CurSeq)>>
|
||||
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
|
||||
Vec<KVCache<Batch, CurSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
);
|
||||
fn forward(&self, input: GraphTensor<(Batch, CurSeq)>) -> Self::Output {
|
||||
let (hidden_states, caches) = self.llama.forward(input);
|
||||
(hidden_states.matmul(self.lm_head.permute()), caches)
|
||||
}
|
||||
}
|
||||
|
||||
// KV cache forward
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
pub fn forward_kv<Batch: Dim, CurSeq: Dim, PrevSeq: Dim, TotSeq: Dim>(
|
||||
&self,
|
||||
(input, caches): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Vec<KVCache<Batch, PrevSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
),
|
||||
) -> (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
|
||||
Vec<KVCache<Batch, TotSeq, NUM_HEADS, HEAD_DIM>>,
|
||||
) {
|
||||
let (hidden_states, caches) = self.llama.forward_kv((input, caches));
|
||||
(hidden_states.matmul(self.lm_head.permute()), caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> InitModule
|
||||
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
llama: InitModule::initialize(cx),
|
||||
lm_head: cx.new_tensor("LM Head"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const VOCAB: usize,
|
||||
const NUM_HEADS: usize,
|
||||
const HIDDEN: usize,
|
||||
const INTERMEDIATE: usize,
|
||||
const HEAD_DIM: usize,
|
||||
const HEAD_DIM_OVER_2: usize,
|
||||
const LAYERS: usize,
|
||||
> SerializeModule
|
||||
for LlamaForCausalLM<VOCAB, NUM_HEADS, HIDDEN, INTERMEDIATE, HEAD_DIM, HEAD_DIM_OVER_2, LAYERS>
|
||||
{
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("model", &self.llama);
|
||||
s.tensor("lm_head/weight", self.lm_head);
|
||||
}
|
||||
}
|
||||
28
examples/llama/setup/convert.py
Normal file
28
examples/llama/setup/convert.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("src", help="root directory", default="llama-7b-hf")
|
||||
args = parser.parse_args()
|
||||
|
||||
for f in os.listdir(args.src):
|
||||
if not f.endswith(".bin"):
|
||||
continue
|
||||
print(f"Loading {f}")
|
||||
sd = torch.load(os.path.join(args.src, f))
|
||||
for key, tensor in sd.items():
|
||||
print("Saving", key, tensor.shape, tensor.dtype)
|
||||
path = os.path.sep.join(key.split("."))
|
||||
os.makedirs(os.path.join(args.src, os.path.dirname(path)), exist_ok=True)
|
||||
np_array = tensor.numpy()
|
||||
with open(os.path.join(args.src, path), "w") as fp:
|
||||
np_array.tofile(fp)
|
||||
del np_array
|
||||
del sd
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
examples/llama/setup/setup.sh
Normal file
20
examples/llama/setup/setup.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
# Setup git LFS
|
||||
echo "Setting up git LFS..."
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
||||
sudo apt install git-lfs
|
||||
elif [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
brew install git-lfs
|
||||
fi
|
||||
git lfs install
|
||||
|
||||
echo "Downloading Model..."
|
||||
git lfs clone https://huggingface.co/decapoda-research/llama-7b-hf $SCRIPT_DIR/llama-7b-hf
|
||||
|
||||
# Convert the model
|
||||
echo "Converting Model..."
|
||||
python3 $SCRIPT_DIR/convert.py $SCRIPT_DIR/llama-7b-hf
|
||||
|
||||
echo "Done!"
|
||||
@@ -1,4 +1,3 @@
|
||||
use half::f16;
|
||||
use luminal::{op::Function, prelude::*};
|
||||
|
||||
/// Load the model in the same way dfdx-llama does
|
||||
@@ -16,36 +15,29 @@ impl DfdxDeferredLoader {
|
||||
}
|
||||
|
||||
impl Loader for DfdxDeferredLoader {
|
||||
type Output = ();
|
||||
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) {
|
||||
let mut serializer = Serializer::default();
|
||||
model.serialize(&mut serializer);
|
||||
|
||||
for (s, n) in serializer.state {
|
||||
let shape: Vec<usize> = graph
|
||||
for (s, n) in state_dict(model) {
|
||||
let Some(n_elements) = graph
|
||||
.graph
|
||||
.node_weight_mut(n)
|
||||
.unwrap()
|
||||
.1
|
||||
.iter()
|
||||
.map(|i| match i {
|
||||
RealDim::Const(m) => *m,
|
||||
RealDim::Dyn => panic!("Dyn dimension in a weight"),
|
||||
})
|
||||
.collect();
|
||||
.edges_directed(n, petgraph::Direction::Outgoing)
|
||||
.find_map(|e| e.weight().as_data())
|
||||
.map(|(_, _, s)| s.n_physical_elements().to_usize().unwrap())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
if let Some(inp_func) = graph
|
||||
.graph
|
||||
.node_weight_mut(n)
|
||||
.unwrap()
|
||||
.0
|
||||
.as_any_mut()
|
||||
.downcast_mut::<Function>()
|
||||
{
|
||||
let path = self.path.clone();
|
||||
inp_func.1 = Box::new(move |_, i| {
|
||||
inp_func.1 = Box::new(move |_| {
|
||||
// Get memmapped tensor
|
||||
let bytes = std::fs::read(format!("{path}/{s}")).unwrap();
|
||||
let num_params: usize = shape.iter().product();
|
||||
let data: Vec<f32> = if bytes.len() == num_params * 2 {
|
||||
let data: Vec<f32> = if bytes.len() == n_elements * 2 {
|
||||
// Half-precision
|
||||
bytes
|
||||
.chunks_exact(std::mem::size_of::<f16>())
|
||||
@@ -53,7 +45,7 @@ impl Loader for DfdxDeferredLoader {
|
||||
std::mem::transmute::<[u8; 2], f16>([chunk[0], chunk[1]]).to_f32()
|
||||
})
|
||||
.collect()
|
||||
} else if bytes.len() == num_params * 4 {
|
||||
} else if bytes.len() == n_elements * 4 {
|
||||
// Full precision
|
||||
bytes
|
||||
.chunks_exact(std::mem::size_of::<f32>())
|
||||
@@ -65,23 +57,16 @@ impl Loader for DfdxDeferredLoader {
|
||||
.collect()
|
||||
} else {
|
||||
panic!(
|
||||
"Expected {} or {} bytes, got {} when loading {}{}",
|
||||
num_params * 2,
|
||||
num_params * 4,
|
||||
"Expected {} or {} bytes, got {} when loading {path}/{s}",
|
||||
n_elements * 2,
|
||||
n_elements * 4,
|
||||
bytes.len(),
|
||||
path,
|
||||
s
|
||||
)
|
||||
};
|
||||
(
|
||||
Some(Tensor {
|
||||
data: Box::new(data),
|
||||
}),
|
||||
TensorView {
|
||||
tensor_id: i,
|
||||
shape: ShapeTracker::new(shape.clone()),
|
||||
},
|
||||
)
|
||||
|
||||
vec![Tensor {
|
||||
data: Box::new(data),
|
||||
}]
|
||||
});
|
||||
};
|
||||
}
|
||||
164
examples/llama/src/main.rs
Normal file
164
examples/llama/src/main.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
mod loader;
|
||||
mod model;
|
||||
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
marker::PhantomData,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use colored::Colorize;
|
||||
use luminal::{prelude::*, shape::symbolic::Expression};
|
||||
use rust_tokenizers::tokenizer::{
|
||||
SentencePieceBpeTokenizer, Tokenizer,
|
||||
TruncationStrategy::{self},
|
||||
};
|
||||
|
||||
use crate::model::KVCache;
|
||||
#[cfg(feature = "metal")]
|
||||
type DeviceCompiler = luminal_metal::MetalCompiler<luminal::prelude::f16>;
|
||||
#[cfg(feature = "cuda")]
|
||||
type DeviceCompiler = luminal_cuda::CudaCompiler<luminal::prelude::f16>;
|
||||
#[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
|
||||
type DeviceCompiler = CPUCompiler;
|
||||
|
||||
fn main() {
|
||||
let prompt = "Here is a python implementation of merge sort:";
|
||||
let tokens_to_generate = 128;
|
||||
let tokenizer =
|
||||
SentencePieceBpeTokenizer::from_file("setup/llama-7b-hf/tokenizer.model", false).unwrap();
|
||||
|
||||
print!("Defining graph");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::LAYERS)
|
||||
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, model::HEADS, 0, model::HEAD_DIM]);
|
||||
let model = model::Llama::initialize(&mut cx);
|
||||
let (logits, mut cache_dest) =
|
||||
model.forward((input, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.retrieve();
|
||||
cache_dest.keep();
|
||||
loader::DfdxDeferredLoader::new("setup/llama-7b-hf").load(&model, &mut cx);
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
print!("Compiling graph");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
cx.compile(
|
||||
<(GenericCompiler, DeviceCompiler)>::default(),
|
||||
(&mut input, &mut logits, &mut cache_src, &mut cache_dest),
|
||||
);
|
||||
// Keep model weights
|
||||
let model_weights = downstream(state_set(&model), &cx);
|
||||
cx.keep_tensors(&model_weights);
|
||||
let cache_src_set = downstream(&cache_src, &cx);
|
||||
let cache_dest_set = cache_dest.to_ids();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Initial forward pass to load weights
|
||||
print!("Loading model");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
input.set_dyn(vec![0.], &[1, 1]);
|
||||
cx.set_dyn_dim('t', 1);
|
||||
cx.execute();
|
||||
logits.drop();
|
||||
cache_dest.drop();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Now that weights are loaded, delete the loading nodes so they don't run again
|
||||
delete_inputs(&model_weights, &mut cx);
|
||||
// Run prompt processing pass
|
||||
let mut input_ids = encode(&tokenizer, prompt);
|
||||
input.set_dyn(
|
||||
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
|
||||
&[1, input_ids.len()],
|
||||
);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
print!("Processing Prompt");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
cx.execute();
|
||||
let elapsed_ms = now.elapsed().as_millis();
|
||||
println!(
|
||||
"\t - {elapsed_ms}ms ({:.2} tok/s)",
|
||||
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
|
||||
);
|
||||
delete_inputs(&cache_src_set, &mut cx);
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
|
||||
// Decode token
|
||||
print!(
|
||||
"{}{}",
|
||||
prompt.white().bold(),
|
||||
decode(&tokenizer, &[output_id]).bright_green()
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
|
||||
|
||||
// Decode loop
|
||||
let mut token_decode_times = vec![];
|
||||
for _ in 0..tokens_to_generate {
|
||||
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
|
||||
cx.set_dyn_dim('p', input_ids.len() - 1);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
|
||||
let now = Instant::now();
|
||||
cx.execute();
|
||||
token_decode_times.push(now.elapsed().as_micros());
|
||||
|
||||
// Sample tokens
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
|
||||
}
|
||||
let avg_token_time = token_decode_times
|
||||
.iter()
|
||||
.map(|t| *t as f32 / 1000.)
|
||||
.sum::<f32>()
|
||||
/ token_decode_times.len() as f32;
|
||||
println!(
|
||||
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
|
||||
avg_token_time,
|
||||
1000.0 / avg_token_time
|
||||
);
|
||||
}
|
||||
|
||||
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
|
||||
let mut vector = tokenizer
|
||||
.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
|
||||
.token_ids;
|
||||
vector.insert(0, 1); // Start token
|
||||
vector
|
||||
}
|
||||
|
||||
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
|
||||
tokenizer
|
||||
.decode(token_ids, true, false)
|
||||
.replace("<0x0A>", "\n")
|
||||
}
|
||||
|
||||
// Currently just an argmax, do actual sampling here
|
||||
fn sample_index(dist: &[f32]) -> i64 {
|
||||
dist.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap()
|
||||
.0 as i64
|
||||
}
|
||||
365
examples/llama/src/model.rs
Normal file
365
examples/llama/src/model.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
#![allow(clippy::type_complexity)]
|
||||
use std::{marker::PhantomData, ops::Mul};
|
||||
|
||||
// LLaMa 1 7B Config
|
||||
pub const VOCAB: usize = 32_000;
|
||||
pub const HEAD_DIM: usize = 128;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const HIDDEN: usize = 4096;
|
||||
pub const INTERMEDIATE: usize = 11008;
|
||||
pub const HEADS: usize = 32;
|
||||
pub const LAYERS: usize = 32;
|
||||
|
||||
use luminal::{
|
||||
nn::{embedding::Embedding, norm::RMSNorm},
|
||||
prelude::*,
|
||||
shape::symbolic::{BigExpression, Expression},
|
||||
};
|
||||
|
||||
// Full LLaMa model implementation, heavily based off of https://github.com/coreylowman/llama-dfdx/blob/main/src/modeling.rs
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
|
||||
pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
}
|
||||
|
||||
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
|
||||
where
|
||||
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
|
||||
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
|
||||
{
|
||||
type Output = GraphTensor<Sh>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
|
||||
let gate = input.matmul(self.gate_proj.permute()).swish();
|
||||
let up = input.matmul(self.up_proj.permute()) * gate;
|
||||
up.matmul(self.down_proj.permute())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: cx.named_tensor("Gate Weight"),
|
||||
up_proj: cx.named_tensor("Up Weight"),
|
||||
down_proj: cx.named_tensor("Down Weight"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("gate_proj/weight", self.gate_proj);
|
||||
s.tensor("up_proj/weight", self.up_proj);
|
||||
s.tensor("down_proj/weight", self.down_proj);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RotaryEmbedding {
|
||||
pub inv_freq: GraphTensor<R1<HEAD_DIM_OVER_2>>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, Seq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
BigExpression,
|
||||
)> for RotaryEmbedding
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(inp, prev_seq): (
|
||||
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
BigExpression,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let (sin, cos) = self.get_sincos::<Seq>(prev_seq);
|
||||
(Self::rotate_half(inp) * sin.expand()) + (inp * cos.expand())
|
||||
}
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn get_sincos<Seq: Dimension>(
|
||||
&self,
|
||||
prev_seq: BigExpression,
|
||||
) -> (
|
||||
GraphTensor<(Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Seq, Const<HEAD_DIM>)>,
|
||||
) {
|
||||
let t = self.inv_freq.graph().arange::<Seq>() + prev_seq;
|
||||
let freqs = t.expand::<(Seq, Const<1>), _>().matmul(
|
||||
self.inv_freq
|
||||
.expand::<(Const<1>, Const<HEAD_DIM_OVER_2>), _>(),
|
||||
);
|
||||
let emb = freqs.concat_along::<(Seq, Const<HEAD_DIM>), Axis<1>, _>(freqs);
|
||||
(emb.sin().reshape(), emb.cos().reshape())
|
||||
}
|
||||
|
||||
fn rotate_half<Batch: Dimension, Seq: Dimension>(
|
||||
x: GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
) -> GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
let x1 = x
|
||||
.slice((.., .., .., ..Expression::from(HEAD_DIM_OVER_2)))
|
||||
.contiguous();
|
||||
let x2 = x
|
||||
.slice((.., .., .., Expression::from(HEAD_DIM_OVER_2)..))
|
||||
.contiguous();
|
||||
(-x2).concat_along::<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>), Axis<3>, _>(x1)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for RotaryEmbedding {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
inv_freq: cx.named_tensor("Inv Freq"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for RotaryEmbedding {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("inv_freq", self.inv_freq);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Attention {
|
||||
pub q_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub k_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub v_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub o_proj: GraphTensor<(Const<HIDDEN>, Const<HIDDEN>)>,
|
||||
pub rotary_embed: RotaryEmbedding,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for Attention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let queries = self
|
||||
.rotary_embed
|
||||
.forward((queries.permute(), PrevSeq::const_size().into()));
|
||||
let keys = self
|
||||
.rotary_embed
|
||||
.forward((keys, PrevSeq::const_size().into()));
|
||||
|
||||
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
|
||||
(
|
||||
k_cache.concat_along::<_, Axis<2>, _>(keys),
|
||||
v_cache.concat_along::<_, Axis<2>, _>(values),
|
||||
)
|
||||
} else {
|
||||
(keys.realize(), values.contiguous().realize())
|
||||
};
|
||||
|
||||
let mut weights = queries
|
||||
.matmul(keys.permute())
|
||||
.mul((HEAD_DIM as f64).sqrt().recip() as f32);
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq), _, _>(&[
|
||||
(0.into(), Expression::from(0)),
|
||||
(TotSeq::const_size() - CurSeq::const_size(), 0.into()),
|
||||
])
|
||||
.expand();
|
||||
|
||||
let outputs = weights
|
||||
.softmax::<3>()
|
||||
.matmul(values)
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN>)>();
|
||||
(
|
||||
outputs.matmul(self.o_proj.permute()),
|
||||
(keys.contiguous(), values.contiguous()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for Attention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Query Weight"),
|
||||
k_proj: cx.named_tensor("Key Weight"),
|
||||
v_proj: cx.named_tensor("Value Weight"),
|
||||
o_proj: cx.named_tensor("Output Weight"),
|
||||
rotary_embed: InitModule::initialize(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Attention {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("q_proj/weight", self.q_proj);
|
||||
s.tensor("k_proj/weight", self.k_proj);
|
||||
s.tensor("v_proj/weight", self.v_proj);
|
||||
s.tensor("o_proj/weight", self.o_proj);
|
||||
s.module("rotary_emb", &self.rotary_embed);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub self_attn: Attention,
|
||||
pub mlp: Mlp<INTERMEDIATE, HIDDEN>,
|
||||
pub input_layer_norm: RMSNorm<HIDDEN>,
|
||||
pub post_attention_layer_norm: RMSNorm<HIDDEN>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.input_layer_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.self_attn
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
|
||||
// Residual Addition
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.mlp.forward(self.post_attention_layer_norm.forward(x));
|
||||
|
||||
// Residual Addition
|
||||
(x + y, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
self_attn: InitModule::initialize(cx),
|
||||
mlp: InitModule::initialize(cx),
|
||||
input_layer_norm: InitModule::initialize(cx),
|
||||
post_attention_layer_norm: InitModule::initialize(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for TransformerBlock {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("self_attn", &self.self_attn);
|
||||
s.module("mlp", &self.mlp);
|
||||
s.module("input_layernorm", &self.input_layer_norm);
|
||||
s.module("post_attention_layernorm", &self.post_attention_layer_norm);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Llama {
|
||||
// Token embeddings
|
||||
pub embedding: Embedding<VOCAB, HIDDEN>,
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Final Norm layer
|
||||
pub norm: RMSNorm<HIDDEN>,
|
||||
// LM Head Layer
|
||||
pub lm_head: GraphTensor<R2<VOCAB, HIDDEN>>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Option<Vec<KVCache<Batch, PrevSeq>>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for Llama
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Option<Vec<KVCache<Batch, PrevSeq>>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Embed tokens
|
||||
let mut x = self.embedding.forward(input);
|
||||
|
||||
// Run through layers and collect new caches
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) =
|
||||
layer.forward((x, cache.as_ref().map(|c| c[i]), PhantomData::<TotSeq>));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
let output = self.norm.forward(x);
|
||||
let output = output.matmul(self.lm_head.permute());
|
||||
|
||||
(output, new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for Llama {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
norm: InitModule::initialize(cx),
|
||||
embedding: InitModule::initialize(cx),
|
||||
layers: (0..LAYERS).map(|_| InitModule::initialize(cx)).collect(),
|
||||
lm_head: cx.named_tensor("LM Head"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Llama {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("model/norm", &self.norm);
|
||||
s.module("model/embed_tokens", &self.embedding);
|
||||
for (i, l) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("model/layers/{i}"), l);
|
||||
}
|
||||
s.tensor("lm_head/weight", self.lm_head);
|
||||
}
|
||||
}
|
||||
14
examples/mistral/.gitignore
vendored
Normal file
14
examples/mistral/.gitignore
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
20
examples/mistral/Cargo.toml
Normal file
20
examples/mistral/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "mistral"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
metal = ["dep:luminal_metal", "dep:metal-rs"]
|
||||
cuda = ["dep:luminal_cuda"]
|
||||
|
||||
[dependencies]
|
||||
luminal = {path="../.."}
|
||||
luminal_metal = {path="../../crates/luminal_metal", optional=true}
|
||||
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
|
||||
rust_tokenizers = "8.1.0"
|
||||
clap = { version = "4.4.18", features = ["derive"] }
|
||||
byteorder = "1.5.0"
|
||||
memmap2 = "0.9.4"
|
||||
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
|
||||
colored = "2.1.0"
|
||||
itertools = "0.12.1"
|
||||
10
examples/mistral/prompts/asimov.txt
Normal file
10
examples/mistral/prompts/asimov.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
# Three Laws of Robotics
|
||||
|
||||
**The Three Laws of Robotics** (often shortened to **The Three Laws** or **Asimov's Laws**) are a set of rules devised by science fiction author Isaac Asimov, which were to be followed by robots in several of his stories. The rules were introduced in his 1942 short story "Runaround" (included in the 1950 collection I, Robot), although similar restrictions had been implied in earlier stories.
|
||||
|
||||
## The Laws
|
||||
|
||||
The Three Laws, presented to be from the fictional "Handbook of Robotics, 56th Edition, 2058 A.D.", are:
|
||||
- The First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm.
|
||||
- The Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
|
||||
- The Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
|
||||
1
examples/mistral/prompts/merge_sort.txt
Normal file
1
examples/mistral/prompts/merge_sort.txt
Normal file
@@ -0,0 +1 @@
|
||||
[INST]Write me a python implementation of merge sort[/INST]
|
||||
209
examples/mistral/prompts/shakespeare.txt
Normal file
209
examples/mistral/prompts/shakespeare.txt
Normal file
@@ -0,0 +1,209 @@
|
||||
[INST] Complete the following
|
||||
|
||||
## SCENE VII. The forest.
|
||||
A table set out. Enter DUKE SENIOR, AMIENS, and Lords like outlaws
|
||||
|
||||
### DUKE SENIOR
|
||||
I think he be transform'd into a beast;
|
||||
For I can no where find him like a man.
|
||||
|
||||
### First Lord
|
||||
My lord, he is but even now gone hence:
|
||||
Here was he merry, hearing of a song.
|
||||
|
||||
### DUKE SENIOR
|
||||
If he, compact of jars, grow musical,
|
||||
We shall have shortly discord in the spheres.
|
||||
Go, seek him: tell him I would speak with him.
|
||||
Enter JAQUES
|
||||
|
||||
### First Lord
|
||||
He saves my labour by his own approach.
|
||||
|
||||
### DUKE SENIOR
|
||||
Why, how now, monsieur! what a life is this,
|
||||
That your poor friends must woo your company?
|
||||
What, you look merrily!
|
||||
|
||||
### JAQUES
|
||||
A fool, a fool! I met a fool i' the forest,
|
||||
A motley fool; a miserable world!
|
||||
As I do live by food, I met a fool
|
||||
Who laid him down and bask'd him in the sun,
|
||||
And rail'd on Lady Fortune in good terms,
|
||||
In good set terms and yet a motley fool.
|
||||
'Good morrow, fool,' quoth I. 'No, sir,' quoth he,
|
||||
'Call me not fool till heaven hath sent me fortune:'
|
||||
And then he drew a dial from his poke,
|
||||
And, looking on it with lack-lustre eye,
|
||||
Says very wisely, 'It is ten o'clock:
|
||||
Thus we may see,' quoth he, 'how the world wags:
|
||||
'Tis but an hour ago since it was nine,
|
||||
And after one hour more 'twill be eleven;
|
||||
And so, from hour to hour, we ripe and ripe,
|
||||
And then, from hour to hour, we rot and rot;
|
||||
And thereby hangs a tale.' When I did hear
|
||||
The motley fool thus moral on the time,
|
||||
My lungs began to crow like chanticleer,
|
||||
That fools should be so deep-contemplative,
|
||||
And I did laugh sans intermission
|
||||
An hour by his dial. O noble fool!
|
||||
A worthy fool! Motley's the only wear.
|
||||
|
||||
### DUKE SENIOR
|
||||
What fool is this?
|
||||
|
||||
### JAQUES
|
||||
O worthy fool! One that hath been a courtier,
|
||||
And says, if ladies be but young and fair,
|
||||
They have the gift to know it: and in his brain,
|
||||
Which is as dry as the remainder biscuit
|
||||
After a voyage, he hath strange places cramm'd
|
||||
With observation, the which he vents
|
||||
In mangled forms. O that I were a fool!
|
||||
I am ambitious for a motley coat.
|
||||
|
||||
### DUKE SENIOR
|
||||
Thou shalt have one.
|
||||
|
||||
### JAQUES
|
||||
It is my only suit;
|
||||
Provided that you weed your better judgments
|
||||
Of all opinion that grows rank in them
|
||||
That I am wise. I must have liberty
|
||||
Withal, as large a charter as the wind,
|
||||
To blow on whom I please; for so fools have;
|
||||
And they that are most galled with my folly,
|
||||
They most must laugh. And why, sir, must they so?
|
||||
The 'why' is plain as way to parish church:
|
||||
He that a fool doth very wisely hit
|
||||
Doth very foolishly, although he smart,
|
||||
Not to seem senseless of the bob: if not,
|
||||
The wise man's folly is anatomized
|
||||
Even by the squandering glances of the fool.
|
||||
Invest me in my motley; give me leave
|
||||
To speak my mind, and I will through and through
|
||||
Cleanse the foul body of the infected world,
|
||||
If they will patiently receive my medicine.
|
||||
|
||||
### DUKE SENIOR
|
||||
Fie on thee! I can tell what thou wouldst do.
|
||||
|
||||
### JAQUES
|
||||
What, for a counter, would I do but good?
|
||||
|
||||
### DUKE SENIOR
|
||||
Most mischievous foul sin, in chiding sin:
|
||||
For thou thyself hast been a libertine,
|
||||
As sensual as the brutish sting itself;
|
||||
And all the embossed sores and headed evils,
|
||||
That thou with licence of free foot hast caught,
|
||||
Wouldst thou disgorge into the general world.
|
||||
|
||||
### JAQUES
|
||||
Why, who cries out on pride,
|
||||
That can therein tax any private party?
|
||||
Doth it not flow as hugely as the sea,
|
||||
Till that the weary very means do ebb?
|
||||
What woman in the city do I name,
|
||||
When that I say the city-woman bears
|
||||
The cost of princes on unworthy shoulders?
|
||||
Who can come in and say that I mean her,
|
||||
When such a one as she such is her neighbour?
|
||||
Or what is he of basest function
|
||||
That says his bravery is not of my cost,
|
||||
Thinking that I mean him, but therein suits
|
||||
His folly to the mettle of my speech?
|
||||
There then; how then? what then? Let me see wherein
|
||||
My tongue hath wrong'd him: if it do him right,
|
||||
Then he hath wrong'd himself; if he be free,
|
||||
Why then my taxing like a wild-goose flies,
|
||||
Unclaim'd of any man. But who comes here?
|
||||
Enter ORLANDO, with his sword drawn
|
||||
|
||||
### ORLANDO
|
||||
Forbear, and eat no more.
|
||||
|
||||
### JAQUES
|
||||
Why, I have eat none yet.
|
||||
|
||||
### ORLANDO
|
||||
Nor shalt not, till necessity be served.
|
||||
|
||||
### JAQUES
|
||||
Of what kind should this cock come of?
|
||||
|
||||
### DUKE SENIOR
|
||||
Art thou thus bolden'd, man, by thy distress,
|
||||
Or else a rude despiser of good manners,
|
||||
That in civility thou seem'st so empty?
|
||||
|
||||
### ORLANDO
|
||||
You touch'd my vein at first: the thorny point
|
||||
Of bare distress hath ta'en from me the show
|
||||
Of smooth civility: yet am I inland bred
|
||||
And know some nurture. But forbear, I say:
|
||||
He dies that touches any of this fruit
|
||||
Till I and my affairs are answered.
|
||||
|
||||
### JAQUES
|
||||
An you will not be answered with reason, I must die.
|
||||
|
||||
### DUKE SENIOR
|
||||
What would you have? Your gentleness shall force
|
||||
More than your force move us to gentleness.
|
||||
|
||||
### ORLANDO
|
||||
I almost die for food; and let me have it.
|
||||
|
||||
### DUKE SENIOR
|
||||
Sit down and feed, and welcome to our table.
|
||||
|
||||
### ORLANDO
|
||||
Speak you so gently? Pardon me, I pray you:
|
||||
I thought that all things had been savage here;
|
||||
And therefore put I on the countenance
|
||||
Of stern commandment. But whate'er you are
|
||||
That in this desert inaccessible,
|
||||
Under the shade of melancholy boughs,
|
||||
Lose and neglect the creeping hours of time
|
||||
If ever you have look'd on better days,
|
||||
If ever been where bells have knoll'd to church,
|
||||
If ever sat at any good man's feast,
|
||||
If ever from your eyelids wiped a tear
|
||||
And know what 'tis to pity and be pitied,
|
||||
Let gentleness my strong enforcement be:
|
||||
In the which hope I blush, and hide my sword.
|
||||
|
||||
### DUKE SENIOR
|
||||
True is it that we have seen better days,
|
||||
And have with holy bell been knoll'd to church
|
||||
And sat at good men's feasts and wiped our eyes
|
||||
Of drops that sacred pity hath engender'd:
|
||||
And therefore sit you down in gentleness
|
||||
And take upon command what help we have
|
||||
That to your wanting may be minister'd.
|
||||
|
||||
### ORLANDO
|
||||
Then but forbear your food a little while,
|
||||
Whiles, like a doe, I go to find my fawn
|
||||
And give it food. There is an old poor man,
|
||||
Who after me hath many a weary step
|
||||
Limp'd in pure love: till he be first sufficed,
|
||||
Oppress'd with two weak evils, age and hunger,
|
||||
I will not touch a bit.
|
||||
|
||||
### DUKE SENIOR
|
||||
Go find him out,
|
||||
And we will nothing waste till you return.
|
||||
|
||||
### ORLANDO
|
||||
I thank ye; and be blest for your good comfort!
|
||||
Exit
|
||||
|
||||
### DUKE SENIOR
|
||||
Thou seest we are not all alone unhappy:
|
||||
This wide and universal theatre
|
||||
Presents more woeful pageants than the scene
|
||||
Wherein we play in.
|
||||
[/INST]
|
||||
8
examples/mistral/setup/setup.sh
Normal file
8
examples/mistral/setup/setup.sh
Normal file
@@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
echo "Downloading Tokenizer"
|
||||
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true --output $SCRIPT_DIR/mistral_tokenizer.model
|
||||
echo "Downloading Model"
|
||||
curl --location https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q8_0.gguf?download=true --output $SCRIPT_DIR/mistral-7b-instruct-v0.2.Q8_0.gguf
|
||||
echo "Done Downloading Model"
|
||||
302
examples/mistral/src/gguf.rs
Normal file
302
examples/mistral/src/gguf.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
//! Support for the GGUF file format.
|
||||
//!
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const DEFAULT_ALIGNMENT: u64 = 32;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Magic {
|
||||
Gguf,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for Magic {
|
||||
type Error = ();
|
||||
fn try_from(value: u32) -> Result<Self, ()> {
|
||||
let magic = match value {
|
||||
0x46554747 | 0x47475546 => Self::Gguf,
|
||||
_ => panic!("unknown magic 0x{value:08x}"),
|
||||
};
|
||||
Ok(magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VersionedMagic {
|
||||
GgufV1,
|
||||
GgufV2,
|
||||
GgufV3,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
pub fn read<R: std::io::Read>(reader: &mut R) -> Result<Self, ()> {
|
||||
let magic = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let magic = Magic::try_from(magic).unwrap();
|
||||
let version = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
(Magic::Gguf, 3) => Self::GgufV3,
|
||||
_ => panic!("gguf: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub tensor_infos: HashMap<String, (usize, usize, GgmlDType)>, // buffer size and offset
|
||||
pub tensor_data_offset: u64,
|
||||
}
|
||||
|
||||
pub fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String, ()> {
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>().unwrap() as usize
|
||||
}
|
||||
};
|
||||
let mut v = vec![0u8; len];
|
||||
reader.read_exact(&mut v).unwrap();
|
||||
// GGUF strings are supposed to be non-null terminated but in practice this happens.
|
||||
while let Some(0) = v.last() {
|
||||
v.pop();
|
||||
}
|
||||
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
|
||||
Ok(String::from_utf8_lossy(&v).into_owned())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ValueType {
|
||||
// The value is a 8-bit unsigned integer.
|
||||
U8,
|
||||
// The value is a 8-bit signed integer.
|
||||
I8,
|
||||
// The value is a 16-bit unsigned little-endian integer.
|
||||
U16,
|
||||
// The value is a 16-bit signed little-endian integer.
|
||||
I16,
|
||||
// The value is a 32-bit unsigned little-endian integer.
|
||||
U32,
|
||||
// The value is a 32-bit signed little-endian integer.
|
||||
I32,
|
||||
// The value is a 64-bit unsigned little-endian integer.
|
||||
U64,
|
||||
// The value is a 64-bit signed little-endian integer.
|
||||
I64,
|
||||
// The value is a 32-bit IEEE754 floating point number.
|
||||
F32,
|
||||
// The value is a 64-bit IEEE754 floating point number.
|
||||
F64,
|
||||
// The value is a boolean.
|
||||
// 1-byte value where 0 is false and 1 is true.
|
||||
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
|
||||
Bool,
|
||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||
String,
|
||||
// The value is an array of other values, with the length and type prepended.
|
||||
///
|
||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||
Array,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Value {
|
||||
U8(u8),
|
||||
I8(i8),
|
||||
U16(u16),
|
||||
I16(i16),
|
||||
U32(u32),
|
||||
I32(i32),
|
||||
U64(u64),
|
||||
I64(i64),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
Bool(bool),
|
||||
String(String),
|
||||
Array(Vec<Value>),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
pub fn read<R: std::io::Read>(
|
||||
reader: &mut R,
|
||||
value_type: ValueType,
|
||||
magic: &VersionedMagic,
|
||||
) -> Result<Self, ()> {
|
||||
let v = match value_type {
|
||||
ValueType::U8 => Self::U8(reader.read_u8().unwrap()),
|
||||
ValueType::I8 => Self::I8(reader.read_i8().unwrap()),
|
||||
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>().unwrap()),
|
||||
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>().unwrap()),
|
||||
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>().unwrap()),
|
||||
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>().unwrap()),
|
||||
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>().unwrap()),
|
||||
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>().unwrap()),
|
||||
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>().unwrap()),
|
||||
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>().unwrap()),
|
||||
ValueType::Bool => match reader.read_u8().unwrap() {
|
||||
0 => Self::Bool(false),
|
||||
1 => Self::Bool(true),
|
||||
b => panic!("unexpected bool value {b}"),
|
||||
},
|
||||
ValueType::String => Self::String(read_string(reader, magic).unwrap()),
|
||||
ValueType::Array => {
|
||||
let value_type = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let value_type = ValueType::from_u32(value_type).unwrap();
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>().unwrap() as usize
|
||||
}
|
||||
};
|
||||
let mut vs = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
vs.push(Value::read(reader, value_type, magic).unwrap())
|
||||
}
|
||||
Self::Array(vs)
|
||||
}
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueType {
|
||||
pub fn from_u32(v: u32) -> Result<Self, ()> {
|
||||
let v = match v {
|
||||
0 => Self::U8,
|
||||
1 => Self::I8,
|
||||
2 => Self::U16,
|
||||
3 => Self::I16,
|
||||
4 => Self::U32,
|
||||
5 => Self::I32,
|
||||
6 => Self::F32,
|
||||
7 => Self::Bool,
|
||||
8 => Self::String,
|
||||
9 => Self::Array,
|
||||
10 => Self::U64,
|
||||
11 => Self::I64,
|
||||
12 => Self::F64,
|
||||
v => panic!("unrecognized value-type {v:#08x}"),
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self, ()> {
|
||||
let magic = VersionedMagic::read(reader).unwrap();
|
||||
|
||||
let tensor_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>().unwrap() as usize
|
||||
}
|
||||
};
|
||||
let metadata_kv_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>().unwrap() as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>().unwrap() as usize
|
||||
}
|
||||
};
|
||||
|
||||
// Read metadata
|
||||
let mut metadata = HashMap::new();
|
||||
for _idx in 0..metadata_kv_count {
|
||||
let key = read_string(reader, &magic).unwrap();
|
||||
let value_type = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let value_type = ValueType::from_u32(value_type).unwrap();
|
||||
let value = Value::read(reader, value_type, &magic).unwrap();
|
||||
metadata.insert(key, value);
|
||||
}
|
||||
// Read tensor infos
|
||||
let mut tensor_infos = HashMap::new();
|
||||
for _idx in 0..tensor_count {
|
||||
let tensor_name = read_string(reader, &magic).unwrap();
|
||||
let n_dimensions = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let n_elements = match magic {
|
||||
VersionedMagic::GgufV1 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader
|
||||
.read_u32_into::<LittleEndian>(&mut dimensions)
|
||||
.unwrap();
|
||||
dimensions.into_iter().map(|c| c as usize).product()
|
||||
}
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader
|
||||
.read_u64_into::<LittleEndian>(&mut dimensions)
|
||||
.unwrap();
|
||||
dimensions.into_iter().map(|c| c as usize).product()
|
||||
}
|
||||
};
|
||||
|
||||
let ggml_dtype = reader.read_u32::<LittleEndian>().unwrap();
|
||||
let offset = reader.read_u64::<LittleEndian>().unwrap();
|
||||
tensor_infos.insert(
|
||||
tensor_name,
|
||||
(n_elements, offset as usize, GgmlDType::from_u32(ggml_dtype)),
|
||||
);
|
||||
}
|
||||
let position = reader.stream_position().unwrap();
|
||||
let alignment = match metadata.get("general.alignment") {
|
||||
Some(Value::U8(v)) => *v as u64,
|
||||
Some(Value::U16(v)) => *v as u64,
|
||||
Some(Value::U32(v)) => *v as u64,
|
||||
Some(Value::I8(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I16(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I32(v)) if *v >= 0 => *v as u64,
|
||||
_ => DEFAULT_ALIGNMENT,
|
||||
};
|
||||
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
|
||||
Ok(Self {
|
||||
magic,
|
||||
metadata,
|
||||
tensor_infos,
|
||||
tensor_data_offset,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
F16,
|
||||
Q4_0,
|
||||
Q4_1,
|
||||
Q5_0,
|
||||
Q5_1,
|
||||
Q8_0,
|
||||
Q8_1,
|
||||
Q2K,
|
||||
Q3K,
|
||||
Q4K,
|
||||
Q5K,
|
||||
Q6K,
|
||||
Q8K,
|
||||
}
|
||||
|
||||
impl GgmlDType {
|
||||
fn from_u32(u: u32) -> Self {
|
||||
match u {
|
||||
0 => Self::F32,
|
||||
1 => Self::F16,
|
||||
2 => Self::Q4_0,
|
||||
3 => Self::Q4_1,
|
||||
6 => Self::Q5_0,
|
||||
7 => Self::Q5_1,
|
||||
8 => Self::Q8_0,
|
||||
9 => Self::Q8_1,
|
||||
10 => Self::Q2K,
|
||||
11 => Self::Q3K,
|
||||
12 => Self::Q4K,
|
||||
13 => Self::Q5K,
|
||||
14 => Self::Q6K,
|
||||
15 => Self::Q8K,
|
||||
_ => panic!("unknown dtype for tensor {u}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
173
examples/mistral/src/loader.rs
Normal file
173
examples/mistral/src/loader.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Seek},
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::{op::Function, prelude::*};
|
||||
|
||||
use crate::gguf::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use {
|
||||
luminal_metal::MetalBuffer,
|
||||
memmap2::Mmap,
|
||||
metal_rs::{Device, MTLResourceOptions},
|
||||
};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub struct MetalQ8Loader(String);
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl MetalQ8Loader {
|
||||
pub fn new<S: Into<String>>(path: S) -> Self {
|
||||
Self(path.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl Loader for MetalQ8Loader {
|
||||
type Output = Vec<NodeIndex>;
|
||||
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) -> Self::Output {
|
||||
// Read metadata from file
|
||||
let mut reader = File::open(&self.0).unwrap();
|
||||
let Content {
|
||||
mut tensor_infos,
|
||||
tensor_data_offset,
|
||||
..
|
||||
} = Content::read(&mut reader).unwrap();
|
||||
|
||||
// Create weight loading closures
|
||||
let mut q8_weights = vec![];
|
||||
for (weight_name, node_index) in state_dict(model) {
|
||||
if let Some(loading_node) = graph
|
||||
.graph
|
||||
.node_weight_mut(node_index)
|
||||
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
|
||||
{
|
||||
let file_path = self.0.clone();
|
||||
let (n_elements, buffer_offset, data_type) =
|
||||
tensor_infos.remove(&weight_name.replace('/', ".")).unwrap();
|
||||
let n_bytes = match data_type {
|
||||
GgmlDType::F32 => n_elements * 4,
|
||||
GgmlDType::Q8_0 => {
|
||||
q8_weights.push(node_index);
|
||||
n_elements + (n_elements / 16)
|
||||
}
|
||||
_ => panic!("Unsupported dtype: {data_type:?}"),
|
||||
};
|
||||
loading_node.1 = Box::new(move |_| {
|
||||
let mmap_buffer =
|
||||
unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() };
|
||||
let buffer = Device::system_default().unwrap().new_buffer_with_data(
|
||||
unsafe {
|
||||
mmap_buffer
|
||||
.as_ptr()
|
||||
.add(buffer_offset + tensor_data_offset as usize)
|
||||
as *const _
|
||||
},
|
||||
n_bytes as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
vec![Tensor {
|
||||
data: Box::new(MetalBuffer(buffer)),
|
||||
}]
|
||||
});
|
||||
}
|
||||
}
|
||||
q8_weights
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
pub struct Q8Loader(String);
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
impl Q8Loader {
|
||||
pub fn new<S: Into<String>>(path: S) -> Self {
|
||||
Self(path.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
impl Loader for Q8Loader {
|
||||
type Output = Vec<NodeIndex>;
|
||||
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) -> Self::Output {
|
||||
#[repr(C, packed)]
|
||||
#[derive(Clone, Copy)]
|
||||
struct Q8Block {
|
||||
delta: f16,
|
||||
weights: [i8; 32],
|
||||
}
|
||||
|
||||
// Read metadata from file
|
||||
let mut reader = File::open(&self.0).unwrap();
|
||||
let Content {
|
||||
mut tensor_infos,
|
||||
tensor_data_offset,
|
||||
..
|
||||
} = Content::read(&mut reader).unwrap();
|
||||
|
||||
// Create weight loading closures
|
||||
let mut q8_weights = vec![];
|
||||
for (weight_name, node_index) in state_dict(model) {
|
||||
if let Some(loading_node) = graph
|
||||
.graph
|
||||
.node_weight_mut(node_index)
|
||||
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
|
||||
{
|
||||
let file_path = self.0.clone();
|
||||
let (n_elements, buffer_offset, data_type) =
|
||||
tensor_infos.remove(&weight_name.replace('/', ".")).unwrap();
|
||||
let n_bytes = match data_type {
|
||||
GgmlDType::F32 => n_elements * 4,
|
||||
GgmlDType::Q8_0 => {
|
||||
q8_weights.push(node_index);
|
||||
n_elements + (n_elements / 16)
|
||||
}
|
||||
_ => panic!("Unsupported dtype: {data_type:?}"),
|
||||
};
|
||||
loading_node.1 = Box::new(move |_| {
|
||||
// Load all bytes
|
||||
let mut bytes = vec![0; n_bytes];
|
||||
let mut file = File::open(&file_path).unwrap();
|
||||
file.seek(std::io::SeekFrom::Start(
|
||||
buffer_offset as u64 + tensor_data_offset,
|
||||
))
|
||||
.unwrap();
|
||||
file.read_exact(&mut bytes).unwrap();
|
||||
// Dequantize into f32
|
||||
let data: Vec<f32> = match data_type {
|
||||
GgmlDType::F32 => bytes
|
||||
.into_iter()
|
||||
.chunks(4)
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let c = c.collect::<Vec<_>>();
|
||||
f32::from_le_bytes([c[0], c[1], c[2], c[3]])
|
||||
})
|
||||
.collect(),
|
||||
GgmlDType::Q8_0 => bytes
|
||||
.into_iter()
|
||||
.chunks(34)
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let chunk = c.collect::<Vec<_>>();
|
||||
unsafe { chunk.align_to::<Q8Block>().1[0] }
|
||||
})
|
||||
.flat_map(|chunk| {
|
||||
chunk
|
||||
.weights
|
||||
.into_iter()
|
||||
.map(move |i| i as f32 * chunk.delta.to_f32())
|
||||
})
|
||||
.collect(),
|
||||
_ => panic!("Unsupported dtype: {data_type:?}"),
|
||||
};
|
||||
vec![Tensor::new(data)]
|
||||
});
|
||||
}
|
||||
}
|
||||
q8_weights
|
||||
}
|
||||
}
|
||||
185
examples/mistral/src/main.rs
Normal file
185
examples/mistral/src/main.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
marker::PhantomData,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use clap::Parser;
|
||||
use colored::Colorize;
|
||||
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
|
||||
|
||||
mod gguf;
|
||||
mod loader;
|
||||
mod model;
|
||||
|
||||
use crate::model::KVCache;
|
||||
use luminal::{prelude::*, shape::symbolic::Expression};
|
||||
|
||||
// Command args parser
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct CLIArgs {
|
||||
/// Number of tokens to generate
|
||||
#[clap(short = 't', long = "gen_tokens", default_value = "128")]
|
||||
gen_tokens: i32,
|
||||
|
||||
/// Prompt for the model
|
||||
#[clap(short = 'p', long = "prompt", default_value = include_str!("../prompts/merge_sort.txt"))]
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let cli_args = CLIArgs::parse();
|
||||
let tokenizer =
|
||||
SentencePieceBpeTokenizer::from_file("setup/mistral_tokenizer.model", false).unwrap();
|
||||
|
||||
print!("Defining graph");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
|
||||
// Set up graph
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
|
||||
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
|
||||
let model = model::MistralLM::initialize(&mut cx);
|
||||
let (logits, mut cache_dest) =
|
||||
model.forward((input, Some(cache_src.clone()), PhantomData::<Dyn<'t'>>));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.retrieve();
|
||||
cache_dest.keep();
|
||||
|
||||
// Set up model loading
|
||||
#[cfg(feature = "metal")]
|
||||
let quantized_weight_nodes =
|
||||
loader::MetalQ8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf")
|
||||
.load(&model, &mut cx);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
loader::Q8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf").load(&model, &mut cx);
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
print!("Compiling graph");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
cx.compile(
|
||||
(
|
||||
GenericCompiler::default(),
|
||||
#[cfg(feature = "metal")]
|
||||
luminal_metal::MetalQuantizedCompiler::<f32>::new(quantized_weight_nodes),
|
||||
#[cfg(feature = "cuda")]
|
||||
luminal_cuda::CudaCompiler::<f32>::default(),
|
||||
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
|
||||
luminal::compilers::CPUCompiler::default(),
|
||||
),
|
||||
(&mut input, &mut logits, &mut cache_src, &mut cache_dest),
|
||||
);
|
||||
// Keep model weights
|
||||
let model_weights = downstream(state_set(&model), &cx);
|
||||
cx.keep_tensors(&model_weights);
|
||||
let cache_src_set = downstream(&cache_src, &cx);
|
||||
let cache_dest_set = cache_dest.to_ids();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Initial forward pass to load weights
|
||||
print!("Loading model");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
input.set_dyn(vec![0.], &[1, 1]);
|
||||
cx.set_dyn_dim('t', 1);
|
||||
cx.execute();
|
||||
logits.drop();
|
||||
cache_dest.drop();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Now that weights are loaded, delete the loading nodes so they don't run again
|
||||
delete_inputs(&model_weights, &mut cx);
|
||||
// Run prompt processing pass
|
||||
let mut input_ids = encode(&tokenizer, &cli_args.prompt);
|
||||
input.set_dyn(
|
||||
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
|
||||
&[1, input_ids.len()],
|
||||
);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
print!("Processing Prompt");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
cx.execute();
|
||||
let elapsed_ms = now.elapsed().as_millis();
|
||||
println!(
|
||||
"\t - {elapsed_ms}ms ({:.2} tok/s)",
|
||||
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
|
||||
);
|
||||
delete_inputs(&cache_src_set, &mut cx);
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
|
||||
// Decode token
|
||||
print!(
|
||||
"{}{}",
|
||||
cli_args.prompt.white().bold(),
|
||||
decode(&tokenizer, &[output_id]).bright_green()
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
|
||||
|
||||
// Decode loop
|
||||
let mut token_decode_times = vec![];
|
||||
for _ in 0..cli_args.gen_tokens {
|
||||
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
|
||||
cx.set_dyn_dim('p', input_ids.len() - 1);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
|
||||
let now = Instant::now();
|
||||
cx.execute();
|
||||
token_decode_times.push(now.elapsed().as_micros());
|
||||
|
||||
// Sample tokens
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
|
||||
}
|
||||
let avg_token_time = token_decode_times
|
||||
.iter()
|
||||
.map(|t| *t as f32 / 1000.)
|
||||
.sum::<f32>()
|
||||
/ token_decode_times.len() as f32;
|
||||
println!(
|
||||
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
|
||||
avg_token_time,
|
||||
1000.0 / avg_token_time
|
||||
);
|
||||
}
|
||||
|
||||
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
|
||||
let mut vector = tokenizer
|
||||
.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
|
||||
.token_ids;
|
||||
vector.insert(0, 1); // Start token
|
||||
vector
|
||||
}
|
||||
|
||||
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
|
||||
tokenizer
|
||||
.decode(token_ids, true, false)
|
||||
.replace("<0x0A>", "\n")
|
||||
}
|
||||
|
||||
// Currently just an argmax, do actual sampling here
|
||||
fn sample_index(dist: &[f32]) -> i64 {
|
||||
dist.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap()
|
||||
.0 as i64
|
||||
}
|
||||
350
examples/mistral/src/model.rs
Normal file
350
examples/mistral/src/model.rs
Normal file
@@ -0,0 +1,350 @@
|
||||
use std::{marker::PhantomData, ops::Div};
|
||||
|
||||
use luminal::{
|
||||
nn::{embedding::Embedding, norm::RMSNorm},
|
||||
prelude::*,
|
||||
shape::symbolic::{BigExpression, Expression},
|
||||
};
|
||||
|
||||
// Mistral 7B Config
|
||||
pub const VOCAB_SIZE: usize = 32000;
|
||||
pub const HIDDEN_DIM: usize = 4096;
|
||||
pub const NUM_LAYERS: usize = 32;
|
||||
pub const N_HEADS: usize = 32;
|
||||
pub const N_KV_HEADS: usize = 8;
|
||||
pub const MLP_DIM: usize = 14336;
|
||||
|
||||
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
pub down_proj: GraphTensor<(Const<H>, Const<I>)>,
|
||||
pub up_proj: GraphTensor<(Const<I>, Const<H>)>,
|
||||
}
|
||||
|
||||
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
|
||||
where
|
||||
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
|
||||
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
|
||||
{
|
||||
type Output = GraphTensor<Sh>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
|
||||
let gate = input.matmul(self.gate_proj.permute()).swish();
|
||||
let up = input.matmul(self.up_proj.permute()) * gate;
|
||||
up.matmul(self.down_proj.permute())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: cx.named_tensor("Gate Weight"),
|
||||
up_proj: cx.named_tensor("Up Weight"),
|
||||
down_proj: cx.named_tensor("Down Weight"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("ffn_gate/weight", self.gate_proj);
|
||||
s.tensor("ffn_up/weight", self.up_proj);
|
||||
s.tensor("ffn_down/weight", self.down_proj);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
|
||||
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
// Get freqs
|
||||
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = freqs.inv_pow(1000000.0).recip();
|
||||
let pos = input.graph().arange::<Seq>() + prev_seq;
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
|
||||
// Split input into evens and odds
|
||||
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
|
||||
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> = split
|
||||
.slice((.., .., .., .., ..Expression::from(1)))
|
||||
.contiguous()
|
||||
.realize();
|
||||
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> = split
|
||||
.slice((.., .., .., .., Expression::from(1)..))
|
||||
.contiguous()
|
||||
.realize();
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
|
||||
// Combine back into output
|
||||
x0_out
|
||||
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
|
||||
x1_out,
|
||||
)
|
||||
.reshape()
|
||||
}
|
||||
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().into());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().into());
|
||||
|
||||
// Add KV cache
|
||||
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
|
||||
(
|
||||
k_cache.concat_along::<_, Axis<2>, _>(keys),
|
||||
v_cache.concat_along::<_, Axis<2>, _>(values),
|
||||
)
|
||||
} else {
|
||||
(keys.realize(), values.contiguous().realize())
|
||||
};
|
||||
|
||||
// Repeat the KV States for Grouped-Query Attention
|
||||
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries
|
||||
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
|
||||
.matmul(repeated_keys.permute())
|
||||
.div((HEAD_DIM as f32).sqrt());
|
||||
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq), _, _>(&[
|
||||
(0.into(), Expression::from(0)),
|
||||
(TotSeq::const_size() - CurSeq::const_size(), 0.into()),
|
||||
])
|
||||
.expand();
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<4>()
|
||||
// Apply distribution to values
|
||||
.matmul(repeated_values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute());
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for SelfAttention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Q Proj"),
|
||||
k_proj: cx.named_tensor("K Proj"),
|
||||
v_proj: cx.named_tensor("V Proj"),
|
||||
o_proj: cx.named_tensor("O Proj"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for SelfAttention {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("attn_q/weight", self.q_proj);
|
||||
s.tensor("attn_v/weight", self.v_proj);
|
||||
s.tensor("attn_k/weight", self.k_proj);
|
||||
s.tensor("attn_output/weight", self.o_proj);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: RMSNorm<HIDDEN_DIM>,
|
||||
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
|
||||
pub feed_forward_norm: RMSNorm<HIDDEN_DIM>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.attention_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.attention
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
|
||||
// Residual Addition
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
|
||||
|
||||
// Residual Addition
|
||||
(x + y, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: {
|
||||
let mut norm = RMSNorm::initialize(cx);
|
||||
norm.epsilon = 1e-5;
|
||||
norm
|
||||
},
|
||||
feed_forward: InitModule::initialize(cx),
|
||||
feed_forward_norm: {
|
||||
let mut norm = RMSNorm::initialize(cx);
|
||||
norm.epsilon = 1e-5;
|
||||
norm
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for TransformerBlock {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("", &self.attention);
|
||||
s.module("attn_norm", &self.attention_norm);
|
||||
s.module("ffn_norm", &self.feed_forward_norm);
|
||||
s.module("", &self.feed_forward);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MistralLM {
|
||||
// Token embeddings
|
||||
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Final Norm layer
|
||||
pub norm: RMSNorm<HIDDEN_DIM>,
|
||||
// LM Head Layer
|
||||
pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Option<Vec<KVCache<Batch, PrevSeq>>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for MistralLM
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
Option<Vec<KVCache<Batch, PrevSeq>>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Embed tokens
|
||||
let mut x = self.embedding.forward(input);
|
||||
|
||||
// Run through layers and collect new caches
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) =
|
||||
layer.forward((x, cache.as_ref().map(|c| c[i]), PhantomData::<TotSeq>));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
let output = self.norm.forward(x).matmul(self.lm_head.permute());
|
||||
|
||||
(output, new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for MistralLM {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embedding: InitModule::initialize(cx),
|
||||
norm: {
|
||||
let mut norm = RMSNorm::initialize(cx);
|
||||
norm.epsilon = 1e-5;
|
||||
norm
|
||||
},
|
||||
lm_head: cx.named_tensor("LM Head"),
|
||||
layers: (0..NUM_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for MistralLM {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("token_embd", &self.embedding);
|
||||
s.module("output_norm", &self.norm);
|
||||
s.tensor("output/weight", self.lm_head);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("blk/{i}"), layer);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,16 @@
|
||||
use luminal::{nn::linear::Linear, prelude::*};
|
||||
|
||||
fn main() {
|
||||
// Create a new graph
|
||||
let mut cx = Graph::new();
|
||||
let model: Linear<4, 5> = InitModule::initialize(&mut cx);
|
||||
let a = cx.new_tensor::<R1<4>>("Input");
|
||||
let b = model.forward(a);
|
||||
|
||||
a.set(vec![1., 2., 3., 4.]);
|
||||
b.mark();
|
||||
cx.execute();
|
||||
|
||||
println!(
|
||||
"B: {:?}",
|
||||
b.retrieve().unwrap().real_data(b.view().unwrap()).unwrap()
|
||||
);
|
||||
// Randomly initialize a linear layer with an input size of 4 and an output size of 5
|
||||
let model = Linear::<4, 5>::initialize(&mut cx);
|
||||
// Make an input tensor
|
||||
let a = cx.tensor::<R1<4>>().set(vec![1., 2., 3., 4.]);
|
||||
// Feed tensor through model
|
||||
let b = model.forward(a).retrieve();
|
||||
// Execute the graph
|
||||
cx.execute_debug();
|
||||
// Print the results
|
||||
println!("B: {:?}", b.data());
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 57 KiB |
13
resources/luminal_cudarc/.github/FUNDING.yml
vendored
Normal file
13
resources/luminal_cudarc/.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: coreylowman
|
||||
patreon: dfdx
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: coreylowman
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
23
resources/luminal_cudarc/.github/workflows/cargo-check.yaml
vendored
Normal file
23
resources/luminal_cudarc/.github/workflows/cargo-check.yaml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
on: [pull_request]
|
||||
|
||||
jobs:
|
||||
cargo-check:
|
||||
name: cargo-check
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --features ci-check
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --no-default-features --features ci-check,no-std,cudnn,cublas,cublaslt,nvrtc,driver,curand,nccl
|
||||
18
resources/luminal_cudarc/.github/workflows/cargo-clippy.yaml
vendored
Normal file
18
resources/luminal_cudarc/.github/workflows/cargo-clippy.yaml
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
on: [pull_request]
|
||||
|
||||
jobs:
|
||||
clippy:
|
||||
name: clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add clippy
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --no-default-features --features ci-check,no-std,cudnn,cublas,cublaslt,nvrtc,driver,curand,nccl -- -D warnings
|
||||
19
resources/luminal_cudarc/.github/workflows/cargo-fmt.yaml
vendored
Normal file
19
resources/luminal_cudarc/.github/workflows/cargo-fmt.yaml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
on: [pull_request]
|
||||
|
||||
jobs:
|
||||
cargo-fmt:
|
||||
name: cargo-fmt
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
3
resources/luminal_cudarc/.gitignore
vendored
Normal file
3
resources/luminal_cudarc/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
/Cargo.lock
|
||||
/target
|
||||
/.vscode/
|
||||
0
resources/luminal_cudarc/.rustfmt.toml
Normal file
0
resources/luminal_cudarc/.rustfmt.toml
Normal file
40
resources/luminal_cudarc/Cargo.toml
Normal file
40
resources/luminal_cudarc/Cargo.toml
Normal file
@@ -0,0 +1,40 @@
|
||||
[package]
|
||||
name = "luminal_cudarc"
|
||||
version = "0.10.0"
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "Safe wrappers around CUDA apis"
|
||||
readme = "README.md"
|
||||
|
||||
keywords = [
|
||||
"cuda",
|
||||
"nvidia",
|
||||
"gpu",
|
||||
"nvrtc",
|
||||
"cublas",
|
||||
]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["ci-check", "f16", "cudnn"]
|
||||
|
||||
[features]
|
||||
default = ["std", "driver", "nvrtc", "cublas", "curand"]
|
||||
nvrtc = []
|
||||
driver = ["nvrtc"]
|
||||
cublas = ["driver"]
|
||||
cublaslt = ["driver"]
|
||||
cudnn = ["driver"]
|
||||
curand = ["driver"]
|
||||
nccl = ["driver"]
|
||||
std = []
|
||||
no-std = ["no-std-compat/std", "dep:spin"]
|
||||
f16 = ["dep:half"]
|
||||
ci-check = []
|
||||
static-linking=[]
|
||||
|
||||
[dependencies]
|
||||
spin = { version = "0.9.8", optional = true, features = ["rwlock"], default-features = false }
|
||||
no-std-compat = { version = "0.4.1", optional = true, features = [ "alloc" ] }
|
||||
half = { version = "2.3.1", optional = true, default-features = false, features = ["num-traits", "rand_distr"] }
|
||||
@@ -174,28 +174,3 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,5 +1,3 @@
|
||||
Copyright (c) 2015
|
||||
|
||||
Permission is hereby granted, free of charge, to any
|
||||
person obtaining a copy of this software and associated
|
||||
documentation files (the "Software"), to deal in the
|
||||
89
resources/luminal_cudarc/README.md
Normal file
89
resources/luminal_cudarc/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# cudarc: minimal and safe api over the cuda toolkit
|
||||
|
||||
[](https://discord.gg/AtUhGqBDP5)
|
||||
[](https://crates.io/crates/cudarc)
|
||||
[](https://docs.rs/cudarc)
|
||||
|
||||
Checkout cudarc on [crates.io](https://crates.io/crates/cudarc) and [docs.rs](https://docs.rs/cudarc/latest/cudarc/).
|
||||
|
||||
Safe abstractions over:
|
||||
1. [CUDA driver API](https://docs.nvidia.com/cuda/cuda-driver-api/index.html)
|
||||
2. [NVRTC API](https://docs.nvidia.com/cuda/nvrtc/index.html)
|
||||
3. [cuRAND API](https://docs.nvidia.com/cuda/curand/index.html)
|
||||
4. [cuBLAS API](https://docs.nvidia.com/cuda/cublas/index.html)
|
||||
5. [cuBLASLt API](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
|
||||
|
||||
**Pre-alpha state**, expect breaking changes and not all cuda functions
|
||||
contain a safe wrapper. **Contributions welcome for any that aren't included!**
|
||||
|
||||
# Design
|
||||
|
||||
Goals are:
|
||||
1. As safe as possible (there will still be a lot of unsafe due to ffi & async)
|
||||
2. As ergonomic as possible
|
||||
3. Allow mixing of high level `safe` apis, with low level `sys` apis
|
||||
|
||||
To that end there are three levels to each wrapper (by default the safe api is exported):
|
||||
```rust
|
||||
use cudarc::driver::{safe, result, sys};
|
||||
use cudarc::nvrtc::{safe, result, sys};
|
||||
use cudarc::cublas::{safe, result, sys};
|
||||
use cudarc::cublaslt::{safe, result, sys};
|
||||
use cudarc::curand::{safe, result, sys};
|
||||
```
|
||||
|
||||
where:
|
||||
1. `sys` is the raw ffi apis generated with bindgen
|
||||
2. `result` is a very small wrapper around sys to return `Result` from each function
|
||||
3. `safe` is a wrapper around result/sys to provide safe abstractions
|
||||
|
||||
*Heavily recommend sticking with safe APIs*
|
||||
|
||||
# API Preview
|
||||
|
||||
It's easy to create a new device and transfer data to the gpu:
|
||||
|
||||
```rust
|
||||
let dev = cudarc::driver::CudaDevice::new(0)?;
|
||||
|
||||
// allocate buffers
|
||||
let inp = dev.htod_copy(vec![1.0f32; 100])?;
|
||||
let mut out = dev.alloc_zeros::<f32>(100)?;
|
||||
```
|
||||
|
||||
You can also use the nvrtc api to compile kernels at runtime:
|
||||
|
||||
```rust
|
||||
let ptx = cudarc::nvrtc::compile_ptx("
|
||||
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, const size_t numel) {
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < numel) {
|
||||
out[i] = sin(inp[i]);
|
||||
}
|
||||
}")?;
|
||||
|
||||
// and dynamically load it into the device
|
||||
dev.load_ptx(ptx, "my_module", &["sin_kernel"])?;
|
||||
```
|
||||
|
||||
`cudarc` provides a very simple interface to launch kernels, tuples
|
||||
are the arguments!
|
||||
|
||||
```rust
|
||||
let sin_kernel = dev.get_func("my_module", "sin_kernel").unwrap();
|
||||
let cfg = LaunchConfig::for_num_elems(100);
|
||||
unsafe { sin_kernel.launch(cfg, (&mut out, &inp, 100usize)) }?;
|
||||
```
|
||||
|
||||
And of course it's easy to copy things back to host after you're done:
|
||||
|
||||
```rust
|
||||
let out_host: Vec<f32> = dev.dtoh_sync_copy(&out)?;
|
||||
assert_eq!(out_host, [1.0; 100].map(f32::sin));
|
||||
```
|
||||
|
||||
# License
|
||||
|
||||
Dual-licensed to be compatible with the Rust project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
130
resources/luminal_cudarc/build.rs
Normal file
130
resources/luminal_cudarc/build.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
link_cuda();
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn link_cuda() {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_ROOT");
|
||||
println!("cargo:rerun-if-env-changed=CUDA_PATH");
|
||||
println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_ROOT_DIR");
|
||||
|
||||
let candidates: Vec<PathBuf> = root_candidates().collect();
|
||||
|
||||
let toolkit_root = root_candidates()
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Unable to find `include/cuda.h` under any of: {:?}. Set the `CUDA_ROOT` environment variable to `$CUDA_ROOT/include/cuda.h` to override path.",
|
||||
candidates
|
||||
)
|
||||
});
|
||||
|
||||
for path in lib_candidates(&toolkit_root) {
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
}
|
||||
|
||||
#[cfg(feature = "driver")]
|
||||
println!("cargo:rustc-link-lib=dylib=cuda");
|
||||
#[cfg(feature = "nccl")]
|
||||
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||
|
||||
#[cfg(feature = "static-linking")]
|
||||
{
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
#[cfg(any(feature = "cublas", feature = "cublaslt"))]
|
||||
{
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=static=cublasLt_static");
|
||||
}
|
||||
#[cfg(feature = "cublas")]
|
||||
println!("cargo:rustc-link-lib=static=cublas_static");
|
||||
#[cfg(feature = "curand")]
|
||||
{
|
||||
println!("cargo:rustc-link-lib=dylib=culibos");
|
||||
println!("cargo:rustc-link-lib=static=curand_static");
|
||||
}
|
||||
#[cfg(feature = "nvrtc")]
|
||||
{
|
||||
println!("cargo:rustc-link-lib=static=nvrtc_static");
|
||||
println!("cargo:rustc-link-lib=static=nvptxcompiler_static");
|
||||
println!("cargo:rustc-link-lib=static=nvrtc-builtins_static");
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "static-linking"))]
|
||||
{
|
||||
#[cfg(feature = "nvrtc")]
|
||||
println!("cargo:rustc-link-lib=dylib=nvrtc");
|
||||
#[cfg(feature = "curand")]
|
||||
println!("cargo:rustc-link-lib=dylib=curand");
|
||||
#[cfg(feature = "cublas")]
|
||||
println!("cargo:rustc-link-lib=dylib=cublas");
|
||||
#[cfg(any(feature = "cublas", feature = "cublaslt"))]
|
||||
println!("cargo:rustc-link-lib=dylib=cublasLt");
|
||||
}
|
||||
|
||||
#[cfg(feature = "cudnn")]
|
||||
{
|
||||
let cudnn_root = root_candidates()
|
||||
.find(|path| path.join("include").join("cudnn.h").is_file())
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Unable to find `include/cudnn.h` under any of: {:?}. Set the `CUDNN_LIB` environment variable to `$CUDNN_LIB/include/cudnn.h` to override path.",
|
||||
candidates
|
||||
)
|
||||
});
|
||||
|
||||
for path in lib_candidates(&cudnn_root) {
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "cudnn")]
|
||||
println!("cargo:rustc-link-lib=dylib=cudnn");
|
||||
}
|
||||
|
||||
fn root_candidates() -> impl Iterator<Item = PathBuf> {
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::into);
|
||||
env_vars.chain(roots).map(Into::<PathBuf>::into)
|
||||
}
|
||||
|
||||
fn lib_candidates(root: &Path) -> Vec<PathBuf> {
|
||||
[
|
||||
"lib",
|
||||
"lib/x64",
|
||||
"lib/Win32",
|
||||
"lib/x86_64",
|
||||
"lib/x86_64-linux-gnu",
|
||||
"lib64",
|
||||
"lib64/stubs",
|
||||
"targets/x86_64-linux",
|
||||
"targets/x86_64-linux/lib",
|
||||
"targets/x86_64-linux/lib/stubs",
|
||||
]
|
||||
.iter()
|
||||
.map(|&p| root.join(p))
|
||||
.filter(|p| p.is_dir())
|
||||
.collect()
|
||||
}
|
||||
19
resources/luminal_cudarc/examples/01-allocate.rs
Normal file
19
resources/luminal_cudarc/examples/01-allocate.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use cudarc::driver::{CudaDevice, CudaSlice, DriverError};
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
|
||||
// unsafe initialization of unset memory
|
||||
let _: CudaSlice<f32> = unsafe { dev.alloc::<f32>(10) }?;
|
||||
|
||||
// this will have memory initialized as 0
|
||||
let _: CudaSlice<f64> = dev.alloc_zeros::<f64>(10)?;
|
||||
|
||||
// initialize with a rust vec
|
||||
let _: CudaSlice<usize> = dev.htod_copy(vec![0; 10])?;
|
||||
|
||||
// or finially, initialize with a slice. this is synchronous though.
|
||||
let _: CudaSlice<u32> = dev.htod_sync_copy(&[1, 2, 3])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
31
resources/luminal_cudarc/examples/02-copy.rs
Normal file
31
resources/luminal_cudarc/examples/02-copy.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use cudarc::driver::{CudaDevice, CudaSlice, DriverError};
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
|
||||
let a: CudaSlice<f64> = dev.alloc_zeros::<f64>(10)?;
|
||||
let mut b = dev.alloc_zeros::<f64>(10)?;
|
||||
|
||||
// you can do device to device copies of course
|
||||
dev.dtod_copy(&a, &mut b)?;
|
||||
|
||||
// but also host to device copys with already allocated buffers
|
||||
dev.htod_copy_into(vec![2.0; 10], &mut b)?;
|
||||
|
||||
// if you want to use slices, you can do synchronous copy
|
||||
dev.htod_sync_copy_into(&[3.0; 10], &mut b)?;
|
||||
|
||||
// you can transfer back using reclaim:
|
||||
let mut a_host: Vec<f64> = dev.sync_reclaim(a)?;
|
||||
assert_eq!(a_host, [0.0; 10]);
|
||||
|
||||
// or copy back without losing ownership:
|
||||
let b_host = dev.dtoh_sync_copy(&b)?;
|
||||
assert_eq!(b_host, [3.0; 10]);
|
||||
|
||||
// or use a slice
|
||||
dev.dtoh_sync_copy_into(&b, &mut a_host)?;
|
||||
assert_eq!(a_host, b_host);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
32
resources/luminal_cudarc/examples/03-launch-kernel.rs
Normal file
32
resources/luminal_cudarc/examples/03-launch-kernel.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use cudarc::{
|
||||
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
|
||||
nvrtc::Ptx,
|
||||
};
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
|
||||
// You can load a function from a pre-compiled PTX like so:
|
||||
dev.load_ptx(Ptx::from_file("./examples/sin.ptx"), "sin", &["sin_kernel"])?;
|
||||
|
||||
// and then retrieve the function with `get_func`
|
||||
let f = dev.get_func("sin", "sin_kernel").unwrap();
|
||||
|
||||
let a_host = [1.0, 2.0, 3.0];
|
||||
|
||||
let a_dev = dev.htod_copy(a_host.into())?;
|
||||
let mut b_dev = a_dev.clone();
|
||||
|
||||
let n = 3;
|
||||
let cfg = LaunchConfig::for_num_elems(n);
|
||||
unsafe { f.launch(cfg, (&mut b_dev, &a_dev, n as i32)) }?;
|
||||
|
||||
let a_host_2 = dev.sync_reclaim(a_dev)?;
|
||||
let b_host = dev.sync_reclaim(b_dev)?;
|
||||
|
||||
println!("Found {:?}", b_host);
|
||||
println!("Expected {:?}", a_host.map(f32::sin));
|
||||
assert_eq!(&a_host, a_host_2.as_slice());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
42
resources/luminal_cudarc/examples/04-streams.rs
Normal file
42
resources/luminal_cudarc/examples/04-streams.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use cudarc::{
|
||||
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
|
||||
nvrtc::Ptx,
|
||||
};
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
dev.load_ptx(Ptx::from_file("./examples/sin.ptx"), "sin", &["sin_kernel"])?;
|
||||
|
||||
let n = 3;
|
||||
let cfg = LaunchConfig::for_num_elems(n);
|
||||
|
||||
let a_host = [1.0, 2.0, 3.0];
|
||||
let a_dev = dev.htod_copy(a_host.into())?;
|
||||
let mut b_dev = a_dev.clone();
|
||||
|
||||
// create a stream with `fork_default_stream()`
|
||||
// This synchronizes with the default stream, so since
|
||||
// we put this call **after** the `htod_copy` & `clone` above,
|
||||
// cuda will complete those orders **before** work on this stream
|
||||
// can start.
|
||||
let stream = dev.fork_default_stream()?;
|
||||
|
||||
let f = dev.get_func("sin", "sin_kernel").unwrap();
|
||||
|
||||
// we launch it differently too
|
||||
unsafe { f.launch_on_stream(&stream, cfg, (&mut b_dev, &a_dev, n as i32)) }?;
|
||||
|
||||
// and we must join with the default work stream in order for copies
|
||||
// to work corrently.
|
||||
// NOTE: this is actually async with respect to the host!
|
||||
dev.wait_for(&stream)?;
|
||||
|
||||
let a_host_2 = dev.sync_reclaim(a_dev)?;
|
||||
let b_host = dev.sync_reclaim(b_dev)?;
|
||||
|
||||
println!("Found {:?}", b_host);
|
||||
println!("Expected {:?}", a_host.map(f32::sin));
|
||||
assert_eq!(&a_host, a_host_2.as_slice());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
52
resources/luminal_cudarc/examples/05-device-repr.rs
Normal file
52
resources/luminal_cudarc/examples/05-device-repr.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use cudarc::{driver::*, nvrtc::compile_ptx};
|
||||
|
||||
/// Here's the struct in rust, note that we have #[repr(C)]
|
||||
/// here which allows us to pass it to cuda.
|
||||
#[repr(C)]
|
||||
struct MyCoolRustStruct {
|
||||
a: f32,
|
||||
b: f64,
|
||||
c: u32,
|
||||
d: usize,
|
||||
}
|
||||
|
||||
/// We have to implement this to send it to cuda!
|
||||
unsafe impl DeviceRepr for MyCoolRustStruct {}
|
||||
|
||||
const PTX_SRC: &str = "
|
||||
// here's the same struct in cuda
|
||||
struct MyCoolStruct {
|
||||
float a;
|
||||
double b;
|
||||
unsigned int c;
|
||||
size_t d;
|
||||
};
|
||||
extern \"C\" __global__ void my_custom_kernel(MyCoolStruct thing) {
|
||||
assert(thing.a == 1.0);
|
||||
assert(thing.b == 2.34);
|
||||
assert(thing.c == 57);
|
||||
assert(thing.d == 420);
|
||||
}
|
||||
";
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
|
||||
let ptx = compile_ptx(PTX_SRC).unwrap();
|
||||
dev.load_ptx(ptx, "module", &["my_custom_kernel"])?;
|
||||
|
||||
// try changing some of these values to see a device assert
|
||||
let thing = MyCoolRustStruct {
|
||||
a: 1.0,
|
||||
b: 2.34,
|
||||
c: 57,
|
||||
d: 420,
|
||||
};
|
||||
|
||||
let f = dev.get_func("module", "my_custom_kernel").unwrap();
|
||||
|
||||
// since MyCoolRustStruct implements DeviceRepr, we can pass it to launch.
|
||||
unsafe { f.launch(LaunchConfig::for_num_elems(1), (thing,)) }?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
65
resources/luminal_cudarc/examples/06-threading.rs
Normal file
65
resources/luminal_cudarc/examples/06-threading.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use cudarc::driver::*;
|
||||
use cudarc::nvrtc::compile_ptx;
|
||||
|
||||
use std::thread;
|
||||
|
||||
const KERNEL_SRC: &str = "
|
||||
extern \"C\" __global__ void hello_world(int i) {
|
||||
printf(\"Hello from the cuda kernel in thread %d\\n\", i);
|
||||
}
|
||||
";
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (1, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
{
|
||||
// Option 1: use the same device on each thread.
|
||||
// This requires calling the CudaDevice::bind_to_thread() method.
|
||||
// Note that all kernels are submitted to the same stream/context,
|
||||
// so the kernels will still execute in sequentially in the order
|
||||
// they are submitted to the gpu.
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let ptx = compile_ptx(KERNEL_SRC).unwrap();
|
||||
dev.load_ptx(ptx, "kernel", &["hello_world"])?;
|
||||
|
||||
// explicit borrow so we don't have to re-clone the device for each thread
|
||||
let dev = &dev;
|
||||
|
||||
thread::scope(move |s| {
|
||||
for i in 0..10i32 {
|
||||
s.spawn(move || {
|
||||
// NOTE: this is the important call to have
|
||||
// without this, you'll get a CUDA_ERROR_INVALID_CONTEXT
|
||||
dev.bind_to_thread()?;
|
||||
let f = dev.get_func("kernel", "hello_world").unwrap();
|
||||
unsafe { f.launch(cfg, (i,)) }
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
// Option 2: create a new device in each thread
|
||||
// This requires loading the PTX for each device, since they won't
|
||||
// share a loaded modules on the Rust side of things.
|
||||
let ptx = compile_ptx(KERNEL_SRC).unwrap();
|
||||
|
||||
thread::scope(|s| {
|
||||
for i in 0..10i32 {
|
||||
let ptx = ptx.clone();
|
||||
s.spawn(move || {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
dev.load_ptx(ptx, "kernel", &["hello_world"])?;
|
||||
let f = dev.get_func("kernel", "hello_world").unwrap();
|
||||
unsafe { f.launch(cfg, (i + 100,)) }
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
1
resources/luminal_cudarc/examples/07-build-workflow/.gitignore
vendored
Normal file
1
resources/luminal_cudarc/examples/07-build-workflow/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
target/
|
||||
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "build-workflow"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.66.1"
|
||||
cc = "1.0.82"
|
||||
regex = "1.9.3"
|
||||
|
||||
[dependencies]
|
||||
cudarc = { path = "../.." }
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user