mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2501 Commits
0.1
...
pytest-cla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ee5b54438 | ||
|
|
389c05abeb | ||
|
|
dcc2c9cbb4 | ||
|
|
a9af4c3923 | ||
|
|
3092d0d68b | ||
|
|
8a2bd714ac | ||
|
|
54a26a044c | ||
|
|
5a0d3f87cc | ||
|
|
a28b755245 | ||
|
|
fd83534e53 | ||
|
|
b5d984c3fa | ||
|
|
64a5ca41b5 | ||
|
|
9bda47714a | ||
|
|
9e513b6589 | ||
|
|
a62d728bd7 | ||
|
|
4114714d3f | ||
|
|
6191597571 | ||
|
|
253cd95ab0 | ||
|
|
d7e396ba5b | ||
|
|
1a53626716 | ||
|
|
4329d68adc | ||
|
|
989e7e2d44 | ||
|
|
4f0a3ab102 | ||
|
|
019972cdd4 | ||
|
|
d7a3f468bd | ||
|
|
c504fbf8a1 | ||
|
|
648720caf9 | ||
|
|
625be7f4da | ||
|
|
21ed7ef31f | ||
|
|
6e94f80c9e | ||
|
|
c2a17a4854 | ||
|
|
386b3df983 | ||
|
|
5c60f1d768 | ||
|
|
4c51e3ea84 | ||
|
|
846551aa6f | ||
|
|
c26076bc75 | ||
|
|
871629b770 | ||
|
|
c6dfa9c62f | ||
|
|
90e3a915d7 | ||
|
|
56cb237aa2 | ||
|
|
a2c42b35c8 | ||
|
|
898204b2dd | ||
|
|
2c1a7f087f | ||
|
|
112d064700 | ||
|
|
c51c36fbcb | ||
|
|
ee372d464e | ||
|
|
1bef1344d1 | ||
|
|
2e27c29b47 | ||
|
|
8d41c491fd | ||
|
|
64f390a833 | ||
|
|
8d20581f38 | ||
|
|
bfd4ae9b27 | ||
|
|
92e4260f1e | ||
|
|
662a564efc | ||
|
|
1761dc6b66 | ||
|
|
da71273d7e | ||
|
|
39122672b4 | ||
|
|
d866ba6407 | ||
|
|
9a0fb453ed | ||
|
|
dab60f0b21 | ||
|
|
1ea872bd2a | ||
|
|
90a66ac704 | ||
|
|
2b94ba0b71 | ||
|
|
2ed65d5386 | ||
|
|
336d49c147 | ||
|
|
1ff5840a76 | ||
|
|
bc94b10648 | ||
|
|
7c921d03a8 | ||
|
|
4e46051617 | ||
|
|
a55952d591 | ||
|
|
679aa7e092 | ||
|
|
3dd2be2fb2 | ||
|
|
c290e266f7 | ||
|
|
be3d8aa064 | ||
|
|
84e0c842a1 | ||
|
|
403fd36b1f | ||
|
|
651a4c2aee | ||
|
|
41ddd244ef | ||
|
|
0dbee87a8c | ||
|
|
194b8adfa5 | ||
|
|
86e616800d | ||
|
|
683205121d | ||
|
|
6d653e854d | ||
|
|
08397b566d | ||
|
|
5310335256 | ||
|
|
638765b62b | ||
|
|
3850b3a533 | ||
|
|
a4c84c6cf5 | ||
|
|
f32161d43b | ||
|
|
da83d51b27 | ||
|
|
29a3ffa3e3 | ||
|
|
97e358916a | ||
|
|
631c1b53d7 | ||
|
|
02449a6bea | ||
|
|
6c4597102e | ||
|
|
b077cfdb76 | ||
|
|
869b519e39 | ||
|
|
2b831c9f25 | ||
|
|
f35a950496 | ||
|
|
9ab0e1472c | ||
|
|
88f2601d5e | ||
|
|
b0ebdcba8c | ||
|
|
0ab124194b | ||
|
|
7f042ae615 | ||
|
|
082d9c48bd | ||
|
|
251e9526f3 | ||
|
|
41b3774ec2 | ||
|
|
3fdb464f5a | ||
|
|
a3a4fd94ec | ||
|
|
5446dccb04 | ||
|
|
8e6535563e | ||
|
|
bdc923aa50 | ||
|
|
ea67742b3b | ||
|
|
149e570f26 | ||
|
|
ac52098d5c | ||
|
|
01946ecd10 | ||
|
|
ef70fee204 | ||
|
|
a6fea110dc | ||
|
|
2d4ebb2cb6 | ||
|
|
5adb875b04 | ||
|
|
2adfcfa70e | ||
|
|
65600e8730 | ||
|
|
4c4f39b4af | ||
|
|
49b9209ad0 | ||
|
|
dea5df51dd | ||
|
|
eb6a6c2174 | ||
|
|
8864ef31fb | ||
|
|
38c98a8835 | ||
|
|
31b5fd886d | ||
|
|
04b2753aa8 | ||
|
|
dbb5282fd6 | ||
|
|
8e315c62df | ||
|
|
f53d990581 | ||
|
|
2c8ecba6a5 | ||
|
|
1644cce031 | ||
|
|
b2bb455b30 | ||
|
|
8628b1425a | ||
|
|
ca66609d6f | ||
|
|
c50e122ac1 | ||
|
|
272acabd0c | ||
|
|
f772c0529a | ||
|
|
c3e1f568ea | ||
|
|
eb3dd02836 | ||
|
|
8a9f85b0ce | ||
|
|
372501e527 | ||
|
|
cda12a6d84 | ||
|
|
566fb00ed2 | ||
|
|
ecb78a2635 | ||
|
|
f401ffb900 | ||
|
|
cd94000140 | ||
|
|
1f1636e188 | ||
|
|
371fa8491a | ||
|
|
6d1fe67b66 | ||
|
|
189d1e2594 | ||
|
|
c7acfb9794 | ||
|
|
fe6af5290a | ||
|
|
8b8669c744 | ||
|
|
7af771b999 | ||
|
|
133757f187 | ||
|
|
07ee241b25 | ||
|
|
958331ab6c | ||
|
|
340199d4a8 | ||
|
|
f17a95e673 | ||
|
|
6bb576e711 | ||
|
|
744e4d767a | ||
|
|
c940161f25 | ||
|
|
3aa2c309f5 | ||
|
|
8a0592646b | ||
|
|
68ce81e52b | ||
|
|
39789404f4 | ||
|
|
8c53234966 | ||
|
|
71eca945cb | ||
|
|
8da130ae1c | ||
|
|
fef6a45c9c | ||
|
|
c6763a69ba | ||
|
|
30caca106c | ||
|
|
6c90bb5059 | ||
|
|
82189cd602 | ||
|
|
cc5e0a639d | ||
|
|
8dc05233cb | ||
|
|
0ab9947292 | ||
|
|
f11ba3a388 | ||
|
|
a346e503db | ||
|
|
6bbf244924 | ||
|
|
a8505668ac | ||
|
|
a0b237c424 | ||
|
|
8fabacd17e | ||
|
|
df0128ad04 | ||
|
|
de55e67594 | ||
|
|
cdff26755f | ||
|
|
9f11b7e24a | ||
|
|
27344a0e45 | ||
|
|
e9e6f824a1 | ||
|
|
d894eeae50 | ||
|
|
da078b5bdd | ||
|
|
f156265ff4 | ||
|
|
b34f104cea | ||
|
|
1873e26185 | ||
|
|
384a426ba3 | ||
|
|
d25654a0ec | ||
|
|
579daa1a57 | ||
|
|
4fa8a92086 | ||
|
|
bf7debc1d6 | ||
|
|
a49c970029 | ||
|
|
1d2db8f88f | ||
|
|
4fff906b8d | ||
|
|
85b49018a3 | ||
|
|
a82c530ae7 | ||
|
|
8fdfac19e1 | ||
|
|
c61dfa0a13 | ||
|
|
3815a6de67 | ||
|
|
375c54b641 | ||
|
|
1dd8853a60 | ||
|
|
050eeba815 | ||
|
|
05e0a2fc31 | ||
|
|
af7d0b002f | ||
|
|
f6c3b68f86 | ||
|
|
afc17eac31 | ||
|
|
5330ab9159 | ||
|
|
ea2cd9da45 | ||
|
|
92957228f5 | ||
|
|
4152c8a732 | ||
|
|
d95e847dba | ||
|
|
153ac59773 | ||
|
|
bcd1fe673a | ||
|
|
99d65ae9df | ||
|
|
af9442bfa0 | ||
|
|
3414a9fe6c | ||
|
|
d3a60de7fa | ||
|
|
63bad7d5a2 | ||
|
|
69e143f33d | ||
|
|
ca9113b8a8 | ||
|
|
4e27a9fb31 | ||
|
|
ad51ad88cd | ||
|
|
9afc41e0a1 | ||
|
|
0cbe874492 | ||
|
|
987ac5b5ec | ||
|
|
d4faf80cf9 | ||
|
|
8ce60f61ec | ||
|
|
3a9c945e71 | ||
|
|
765ddd5070 | ||
|
|
4178b38eb0 | ||
|
|
516faaa205 | ||
|
|
5033efc8d7 | ||
|
|
de659b4a06 | ||
|
|
612c76bd7c | ||
|
|
2230b28751 | ||
|
|
ae3e2c9331 | ||
|
|
f09c2ddd68 | ||
|
|
dae508f26c | ||
|
|
9b970e9d01 | ||
|
|
c14c7f4be4 | ||
|
|
f60147f4a7 | ||
|
|
2f2ba2fb1f | ||
|
|
695df9cb29 | ||
|
|
97af33e671 | ||
|
|
c1824a1e8d | ||
|
|
96cfdd9fcf | ||
|
|
e4d2be9c89 | ||
|
|
10ea5be27d | ||
|
|
3d5c939a7f | ||
|
|
4ba5b3b121 | ||
|
|
2dc8e23496 | ||
|
|
9798fd0e38 | ||
|
|
cf36d45b77 | ||
|
|
dccda51f30 | ||
|
|
34e9f6f5e8 | ||
|
|
ba1aa9575c | ||
|
|
77a68d3091 | ||
|
|
aaaa476352 | ||
|
|
879e4203af | ||
|
|
f5185b86d6 | ||
|
|
7bf2543909 | ||
|
|
00b797e281 | ||
|
|
41f6a64746 | ||
|
|
2556bfa90b | ||
|
|
32756f04c1 | ||
|
|
eaf8ba9219 | ||
|
|
b37c06d9b9 | ||
|
|
dd85e23e60 | ||
|
|
a0a162049e | ||
|
|
d63cb1a115 | ||
|
|
e0413c640a | ||
|
|
da9a45a044 | ||
|
|
a736d1aa2f | ||
|
|
0d5880296a | ||
|
|
ea3fa459ec | ||
|
|
d21969370f | ||
|
|
26ffc23c12 | ||
|
|
df3f6b539b | ||
|
|
45a6a62909 | ||
|
|
4d2bb35e9e | ||
|
|
fa6991d8fb | ||
|
|
a5c02dd6f4 | ||
|
|
4dc7623b9d | ||
|
|
f8fd9d568d | ||
|
|
1aeb825e12 | ||
|
|
ff45626ae1 | ||
|
|
659952ef13 | ||
|
|
45a4e8c617 | ||
|
|
548a3eab83 | ||
|
|
8c0beb1dbf | ||
|
|
ec30bd6b6b | ||
|
|
b78b57b41b | ||
|
|
539c705c22 | ||
|
|
440130a68f | ||
|
|
c6ee7e2f21 | ||
|
|
335ec78d3e | ||
|
|
057e40b26c | ||
|
|
9843d37278 | ||
|
|
38533176df | ||
|
|
c828275906 | ||
|
|
7f760ad847 | ||
|
|
a2e9b5209f | ||
|
|
d18e93d671 | ||
|
|
5a432c717c | ||
|
|
a15ae2d41c | ||
|
|
2a03239c8d | ||
|
|
9b11ce16ba | ||
|
|
428d28e307 | ||
|
|
8f4680e7c0 | ||
|
|
4cb7c2c337 | ||
|
|
38be807c11 | ||
|
|
26902c3cbf | ||
|
|
8550e98370 | ||
|
|
b47aa9f4c5 | ||
|
|
0264290091 | ||
|
|
d3c8bbb838 | ||
|
|
09af10df2d | ||
|
|
a3bc233cc9 | ||
|
|
472148ec2e | ||
|
|
202ee2e4b4 | ||
|
|
a52c90de55 | ||
|
|
d0fc4e528b | ||
|
|
3c51d2cd6e | ||
|
|
a0783896a1 | ||
|
|
64700aa2a8 | ||
|
|
c7325f5590 | ||
|
|
e1f8d2366f | ||
|
|
bfee79d764 | ||
|
|
aab5db7d10 | ||
|
|
840ab6abd2 | ||
|
|
4079eebf1f | ||
|
|
0ad7b4e509 | ||
|
|
60206e362c | ||
|
|
ce72c5223b | ||
|
|
d0d958d69d | ||
|
|
e554532108 | ||
|
|
4ebb762724 | ||
|
|
734787b7c4 | ||
|
|
97e7e6e3f2 | ||
|
|
adf1e6be8b | ||
|
|
dcf49e3a61 | ||
|
|
a3ff32d7a7 | ||
|
|
8760818756 | ||
|
|
e4c153b388 | ||
|
|
6470b74a9d | ||
|
|
e2cf926e5a | ||
|
|
97766c59cf | ||
|
|
379ea51474 | ||
|
|
4e87658460 | ||
|
|
2b16cdea24 | ||
|
|
bc7d5e8b14 | ||
|
|
4e15740876 | ||
|
|
53a056ba8a | ||
|
|
615f43ed05 | ||
|
|
5847810f52 | ||
|
|
1fbce19cfc | ||
|
|
4908ab0db1 | ||
|
|
68da1b69e0 | ||
|
|
64bc9d1786 | ||
|
|
d81253a759 | ||
|
|
0cab8bf8c8 | ||
|
|
c59f54b503 | ||
|
|
4b541db780 | ||
|
|
30e1aab3bb | ||
|
|
f3b929b1cc | ||
|
|
5f3bc73753 | ||
|
|
c10702e638 | ||
|
|
cf1c916bdb | ||
|
|
553f096907 | ||
|
|
3d98fba66c | ||
|
|
dbe4d9d2e5 | ||
|
|
95e5b40e2b | ||
|
|
c4866efd25 | ||
|
|
944000fca8 | ||
|
|
013246ee8f | ||
|
|
3942a4dc9d | ||
|
|
3cfad049b0 | ||
|
|
398b533354 | ||
|
|
b6e40d1b57 | ||
|
|
d591b55579 | ||
|
|
8808c58790 | ||
|
|
35bb113209 | ||
|
|
0cb890689a | ||
|
|
a4d49ac4e1 | ||
|
|
9abab37e57 | ||
|
|
9794146f65 | ||
|
|
0659b2a0bd | ||
|
|
aca17a072f | ||
|
|
b49c4cf36b | ||
|
|
50d16f6efa | ||
|
|
4016a0a253 | ||
|
|
5ee6127bca | ||
|
|
bdf05ce95a | ||
|
|
dce379c151 | ||
|
|
54bb1a6370 | ||
|
|
cac5eacc35 | ||
|
|
0749fd9ea4 | ||
|
|
aeb7274d55 | ||
|
|
cd246bf7cb | ||
|
|
17035b273e | ||
|
|
d43b6aa214 | ||
|
|
16b93235d7 | ||
|
|
d3e6dccd76 | ||
|
|
cbf731d10c | ||
|
|
07e637a584 | ||
|
|
1db3b4d837 | ||
|
|
0a6366c398 | ||
|
|
87c91dd0c4 | ||
|
|
96f89e16e8 | ||
|
|
e019f312ee | ||
|
|
623fd10ac2 | ||
|
|
855f996963 | ||
|
|
cd61de9362 | ||
|
|
694f66982e | ||
|
|
49ddba30fe | ||
|
|
53c0b08f5e | ||
|
|
92ab77c74c | ||
|
|
3e9c889fd2 | ||
|
|
f18556475f | ||
|
|
a5f6abc6c0 | ||
|
|
dcdc6864bf | ||
|
|
acbc2851c8 | ||
|
|
d97629b759 | ||
|
|
13889b50e0 | ||
|
|
92037538e5 | ||
|
|
710da39851 | ||
|
|
736b459a81 | ||
|
|
aa136a09ae | ||
|
|
ee782ea829 | ||
|
|
725c5947a1 | ||
|
|
63667b724a | ||
|
|
a43319b8c0 | ||
|
|
26ef39cca0 | ||
|
|
502a36046d | ||
|
|
47494e1489 | ||
|
|
5143759ef4 | ||
|
|
69ec453651 | ||
|
|
9e3e038e74 | ||
|
|
b7196102f9 | ||
|
|
9244a9acdc | ||
|
|
d48289b64b | ||
|
|
dbcd747d70 | ||
|
|
7e2df5a15b | ||
|
|
795cae0dfa | ||
|
|
9c9a05ba4e | ||
|
|
f4d8b880f5 | ||
|
|
75ccffa6db | ||
|
|
9d7c41a063 | ||
|
|
81d3c504bb | ||
|
|
02ae6cce3d | ||
|
|
223b81c872 | ||
|
|
75d2b57a76 | ||
|
|
f16f3643ad | ||
|
|
9b225f21ee | ||
|
|
5ede75ddd1 | ||
|
|
86404b762d | ||
|
|
83ae7a4f43 | ||
|
|
d4cc6b9851 | ||
|
|
04c7389ec1 | ||
|
|
f3846d482a | ||
|
|
e8bbe199f6 | ||
|
|
3a0d80acc0 | ||
|
|
47ce1f2c53 | ||
|
|
eb4a5cfe0b | ||
|
|
dcdff24d27 | ||
|
|
1d37b3f279 | ||
|
|
7044913680 | ||
|
|
c6f32e8770 | ||
|
|
d9f6f71bdf | ||
|
|
35d3c35c31 | ||
|
|
6d9899a42a | ||
|
|
b297a3ad9e | ||
|
|
06f83aecdf | ||
|
|
38a79a57ab | ||
|
|
dbbee06363 | ||
|
|
ad973665ff | ||
|
|
c0427d5680 | ||
|
|
ed71697de4 | ||
|
|
04030bb5d6 | ||
|
|
4fd2091610 | ||
|
|
86041eab27 | ||
|
|
817c1416e1 | ||
|
|
56bbd29e21 | ||
|
|
8f0790736d | ||
|
|
22365da202 | ||
|
|
a8290a4505 | ||
|
|
8fa52c8028 | ||
|
|
888512507a | ||
|
|
1b7a2e8c57 | ||
|
|
691f66c030 | ||
|
|
b66a8b9370 | ||
|
|
98cc057d4d | ||
|
|
48dfd09e28 | ||
|
|
9ee58410e6 | ||
|
|
7623ee1d66 | ||
|
|
75120847de | ||
|
|
e13e83089a | ||
|
|
3031ab25c4 | ||
|
|
23ea4c9b00 | ||
|
|
b0880e60f2 | ||
|
|
e1681f3b6b | ||
|
|
ca3159dfe3 | ||
|
|
6f9f1f9078 | ||
|
|
b003e88169 | ||
|
|
233a890e49 | ||
|
|
642f4739d1 | ||
|
|
6cfd456f70 | ||
|
|
4b710ac380 | ||
|
|
2a3a510317 | ||
|
|
3a3c610e20 | ||
|
|
cf98f09cd9 | ||
|
|
eeba8bba19 | ||
|
|
61139c4b51 | ||
|
|
190a115366 | ||
|
|
0ae77dd630 | ||
|
|
5fbc260fb1 | ||
|
|
b3aca06857 | ||
|
|
6ec9088055 | ||
|
|
7bd28301d1 | ||
|
|
30c2af5523 | ||
|
|
09644fe4a4 | ||
|
|
74ce1ad751 | ||
|
|
3ccfe54d1b | ||
|
|
a283ee099d | ||
|
|
86b04dff3b | ||
|
|
d322962b8a | ||
|
|
b424230113 | ||
|
|
557d5137ee | ||
|
|
472fdf0fa4 | ||
|
|
15cf055f6f | ||
|
|
1b473f81b0 | ||
|
|
038fdfca11 | ||
|
|
dec8aa90f9 | ||
|
|
1f359bc0d2 | ||
|
|
06d50459a1 | ||
|
|
9d7d87de81 | ||
|
|
8409a1ef1c | ||
|
|
3f19f4e331 | ||
|
|
fcfa8806e5 | ||
|
|
8c6c6a5964 | ||
|
|
e9fce9ef81 | ||
|
|
69edcb9a7e | ||
|
|
7bfe0fd61e | ||
|
|
9e941e03c1 | ||
|
|
e1191e46d5 | ||
|
|
bb973bb6eb | ||
|
|
2666bf00cc | ||
|
|
dd9b92517d | ||
|
|
2c2f486385 | ||
|
|
42837f800d | ||
|
|
f886b55bbf | ||
|
|
613ca6895e | ||
|
|
9d6dec791a | ||
|
|
5e6860d669 | ||
|
|
e39742a397 | ||
|
|
d25210615c | ||
|
|
1b183f1515 | ||
|
|
b2e8a35d94 | ||
|
|
0f2d109cf5 | ||
|
|
8bce1e1d38 | ||
|
|
687e10b31b | ||
|
|
bf3f4a33ef | ||
|
|
47520d0291 | ||
|
|
54ab273b88 | ||
|
|
abc500bc67 | ||
|
|
3d46196fb5 | ||
|
|
ec62c3441f | ||
|
|
6c3e3232ef | ||
|
|
139c7b0d70 | ||
|
|
a194fea8ac | ||
|
|
35e9695760 | ||
|
|
b9e3ecaeb8 | ||
|
|
40655a3916 | ||
|
|
96af7c7670 | ||
|
|
6dc8e8f21c | ||
|
|
6658b8dae3 | ||
|
|
fee1952e70 | ||
|
|
6eb45916a3 | ||
|
|
adfb98ca46 | ||
|
|
eb5becf391 | ||
|
|
c294bee221 | ||
|
|
65a03cd77e | ||
|
|
3057849501 | ||
|
|
6747de13a8 | ||
|
|
8f6e5aaae2 | ||
|
|
74d8564a72 | ||
|
|
7f2ea49a4f | ||
|
|
7aa4898bec | ||
|
|
4c7aab343a | ||
|
|
d3a4e3c4cf | ||
|
|
e8f0abab6a | ||
|
|
d2bf7182e9 | ||
|
|
f3a9422de0 | ||
|
|
61262232b2 | ||
|
|
89b22e0e9e | ||
|
|
7d40194902 | ||
|
|
53b45f127d | ||
|
|
6b621a3077 | ||
|
|
873c41f6b3 | ||
|
|
bb639d2861 | ||
|
|
e4fdcb7730 | ||
|
|
31fa21d9d4 | ||
|
|
ebb4f02a47 | ||
|
|
d331a5c48c | ||
|
|
23d39031ce | ||
|
|
241313c018 | ||
|
|
ec47346080 | ||
|
|
2348bcfc20 | ||
|
|
4ce67cdb11 | ||
|
|
0e8ca500ec | ||
|
|
18f8f09429 | ||
|
|
2664be3b81 | ||
|
|
6a7061497c | ||
|
|
17e5ab8690 | ||
|
|
9e6e1755e0 | ||
|
|
c93070a096 | ||
|
|
488348a890 | ||
|
|
852bd67d77 | ||
|
|
e9e169fb51 | ||
|
|
a8b7600cf2 | ||
|
|
fe7d7403e1 | ||
|
|
f7b945940e | ||
|
|
472ea979c7 | ||
|
|
07fac1e0ca | ||
|
|
e3b0d79c78 | ||
|
|
6d0fff7f35 | ||
|
|
6fc80d1432 | ||
|
|
445b8a621e | ||
|
|
16c9b1b250 | ||
|
|
ad190e4951 | ||
|
|
fac121a8fc | ||
|
|
280be38da9 | ||
|
|
919cf6b97d | ||
|
|
32ecfa83df | ||
|
|
d92295407a | ||
|
|
fa473087a6 | ||
|
|
b88e3cb60c | ||
|
|
1bec5dfb9e | ||
|
|
56573ba532 | ||
|
|
8abc155ff0 | ||
|
|
bc6ab17048 | ||
|
|
293adaf6b0 | ||
|
|
21fe6cba57 | ||
|
|
a55ee71bff | ||
|
|
9641f930a7 | ||
|
|
6529fd48f9 | ||
|
|
83d8a9a7c2 | ||
|
|
07201d4d99 | ||
|
|
7280de75d9 | ||
|
|
2e4939ef58 | ||
|
|
44a8f1019f | ||
|
|
c52e64d73d | ||
|
|
c09a9d8f50 | ||
|
|
0acd632847 | ||
|
|
ce71dd3ce9 | ||
|
|
b90a29b01a | ||
|
|
f325e94e5b | ||
|
|
f423c9aa67 | ||
|
|
75be858339 | ||
|
|
7e76f68e82 | ||
|
|
cb96fbc669 | ||
|
|
f63abbc638 | ||
|
|
b25c7a528c | ||
|
|
871e182647 | ||
|
|
a8b59361e5 | ||
|
|
398edb31e2 | ||
|
|
e7d495f4de | ||
|
|
a51300adbe | ||
|
|
5c2c1bce32 | ||
|
|
04af8b6605 | ||
|
|
352ef5bc69 | ||
|
|
0b0852c8e2 | ||
|
|
5c11c41ac3 | ||
|
|
b197180607 | ||
|
|
1bc238a0f2 | ||
|
|
e8e8257cbc | ||
|
|
5b4701a304 | ||
|
|
c08881f3ca | ||
|
|
9b8416b5e1 | ||
|
|
32637f4279 | ||
|
|
16783e3851 | ||
|
|
f802faa528 | ||
|
|
524c28b729 | ||
|
|
fd569a71a4 | ||
|
|
f7bcea3eab | ||
|
|
9c16272f03 | ||
|
|
e52087ef18 | ||
|
|
ca3ad47106 | ||
|
|
ade077f2f3 | ||
|
|
ce9d0b9fc8 | ||
|
|
a09dbe8b3a | ||
|
|
a27032f9f4 | ||
|
|
b21f5c17da | ||
|
|
53345ea47c | ||
|
|
a5ffafa300 | ||
|
|
a6fec2d1b6 | ||
|
|
b7468a4d2b | ||
|
|
1aa149f787 | ||
|
|
23f2952653 | ||
|
|
5dec359501 | ||
|
|
ab611aedae | ||
|
|
6c7f60299d | ||
|
|
51a57068ed | ||
|
|
5e060887dd | ||
|
|
7576cc892b | ||
|
|
fc0e9deb28 | ||
|
|
31102a5443 | ||
|
|
a2cad934ff | ||
|
|
89752055a2 | ||
|
|
51339e7564 | ||
|
|
2afd077198 | ||
|
|
9565b8a324 | ||
|
|
688cf9bc4a | ||
|
|
d50fd065e7 | ||
|
|
857030005d | ||
|
|
6766b47c05 | ||
|
|
5b27f661db | ||
|
|
edde16845b | ||
|
|
b93b13d3f3 | ||
|
|
450fddf98e | ||
|
|
4abe144cfa | ||
|
|
c21a482025 | ||
|
|
3dee4aa1e6 | ||
|
|
544520a71f | ||
|
|
5bcb666ad9 | ||
|
|
80794386ce | ||
|
|
0a1d0d70fe | ||
|
|
a4d5941437 | ||
|
|
dc73e445e6 | ||
|
|
808ad3d4e8 | ||
|
|
337aad82c5 | ||
|
|
14fae6755e | ||
|
|
80f0e48e08 | ||
|
|
950a108904 | ||
|
|
a355414a70 | ||
|
|
5b42419a9f | ||
|
|
baedeeec63 | ||
|
|
b1426ba8b2 | ||
|
|
a8f6110fff | ||
|
|
e6e1801426 | ||
|
|
ec578b70a9 | ||
|
|
ec91872d04 | ||
|
|
ee1daf3979 | ||
|
|
aa55c46f9e | ||
|
|
48db7a7191 | ||
|
|
da0b514f30 | ||
|
|
eafc31de50 | ||
|
|
3fbdbad4e8 | ||
|
|
d0b7acd27d | ||
|
|
33133ad7a8 | ||
|
|
a1af80c677 | ||
|
|
5c4ee6b272 | ||
|
|
1bd80ec18b | ||
|
|
cf43620a35 | ||
|
|
625ea8aaf2 | ||
|
|
a75d19d645 | ||
|
|
710375673a | ||
|
|
a945e03a14 | ||
|
|
00fb85a4da | ||
|
|
d22ff09f2b | ||
|
|
d4342df432 | ||
|
|
22168d7169 | ||
|
|
437fb84ae0 | ||
|
|
ed63b25cd6 | ||
|
|
66a56936b3 | ||
|
|
9a31917f92 | ||
|
|
7bf74b6afd | ||
|
|
94c5ac3977 | ||
|
|
bb16f481b1 | ||
|
|
d344270c79 | ||
|
|
e533ad35f1 | ||
|
|
b29e373cf1 | ||
|
|
d6321f3f6a | ||
|
|
2bcc5d54a6 | ||
|
|
eeec8d4eb5 | ||
|
|
6cb5e39e97 | ||
|
|
62170cb64a | ||
|
|
fbf3e81ef7 | ||
|
|
5c6d37b67e | ||
|
|
e51edb326e | ||
|
|
cff26db0c0 | ||
|
|
f6dad3b9c7 | ||
|
|
094eb86db0 | ||
|
|
8bc04477f3 | ||
|
|
44a60dc38a | ||
|
|
1f7b29a7d9 | ||
|
|
d3fbc58173 | ||
|
|
5fa21d2f8f | ||
|
|
53ed1f7898 | ||
|
|
6fe0f7c9d1 | ||
|
|
9bfa4096e7 | ||
|
|
0082fedd3c | ||
|
|
183aeae009 | ||
|
|
112b833989 | ||
|
|
09ada8acc3 | ||
|
|
d654927b2e | ||
|
|
fe78f9b4a0 | ||
|
|
f41e7600f2 | ||
|
|
8b0fafedd2 | ||
|
|
cf7381529f | ||
|
|
81f2276c62 | ||
|
|
b384319d42 | ||
|
|
81251f6c63 | ||
|
|
0ef5a28c5f | ||
|
|
b8906af6f8 | ||
|
|
db56feb260 | ||
|
|
2d393d6569 | ||
|
|
25be1205b1 | ||
|
|
2f15f449a2 | ||
|
|
cd4c107bdf | ||
|
|
fbe3b52703 | ||
|
|
b23a6c7555 | ||
|
|
c0519572cc | ||
|
|
6967193f16 | ||
|
|
4200d2f26a | ||
|
|
4ddd725232 | ||
|
|
e0727a3678 | ||
|
|
f9599ceb8e | ||
|
|
947e5901d2 | ||
|
|
1dcc626884 | ||
|
|
2baac7abe6 | ||
|
|
6015349e41 | ||
|
|
72958f82c3 | ||
|
|
62bd3d6725 | ||
|
|
8e1fe3b1f4 | ||
|
|
6693df7a03 | ||
|
|
5712458fcb | ||
|
|
fc9047401f | ||
|
|
2039f4983a | ||
|
|
9492cbaf5a | ||
|
|
c546bc58f0 | ||
|
|
2285849b38 | ||
|
|
8a4b2c3c13 | ||
|
|
2f3fe5569d | ||
|
|
af7fa415bf | ||
|
|
d21914888d | ||
|
|
ee84b4f534 | ||
|
|
1c33a35bbb | ||
|
|
8929fc5b21 | ||
|
|
5054cd90d8 | ||
|
|
3b4e332c65 | ||
|
|
9d8a46c665 | ||
|
|
d9aa330d1a | ||
|
|
648880f3bd | ||
|
|
16bbbf551e | ||
|
|
99c988a0b6 | ||
|
|
8564970ad8 | ||
|
|
4c8368d230 | ||
|
|
90c96fbe91 | ||
|
|
6788b98d38 | ||
|
|
262b77283f | ||
|
|
5938a06bc5 | ||
|
|
1fded13426 | ||
|
|
4fe545709e | ||
|
|
665a7c47f0 | ||
|
|
019f5b05de | ||
|
|
0b44f1a67e | ||
|
|
aec2f0ca4f | ||
|
|
9140e012d7 | ||
|
|
eaf4e350e6 | ||
|
|
aaec23ea43 | ||
|
|
0ac28cf9f7 | ||
|
|
99b15d57bc | ||
|
|
012dc6a893 | ||
|
|
709d37b349 | ||
|
|
b0ca4bf007 | ||
|
|
dde521d0c8 | ||
|
|
da359acf18 | ||
|
|
148a159a9a | ||
|
|
b7e777f3eb | ||
|
|
3d214c8dc5 | ||
|
|
de42253657 | ||
|
|
c59c6d873e | ||
|
|
06175991af | ||
|
|
e1d7a630b4 | ||
|
|
e946d7953a | ||
|
|
8304525d44 | ||
|
|
ba241070ed | ||
|
|
44141edce8 | ||
|
|
2c5c6db989 | ||
|
|
fa607d6cef | ||
|
|
f2f16f3931 | ||
|
|
180ad53ebf | ||
|
|
abc6b4fbdf | ||
|
|
d9161d1e47 | ||
|
|
fc7c69c0ef | ||
|
|
7dff991ab0 | ||
|
|
237497ea9a | ||
|
|
e7a6ca52b3 | ||
|
|
e05118d3ee | ||
|
|
0ef9de30b3 | ||
|
|
53cc3c22f3 | ||
|
|
a467e6927f | ||
|
|
472df70919 | ||
|
|
a501f08bc5 | ||
|
|
192b5a544a | ||
|
|
b09b8f14ea | ||
|
|
b73d476ec5 | ||
|
|
647ab26736 | ||
|
|
ad1ec37d60 | ||
|
|
32e7cf3813 | ||
|
|
39fdf43354 | ||
|
|
05309ece1d | ||
|
|
7453c40d62 | ||
|
|
0098b1e4e4 | ||
|
|
ebff9a563d | ||
|
|
2a10f105b8 | ||
|
|
4692788b87 | ||
|
|
3157a47585 | ||
|
|
71584e5f3a | ||
|
|
0575328fa9 | ||
|
|
9074b6ffd2 | ||
|
|
6b2e7dd83f | ||
|
|
db1de85fe1 | ||
|
|
fa62b4f3a5 | ||
|
|
d6692721ef | ||
|
|
e17269fcfc | ||
|
|
1d91ea244f | ||
|
|
e09598e97b | ||
|
|
5df9fdb311 | ||
|
|
92e0f259e4 | ||
|
|
8c7ea89ade | ||
|
|
886d9b812e | ||
|
|
bb694febc5 | ||
|
|
4288566c33 | ||
|
|
14f348ecb1 | ||
|
|
86ebdeb7a4 | ||
|
|
4a4c511323 | ||
|
|
73daea1a40 | ||
|
|
b60d348af6 | ||
|
|
58ce0722eb | ||
|
|
dca3c3bdda | ||
|
|
f1d75b97cf | ||
|
|
1ccbafdd32 | ||
|
|
c94bb93c40 | ||
|
|
57e23e5a7d | ||
|
|
8607e10d4d | ||
|
|
fd0b796d78 | ||
|
|
b6e3894693 | ||
|
|
08f118ba83 | ||
|
|
a6d785c55c | ||
|
|
bb99be01bf | ||
|
|
f462614f9b | ||
|
|
612146afed | ||
|
|
aa0c8a6532 | ||
|
|
629411b514 | ||
|
|
9a9a5bd9e5 | ||
|
|
ff16cd988e | ||
|
|
0e99c676b5 | ||
|
|
c6caca1afc | ||
|
|
617e3c771c | ||
|
|
5a1ecbad1e | ||
|
|
bccb7b4e27 | ||
|
|
8e481afd9f | ||
|
|
0a514e5fcc | ||
|
|
cbfa3bccf9 | ||
|
|
b32e65652f | ||
|
|
355378ddda | ||
|
|
8a6afcc28c | ||
|
|
2a16edb7eb | ||
|
|
7d67297081 | ||
|
|
5130641477 | ||
|
|
c6e99c7255 | ||
|
|
edcc2fbb1c | ||
|
|
3a8811a1b5 | ||
|
|
7b25e910b3 | ||
|
|
2ec40362b0 | ||
|
|
c5127855ca | ||
|
|
fffd08120d | ||
|
|
71b35e1fb0 | ||
|
|
2ae52b3570 | ||
|
|
7856b08a5a | ||
|
|
cbf594bcf9 | ||
|
|
46a11956b4 | ||
|
|
0dc29128cb | ||
|
|
c980a9fd3a | ||
|
|
291a780ab9 | ||
|
|
adcf4d02db | ||
|
|
a8c18d9f42 | ||
|
|
4c6986b012 | ||
|
|
6e0ca785b4 | ||
|
|
ca86d99f8b | ||
|
|
0bd3b80c3e | ||
|
|
455dc75efa | ||
|
|
6ff270d3a0 | ||
|
|
33c934e8fc | ||
|
|
3e18df684a | ||
|
|
5b658f6355 | ||
|
|
3f727aa9f7 | ||
|
|
bb02334848 | ||
|
|
a7f6e63170 | ||
|
|
96bb3e5e1c | ||
|
|
28c501f9d0 | ||
|
|
7520e71d05 | ||
|
|
5a241283b6 | ||
|
|
62568054ec | ||
|
|
ca8394d7af | ||
|
|
63ace83876 | ||
|
|
884d34abeb | ||
|
|
dee4aa3b69 | ||
|
|
106f0a8fa8 | ||
|
|
3c99f23a8b | ||
|
|
d2339084ca | ||
|
|
db24774070 | ||
|
|
7cd8662650 | ||
|
|
217de87959 | ||
|
|
961849fda6 | ||
|
|
fa2799a629 | ||
|
|
97b221458b | ||
|
|
823df2a826 | ||
|
|
d3d5f7f69c | ||
|
|
540ba3d0ba | ||
|
|
b732016f2e | ||
|
|
444b146987 | ||
|
|
554775e6f7 | ||
|
|
2661475362 | ||
|
|
775b6cafb2 | ||
|
|
ab0d6226dd | ||
|
|
7bd08631ab | ||
|
|
8a5711fdb3 | ||
|
|
f7cdff5ed7 | ||
|
|
90134c0a7b | ||
|
|
edc50b02ad | ||
|
|
0b9e3da0a3 | ||
|
|
3edf33d80d | ||
|
|
b6151ec2f5 | ||
|
|
4f5086b457 | ||
|
|
ca1c509261 | ||
|
|
1bcae63f0b | ||
|
|
7c259f4def | ||
|
|
1d63c86f6c | ||
|
|
0e83fa91ed | ||
|
|
4aad8acab4 | ||
|
|
c8373bc6ed | ||
|
|
2ccefef1f0 | ||
|
|
1fcaf70272 | ||
|
|
5c12a76318 | ||
|
|
3954ecc168 | ||
|
|
8c1f41058a | ||
|
|
e9a0d8353b | ||
|
|
bd4357e281 | ||
|
|
01cede3124 | ||
|
|
1f998aa4b9 | ||
|
|
4bd28878d6 | ||
|
|
421270a822 | ||
|
|
35b6d05b93 | ||
|
|
fd7962dbc9 | ||
|
|
207d3b9253 | ||
|
|
76ffbcdd84 | ||
|
|
4741ea9861 | ||
|
|
deecb40d21 | ||
|
|
781f6cb925 | ||
|
|
647fc4cadd | ||
|
|
5ff583d657 | ||
|
|
d0ef1f7034 | ||
|
|
c9c8adb5b4 | ||
|
|
cf8509ff3b | ||
|
|
ebf14c6962 | ||
|
|
7a5fb61199 | ||
|
|
352128e066 | ||
|
|
d177f698de | ||
|
|
d5c84c1a86 | ||
|
|
4f3702bfa0 | ||
|
|
3857950e7e | ||
|
|
b2b4a8de9b | ||
|
|
2e5cc0367a | ||
|
|
6826e9b248 | ||
|
|
27de8c17f3 | ||
|
|
b804383ed5 | ||
|
|
6b126a0733 | ||
|
|
eb63c3c44b | ||
|
|
ec730d4374 | ||
|
|
77a3bfa730 | ||
|
|
2173531763 | ||
|
|
7d40d58a86 | ||
|
|
b332538c59 | ||
|
|
ec9da07f82 | ||
|
|
c8d58e89d8 | ||
|
|
4ae2f12815 | ||
|
|
a3eaf48e04 | ||
|
|
79e385c9e2 | ||
|
|
6586ddad6c | ||
|
|
92d04ceed0 | ||
|
|
d8246c5fde | ||
|
|
52aed4f431 | ||
|
|
f59b0fa089 | ||
|
|
072e39e7e8 | ||
|
|
88dedff7f2 | ||
|
|
fdeb011eea | ||
|
|
31c9caf5a6 | ||
|
|
be85bf5dbe | ||
|
|
32e5e81591 | ||
|
|
b65aba9153 | ||
|
|
5348150f22 | ||
|
|
53de6ee5b3 | ||
|
|
dac589c17a | ||
|
|
4157526890 | ||
|
|
ae46ced304 | ||
|
|
f422616636 | ||
|
|
c6c7797148 | ||
|
|
44e70dffde | ||
|
|
d6ac3f7a83 | ||
|
|
70027b4691 | ||
|
|
9912d01160 | ||
|
|
6562848a23 | ||
|
|
27d300fffd | ||
|
|
d4420b8e5f | ||
|
|
903e8ca7a6 | ||
|
|
51062ba74d | ||
|
|
97fd41b338 | ||
|
|
7e0f5cc227 | ||
|
|
454e520490 | ||
|
|
97773d77fc | ||
|
|
56da21a661 | ||
|
|
b600cc5e0d | ||
|
|
6b8053c210 | ||
|
|
bd6afb8cfc | ||
|
|
944dd01971 | ||
|
|
f06d4ad138 | ||
|
|
0fec01bf43 | ||
|
|
5418817301 | ||
|
|
48a9fa324d | ||
|
|
d8cb189351 | ||
|
|
a6dba9e8f7 | ||
|
|
4f0ed828e6 | ||
|
|
423b2ffa8c | ||
|
|
e4b18fe5ab | ||
|
|
7e872ad8af | ||
|
|
a0697b201b | ||
|
|
12c1893c13 | ||
|
|
66775a03be | ||
|
|
b051f10302 | ||
|
|
5f62412e7e | ||
|
|
89a9b04f60 | ||
|
|
7f7ef1de42 | ||
|
|
8ad99caca1 | ||
|
|
a53dddbac0 | ||
|
|
57c266e783 | ||
|
|
17eab5fd8d | ||
|
|
2fb35a6420 | ||
|
|
3cf26e5239 | ||
|
|
dc19fdd2ab | ||
|
|
668e678882 | ||
|
|
9001d2ad11 | ||
|
|
f63d79ee1d | ||
|
|
1fedf03098 | ||
|
|
a5c6afa6b0 | ||
|
|
1c4ea9bdb7 | ||
|
|
c6b06a2130 | ||
|
|
05fe3d64e8 | ||
|
|
7a61edb23c | ||
|
|
8b7fc39b8b | ||
|
|
8be90191b1 | ||
|
|
9cf0eec4f8 | ||
|
|
6f71ab1450 | ||
|
|
f322c862ad | ||
|
|
5012a86249 | ||
|
|
a1e752d7cc | ||
|
|
8862e254fa | ||
|
|
3e566a2e37 | ||
|
|
88316762ec | ||
|
|
ba692ecfe0 | ||
|
|
752d3e2401 | ||
|
|
97d5e05820 | ||
|
|
89d235cf2b | ||
|
|
cde6e0e55a | ||
|
|
eff6a677a0 | ||
|
|
a1393db02a | ||
|
|
c415a1021c | ||
|
|
5aee772092 | ||
|
|
44caf80f23 | ||
|
|
6450de534c | ||
|
|
67166cac81 | ||
|
|
d6901cc018 | ||
|
|
9c2754d42d | ||
|
|
73299b535e | ||
|
|
fbbe4e1abb | ||
|
|
bd33460c96 | ||
|
|
5e9e811de9 | ||
|
|
cbd2104dc5 | ||
|
|
a550d003f2 | ||
|
|
0bdecf6d92 | ||
|
|
cf8eb769fe | ||
|
|
63d950d106 | ||
|
|
73af2e22e8 | ||
|
|
c0cf78f36f | ||
|
|
6ec4439714 | ||
|
|
d2e34a3093 | ||
|
|
e9be274c48 | ||
|
|
c309cb7b12 | ||
|
|
523d3021fb | ||
|
|
49498aa923 | ||
|
|
952b45e231 | ||
|
|
ea53fa922d | ||
|
|
26ae0c3654 | ||
|
|
555960fa22 | ||
|
|
1efd2dd4f8 | ||
|
|
dd8c34b035 | ||
|
|
e2d3f2beb0 | ||
|
|
e334dbabe4 | ||
|
|
83736c4886 | ||
|
|
ab790153ce | ||
|
|
f4c17b0f1c | ||
|
|
c5cf97ffbe | ||
|
|
5a8903c5f3 | ||
|
|
6137cd3efd | ||
|
|
177beb108e | ||
|
|
a22f8a73a1 | ||
|
|
be2c70c335 | ||
|
|
51172d9324 | ||
|
|
8d00d7628b | ||
|
|
af9214b416 | ||
|
|
6b12fda15e | ||
|
|
30a075bc06 | ||
|
|
596a2240f2 | ||
|
|
577d5f141f | ||
|
|
d8f4201929 | ||
|
|
03a927a174 | ||
|
|
fff1672936 | ||
|
|
8912478ec1 | ||
|
|
8e30b0f939 | ||
|
|
3e3c721457 | ||
|
|
89dcc519a8 | ||
|
|
fd63b22378 | ||
|
|
4fe67b2a29 | ||
|
|
976b9888a0 | ||
|
|
1daa82aab0 | ||
|
|
537a9ab161 | ||
|
|
602b25cc20 | ||
|
|
86fe320fcb | ||
|
|
59d3d656e4 | ||
|
|
61acd7ab0d | ||
|
|
71ea2d0693 | ||
|
|
2e65c6f8be | ||
|
|
1580491e29 | ||
|
|
9dea6a442e | ||
|
|
1d644dcbc1 | ||
|
|
5cf9449698 | ||
|
|
00d8dbfe31 | ||
|
|
a6072212f7 | ||
|
|
2d45b714a5 | ||
|
|
bf7b8faba1 | ||
|
|
ea381f9088 | ||
|
|
c72dacacfd | ||
|
|
fc0343918d | ||
|
|
4423661fe9 | ||
|
|
3630c538ef | ||
|
|
a0d648594f | ||
|
|
2f83a25ef8 | ||
|
|
3d5b8959a7 | ||
|
|
0f3f0c52e7 | ||
|
|
2be18f3d98 | ||
|
|
85edb92a89 | ||
|
|
fb42bf6c02 | ||
|
|
9fcf133d57 | ||
|
|
24760ba120 | ||
|
|
11a388e7cc | ||
|
|
70c427dc71 | ||
|
|
002cfeb4ce | ||
|
|
127a25b6f4 | ||
|
|
acf2ea59c6 | ||
|
|
c4e0a4871c | ||
|
|
ae9233829f | ||
|
|
be8caae06d | ||
|
|
072394c341 | ||
|
|
39259d7126 | ||
|
|
041092b0ea | ||
|
|
ee4f8ba5a3 | ||
|
|
486925da01 | ||
|
|
ad4ea3633a | ||
|
|
355f1b7fce | ||
|
|
9a23bbd3bd | ||
|
|
fcbf03add6 | ||
|
|
f72b92e4f3 | ||
|
|
6c90b412c9 | ||
|
|
4bb8477002 | ||
|
|
02e6ab96ff | ||
|
|
f27df3ddf3 | ||
|
|
7399cf3f5a | ||
|
|
25461ea5bb | ||
|
|
8baff10811 | ||
|
|
f4497242c4 | ||
|
|
26c7cf7d0c | ||
|
|
2ff3dffa95 | ||
|
|
5057689ecb | ||
|
|
59dd0e2be6 | ||
|
|
2602a5676c | ||
|
|
17f8a49185 | ||
|
|
e827571ba1 | ||
|
|
e32a744739 | ||
|
|
c07b3a0f69 | ||
|
|
eab1260d56 | ||
|
|
71745f08e3 | ||
|
|
6dcb4dde02 | ||
|
|
35892ecba5 | ||
|
|
d5bccfb503 | ||
|
|
469c7a15f0 | ||
|
|
038df8109c | ||
|
|
57e8e68973 | ||
|
|
51d652e1eb | ||
|
|
60bcbfc792 | ||
|
|
ca88789df2 | ||
|
|
0c038ee56b | ||
|
|
f5bbeb331e | ||
|
|
46ab40e82a | ||
|
|
6bf42da03c | ||
|
|
23b199c6f3 | ||
|
|
482b8771b3 | ||
|
|
6b41630179 | ||
|
|
e27109b63c | ||
|
|
344b599bca | ||
|
|
6e050631df | ||
|
|
79de7b71e5 | ||
|
|
d0083a0607 | ||
|
|
57edd100e0 | ||
|
|
6fa5aaccdb | ||
|
|
e32ef05e5d | ||
|
|
35db03f9ad | ||
|
|
9dab0f5466 | ||
|
|
d504dfaf21 | ||
|
|
f350e92952 | ||
|
|
81eeb92f4d | ||
|
|
3a8fbc8a70 | ||
|
|
3147ad793e | ||
|
|
dd2ee45b2c | ||
|
|
23af621600 | ||
|
|
e04065de31 | ||
|
|
e7f3434fb7 | ||
|
|
be1f2c5d60 | ||
|
|
f91d50f67d | ||
|
|
9a2eed2dab | ||
|
|
3ad169ed86 | ||
|
|
920ed56a5d | ||
|
|
add2a7520a | ||
|
|
bfc80bee4e | ||
|
|
5b1a1001de | ||
|
|
56621e294f | ||
|
|
b0c8ceca9f | ||
|
|
a67742d203 | ||
|
|
9760d335f1 | ||
|
|
e89907b42a | ||
|
|
8bb6d5ba1c | ||
|
|
a94f9ba13c | ||
|
|
bf2fbb2f4f | ||
|
|
ba25f6fff6 | ||
|
|
c80666c0da | ||
|
|
7831a112d3 | ||
|
|
ae5609c146 | ||
|
|
a525933737 | ||
|
|
980b7a9148 | ||
|
|
0ccd344a69 | ||
|
|
0d61e77d83 | ||
|
|
8f427daff6 | ||
|
|
1a7050a120 | ||
|
|
3c5aa2c253 | ||
|
|
969689f8e9 | ||
|
|
ea2eb367e9 | ||
|
|
8afc084553 | ||
|
|
f7997d877a | ||
|
|
67bf8e7eea | ||
|
|
e3634ebbd6 | ||
|
|
0ba3e9e43b | ||
|
|
25dd25bb38 | ||
|
|
01862e2ab7 | ||
|
|
6318d1d26e | ||
|
|
7278c722b7 | ||
|
|
b36eaa9ba6 | ||
|
|
92c3fff60a | ||
|
|
96a2187539 | ||
|
|
d40b68c367 | ||
|
|
8b608ecde7 | ||
|
|
c767877507 | ||
|
|
f2b8fbf0ad | ||
|
|
e110bb516f | ||
|
|
9a8c59c86e | ||
|
|
11db75b649 | ||
|
|
7a198f816e | ||
|
|
d235fb5a73 | ||
|
|
ee993f2caa | ||
|
|
6a24aab3d3 | ||
|
|
4b95be8c15 | ||
|
|
bc86f88b4a | ||
|
|
758f496355 | ||
|
|
6d596f8d2a | ||
|
|
211ee828ab | ||
|
|
84a2e285b6 | ||
|
|
a8810b908f | ||
|
|
44530716bd | ||
|
|
11bfb96245 | ||
|
|
da5d167a61 | ||
|
|
16177fc5e6 | ||
|
|
50619342e7 | ||
|
|
30cf346af8 | ||
|
|
3822a38a6e | ||
|
|
e4a9bb31be | ||
|
|
beeccc20f3 | ||
|
|
922c9534de | ||
|
|
ca99a3c58c | ||
|
|
40704801eb | ||
|
|
09f125ca49 | ||
|
|
a8bb86ca07 | ||
|
|
d4d90dd75e | ||
|
|
621a4ea5e3 | ||
|
|
20f831d3ee | ||
|
|
1b46e022ba | ||
|
|
5dc4d355fa | ||
|
|
fd288cfa32 | ||
|
|
cc188db593 | ||
|
|
81175151cb | ||
|
|
3533538731 | ||
|
|
67028efbf4 | ||
|
|
40293587b8 | ||
|
|
24660e793d | ||
|
|
88f7418100 | ||
|
|
63b15ea3fc | ||
|
|
80602e4278 | ||
|
|
5de0e273e6 | ||
|
|
59060da192 | ||
|
|
be9dc77b84 | ||
|
|
06aae2fe13 | ||
|
|
b28363b86d | ||
|
|
99ada89a2e | ||
|
|
ce496b9640 | ||
|
|
2f5d299fd1 | ||
|
|
385ca78757 | ||
|
|
31b7ba8aa6 | ||
|
|
128d7c5b0e | ||
|
|
f2291ffb6b | ||
|
|
1f6d629bd7 | ||
|
|
d8e9b898e9 | ||
|
|
000db2f194 | ||
|
|
d311003e8e | ||
|
|
caa7e55524 | ||
|
|
ecf6d65f23 | ||
|
|
5c7d67bff1 | ||
|
|
c23aa440b2 | ||
|
|
c47b38a56a | ||
|
|
e1279d9780 | ||
|
|
0db8f6c793 | ||
|
|
207bc3686d | ||
|
|
fe8eb971d5 | ||
|
|
81db899306 | ||
|
|
38e2029ff4 | ||
|
|
a31000a9a6 | ||
|
|
50d1bb4d9f | ||
|
|
65909205e3 | ||
|
|
1a1fd056b0 | ||
|
|
b5c38dc6db | ||
|
|
56e5c5797d | ||
|
|
2bdad487d9 | ||
|
|
9aba341b95 | ||
|
|
fb544b7530 | ||
|
|
9d37f25c7e | ||
|
|
e917ad2b43 | ||
|
|
8d7b8c8972 | ||
|
|
ddc201e1c1 | ||
|
|
0c98ce2701 | ||
|
|
7136f54404 | ||
|
|
82c9833e30 | ||
|
|
8d36e703d7 | ||
|
|
14c269e604 | ||
|
|
4db3120d68 | ||
|
|
9f98da3581 | ||
|
|
f61d53f859 | ||
|
|
82f2165e20 | ||
|
|
26af0a81d7 | ||
|
|
db69f64c31 | ||
|
|
f5c9f6d56b | ||
|
|
3c47c9f874 | ||
|
|
e4fecf85ea | ||
|
|
2605a02d04 | ||
|
|
35b3883f98 | ||
|
|
ad1a3d9eca | ||
|
|
2756c87b42 | ||
|
|
1c59194427 | ||
|
|
ee69188842 | ||
|
|
d1a493b162 | ||
|
|
aebbd9c08c | ||
|
|
ad2d73fa9b | ||
|
|
89b6d753bc | ||
|
|
2d5393f29f | ||
|
|
619b354bb9 | ||
|
|
e35dd4b95b | ||
|
|
07e6f64a3c | ||
|
|
38c5699977 | ||
|
|
c11fc64edd | ||
|
|
706758b45e | ||
|
|
e2f7e813dc | ||
|
|
1df0327af3 | ||
|
|
1b00a04238 | ||
|
|
b385db2c3a | ||
|
|
63e22243a0 | ||
|
|
571224ec28 | ||
|
|
efd9b1ce0b | ||
|
|
252992f6f5 | ||
|
|
2795c594a5 | ||
|
|
f108b7aa99 | ||
|
|
92bf7470dc | ||
|
|
25f4247208 | ||
|
|
60eba92e34 | ||
|
|
2364298296 | ||
|
|
3c56aa6842 | ||
|
|
af0f7a150a | ||
|
|
d96cdc3430 | ||
|
|
a52ab8072b | ||
|
|
07ab3b31d6 | ||
|
|
a5d292a47d | ||
|
|
1e85a64f56 | ||
|
|
3f75ed6c28 | ||
|
|
3ed9fa6353 | ||
|
|
61fc7aaccb | ||
|
|
2597455824 | ||
|
|
aa615c003e | ||
|
|
5409c64349 | ||
|
|
4b2c0c6251 | ||
|
|
5f730aef1f | ||
|
|
6d917dd579 | ||
|
|
97b66a376f | ||
|
|
17dca47be7 | ||
|
|
4213177982 | ||
|
|
114c587a49 | ||
|
|
bf97c89873 | ||
|
|
816deea270 | ||
|
|
70bbd8d4cf | ||
|
|
bb9552e13a | ||
|
|
a86b915b00 | ||
|
|
a73c2b3dd5 | ||
|
|
ffb2baa3a8 | ||
|
|
7c78ea4ec7 | ||
|
|
dfd40edb95 | ||
|
|
9f615d3319 | ||
|
|
78de438035 | ||
|
|
bb97aab756 | ||
|
|
b7e40c1317 | ||
|
|
aa939cb031 | ||
|
|
349ff23685 | ||
|
|
dc01280737 | ||
|
|
07b7bb69b8 | ||
|
|
0f38920aaa | ||
|
|
9e048ff98d | ||
|
|
962e301d2b | ||
|
|
f2c3130e3b | ||
|
|
fcb824df6a | ||
|
|
a92d13642f | ||
|
|
284532420d | ||
|
|
a29319e2fb | ||
|
|
d03c8a041a | ||
|
|
b323f827dd | ||
|
|
c1cedbc268 | ||
|
|
0f196966a7 | ||
|
|
ae300ec7e4 | ||
|
|
965b362f34 | ||
|
|
efb53d6900 | ||
|
|
ced823c707 | ||
|
|
c50d51a69e | ||
|
|
c5f54b5104 | ||
|
|
eb774b7647 | ||
|
|
88ddc0cf09 | ||
|
|
fc9aa59a13 | ||
|
|
0ad2d92a5b | ||
|
|
d3178b3443 | ||
|
|
e2be277347 | ||
|
|
15ba813ca6 | ||
|
|
8f01f4dba3 | ||
|
|
0efbc51e41 | ||
|
|
4fee36107d | ||
|
|
b48f611464 | ||
|
|
4e6fa93c9b | ||
|
|
41d6d08cd3 | ||
|
|
0f63429513 | ||
|
|
eb5c40832e | ||
|
|
d1c06cc9f3 | ||
|
|
c62419270d | ||
|
|
fa01dbe099 | ||
|
|
67f267a276 | ||
|
|
f497531908 | ||
|
|
7b3214797e | ||
|
|
9e6403ca75 | ||
|
|
003c824a02 | ||
|
|
568b5de507 | ||
|
|
012734c044 | ||
|
|
38bc35230f | ||
|
|
1e5d63949e | ||
|
|
fa2b7ac22e | ||
|
|
18a9aaa38e | ||
|
|
868e1c667e | ||
|
|
9b734d6cbd | ||
|
|
076a165904 | ||
|
|
fb84e93815 | ||
|
|
6869d81383 | ||
|
|
8bf379b3eb | ||
|
|
ad8908ab79 | ||
|
|
3850e14649 | ||
|
|
5e3e69d109 | ||
|
|
09264419ec | ||
|
|
b502804e98 | ||
|
|
a859870a08 | ||
|
|
45dc0f688a | ||
|
|
358dcebde2 | ||
|
|
c49507e087 | ||
|
|
194f221372 | ||
|
|
cd0195ca81 | ||
|
|
5bc7417f55 | ||
|
|
7357817b46 | ||
|
|
a7f823080d | ||
|
|
862af13096 | ||
|
|
a39176dc7a | ||
|
|
b359ecf887 | ||
|
|
2bc3c3b00e | ||
|
|
4bceec314f | ||
|
|
31a93c0045 | ||
|
|
f6029cefc4 | ||
|
|
340a31f410 | ||
|
|
29b019fe01 | ||
|
|
5c9a236843 | ||
|
|
896d3913d9 | ||
|
|
172ae602ca | ||
|
|
29a0bfe195 | ||
|
|
5b8e922a70 | ||
|
|
1424c40384 | ||
|
|
b9d1f6b00e | ||
|
|
149154ba26 | ||
|
|
8fc438550c | ||
|
|
5c99c9da04 | ||
|
|
3184d273cf | ||
|
|
1a43b7485e | ||
|
|
1c40dd0ed4 | ||
|
|
577fd5531b | ||
|
|
3412708a91 | ||
|
|
daaefaeabd | ||
|
|
c576a4e7e4 | ||
|
|
f0d6a58a7f | ||
|
|
396cae039b | ||
|
|
9c63b6456f | ||
|
|
4a0bd46d4d | ||
|
|
a94c97e724 | ||
|
|
d603aac3f2 | ||
|
|
90e350b6fb | ||
|
|
2126556f25 | ||
|
|
a15a79bff6 | ||
|
|
da68799982 | ||
|
|
ddd001ef0e | ||
|
|
acf1e4b465 | ||
|
|
589148707f | ||
|
|
239af0df26 | ||
|
|
5fcee0050c | ||
|
|
481e416b0d | ||
|
|
a6eef04817 | ||
|
|
a4d033a9af | ||
|
|
a68d8bd341 | ||
|
|
8ef34a5805 | ||
|
|
9a34224891 | ||
|
|
21c1e72d1f | ||
|
|
784fe20f32 | ||
|
|
355f1a4816 | ||
|
|
aff47d5960 | ||
|
|
5d07fbc853 | ||
|
|
9540ab1c88 | ||
|
|
03a9dedb04 | ||
|
|
bf64f3cc7b | ||
|
|
c95c514c52 | ||
|
|
a30212bb11 | ||
|
|
6ed065d4b7 | ||
|
|
00be879d6a | ||
|
|
72c2ff6c3c | ||
|
|
ca1c7666cb | ||
|
|
dcb8b0de22 | ||
|
|
f65b029249 | ||
|
|
297b0c3a9d | ||
|
|
f0248d2954 | ||
|
|
eb87e69b0e | ||
|
|
6a76c4a7d5 | ||
|
|
59cd270b3f | ||
|
|
4219e5c0ef | ||
|
|
0c1885fce7 | ||
|
|
c7777ecc3a | ||
|
|
f7e699d9d8 | ||
|
|
371df9ecb0 | ||
|
|
95438654ac | ||
|
|
9b178b99ed | ||
|
|
aac542edcb | ||
|
|
c636bd34d5 | ||
|
|
780951c828 | ||
|
|
eb3e8a26b2 | ||
|
|
19abb4a6b0 | ||
|
|
b58b9848ad | ||
|
|
8ffd089b51 | ||
|
|
32bbff4953 | ||
|
|
3e956f91e7 | ||
|
|
3be6fb103c | ||
|
|
890e3bba93 | ||
|
|
953ce1b1f8 | ||
|
|
ede1454338 | ||
|
|
f7e3786455 | ||
|
|
bff20dd582 | ||
|
|
3617460ac5 | ||
|
|
34fc6156e5 | ||
|
|
f9fac0ceec | ||
|
|
029c22ef95 | ||
|
|
e89e3617b8 | ||
|
|
b3b34277f4 | ||
|
|
9b18274db6 | ||
|
|
32f10ac347 | ||
|
|
d47ca517a4 | ||
|
|
8078dea150 | ||
|
|
b785132c2c | ||
|
|
cb07523f02 | ||
|
|
75dd9a2856 | ||
|
|
94dd04d3d4 | ||
|
|
9394c5b24c | ||
|
|
13604b696b | ||
|
|
9d4d9a51da | ||
|
|
dc9b57b91a | ||
|
|
11da4b339d | ||
|
|
b7e02f7995 | ||
|
|
be4fb7dd9f | ||
|
|
6b2c216e45 | ||
|
|
d309b4f338 | ||
|
|
d562a0321b | ||
|
|
1b61fd2e4d | ||
|
|
e0c9b2b1ff | ||
|
|
d976f71585 | ||
|
|
00acf7aebe | ||
|
|
29751edf20 | ||
|
|
235db905da | ||
|
|
6337d90bce | ||
|
|
8c434b5081 | ||
|
|
84ba491b56 | ||
|
|
c8044504c5 | ||
|
|
9132ad8d94 | ||
|
|
c20b257657 | ||
|
|
89d9bbe105 | ||
|
|
2f34d413e1 | ||
|
|
a9875fde4d | ||
|
|
c19e211629 | ||
|
|
3e49033616 | ||
|
|
f0135920aa | ||
|
|
8776a1c3de | ||
|
|
b6022900b0 | ||
|
|
cc94802ed0 | ||
|
|
ed2fc61c73 | ||
|
|
ebbc86a312 | ||
|
|
b065cdd22b | ||
|
|
c7a0944eda | ||
|
|
b2d6a48eab | ||
|
|
ca71f0dd16 | ||
|
|
94306b086a | ||
|
|
7f2b9cf336 | ||
|
|
df1f5c3ca8 | ||
|
|
ffb7e4c706 | ||
|
|
598e303649 | ||
|
|
866dfb7804 | ||
|
|
1cfefed1ce | ||
|
|
d5c6ef451c | ||
|
|
e363385d3f | ||
|
|
258d1be49f | ||
|
|
18f15d98e2 | ||
|
|
0b5ce105a3 | ||
|
|
d66bf3412a | ||
|
|
5ede551fcb | ||
|
|
e10a19668c | ||
|
|
881fa13a13 | ||
|
|
e279912bb8 | ||
|
|
6dc2a996d2 | ||
|
|
7ee1cad15c | ||
|
|
b8a0f08cea | ||
|
|
edc96f3626 | ||
|
|
dd369f18a9 | ||
|
|
9326fe3cc8 | ||
|
|
3bd99c9f24 | ||
|
|
bd56364160 | ||
|
|
9547004247 | ||
|
|
647f119d3c | ||
|
|
8952443ebd | ||
|
|
5947e5cd3d | ||
|
|
10d94710f7 | ||
|
|
d13af7c562 | ||
|
|
c2bbe446da | ||
|
|
b0a732e5b0 | ||
|
|
59cf7998c9 | ||
|
|
a6f38be402 | ||
|
|
bc92e3137f | ||
|
|
30310a173d | ||
|
|
c00935b451 | ||
|
|
15e4ee6aa3 | ||
|
|
9ec1e75fe6 | ||
|
|
5898076da5 | ||
|
|
5b17c1880e | ||
|
|
1afea6bd86 | ||
|
|
8dff3619b9 | ||
|
|
111452a68e | ||
|
|
d147ed5063 | ||
|
|
162859dedb | ||
|
|
56de7fa4c3 | ||
|
|
7cc02dd51d | ||
|
|
e5963f1c9a | ||
|
|
9d32721ca7 | ||
|
|
bc6b8fb283 | ||
|
|
12381b2624 | ||
|
|
2821145268 | ||
|
|
959528efad | ||
|
|
6a5a45eeae | ||
|
|
4166e27055 | ||
|
|
f55cf6c0f7 | ||
|
|
6ddabf2995 | ||
|
|
54461a6d33 | ||
|
|
b5d6f424d9 | ||
|
|
f846af5901 | ||
|
|
f9c766dca7 | ||
|
|
218db50c79 | ||
|
|
3fddb7e5a8 | ||
|
|
7bd8de272b | ||
|
|
80915d3f3a | ||
|
|
791f1395d5 | ||
|
|
b5a13381a9 | ||
|
|
c64e408471 | ||
|
|
b1770a0b0e | ||
|
|
37dc4428af | ||
|
|
2d198b6be7 | ||
|
|
67e8e439c0 | ||
|
|
908d2c9222 | ||
|
|
c401a95af2 | ||
|
|
e2864d852f | ||
|
|
f043ba2d5e | ||
|
|
cf8412d3bf | ||
|
|
5b4bde0070 | ||
|
|
9fead8dad3 | ||
|
|
0d44507f3c | ||
|
|
3272749663 | ||
|
|
5f917dcbcf | ||
|
|
85a08aca3f | ||
|
|
192858edf1 | ||
|
|
9a5e6f6e69 | ||
|
|
6884bd010d | ||
|
|
9dd852c27e | ||
|
|
198fe76cb3 | ||
|
|
9696c4ce09 | ||
|
|
9a2f8fadd3 | ||
|
|
b59fefaa11 | ||
|
|
8348d06902 | ||
|
|
8f7f6a6ab3 | ||
|
|
13e6dc6da5 | ||
|
|
244711d46e | ||
|
|
9695bcef84 | ||
|
|
2f20b9959c | ||
|
|
308938ec02 | ||
|
|
b1c435b6be | ||
|
|
4219d8ec7b | ||
|
|
8bd7598678 | ||
|
|
e89bdbb612 | ||
|
|
ebb0df6c69 | ||
|
|
8f2d13df3d | ||
|
|
69c207b599 | ||
|
|
fa04b05b5d | ||
|
|
54912c4f6a | ||
|
|
1c0f525e57 | ||
|
|
26c0de512f | ||
|
|
0c27cb02a8 | ||
|
|
b822800ffe | ||
|
|
b54da0ddde | ||
|
|
9295ff8d72 | ||
|
|
e5dcff3f34 | ||
|
|
a1acd5883b | ||
|
|
556e386621 | ||
|
|
9f9256f08a | ||
|
|
f3c53c1193 | ||
|
|
9f668ee333 | ||
|
|
617ef95c09 | ||
|
|
c539946c25 | ||
|
|
7e9f1c7fc0 | ||
|
|
cf0e6ad2f6 | ||
|
|
9813b188f3 | ||
|
|
bf7c1c5608 | ||
|
|
ec09c0202b | ||
|
|
71365cf2d4 | ||
|
|
481d074f5a | ||
|
|
a240e2adc8 | ||
|
|
c3643925ef | ||
|
|
a6b368fa14 | ||
|
|
ab9df3d94e | ||
|
|
c727113351 | ||
|
|
d203df40d5 | ||
|
|
c506d1e783 | ||
|
|
56ce86f194 | ||
|
|
54a8ebc60d | ||
|
|
b3e07bd638 | ||
|
|
94a6a0a9e9 | ||
|
|
fb279c9ee6 | ||
|
|
3ae34ad3b3 | ||
|
|
6b08212df8 | ||
|
|
03d2d02d00 | ||
|
|
0f09b19199 | ||
|
|
fcf232699f | ||
|
|
1ed89b5656 | ||
|
|
69da97727b | ||
|
|
9edf9cdc0b | ||
|
|
2f13fd6100 | ||
|
|
ed278c9be3 | ||
|
|
9e04457895 | ||
|
|
e6c4291db6 | ||
|
|
f62e6ad85e | ||
|
|
0ba62fde38 | ||
|
|
d62f2e217a | ||
|
|
f385ea287e | ||
|
|
140ee69480 | ||
|
|
2c93b7788c | ||
|
|
4fdc8f38eb | ||
|
|
c0645fe35e | ||
|
|
5b5812defa | ||
|
|
349e3d2472 | ||
|
|
fa67608d48 | ||
|
|
527c20d146 | ||
|
|
ff1da67423 | ||
|
|
efd7489a1c | ||
|
|
4dd7cd7cfd | ||
|
|
33274b905e | ||
|
|
3670378bc6 | ||
|
|
275180be20 | ||
|
|
40a62e70be | ||
|
|
95462aa89e | ||
|
|
7a9f9e04d0 | ||
|
|
cf35b286f2 | ||
|
|
e1cf44a4e0 | ||
|
|
b891b8b595 | ||
|
|
67366e1a2f | ||
|
|
ee8206e2ca | ||
|
|
5cdc559241 | ||
|
|
daa7166534 | ||
|
|
2cf0bc29c8 | ||
|
|
139ae0ddad | ||
|
|
703f4d3847 | ||
|
|
d79042d334 | ||
|
|
f9b52f0058 | ||
|
|
5b50192830 | ||
|
|
ae431e0dd4 | ||
|
|
35626309ac | ||
|
|
a38168a91c | ||
|
|
64ebab654f | ||
|
|
ec0ea40bbe | ||
|
|
49ae10a25e | ||
|
|
1a1ba5216b | ||
|
|
0bbc6215d8 | ||
|
|
4e5300c4d4 | ||
|
|
166d4a12a5 | ||
|
|
e4f90c304b | ||
|
|
fa966c8c7c | ||
|
|
9a0261acd2 | ||
|
|
743bacb125 | ||
|
|
d0afd42eb2 | ||
|
|
4c9691c49d | ||
|
|
9aaff41dfa | ||
|
|
a8b6508155 | ||
|
|
a23e536fa0 | ||
|
|
e654f3e72d | ||
|
|
1a6ce5df82 | ||
|
|
a6cd8d9b0f | ||
|
|
8a62e090a3 | ||
|
|
b550de47e4 | ||
|
|
5bc2477352 | ||
|
|
370973108d | ||
|
|
88ed1ded6d | ||
|
|
e9b8a883d0 | ||
|
|
4a7db75715 | ||
|
|
72b3cba68b | ||
|
|
e7c78e9b46 | ||
|
|
0bc32b9c92 | ||
|
|
9b81ef2326 | ||
|
|
cfc8e7dae2 | ||
|
|
09666f93ab | ||
|
|
b489a86fa9 | ||
|
|
4d4338fb58 | ||
|
|
805ebb1931 | ||
|
|
a57b316216 | ||
|
|
94e08ae947 | ||
|
|
21aee96114 | ||
|
|
ac802a3273 | ||
|
|
70f4fff5c2 | ||
|
|
f2e1c17c8c | ||
|
|
9493c11a53 | ||
|
|
7c72d5b06f | ||
|
|
a15cfbae65 | ||
|
|
34ab545763 | ||
|
|
e67d3e6598 | ||
|
|
621536a1dd | ||
|
|
6d9f9176cd | ||
|
|
2e81b54446 | ||
|
|
38acdf315e | ||
|
|
30dff8597c | ||
|
|
2ebd5f2deb | ||
|
|
162b8c38a1 | ||
|
|
1fb155ddfd | ||
|
|
241b9f527b | ||
|
|
53dc4dd9df | ||
|
|
7c7558fcb3 | ||
|
|
5262e32346 | ||
|
|
664fad5f84 | ||
|
|
4c3e530ef3 | ||
|
|
d582111d04 | ||
|
|
e9384dc714 | ||
|
|
b97da50c9d | ||
|
|
517124b424 | ||
|
|
0ef3121ac6 | ||
|
|
542f74f404 | ||
|
|
8662ba864d | ||
|
|
0fc68006d5 | ||
|
|
eac3a57b6d | ||
|
|
f46bc1cb99 | ||
|
|
8f7004c4c3 | ||
|
|
185facb1d5 | ||
|
|
07d0febef1 | ||
|
|
d35a40eacb | ||
|
|
8a744e6035 | ||
|
|
50b47f8610 | ||
|
|
c1af144891 | ||
|
|
dd123fec89 | ||
|
|
92cca97a76 | ||
|
|
a5d01c7576 | ||
|
|
84fbf805c3 | ||
|
|
51545ee82c | ||
|
|
e16771035f | ||
|
|
10ee2c7343 | ||
|
|
f637fff192 | ||
|
|
3e0cafbae3 | ||
|
|
d6c9c977d8 | ||
|
|
1a0f59943e | ||
|
|
5032d894b8 | ||
|
|
2046ee9ade | ||
|
|
d48ac14458 | ||
|
|
24ff638e43 | ||
|
|
7bb4e856ec | ||
|
|
be93cfe817 | ||
|
|
140aeb4591 | ||
|
|
907dadc6a0 | ||
|
|
c9a1e5c47d | ||
|
|
f36d98363c | ||
|
|
c05c0e0575 | ||
|
|
123b48d5ec | ||
|
|
c4553fc132 | ||
|
|
b38be86191 | ||
|
|
da3970082a | ||
|
|
c2a11bf114 | ||
|
|
75a141d8ba | ||
|
|
ee17a48dbe | ||
|
|
02df7e7f8d | ||
|
|
22ae700048 | ||
|
|
dbe6a42018 | ||
|
|
7387ca1b19 | ||
|
|
9e03c3421f | ||
|
|
8a6d088ff3 | ||
|
|
305e8f104c | ||
|
|
472eae1576 | ||
|
|
6c234daba2 | ||
|
|
dfb8691923 | ||
|
|
2784738e41 | ||
|
|
b86b27e0c7 | ||
|
|
ba3faa49df | ||
|
|
c833a65153 | ||
|
|
cab6b2fff2 | ||
|
|
35e5da1ff4 | ||
|
|
67aac97299 | ||
|
|
2b884d6304 | ||
|
|
9fa0b8d0a5 | ||
|
|
422fd32d74 | ||
|
|
1d88be2001 | ||
|
|
3d5c3180be | ||
|
|
47b61ac847 | ||
|
|
18560d0852 | ||
|
|
1400aecf1d | ||
|
|
9e3bea8cac | ||
|
|
b4bf84840e | ||
|
|
941a8b93eb | ||
|
|
666cbe6c5a | ||
|
|
33b7f0914f | ||
|
|
2b2e06d6fa | ||
|
|
eaa4ad8ef5 | ||
|
|
750a6e9e8b | ||
|
|
0028b5ca78 | ||
|
|
d4b18a0e35 | ||
|
|
22d7c563cb | ||
|
|
4671708601 | ||
|
|
9b3948a3ff | ||
|
|
4c415fba7b | ||
|
|
f775833e10 | ||
|
|
b1b06b1e15 | ||
|
|
bf8f3d91d2 | ||
|
|
b40fb1a94b | ||
|
|
2e52833bb5 | ||
|
|
7e2518bbba | ||
|
|
808cf7849e | ||
|
|
1a454b23f8 | ||
|
|
bc4483706b | ||
|
|
8d0cff2b0b | ||
|
|
35097e8e2b | ||
|
|
1fdc8de899 | ||
|
|
995293e5da | ||
|
|
e9d7604f0b | ||
|
|
85824bb1ee | ||
|
|
652f0e365f | ||
|
|
554331f567 | ||
|
|
7311c8f48c | ||
|
|
9a904b6dcc | ||
|
|
8b4234eb60 | ||
|
|
b121bcb20b | ||
|
|
d63ceba488 | ||
|
|
67d8d6b992 | ||
|
|
0130d5dfd9 | ||
|
|
ab8f7187e6 | ||
|
|
0c291d594b | ||
|
|
1d828f7982 | ||
|
|
6f3d52f345 | ||
|
|
ed1f76808d | ||
|
|
a00fe78aa1 | ||
|
|
58a56f9fc0 | ||
|
|
2254b4c96c | ||
|
|
b6a0caa79b | ||
|
|
4e7c6c27ce | ||
|
|
4eb0a8e1fb | ||
|
|
e222cb7a97 | ||
|
|
858f198b43 | ||
|
|
4b5872b5d1 | ||
|
|
d2269eebf7 | ||
|
|
7f8b21f71f | ||
|
|
ca1703745f | ||
|
|
c3f2547349 | ||
|
|
5c24050775 | ||
|
|
927fb9fac2 | ||
|
|
167944b422 | ||
|
|
4cec36f4b5 | ||
|
|
66fbf23d67 | ||
|
|
7e401c69c7 | ||
|
|
5c4076bc8c | ||
|
|
d2cb4f0d48 | ||
|
|
ef16ee6b23 | ||
|
|
c518caacf2 | ||
|
|
acef1725f3 | ||
|
|
6cad14a20b | ||
|
|
c33333724d | ||
|
|
8149440f8f | ||
|
|
b4717747d5 | ||
|
|
9fc98f3288 | ||
|
|
24347bf69c | ||
|
|
c5aa4d2975 | ||
|
|
9ec05b25a8 | ||
|
|
bd83d880a9 | ||
|
|
8037d370ee | ||
|
|
4a0a86577e | ||
|
|
75ea980bd2 | ||
|
|
de5049577c | ||
|
|
5f99756be4 | ||
|
|
33724c7214 | ||
|
|
4f75032c7e | ||
|
|
a45b4b6e85 | ||
|
|
912db261fe | ||
|
|
fad53704fd | ||
|
|
29aeac0531 | ||
|
|
a426971470 | ||
|
|
1bd50bff21 | ||
|
|
d2d733b931 | ||
|
|
1ad6edd9ce | ||
|
|
d924809d85 | ||
|
|
24b1b324e6 | ||
|
|
531b28f75a | ||
|
|
ce40bb7f58 | ||
|
|
be667fb936 | ||
|
|
de2a2c8bb8 | ||
|
|
55e68dff43 | ||
|
|
e922d565a7 | ||
|
|
1763e85aa7 | ||
|
|
5d10422881 | ||
|
|
a004408327 | ||
|
|
cc0b34a640 | ||
|
|
21596a01d7 | ||
|
|
68f0c6f6ca | ||
|
|
3a2ab1d176 | ||
|
|
ff7289ef39 | ||
|
|
0b370359c4 | ||
|
|
d2b720da3f | ||
|
|
ef1054a921 | ||
|
|
0b30af2a7a | ||
|
|
5d97b4ee52 | ||
|
|
e179494ac4 | ||
|
|
db2fc3cbb0 | ||
|
|
935caa24ce | ||
|
|
268a9b2cf8 | ||
|
|
5d8238bcf4 | ||
|
|
a1c4f18725 | ||
|
|
19ec1f1d36 | ||
|
|
3032c685cd | ||
|
|
c890ebdbe1 | ||
|
|
a26d2fe86f | ||
|
|
312305fcb7 | ||
|
|
a402a29f93 | ||
|
|
ed964105ec | ||
|
|
414a3dcc83 | ||
|
|
c51e87385f | ||
|
|
fec403b9f5 | ||
|
|
90e06d90e5 | ||
|
|
25bf6ee63a | ||
|
|
9e5880b130 | ||
|
|
9220e7b1e0 | ||
|
|
4a97c8bee9 | ||
|
|
93d45509ad | ||
|
|
c0d0ec0c32 | ||
|
|
c4b4233e20 | ||
|
|
61e59b27ec | ||
|
|
d16d22492e | ||
|
|
3a25325d37 | ||
|
|
4599fec534 | ||
|
|
1637d0fdb8 | ||
|
|
13de77b68c | ||
|
|
700d8f71e2 | ||
|
|
84eea2a0eb | ||
|
|
b2be7b2583 | ||
|
|
5c396368b6 | ||
|
|
7335d07755 | ||
|
|
b63746fe84 | ||
|
|
ef964536e9 | ||
|
|
f96e3a903e | ||
|
|
7ec82a97d6 | ||
|
|
741b167910 | ||
|
|
268fb4e9aa | ||
|
|
3b0b264ba5 | ||
|
|
1c7a3b8ed9 | ||
|
|
7c307c886e | ||
|
|
e00a89c647 | ||
|
|
c6d37ed5c5 | ||
|
|
deef279977 | ||
|
|
835527333c | ||
|
|
b2735b8dc6 | ||
|
|
7050a8bd7a | ||
|
|
2c6ac7124e | ||
|
|
cc0c2bf8cb | ||
|
|
e335bb24df | ||
|
|
ef0768ebef | ||
|
|
a6c8c4c254 | ||
|
|
1f81ffb182 | ||
|
|
23b7937507 | ||
|
|
ac23472220 | ||
|
|
0e07eb7614 | ||
|
|
f6e2fd1be2 | ||
|
|
ddc6644a87 | ||
|
|
6d987df3e2 | ||
|
|
7b2fd581b6 | ||
|
|
6f810111c4 | ||
|
|
3b154540da | ||
|
|
0675610007 | ||
|
|
2ae67dd894 | ||
|
|
8623843e72 | ||
|
|
98ef29fec0 | ||
|
|
7d37b56c20 | ||
|
|
54c48df279 | ||
|
|
7460fcde9d | ||
|
|
fc2a56039a | ||
|
|
b29f8e3a0f | ||
|
|
3031ead6dc | ||
|
|
20951c0721 | ||
|
|
75b1064922 | ||
|
|
1a4135515b | ||
|
|
acbb1b6e2c | ||
|
|
c0632cb689 | ||
|
|
144e3b7a98 | ||
|
|
dfd21a343b | ||
|
|
0faadea621 | ||
|
|
0dd8f4b7c7 | ||
|
|
96e39c2535 | ||
|
|
909d5b7836 | ||
|
|
1125351f4c | ||
|
|
345622f452 | ||
|
|
53b9bd6e61 | ||
|
|
e7d0a08150 | ||
|
|
0939f50ce2 | ||
|
|
84d7a0cedc | ||
|
|
c9c540057b | ||
|
|
a2edbe14ec | ||
|
|
694fa93d30 | ||
|
|
4214a33525 | ||
|
|
404322b4ab | ||
|
|
afd3eeee88 | ||
|
|
84adc99c33 | ||
|
|
3f4b592c60 | ||
|
|
d61c848f6a | ||
|
|
94c7d00517 | ||
|
|
e799363d0d | ||
|
|
d0d7f74e42 | ||
|
|
de5835822d | ||
|
|
77fb4305e8 | ||
|
|
4cdb364e4a | ||
|
|
fbebf6d485 | ||
|
|
4fde0f4524 | ||
|
|
6f3cff1cd4 | ||
|
|
b90847c43f | ||
|
|
e5c7c8b2a2 | ||
|
|
9e453719e3 | ||
|
|
63b04f1e9a | ||
|
|
904baefa68 | ||
|
|
c82a00981a | ||
|
|
d1add4231f | ||
|
|
678591a1a5 | ||
|
|
f4a07f5259 | ||
|
|
e5e904498c | ||
|
|
89740bdd30 | ||
|
|
c6b72fa317 | ||
|
|
80b917b02f | ||
|
|
971361feac | ||
|
|
4a553724a2 | ||
|
|
b87b30f045 | ||
|
|
a3a69f53da | ||
|
|
cb659f3c25 | ||
|
|
8135540b22 | ||
|
|
3a10b6f4db | ||
|
|
82d4a96ae1 | ||
|
|
802091e15e | ||
|
|
4292259db1 | ||
|
|
b6efabf216 | ||
|
|
7e5471bdfa | ||
|
|
5fa5aff813 | ||
|
|
5035ad1d99 | ||
|
|
63797c90f9 | ||
|
|
8b475ea4f2 | ||
|
|
1cd4fb2e73 | ||
|
|
946ea8dfb8 | ||
|
|
8c264fb2a5 | ||
|
|
b96b792612 | ||
|
|
909ea995b6 | ||
|
|
4e197b512f | ||
|
|
aa3f8cce3d | ||
|
|
00dcc29eb1 | ||
|
|
032cec5c5a | ||
|
|
f0d6fedc90 | ||
|
|
bb9ff4f113 | ||
|
|
1c3f6735f8 | ||
|
|
2d210641d3 | ||
|
|
4078d895c7 | ||
|
|
d67820b6ba | ||
|
|
6869047b44 | ||
|
|
ba7c3972b5 | ||
|
|
f931504a09 | ||
|
|
52c18171a1 | ||
|
|
b89dbefb3c | ||
|
|
07936bc8e4 | ||
|
|
647eda7895 | ||
|
|
5bb703084c | ||
|
|
e7283e9105 | ||
|
|
7102d06e73 | ||
|
|
cf54dee88e | ||
|
|
854864ac5e | ||
|
|
3eceaae45f | ||
|
|
e604a8cba0 | ||
|
|
c351acb075 | ||
|
|
d880efc1db | ||
|
|
e118b293fd | ||
|
|
254996063a | ||
|
|
a85b2ac301 | ||
|
|
d2f8471943 | ||
|
|
b6a7a3bc1e | ||
|
|
3b3007cbdd | ||
|
|
adc2092275 | ||
|
|
96831f2d4e | ||
|
|
baf8664d10 | ||
|
|
d071fd5397 | ||
|
|
44f9415811 | ||
|
|
b104364edb | ||
|
|
24bbf0ead9 | ||
|
|
e8af292958 | ||
|
|
8d14f83bc3 | ||
|
|
2ff89167c2 | ||
|
|
87854bbdf0 | ||
|
|
a3a4a972d7 | ||
|
|
75fbb709d7 | ||
|
|
6a311347bf | ||
|
|
2787fdd8b6 | ||
|
|
634f5c26ee | ||
|
|
e7683ac3ff | ||
|
|
72fdc3bcfe | ||
|
|
271977d1dd | ||
|
|
e6de090ed3 | ||
|
|
65c0224ae5 | ||
|
|
e957a4c99a | ||
|
|
4e6d5b733c | ||
|
|
be591b2f4a | ||
|
|
bb90b73533 | ||
|
|
0abf5c2379 | ||
|
|
ccbf55923d | ||
|
|
eb842428b7 | ||
|
|
e61aa736db | ||
|
|
8794afb246 | ||
|
|
ddf32b6215 | ||
|
|
811fe65412 | ||
|
|
0ad73d19ed | ||
|
|
0b845dc7ee | ||
|
|
a0449b4d6b | ||
|
|
7e58e1f299 | ||
|
|
f1da8c3cb7 | ||
|
|
b87f0124b7 | ||
|
|
49db4cdea8 | ||
|
|
b72e0a2270 | ||
|
|
67965bc275 | ||
|
|
a8abee1422 | ||
|
|
4da5e94adf | ||
|
|
0dc1e71148 | ||
|
|
7d7972d54c | ||
|
|
41de512cdc | ||
|
|
07b2b1f28c | ||
|
|
2aec49d0e5 | ||
|
|
ffa50d43c5 | ||
|
|
ef06f5a746 | ||
|
|
aebdbe5ca8 | ||
|
|
acdcfc14fb | ||
|
|
a6b403e667 | ||
|
|
e5cfe80029 | ||
|
|
798ac9dd69 | ||
|
|
64a05e2f14 | ||
|
|
99f5843c42 | ||
|
|
3f2250e51f | ||
|
|
9a3de0103d | ||
|
|
1848ef4905 | ||
|
|
b8725ec9aa | ||
|
|
fbba2eb1db | ||
|
|
8554a1fcfc | ||
|
|
2f4e189f93 | ||
|
|
1bda13aec0 | ||
|
|
49cadac789 | ||
|
|
1fe9f3a068 | ||
|
|
edb102f7a2 | ||
|
|
97376b36bc | ||
|
|
41d88b0c4a | ||
|
|
ee70a44f8b | ||
|
|
1be715f322 | ||
|
|
519319c9b2 | ||
|
|
76abe671e4 | ||
|
|
14541394dc | ||
|
|
78a10f89ed | ||
|
|
f20a9fd2ed | ||
|
|
18eb48735d | ||
|
|
6274ba8169 | ||
|
|
a63bae227e | ||
|
|
6182590829 | ||
|
|
0922bcb903 | ||
|
|
6eb62664a5 | ||
|
|
dcb2072f36 | ||
|
|
c5bd1a9ce9 | ||
|
|
da1192bd01 | ||
|
|
ef3b917f5e | ||
|
|
2f32bcbb8f | ||
|
|
71adf60a71 | ||
|
|
8a1c51317c | ||
|
|
783d01dd6f | ||
|
|
8c0567146f | ||
|
|
8a4a98fa27 | ||
|
|
37b363a92f | ||
|
|
e9352d0506 | ||
|
|
6efdcdb2b9 | ||
|
|
eb1355c65a | ||
|
|
a135938588 | ||
|
|
aa3bf1ef51 | ||
|
|
dcaa13b20e | ||
|
|
c8ca146d0c | ||
|
|
8c73edb584 | ||
|
|
f6a52704d9 | ||
|
|
40814bc323 | ||
|
|
c652f8050a | ||
|
|
c4bf441fc1 | ||
|
|
8ca9add11e | ||
|
|
1282de3d05 | ||
|
|
f25c40bb08 | ||
|
|
3e12aa3492 | ||
|
|
78696adb53 | ||
|
|
51f649da8a | ||
|
|
3a35e59691 | ||
|
|
f5098784d7 | ||
|
|
b3a21eaa52 | ||
|
|
cc1d92e62f | ||
|
|
3fdc34e286 | ||
|
|
10f3eaad39 | ||
|
|
ede46bd1e0 | ||
|
|
aa48af32ea | ||
|
|
4acc6d1114 | ||
|
|
b781be3cc2 | ||
|
|
d4a04a5055 | ||
|
|
3985301749 | ||
|
|
51b6d2536d | ||
|
|
2001353e9e | ||
|
|
da8b5f62d2 | ||
|
|
1cf47a06c6 | ||
|
|
2b68b022f9 | ||
|
|
7de2e883b9 | ||
|
|
ebbbfa1998 | ||
|
|
7120aede15 | ||
|
|
14c93f0e96 |
130
.agents/skills/aoti-debug/SKILL.md
Normal file
130
.agents/skills/aoti-debug/SKILL.md
Normal file
@@ -0,0 +1,130 @@
|
||||
---
|
||||
name: aoti-debug
|
||||
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
|
||||
---
|
||||
|
||||
# AOTInductor Debugging
|
||||
|
||||
Debug errors when compiling and deploying PyTorch models with AOTInductor.
|
||||
|
||||
## First Step: Always Check Device and Shape Matching
|
||||
|
||||
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
|
||||
|
||||
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
|
||||
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
|
||||
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
|
||||
|
||||
```python
|
||||
# Compilation -- note the device and shapes
|
||||
model = MyModel().eval().cuda()
|
||||
inp = torch.randn(2, 10, device="cuda")
|
||||
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
|
||||
|
||||
# Loading -- device type MUST match compilation
|
||||
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
|
||||
|
||||
# Inference -- device and shapes MUST match
|
||||
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
|
||||
```
|
||||
|
||||
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
|
||||
|
||||
## Current vs Deprecated API
|
||||
|
||||
### Current API (use this)
|
||||
```python
|
||||
torch._inductor.aoti_compile_and_package() # compile
|
||||
torch._inductor.aoti_load_package() # load (auto-detects device)
|
||||
```
|
||||
|
||||
### Deprecated API (migrate away)
|
||||
```python
|
||||
torch._export.aot_compile() # deprecated
|
||||
torch._export.aot_load() # deprecated
|
||||
```
|
||||
|
||||
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
|
||||
|
||||
## Common Error Patterns
|
||||
|
||||
### Device Mismatch Segfault
|
||||
|
||||
**Symptom**: Segfault, exception, or crash during load or execution.
|
||||
|
||||
**Example errors**:
|
||||
- `The specified pointer resides on host memory and is not registered with any CUDA device`
|
||||
- Crash during constant loading
|
||||
- `Expected out tensor to have device cuda:0, but got cpu instead`
|
||||
|
||||
**Solution**: Ensure compile and load use the same device type.
|
||||
|
||||
### Input Device Mismatch at Runtime
|
||||
|
||||
**Symptom**: RuntimeError during model execution.
|
||||
|
||||
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
|
||||
```
|
||||
|
||||
Produces actionable messages like:
|
||||
```
|
||||
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
|
||||
```
|
||||
|
||||
## Debugging CUDA Illegal Memory Access (IMA)
|
||||
|
||||
### Step 1: Sanity Checks
|
||||
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
|
||||
```
|
||||
|
||||
Both flags take effect at **compile time** (codegen time).
|
||||
|
||||
### Step 2: Make IMA Deterministic
|
||||
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
|
||||
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
|
||||
|
||||
Both take effect at **runtime**.
|
||||
|
||||
### Step 3: Identify the Problematic Kernel
|
||||
|
||||
```bash
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
|
||||
```
|
||||
|
||||
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
|
||||
|
||||
To inspect inputs to specific kernels:
|
||||
```bash
|
||||
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
|
||||
```
|
||||
|
||||
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
|
||||
|
||||
## Environment Variables Reference
|
||||
|
||||
| Variable | When | Purpose |
|
||||
|---|---|---|
|
||||
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
|
||||
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
|
||||
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
|
||||
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
|
||||
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
|
||||
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
|
||||
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
|
||||
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
|
||||
|
||||
## Common Sources of Issues
|
||||
|
||||
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
|
||||
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.
|
||||
195
.agents/skills/pt2-debug/SKILL.md
Normal file
195
.agents/skills/pt2-debug/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: pt2-debug
|
||||
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
|
||||
---
|
||||
|
||||
# PyTorch 2 Compile Debugging
|
||||
|
||||
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
|
||||
|
||||
## Diagnostic Environment Variables
|
||||
|
||||
Pick the right diagnostic based on the error:
|
||||
|
||||
| Command | When to use |
|
||||
|---|---|
|
||||
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
|
||||
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
|
||||
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
|
||||
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
|
||||
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
|
||||
|
||||
## Error Triage
|
||||
|
||||
Classify the failure and jump to the right section:
|
||||
|
||||
| Error Pattern | Category |
|
||||
|---|---|
|
||||
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
|
||||
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
|
||||
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
|
||||
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
|
||||
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
|
||||
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
|
||||
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
|
||||
|
||||
## Graph Breaks
|
||||
|
||||
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="graph_breaks" python script.py
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Data-dependent control flow
|
||||
- Unsupported Python builtins
|
||||
- In-place ops on inputs, unsupported dtypes
|
||||
- Calls to non-traceable functions
|
||||
|
||||
**Fix approaches:**
|
||||
1. Read the graph break message to identify the unsupported operation
|
||||
2. Check for a decomposition or supported alternative
|
||||
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
|
||||
|
||||
## Backend Compiler Failures
|
||||
|
||||
`BackendCompilerFailed` means Inductor crashed during compilation.
|
||||
|
||||
**Diagnose with the minifier:**
|
||||
```bash
|
||||
# Generate minifier launcher
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
|
||||
# Run the minifier to get minimal failing graph
|
||||
python minifier_launcher.py minify
|
||||
|
||||
# Run the minimized reproduction
|
||||
python minifier_launcher.py run
|
||||
```
|
||||
|
||||
**Then inspect:**
|
||||
```bash
|
||||
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
|
||||
```
|
||||
|
||||
## Recompilation Issues
|
||||
|
||||
Excessive recompilation from guards that are too specific, causing cache misses.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
|
||||
```
|
||||
|
||||
**Key config:**
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit # default: 8
|
||||
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Changing tensor shapes without marking them dynamic
|
||||
- Python scalar values that change between calls
|
||||
- Global state mutations between calls
|
||||
|
||||
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
|
||||
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
|
||||
- Fix the source of guard instability
|
||||
|
||||
## Accuracy Issues
|
||||
|
||||
Compiled model produces different numerical results than eager mode.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
# Compares compiled vs eager with fp64 reference, dumps repro on failure
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get minimal failing graph from the minifier
|
||||
2. Compare eager vs compiled output at fp64 precision
|
||||
3. Binary search through ops to find the diverging operation
|
||||
4. Check for known issues: reduction order, fused kernels, dtype promotions
|
||||
|
||||
## Internal Dynamo Errors
|
||||
|
||||
`InternalTorchDynamoError` indicates a bug in Dynamo.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCHDYNAMO_VERBOSE=1 python script.py
|
||||
# or equivalently:
|
||||
TORCH_LOGS="+dynamo" python script.py
|
||||
```
|
||||
|
||||
**Debug interactively:**
|
||||
```bash
|
||||
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
|
||||
```
|
||||
|
||||
## Runtime Crashes
|
||||
|
||||
Segfaults and CUDA illegal memory access during execution of compiled code.
|
||||
|
||||
**Make crash deterministic:**
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
**Add NaN checks to find the first bad kernel:**
|
||||
```bash
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
|
||||
```
|
||||
|
||||
**Inductor sync debugging:**
|
||||
```python
|
||||
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
|
||||
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
|
||||
2. Check input shapes, devices, dtypes
|
||||
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
|
||||
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
|
||||
5. Dynamic shapes are historically a common source of IMA
|
||||
|
||||
## Triton Kernel Failures
|
||||
|
||||
Triton assertion failures or index-out-of-bounds in generated kernels.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="output_code,schedule" python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get the generated Triton kernel from `output_code` logs
|
||||
2. Check index computations for off-by-one or wrong stride calculations
|
||||
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
|
||||
4. Check if fusion decisions created invalid index combinations
|
||||
|
||||
## Distinguish Trace-Time vs Runtime
|
||||
|
||||
Many bugs come from confusing these:
|
||||
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
|
||||
- **Runtime**: Real tensors, real Python calls.
|
||||
|
||||
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
|
||||
|
||||
## Using the Minifier
|
||||
|
||||
The minifier reduces a failing graph to the smallest reproduction:
|
||||
|
||||
```bash
|
||||
# For compilation failures (level 2)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
python minifier_launcher.py minify
|
||||
python minifier_launcher.py run
|
||||
|
||||
# For accuracy failures (level 4)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
134
.agents/skills/ruff/SKILL.md
Normal file
134
.agents/skills/ruff/SKILL.md
Normal file
@@ -0,0 +1,134 @@
|
||||
---
|
||||
name: ruff
|
||||
description:
|
||||
Guide for using ruff, the extremely fast Python linter and formatter. Use this
|
||||
when linting, formatting, or fixing Python code.
|
||||
---
|
||||
|
||||
# ruff
|
||||
|
||||
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
|
||||
isort, Black, pyupgrade, autoflake, and dozens of other tools.
|
||||
|
||||
## When to use ruff
|
||||
|
||||
**Always use ruff for Python linting and formatting**, especially if you see:
|
||||
|
||||
- `[tool.ruff]` section in `pyproject.toml`
|
||||
- A `ruff.toml` or `.ruff.toml` configuration file
|
||||
|
||||
However, avoid making unnecessary changes:
|
||||
|
||||
- **Don't format unformatted code** - If `ruff format --diff` shows changes
|
||||
throughout an entire file, the project likely isn't using ruff for formatting.
|
||||
Skip formatting to avoid obscuring actual changes.
|
||||
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
|
||||
relevant to the code you're changing. Only apply fixes to files you're
|
||||
modifying unless the user explicitly asks for broader fixes.
|
||||
|
||||
## How to invoke ruff
|
||||
|
||||
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
|
||||
you use the pinned version
|
||||
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
|
||||
one-off checks
|
||||
- `ruff ...` - Use if ruff is installed globally
|
||||
|
||||
## Commands
|
||||
|
||||
### Linting
|
||||
|
||||
```bash
|
||||
ruff check . # Check all files in current directory
|
||||
ruff check path/to/file.py # Check specific file
|
||||
ruff check --fix . # Auto-fix fixable violations
|
||||
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
|
||||
ruff check --watch . # Watch for changes and re-lint
|
||||
ruff check --select E,F . # Only check specific rules
|
||||
ruff check --ignore E501 . # Ignore specific rules
|
||||
ruff rule E501 # Explain a specific rule
|
||||
ruff linter # List available linters
|
||||
```
|
||||
|
||||
### Formatting
|
||||
|
||||
```bash
|
||||
ruff format . # Format all files
|
||||
ruff format path/to/file.py # Format specific file
|
||||
ruff format --check . # Check if files are formatted (no changes)
|
||||
ruff format --diff . # Show formatting diff without applying
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Ruff is configured in `pyproject.toml` or `ruff.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP"] # Enable specific rule sets
|
||||
ignore = ["E501"] # Ignore specific rules
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["myproject"]
|
||||
```
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### Black → ruff format
|
||||
|
||||
```bash
|
||||
black . → ruff format .
|
||||
black --check . → ruff format --check .
|
||||
black --diff . → ruff format --diff .
|
||||
```
|
||||
|
||||
### Flake8 → ruff check
|
||||
|
||||
```bash
|
||||
flake8 . → ruff check .
|
||||
flake8 --select E,F . → ruff check --select E,F .
|
||||
flake8 --ignore E501 . → ruff check --ignore E501 .
|
||||
```
|
||||
|
||||
### isort → ruff check
|
||||
|
||||
```bash
|
||||
isort . → ruff check --select I --fix .
|
||||
isort --check . → ruff check --select I .
|
||||
isort --diff . → ruff check --select I --diff .
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Apply lint fixes before formatting
|
||||
|
||||
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
|
||||
structure (e.g., reordering imports), which formatting then cleans up.
|
||||
|
||||
```bash
|
||||
ruff check --fix .
|
||||
ruff format .
|
||||
```
|
||||
|
||||
### Applying and reviewing unsafe fixes
|
||||
|
||||
Ruff categorizes some auto-fixes as "unsafe" because they may change code
|
||||
behavior, not just style. For example, removing unused imports could break code
|
||||
that relies on side effects.
|
||||
|
||||
```bash
|
||||
ruff check --fix --unsafe-fixes --diff . # Preview changes first
|
||||
ruff check --fix --unsafe-fixes . # Apply changes
|
||||
```
|
||||
|
||||
**Always review changes before applying `--unsafe-fixes`:**
|
||||
|
||||
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
|
||||
- Verify the fix doesn't violate those assumptions in your code
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ruff/
|
||||
135
.agents/skills/ty/SKILL.md
Normal file
135
.agents/skills/ty/SKILL.md
Normal file
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: ty
|
||||
description:
|
||||
Guide for using ty, the extremely fast Python type checker and language
|
||||
server. Use this when type checking Python code or setting up type checking in
|
||||
Python projects.
|
||||
---
|
||||
|
||||
# ty
|
||||
|
||||
ty is an extremely fast Python type checker and language server. It replaces
|
||||
mypy, Pyright, and other type checkers.
|
||||
|
||||
## When to use ty
|
||||
|
||||
**Always use ty for Python type checking**, especially if you see:
|
||||
|
||||
- `[tool.ty]` section in `pyproject.toml`
|
||||
- A `ty.toml` configuration file
|
||||
|
||||
## How to invoke ty
|
||||
|
||||
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
|
||||
use the pinned version or when ty is installed globally and you are in a
|
||||
project so the virtual environment is updated.
|
||||
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
|
||||
checks
|
||||
|
||||
## Commands
|
||||
|
||||
### Type checking
|
||||
|
||||
```bash
|
||||
ty check # Check all files in current directory
|
||||
ty check path/to/file.py # Check specific file
|
||||
ty check src/ # Check specific directory
|
||||
```
|
||||
|
||||
### Rule configuration
|
||||
|
||||
```bash
|
||||
ty check --error possibly-unresolved-reference # Treat as error
|
||||
ty check --warn division-by-zero # Treat as warning
|
||||
ty check --ignore unresolved-import # Disable rule
|
||||
```
|
||||
|
||||
### Python version targeting
|
||||
|
||||
```bash
|
||||
ty check --python-version 3.12 # Check against Python 3.12
|
||||
ty check --python-platform linux # Target Linux platform
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
ty is configured in `pyproject.toml` or `ty.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ty.environment]
|
||||
python-version = "3.12"
|
||||
|
||||
[tool.ty.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
division-by-zero = "error"
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src/**/*.py"]
|
||||
exclude = ["**/migrations/**"]
|
||||
|
||||
[tool.ty.terminal]
|
||||
output-format = "full"
|
||||
error-on-warning = false
|
||||
```
|
||||
|
||||
### Per-file overrides
|
||||
|
||||
Use overrides to apply different rules to specific files, such as relaxing rules
|
||||
for tests or scripts that have different typing requirements than production
|
||||
code:
|
||||
|
||||
```toml
|
||||
[[tool.ty.overrides]]
|
||||
include = ["tests/**", "**/test_*.py"]
|
||||
|
||||
[tool.ty.overrides.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
```
|
||||
|
||||
## Language server
|
||||
|
||||
This plugin automatically configures the ty language server for Python files
|
||||
(`.py` and `.pyi`).
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### mypy → ty
|
||||
|
||||
```bash
|
||||
mypy . → ty check
|
||||
mypy --strict . → ty check --error-on-warning
|
||||
mypy path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
### Pyright → ty
|
||||
|
||||
```bash
|
||||
pyright . → ty check
|
||||
pyright path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't add ignore comments
|
||||
|
||||
Fix type errors instead of suppressing them. Only add ignore comments when
|
||||
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
|
||||
prefer rule-specific ignores:
|
||||
|
||||
```python
|
||||
# Good: rule-specific ignore
|
||||
x = undefined_var # ty: ignore[possibly-unresolved-reference]
|
||||
|
||||
# Bad: blanket ty ignore
|
||||
x = undefined_var # ty: ignore
|
||||
|
||||
# Bad: tool agnostic blanket ignore
|
||||
x = undefined_var # type: ignore
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ty/
|
||||
182
.agents/skills/uv/SKILL.md
Normal file
182
.agents/skills/uv/SKILL.md
Normal file
@@ -0,0 +1,182 @@
|
||||
---
|
||||
name: uv
|
||||
description:
|
||||
Guide for using uv, the Python package and project manager. Use this when
|
||||
working with Python projects, scripts, packages, or tools.
|
||||
---
|
||||
|
||||
# uv
|
||||
|
||||
uv is an extremely fast Python package and project manager. It replaces pip,
|
||||
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
|
||||
|
||||
## When to use uv
|
||||
|
||||
**Always use uv for Python work**, especially if you see:
|
||||
|
||||
- The `uv.lock` file
|
||||
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
|
||||
|
||||
Don't use uv in projects managed by other tools:
|
||||
|
||||
- Poetry projects (identifiable by `poetry.lock` file)
|
||||
- PDM projects (identifiable by `pdm.lock` file)
|
||||
|
||||
## Choosing the right workflow
|
||||
|
||||
### Scripts
|
||||
|
||||
**Use when:** Running single Python files and standalone scripts.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv run script.py # Run a script
|
||||
uv run --with requests script.py # Run with additional packages
|
||||
uv add --script script.py requests # Add dependencies inline to the script
|
||||
```
|
||||
|
||||
### Projects
|
||||
|
||||
**Use when:** There is a `pyproject.toml` or `uv.lock`
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv init # Create new project
|
||||
uv add requests # Add dependency
|
||||
uv remove requests # Remove dependency
|
||||
uv sync # Install from lockfile
|
||||
uv run <command> # Run commands in environment
|
||||
uv run python -c "" # Run Python in project environment
|
||||
uv run -p 3.12 <command> # Run with specific Python version
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
||||
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
|
||||
installation.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uvx <tool> <args> # Run a tool without installation
|
||||
uvx <tool>@<version> <args> # Run a specific version of a tool
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
|
||||
well-known tools.
|
||||
- Only use `uv tool install` only when specifically requested by the user.
|
||||
|
||||
### Pip interface
|
||||
|
||||
**Use when:** Legacy workflows with `requirements.txt` or manual environment
|
||||
management, no `uv.lock` present.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
uv pip install -r requirements.txt
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
uv pip sync requirements.txt
|
||||
|
||||
# Platform independent resolution
|
||||
uv pip compile --universal requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- Don't use the pip interface unless clearly needed.
|
||||
- Don't introduce new `requirements.txt` files.
|
||||
- Prefer `uv init` for new projects.
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### pyenv → uv python
|
||||
|
||||
```bash
|
||||
pyenv install 3.12 → uv python install 3.12
|
||||
pyenv versions → uv python list --only-installed
|
||||
pyenv local 3.12 → uv python pin 3.12
|
||||
pyenv global 3.12 → uv python install 3.12 --default
|
||||
```
|
||||
|
||||
### pipx → uvx
|
||||
|
||||
```bash
|
||||
pipx run ruff → uvx ruff
|
||||
pipx install ruff → uv tool install ruff
|
||||
pipx upgrade ruff → uv tool upgrade ruff
|
||||
pipx list → uv tool list
|
||||
```
|
||||
|
||||
### pip and pip-tools → uv pip
|
||||
|
||||
```bash
|
||||
pip install package → uv pip install package
|
||||
pip install -r req.txt → uv pip install -r req.txt
|
||||
pip freeze → uv pip freeze
|
||||
pip-compile req.in → uv pip compile req.in
|
||||
pip-sync req.txt → uv pip sync req.txt
|
||||
virtualenv .venv → uv venv
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't use pip in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
pip install requests
|
||||
|
||||
# Good
|
||||
uv add requests
|
||||
```
|
||||
|
||||
### Don't run python directly
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python script.py
|
||||
|
||||
# Good
|
||||
uv run script.py
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -c "..."
|
||||
|
||||
# Good
|
||||
uv run python -c "..."
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python3.12 -c "..."
|
||||
|
||||
# Good
|
||||
uvx python@3.12 -c "..."
|
||||
```
|
||||
|
||||
### Don't manually manage environments in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Good
|
||||
uv run <command>
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/uv/llms.txt
|
||||
|
||||
The documentation links to specific pages for each of these workflows.
|
||||
4
.cargo/config.toml
Normal file
4
.cargo/config.toml
Normal file
@@ -0,0 +1,4 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
rustflags = [
|
||||
"-Ctarget-feature=+fp16,+fhm"
|
||||
]
|
||||
55
.devcontainer/cpu/devcontainer.json
Normal file
55
.devcontainer/cpu/devcontainer.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"name": "Luminal (CPU)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--env-file", ".env"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
59
.devcontainer/cuda/devcontainer.json
Normal file
59
.devcontainer/cuda/devcontainer.json
Normal file
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"name": "Luminal (CUDA)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--env-file",
|
||||
".env",
|
||||
"--runtime=nvidia",
|
||||
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
|
||||
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: CUDA Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: CUDA Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as safe for git
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
23
.github/workflows/fmt.yml
vendored
Normal file
23
.github/workflows/fmt.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Fmt
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
25
.github/workflows/metal-clippy.yml
vendored
Normal file
25
.github/workflows/metal-clippy.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Metal Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
45
.github/workflows/modal-examples.yml
vendored
Normal file
45
.github/workflows/modal-examples.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Modal Examples
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 70
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
# - { type: "H100" }
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
EXAMPLE: ${{ matrix.example }}
|
||||
GPU_TYPE: ${{ matrix.gpu.type }}
|
||||
run: modal run ci/modal_example.py
|
||||
23
.github/workflows/ruff-format.yml
vendored
Normal file
23
.github/workflows/ruff-format.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff Format
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
23
.github/workflows/ruff.yml
vendored
Normal file
23
.github/workflows/ruff.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
22
.github/workflows/rust.yml
vendored
22
.github/workflows/rust.yml
vendored
@@ -1,22 +0,0 @@
|
||||
name: Rust
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Build
|
||||
run: cargo build --verbose
|
||||
- name: Run tests
|
||||
run: cargo test --verbose
|
||||
24
.github/workflows/test-core.yml
vendored
Normal file
24
.github/workflows/test-core.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Test Core
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
35
.github/workflows/test-cuda.yml
vendored
Normal file
35
.github/workflows/test-cuda.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Test CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_unit_test:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
19
.github/workflows/test-metal.yml
vendored
Normal file
19
.github/workflows/test-metal.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Test Metal
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
47
.github/workflows/test-python-cuda.yml
vendored
Normal file
47
.github/workflows/test-python-cuda.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
name: Test Python CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_cuda_tests:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run pytest with CUDA backend on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
retention-days: 7
|
||||
if-no-files-found: warn
|
||||
28
.github/workflows/test-python-native.yml
vendored
Normal file
28
.github/workflows/test-python-native.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Test Python Native
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
39
.gitignore
vendored
39
.gitignore
vendored
@@ -1,7 +1,42 @@
|
||||
/target
|
||||
/crates/**/target
|
||||
/examples/**/target
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.env
|
||||
.claude/
|
||||
.DS_Store
|
||||
.vscode
|
||||
*.vscode
|
||||
*.zed
|
||||
Cargo.lock
|
||||
*.st
|
||||
*.npx
|
||||
*.npz
|
||||
*.npz
|
||||
*.model
|
||||
*.gguf
|
||||
|
||||
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.pftrace
|
||||
*.safetensors
|
||||
*.safetensors.index.json
|
||||
tokenizer.json
|
||||
**/.cache
|
||||
**/proptest-regressions
|
||||
opencode.json
|
||||
|
||||
# Python build artifacts
|
||||
*.so
|
||||
*.pyd
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
uv.lock
|
||||
|
||||
38
.pre-commit-config.yaml
Normal file
38
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.5
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
name: ruff check
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- id: ruff-format
|
||||
name: ruff format
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: cargo-fmt
|
||||
name: cargo fmt
|
||||
entry: cargo fmt --all --check
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy
|
||||
name: cargo clippy
|
||||
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy-metal
|
||||
name: cargo clippy metal
|
||||
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
- id: cargo-clippy-cuda-lite
|
||||
name: cargo clippy cuda_lite
|
||||
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
11
AGENTS.md
Normal file
11
AGENTS.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# Contributor Guide
|
||||
|
||||
## Structure
|
||||
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
|
||||
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
34
CLAUDE.md
Normal file
34
CLAUDE.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Luminal
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
|
||||
- Use `uv sync` to sync the Python environment
|
||||
- Never use pip, pip-tools, poetry, or conda
|
||||
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
|
||||
- Never generate requirements.txt
|
||||
|
||||
## Code Execution
|
||||
|
||||
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
|
||||
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
|
||||
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
|
||||
|
||||
## Building the Python Package (Maturin)
|
||||
|
||||
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
|
||||
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
|
||||
|
||||
## Pre-commit
|
||||
|
||||
- Run with: `uv run pre-commit run --all-files`
|
||||
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
|
||||
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
|
||||
|
||||
## Testing
|
||||
|
||||
- **Rust tests**: `cargo test -p <crate_name>`
|
||||
- **Python tests**: `cd crates/luminal_python && uv run pytest`
|
||||
- `./run_test.sh` — native backend
|
||||
- `./run_tests_cuda.sh` — CUDA backend
|
||||
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions
|
||||
64
Cargo.toml
64
Cargo.toml
@@ -1,31 +1,57 @@
|
||||
[package]
|
||||
name = "luminal"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
version = "0.2.0"
|
||||
edition.workspace = true
|
||||
rust-version = "1.85"
|
||||
description = "Deep learning at the speed of light."
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
#default = ["cuda"]
|
||||
cuda = ["dep:cudarc"]
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.11.0"
|
||||
matrixmultiply = "0.3.7"
|
||||
num-traits = "0.2.16"
|
||||
petgraph = {path="./resources/petgraph"}
|
||||
rand = "0.8.5"
|
||||
strum = { version = "0.25.0", features = ["derive"] }
|
||||
petgraph = "0.6.4"
|
||||
rand = "0.9.2"
|
||||
urlencoding = "2.1.2"
|
||||
webbrowser = "0.8.10"
|
||||
open = "5"
|
||||
dyn-clone = "1.0.12"
|
||||
cudarc = {version="0.9.13", optional=true}
|
||||
safetensors = "0.3.1"
|
||||
memmap2 = "0.7.1"
|
||||
half = "2.3.1"
|
||||
half = {version="2.7.1", features=["num-traits"]}
|
||||
tinyvec = { version = "1.6.0", features = ["serde"] }
|
||||
colored = "2.0.4"
|
||||
regex = "1.9.5"
|
||||
rustc-hash = "2.1.1"
|
||||
as-any = "0.3.1"
|
||||
serde = { version = "1.0.202", features = ["derive"] }
|
||||
generational-box = "0.5.6"
|
||||
serde_json = "1.0.140"
|
||||
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
|
||||
tracing = "0.1.43"
|
||||
paste = "1.0.15"
|
||||
pretty-duration = "0.1.1"
|
||||
anyhow = "1.0"
|
||||
graphviz-rust = { version = "0.9", default-features = false}
|
||||
lru = "0.16.2"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = "0.13"
|
||||
tokenizers = "0.13.3"
|
||||
candle-core = "0.9.2"
|
||||
candle-nn = "0.9.2"
|
||||
ordered-float = "5.1.0"
|
||||
proptest = "1.9.0"
|
||||
|
||||
[workspace]
|
||||
members = [
|
||||
"examples/*",
|
||||
"crates/luminal_nn",
|
||||
"crates/luminal_cuda_lite",
|
||||
"crates/luminal_metal",
|
||||
"crates/luminal_tracing",
|
||||
"crates/luminal_bench",
|
||||
"crates/luminal_python/rust",
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }
|
||||
|
||||
201
LICENSE-APACHE
Normal file
201
LICENSE-APACHE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
21
LICENSE-MIT
Normal file
21
LICENSE-MIT
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Joe Fioti
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
150
README.md
150
README.md
@@ -1,76 +1,126 @@
|
||||
# luminal
|
||||

|
||||
**Deep learning at the speed of light.**
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/user-attachments/assets/c5832634-55d5-45b7-ba65-6efe36afce4a" />
|
||||
|
||||
Luminal is a deep learning library that prioritizes **static computation** and **operator fusion** to achieve high performance.
|
||||
<h3 align="center">
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
</h3>
|
||||
|
||||
[](https://github.com/jafioti/luminal/actions)
|
||||
[](https://docs.luminalai.com)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/APjuwHAbGy)
|
||||
|
||||
## Usage
|
||||
|
||||
```rust
|
||||
use luminal::prelude::*;
|
||||
|
||||
// Setup graph and tensors
|
||||
// Create compute graph
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.new_tensor::<R2<3, 1>>("A");
|
||||
let b = cx.new_tensor::<R2<1, 4>>("B");
|
||||
let a = cx.tensor((3, 1));
|
||||
let b = cx.tensor((1, 4));
|
||||
|
||||
// Do stuff...
|
||||
let c = a.matmul(b);
|
||||
let c = a.matmul(b).output();
|
||||
|
||||
// Set inputs and mark outputs
|
||||
a.set(vec![1.0, 2.0, 3.0]);
|
||||
b.set(vec![1.0, 2.0, 3.0, 3.0]);
|
||||
c.mark();
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Optimize and run graph
|
||||
cx.optimize(GenericOptimizer::default());
|
||||
cx.execute();
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
rt.set_data(b, vec![1.0, 2.0, 3.0, 3.0]);
|
||||
|
||||
// Get result
|
||||
println!("Result: {:?}", c.retrieve().unwrap().data);
|
||||
// Run
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Get output tensor
|
||||
println!("Result: {:?}", rt.get_f32(c));
|
||||
```
|
||||
|
||||
## Why does this look so different from other DL libraries?
|
||||
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. So when you see `x + y`, the addition actually happens right there. This is great for debugging, it works exactly as most developers expect.
|
||||
## Getting Started
|
||||
|
||||
However, this isn't great for performance because what makes sense for a developer doesn't make sense for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://pytorch.org/docs/stable/dynamo/index.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
**Llama 3 8B**
|
||||
|
||||
Luminal takes a different approach, more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Here everything's static. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. *But isn't that just lazy execution?* Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, and executed later.
|
||||
Here's a quick example of how you can run Llama 3 8B locally using Luminal on CUDA:
|
||||
```bash
|
||||
cd ./examples/llama
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
## But Why?
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our optimizers have global knowledge and can do much more aggressive optimization **without any sync points**.
|
||||
## Features
|
||||
|
||||
Of course, we can still split the network into multiple seperate graphs if we want to insert dynamic control flow part-way through, which means this method doesn't preclude optimizations like KV caching, because the KV cached forward pass is just a seperate graph!
|
||||
### Speed
|
||||
|
||||
Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100. The goal is to become the fastest ML framework for any model on any device.
|
||||
|
||||
### Simplicity
|
||||
|
||||
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
### RISC-style architecture
|
||||
|
||||
Everything in Luminal boils down to 14 primitive ops:
|
||||
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
|
||||
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model.
|
||||
|
||||
### Search
|
||||
|
||||
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
|
||||
|
||||
### Native
|
||||
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
|
||||
Correctness matters. We write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
|
||||
|
||||
## Ideology
|
||||
|
||||
### Why does this look so different from other DL libraries?
|
||||
|
||||
Most deep learning libraries are eager-first, meaning each op call directly operates on the data. In PyTorch, when you see `x + y`, the addition actually happens right there. This is great for debugging because it works exactly as most developers expect.
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### Compile everything
|
||||
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
|
||||
|
||||
Now we can do:
|
||||
|
||||
Some huge benefits are now unlocked:
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Devices and Dtypes are handled through optimizers (just run the CUDA optimizer to convert the graph to use CUDA kernels, then the fp16 optimizer to convert to half-precision kernels)
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
|
||||
## RISC-style architecture
|
||||
Luminal can be ran on new accelerators by implementing 11 primitive ops. Take a look at `src/optimizers/cuda/prim.rs` to see 1-to-1 CUDA translations of the primops.
|
||||
|
||||
Accellerators are free to implement their own custom ops, and their own optimizers to convert luminal primitive ops to their bespoke ops.
|
||||
|
||||
## Compile-time Shape Checks
|
||||
All operations are shape checked at compile time, so no more shape mismatches! All credit for this goes to [dfdx](https://github.com/coreylowman/dfdx).
|
||||
|
||||
## View the Graph
|
||||
Once you've written all your computation code, run `cx.display_graph()` to see the entire computation graph in all it's glory. Pretty messy looking! Now run `cx.optimize(GeneralOptimizer::default())` and display the graph again. Much better.
|
||||
|
||||
## Where are we?
|
||||
Currently luminal is extremely alpha. Please don't use this in prod.
|
||||
|
||||
- Llama 1 is implemented in `examples/llama`. You'll need to follow the instructions in [llama-dfdx](https://github.com/coreylowman/llama-dfdx) to download and convert the llama weights, and point this example loading path at them.
|
||||
- The llama example shows how to implement a loader for a custom format. Safetensors loaders are already implemented, and are the recommended way to load a model.
|
||||
- We have a small library of NN modules in `nn`, including transformers.
|
||||
- A signifigant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the tinygrad ops set.
|
||||
- Currently there are very few optimizers, so primops are mostly used to run these models, which are very slow.
|
||||
- Next release will bring a signifigant amount of optimizers which should fuse primops into much faster ops. The aim for 0.2 is to be usably fast, not SOTA yet.
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
|
||||
Some things on the roadmap:
|
||||
- Write common sense cuda ops and optimizer (matmuls, mul-add, etc.)
|
||||
|
||||
- Expand the search space to utilize Tensor Cores more flexibly
|
||||
- Bring cuda to parity with Metal
|
||||
- Add Blackwell intrinsics, such as TMEM and TMA
|
||||
- Build a ROCm backend
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Write specialized CUDA kernels for full transformer architecture (FlashAttention, etc.)
|
||||
- Automatic differentiation of graphs
|
||||
- Beat PT 2.0 perf on LLM training
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM inference _and_ training
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
## License
|
||||
|
||||
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
BIN
ci/__pycache__/modal_llama.cpython-312.pyc
Normal file
BIN
ci/__pycache__/modal_llama.cpython-312.pyc
Normal file
Binary file not shown.
67
ci/modal_cargo_test.py
Normal file
67
ci/modal_cargo_test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
# Detect GPU compute capability
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo", "test",
|
||||
"-p", "luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"CUDA_COMPUTE_CAP": compute_cap,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_cargo_test.remote()
|
||||
67
ci/modal_example.py
Normal file
67
ci/modal_example.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
|
||||
app = modal.App(f"luminal-ci-{example}")
|
||||
|
||||
hf_cache = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
)
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
)
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
subprocess.run(
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_example.remote(example)
|
||||
34
crates/luminal_bench/Cargo.toml
Normal file
34
crates/luminal_bench/Cargo.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "luminal_bench"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
description = "Universal benchmark infrastructure for Luminal backends"
|
||||
license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
chrono = "0.4"
|
||||
egraph-serialize = { version = "0.3.0", default-features = false }
|
||||
|
||||
# Backend dependencies - optional, enabled via features
|
||||
luminal_metal = { path = "../luminal_metal", optional = true }
|
||||
metal = { version = "0.32", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
rand = "0.9.2"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
metal = ["dep:luminal_metal", "dep:metal"]
|
||||
|
||||
[[bench]]
|
||||
name = "micro"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "patterns"
|
||||
harness = false
|
||||
98
crates/luminal_bench/README.md
Normal file
98
crates/luminal_bench/README.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# `luminal_bench`
|
||||
|
||||
Benchmarks and debugging utilities for Luminal (Criterion benchmarks + egglog lowering debug).
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
The benches in this crate are typically run with the Metal backend enabled via a feature flag.
|
||||
|
||||
```bash
|
||||
# L1: micro (single op / HLIR primitive)
|
||||
cargo bench -p luminal_bench --features metal --bench micro
|
||||
|
||||
# L2: patterns (composed patterns)
|
||||
cargo bench -p luminal_bench --features metal --bench patterns
|
||||
```
|
||||
|
||||
### Outputs (Criterion)
|
||||
|
||||
After running, common outputs are under:
|
||||
|
||||
- HTML report: `target/criterion/report/index.html`
|
||||
- micro metrics mapping: `target/criterion/bench_metrics.json`
|
||||
- micro full report: `target/criterion/bench_report.json`
|
||||
- patterns metrics mapping: `target/criterion/pattern_metrics.json`
|
||||
- patterns full report: `target/criterion/pattern_report.json`
|
||||
|
||||
These JSON files (constant metrics such as bytes/flops) can be combined with Criterion timing to
|
||||
compute derived throughput metrics (MBU/MFU/etc.).
|
||||
|
||||
## Coverage (Overview)
|
||||
|
||||
### L1 micro (single op)
|
||||
|
||||
Measures single-op performance for HLIR primitives (currently includes):
|
||||
|
||||
- Unary: `Exp2` / `Log2` / `Sin` / `Recip` / `Sqrt`
|
||||
- Binary: `Add` / `Mul` / `Mod` / `LessThan`
|
||||
- Indexing: `Gather` / `Cast`
|
||||
- Reduction: `Sum` / `Max`
|
||||
|
||||
### L2 patterns (composed patterns)
|
||||
|
||||
Covers common composed patterns (currently includes):
|
||||
|
||||
- `MatMul`
|
||||
- `Softmax`
|
||||
- `GeLU`
|
||||
- `Attention`
|
||||
- `LayerNorm` (currently skipped in the Metal bench: requires unsupported HLIR primitives)
|
||||
|
||||
## egglog Debug Tool: `debug_ops`
|
||||
|
||||
`examples/debug_ops.rs` is a general egglog / lowering debug tool to help diagnose:
|
||||
|
||||
- Why a particular HLIR op failed to lower into backend dialect ops (and cleanup triggers
|
||||
`No valid graphs present in the e-graph!`)
|
||||
- Why a particular egglog function fact (e.g. `dtype`) is missing for some nodes
|
||||
|
||||
### Common Commands (Metal examples)
|
||||
|
||||
```bash
|
||||
# Default: print summaries (HLIR/egglog op counts + root) and try build_search_space
|
||||
# (which prints egglog rule match counts)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner
|
||||
|
||||
# Explicit op coverage check: provide HLIR:Backend mapping(s)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --inspect-op Add:MetalAdd
|
||||
|
||||
# Print full analysis output (HLIR-only + Backend+HLIR)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --analyze --inspect-op Add:MetalAdd
|
||||
|
||||
# Trace an egglog function fact for a specific var (HLIR-only)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --trace-fact dtype t24
|
||||
|
||||
# Scan all vars whose op-head is Add, find the first missing dtype, then trace it (HLIR-only)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner \
|
||||
--trace-first-missing-fact dtype --within-op Add
|
||||
|
||||
# Inspect a var's eclass/enodes/children and dtype facts (HLIR-only)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --inspect-var t24
|
||||
|
||||
# Dump the raw egglog program (the `(let tN ...)` program from `hlir_to_egglog`)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner \
|
||||
--dump-egglog target/gelu-inner.egg
|
||||
|
||||
# Export structured JSON (useful for repro/diffing)
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --case gelu-inner --json target/debug_ops.json
|
||||
```
|
||||
|
||||
Notes:
|
||||
- `--trace-fact` can only evaluate functions that exist in the egglog program (e.g. `dtype`).
|
||||
Many values such as shape/strides are encoded as IR term parameters, not as function facts.
|
||||
|
||||
For more options, see:
|
||||
|
||||
```bash
|
||||
cargo run -p luminal_bench --features metal --example debug_ops -- --help
|
||||
```
|
||||
153
crates/luminal_bench/benches/micro.rs
Normal file
153
crates/luminal_bench/benches/micro.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
#![allow(unused)]
|
||||
|
||||
//! Micro benchmark runner using criterion.
|
||||
//!
|
||||
//! Usage and output locations: see `crates/luminal_bench/README.md`.
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal_bench::{
|
||||
BenchMetrics, BenchMetricsMap, BenchResultCollector, BenchmarkBackend, BenchmarkPattern,
|
||||
HardwareSpec, MetalBenchmark, all_micro_patterns,
|
||||
};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn run_metal_pattern_benchmark(
|
||||
c: &mut Criterion,
|
||||
pattern: &dyn BenchmarkPattern,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
use luminal::hlir::Input;
|
||||
use luminal::op::{Runtime, RuntimeStats};
|
||||
use luminal_metal::runtime::MetalRuntime;
|
||||
use rand::Rng;
|
||||
|
||||
let backend_name = MetalBenchmark::name();
|
||||
let pattern_name = pattern.name();
|
||||
let group_name = format!("{}/{}", backend_name, pattern_name);
|
||||
|
||||
let mut group = c.benchmark_group(&group_name);
|
||||
|
||||
for size in pattern.sizes() {
|
||||
// Build graph and run search once per size; the benchmark loop only measures execution.
|
||||
let mut cx = Graph::default();
|
||||
pattern.build_graph(&mut cx, *size);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(Input { .. }) = (*cx.graph[node]).as_any().downcast_ref::<Input>() {
|
||||
let data: Vec<f32> = (0..size.value).map(|_| rng.random::<f32>()).collect();
|
||||
rt.set_data(node, &data);
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, 5);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
if let Some(stats) = rt.execute_with_stats(&cx.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add(pattern_name, size.name, metrics.clone());
|
||||
bench_metrics = Some(metrics);
|
||||
}
|
||||
|
||||
let dyn_map = cx.dyn_map.clone();
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size.name), size, |b, size| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = rt.execute_with_stats(&dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = bench_metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add(pattern_name, size.name, size.value, avg_time_us, metrics);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_micro_benchmarks(c: &mut Criterion) {
|
||||
let hw = MetalBenchmark::hardware_info();
|
||||
|
||||
println!("\n=== Metal Benchmark ===");
|
||||
println!("Device: {}", hw.device_name);
|
||||
println!("Memory: {:.1} GB", hw.memory_gb);
|
||||
if let Some(bw) = hw.peak_bandwidth_gbps {
|
||||
println!("Peak Bandwidth: {:.0} GB/s", bw);
|
||||
}
|
||||
if let Some(tf) = hw.peak_tflops {
|
||||
println!("Peak Compute: {:.1} TFLOPS", tf);
|
||||
}
|
||||
println!();
|
||||
|
||||
let hardware_spec = HardwareSpec {
|
||||
device_name: hw.device_name.clone(),
|
||||
memory_gb: hw.memory_gb,
|
||||
peak_bandwidth_gbps: hw.peak_bandwidth_gbps.unwrap_or(100.0),
|
||||
peak_tflops: hw.peak_tflops.unwrap_or(1.0),
|
||||
};
|
||||
|
||||
let mut metrics_map = BenchMetricsMap::new(hardware_spec.clone());
|
||||
let collector = BenchResultCollector::new(hardware_spec);
|
||||
|
||||
for pattern in all_micro_patterns() {
|
||||
run_metal_pattern_benchmark(c, pattern.as_ref(), &mut metrics_map, &collector);
|
||||
}
|
||||
|
||||
let metrics_path = std::path::Path::new("target/criterion/bench_metrics.json");
|
||||
if let Some(parent) = metrics_path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
if let Err(e) = metrics_map.save(metrics_path) {
|
||||
eprintln!("Warning: Failed to save metrics mapping: {}", e);
|
||||
}
|
||||
|
||||
let report = collector.into_report();
|
||||
report.print_summary();
|
||||
|
||||
let report_path = std::path::Path::new("target/criterion/bench_report.json");
|
||||
if let Err(e) = report.save(report_path) {
|
||||
eprintln!("Warning: Failed to save full report: {}", e);
|
||||
} else {
|
||||
println!("\nReports saved to:");
|
||||
println!(" - {}", metrics_path.display());
|
||||
println!(" - {}", report_path.display());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
fn metal_micro_benchmarks(_c: &mut Criterion) {
|
||||
println!("Metal benchmarks disabled. Run with --features metal");
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default()
|
||||
.sample_size(50)
|
||||
.warm_up_time(std::time::Duration::from_millis(500))
|
||||
.measurement_time(std::time::Duration::from_secs(2));
|
||||
targets = metal_micro_benchmarks
|
||||
}
|
||||
|
||||
criterion_main!(benches);
|
||||
514
crates/luminal_bench/benches/patterns.rs
Normal file
514
crates/luminal_bench/benches/patterns.rs
Normal file
@@ -0,0 +1,514 @@
|
||||
#![allow(unused)]
|
||||
|
||||
//! Pattern benchmark runner using criterion.
|
||||
//!
|
||||
//! Usage and output locations: see `crates/luminal_bench/README.md`.
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal_bench::{
|
||||
ATTENTION_SIZES, BenchMetrics, BenchMetricsMap, BenchResultCollector, BenchmarkBackend,
|
||||
HardwareSpec, MATMUL_SIZES, MetalBenchmark, TRANSFORMER_SIZES,
|
||||
};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal::hlir::Input;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal::op::{Runtime, RuntimeStats};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use luminal_metal::runtime::MetalRuntime;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use rand::Rng;
|
||||
|
||||
// ============================================================================
|
||||
// Helper: Prepare runtime with graph and search (done once per size)
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
struct PreparedBench {
|
||||
rt: MetalRuntime,
|
||||
dyn_map: luminal::prelude::FxHashMap<char, usize>,
|
||||
metrics: Option<BenchMetrics>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
for (node, size) in input_sizes {
|
||||
let data: Vec<f32> = (0..*size).map(|_| rng.random::<f32>()).collect();
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, 5);
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
dyn_map: cx.dyn_map.clone(),
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MatMul Benchmark
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn bench_matmul(
|
||||
c: &mut Criterion,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
let mut group = c.benchmark_group("metal/matmul");
|
||||
|
||||
for size in MATMUL_SIZES {
|
||||
let size_name = size.name;
|
||||
let (m, k, n) = (size.m, size.k, size.n);
|
||||
|
||||
// Build graph and run search once per size; the benchmark loop only measures execution.
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor((m, k));
|
||||
let b_tensor = cx.tensor((k, n));
|
||||
let _ = a.matmul(b_tensor).output();
|
||||
|
||||
let input_sizes: Vec<(NodeIndex, usize)> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
|
||||
Some((node, m * k.max(k * n)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
|
||||
println!("error: Skipping matmul/{} - search failed", size_name);
|
||||
continue;
|
||||
};
|
||||
|
||||
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
|
||||
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add("matmul", size_name, metrics.clone());
|
||||
prepared.metrics = Some(metrics);
|
||||
}
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = prepared.metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add("matmul", size_name, m * k * n, avg_time_us, metrics);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Softmax Benchmark
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn bench_softmax(
|
||||
c: &mut Criterion,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
let mut group = c.benchmark_group("metal/softmax");
|
||||
|
||||
for size in TRANSFORMER_SIZES {
|
||||
let size_name = size.name;
|
||||
let size_value = size.value;
|
||||
|
||||
let dim = (size_value as f64).sqrt() as usize;
|
||||
let rows = size_value / dim;
|
||||
let cols = dim;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((rows, cols));
|
||||
let _ = x.softmax(1).output();
|
||||
|
||||
let input_sizes: Vec<(NodeIndex, usize)> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
|
||||
Some((node, size_value))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
|
||||
println!("error: Skipping softmax/{} - search failed", size_name);
|
||||
continue;
|
||||
};
|
||||
|
||||
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
|
||||
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add("softmax", size_name, metrics.clone());
|
||||
prepared.metrics = Some(metrics);
|
||||
}
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = prepared.metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add("softmax", size_name, size_value, avg_time_us, metrics);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LayerNorm Benchmark
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn bench_layer_norm(
|
||||
c: &mut Criterion,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
let mut group = c.benchmark_group("metal/layer_norm");
|
||||
|
||||
for size in TRANSFORMER_SIZES {
|
||||
let size_name = size.name;
|
||||
let size_value = size.value;
|
||||
|
||||
// Typical shape: (batch * seq_len, hidden_dim)
|
||||
let hidden_dim = 128;
|
||||
let batch_seq = (size_value / hidden_dim).max(1);
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((batch_seq, hidden_dim));
|
||||
// LayerNorm along last axis with epsilon
|
||||
let _ = x.layer_norm(1, 1e-5).output();
|
||||
|
||||
let input_sizes: Vec<(NodeIndex, usize)> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
|
||||
Some((node, batch_seq * hidden_dim))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
|
||||
println!("error: Skipping layer_norm/{} - search failed", size_name);
|
||||
continue;
|
||||
};
|
||||
|
||||
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
|
||||
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add("layer_norm", size_name, metrics.clone());
|
||||
prepared.metrics = Some(metrics);
|
||||
}
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = prepared.metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add(
|
||||
"layer_norm",
|
||||
size_name,
|
||||
batch_seq * hidden_dim,
|
||||
avg_time_us,
|
||||
metrics,
|
||||
);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GeLU Benchmark
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn bench_gelu(
|
||||
c: &mut Criterion,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
let mut group = c.benchmark_group("metal/gelu");
|
||||
|
||||
for size in TRANSFORMER_SIZES {
|
||||
let size_name = size.name;
|
||||
let size_value = size.value;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor(size_value);
|
||||
let _ = x.gelu().output();
|
||||
|
||||
let input_sizes: Vec<(NodeIndex, usize)> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
|
||||
Some((node, size_value))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
|
||||
println!("error: Skipping gelu/{} - search failed", size_name);
|
||||
continue;
|
||||
};
|
||||
|
||||
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
|
||||
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add("gelu", size_name, metrics.clone());
|
||||
prepared.metrics = Some(metrics);
|
||||
}
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size_name), &size, |b, _| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = prepared.metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add("gelu", size_name, size_value, avg_time_us, metrics);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Attention Benchmark
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn bench_attention(
|
||||
c: &mut Criterion,
|
||||
metrics_map: &mut BenchMetricsMap,
|
||||
collector: &BenchResultCollector,
|
||||
) {
|
||||
let mut group = c.benchmark_group("metal/attention");
|
||||
|
||||
for (seq_len, head_dim) in ATTENTION_SIZES {
|
||||
let size_name = format!("{}x{}", seq_len, head_dim);
|
||||
let seq_len = *seq_len;
|
||||
let head_dim = *head_dim;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q = cx.tensor((seq_len, head_dim));
|
||||
let k = cx.tensor((seq_len, head_dim));
|
||||
let v = cx.tensor((seq_len, head_dim));
|
||||
|
||||
let scores = q.matmul(k.permute((1, 0)));
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let scaled_scores = scores * scale;
|
||||
let attn_weights = scaled_scores.softmax(1);
|
||||
let _ = attn_weights.matmul(v).output();
|
||||
|
||||
let input_sizes: Vec<(NodeIndex, usize)> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
if (*cx.graph[node]).as_any().downcast_ref::<Input>().is_some() {
|
||||
Some((node, seq_len * head_dim))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(mut prepared) = prepare_and_search(&mut cx, &input_sizes) else {
|
||||
println!("error: Skipping attention/{} - search failed", size_name);
|
||||
continue;
|
||||
};
|
||||
|
||||
prepared.rt.allocate_intermediate_buffers(&prepared.dyn_map);
|
||||
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
let metrics = BenchMetrics::new(stats.bytes_loaded, stats.bytes_stored, stats.flops);
|
||||
metrics_map.add("attention", &size_name, metrics.clone());
|
||||
prepared.metrics = Some(metrics);
|
||||
}
|
||||
|
||||
let size_name_clone = size_name.clone();
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(&size_name),
|
||||
&(seq_len, head_dim),
|
||||
|b, _| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut total_time = Duration::ZERO;
|
||||
|
||||
for _ in 0..iters {
|
||||
if let Some(stats) = prepared.rt.execute_with_stats(&prepared.dyn_map) {
|
||||
total_time +=
|
||||
Duration::from_secs_f64(stats.execution_time_us / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref metrics) = prepared.metrics {
|
||||
let avg_time_us = total_time.as_secs_f64() * 1_000_000.0 / iters as f64;
|
||||
collector.add(
|
||||
"attention",
|
||||
&size_name_clone,
|
||||
seq_len * head_dim,
|
||||
avg_time_us,
|
||||
metrics,
|
||||
);
|
||||
}
|
||||
|
||||
total_time
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Benchmark Entry
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_pattern_benchmarks(c: &mut Criterion) {
|
||||
let hw = MetalBenchmark::hardware_info();
|
||||
|
||||
println!("\n=== Metal Pattern Benchmarks ===");
|
||||
println!("Device: {}", hw.device_name);
|
||||
println!("Memory: {:.1} GB", hw.memory_gb);
|
||||
if let Some(bw) = hw.peak_bandwidth_gbps {
|
||||
println!("Peak Bandwidth: {:.0} GB/s", bw);
|
||||
}
|
||||
if let Some(tf) = hw.peak_tflops {
|
||||
println!("Peak Compute: {:.1} TFLOPS", tf);
|
||||
}
|
||||
println!();
|
||||
|
||||
let hardware_spec = HardwareSpec {
|
||||
device_name: hw.device_name.clone(),
|
||||
memory_gb: hw.memory_gb,
|
||||
peak_bandwidth_gbps: hw.peak_bandwidth_gbps.unwrap_or(100.0),
|
||||
peak_tflops: hw.peak_tflops.unwrap_or(1.0),
|
||||
};
|
||||
|
||||
let mut metrics_map = BenchMetricsMap::new(hardware_spec.clone());
|
||||
let collector = BenchResultCollector::new(hardware_spec);
|
||||
|
||||
bench_matmul(c, &mut metrics_map, &collector);
|
||||
bench_softmax(c, &mut metrics_map, &collector);
|
||||
bench_layer_norm(c, &mut metrics_map, &collector);
|
||||
bench_gelu(c, &mut metrics_map, &collector);
|
||||
bench_attention(c, &mut metrics_map, &collector);
|
||||
|
||||
let metrics_path = std::path::Path::new("target/criterion/pattern_metrics.json");
|
||||
if let Some(parent) = metrics_path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
if let Err(e) = metrics_map.save(metrics_path) {
|
||||
eprintln!("Warning: Failed to save metrics mapping: {}", e);
|
||||
}
|
||||
|
||||
let report = collector.into_report();
|
||||
report.print_summary();
|
||||
|
||||
let report_path = std::path::Path::new("target/criterion/pattern_report.json");
|
||||
if let Err(e) = report.save(report_path) {
|
||||
eprintln!("Warning: Failed to save full report: {}", e);
|
||||
} else {
|
||||
println!("\nReports saved to:");
|
||||
println!(" - {}", metrics_path.display());
|
||||
println!(" - {}", report_path.display());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
fn metal_pattern_benchmarks(_c: &mut Criterion) {
|
||||
println!("Metal benchmarks disabled. Run with --features metal");
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default()
|
||||
.sample_size(30)
|
||||
.warm_up_time(std::time::Duration::from_millis(500))
|
||||
.measurement_time(std::time::Duration::from_secs(3));
|
||||
targets = metal_pattern_benchmarks
|
||||
}
|
||||
|
||||
criterion_main!(benches);
|
||||
586
crates/luminal_bench/examples/debug_ops.rs
Normal file
586
crates/luminal_bench/examples/debug_ops.rs
Normal file
@@ -0,0 +1,586 @@
|
||||
#![allow(unused)]
|
||||
|
||||
//! Debug script to locate which HLIR op(s) fail to lower to a backend dialect,
|
||||
//! leading to `No valid graphs present in the e-graph!`.
|
||||
//!
|
||||
//! This tool is backend-agnostic. The specific backend is selected via feature flags.
|
||||
//! All core analysis logic lives in `luminal_bench::egglog_debug` module.
|
||||
//!
|
||||
//! Usage examples: see `crates/luminal_bench/README.md`.
|
||||
|
||||
use luminal::op::IntoEgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use luminal::{egglog_utils::hlir_to_egglog, hlir::HLIROps};
|
||||
use luminal_bench::egglog_debug::{
|
||||
DebugReport, FactQuery, analyze_hlir_dtype_chain, analyze_hlir_function_chain,
|
||||
analyze_lowering, analyze_with_ops, inspect_var_hlir, print_dtype_chain, print_function_chain,
|
||||
print_lowering_analysis, print_var_inspection, summarize_egglog_ops, summarize_hlir_ops,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Backend Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Backend-specific configuration trait.
|
||||
trait BackendConfig {
|
||||
type Runtime: luminal::op::Runtime;
|
||||
const NAME: &'static str;
|
||||
|
||||
fn build_search_space(cx: &mut Graph);
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
mod metal_backend {
|
||||
use super::*;
|
||||
use luminal_metal::runtime::MetalRuntime;
|
||||
|
||||
pub struct MetalConfig;
|
||||
|
||||
impl BackendConfig for MetalConfig {
|
||||
type Runtime = MetalRuntime;
|
||||
const NAME: &'static str = "Metal";
|
||||
|
||||
fn build_search_space(cx: &mut Graph) {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use metal_backend::MetalConfig as ActiveBackend;
|
||||
|
||||
// Future: Add CUDA backend
|
||||
// #[cfg(feature = "cuda")]
|
||||
// mod cuda_backend { ... }
|
||||
|
||||
// ============================================================================
|
||||
// Test Cases
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum Case {
|
||||
Mul,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
GeluInner,
|
||||
Gelu,
|
||||
LayerNorm,
|
||||
}
|
||||
|
||||
impl Case {
|
||||
fn all() -> &'static [Case] {
|
||||
&[
|
||||
Case::Mul,
|
||||
Case::Sigmoid,
|
||||
Case::Tanh,
|
||||
Case::GeluInner,
|
||||
Case::Gelu,
|
||||
Case::LayerNorm,
|
||||
]
|
||||
}
|
||||
|
||||
fn from_str(s: &str) -> Option<Case> {
|
||||
match s {
|
||||
"mul" => Some(Case::Mul),
|
||||
"sigmoid" => Some(Case::Sigmoid),
|
||||
"tanh" => Some(Case::Tanh),
|
||||
"gelu-inner" => Some(Case::GeluInner),
|
||||
"gelu" => Some(Case::Gelu),
|
||||
"layer-norm" | "layer_norm" => Some(Case::LayerNorm),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Case::Mul => "Mul",
|
||||
Case::Sigmoid => "Sigmoid",
|
||||
Case::Tanh => "Tanh",
|
||||
Case::GeluInner => "GeluInner",
|
||||
Case::Gelu => "Gelu",
|
||||
Case::LayerNorm => "LayerNorm",
|
||||
}
|
||||
}
|
||||
|
||||
fn build(&self, cx: &mut Graph, size: usize) {
|
||||
let out = match self {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor(size);
|
||||
x.clone() * x
|
||||
}
|
||||
Case::Sigmoid => cx.tensor(size).sigmoid(),
|
||||
Case::Tanh => cx.tensor(size).tanh(),
|
||||
Case::GeluInner => {
|
||||
let x = cx.tensor(size);
|
||||
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
|
||||
}
|
||||
Case::Gelu => cx.tensor(size).gelu(),
|
||||
Case::LayerNorm => {
|
||||
// Mirror `crates/luminal_bench/src/patterns.rs`: normalize along last axis.
|
||||
let hidden_dim = 128usize;
|
||||
let batch_seq = (size / hidden_dim).max(1);
|
||||
cx.tensor((batch_seq, hidden_dim)).layer_norm(1, 1e-5)
|
||||
}
|
||||
};
|
||||
let _ = out.output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CLI Argument Parsing
|
||||
// ============================================================================
|
||||
|
||||
struct Args {
|
||||
case: Case,
|
||||
size: usize,
|
||||
dump_egglog: Option<std::path::PathBuf>,
|
||||
print_egglog: bool,
|
||||
analyze: bool,
|
||||
inspect_vars: Vec<String>,
|
||||
inspect_ops: Vec<(String, String)>,
|
||||
trace_facts: Vec<(String, String)>,
|
||||
trace_first_missing_facts: Vec<TraceFirstMissingFact>,
|
||||
checks: Vec<Check>,
|
||||
json_out: Option<std::path::PathBuf>,
|
||||
all: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TraceFirstMissingFact {
|
||||
fn_name: String,
|
||||
within_op: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum Check {
|
||||
MissingBackend,
|
||||
DType,
|
||||
Function,
|
||||
All,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args {
|
||||
case: Case::Gelu,
|
||||
size: 262_144,
|
||||
dump_egglog: None,
|
||||
print_egglog: false,
|
||||
analyze: false,
|
||||
inspect_vars: Vec::new(),
|
||||
inspect_ops: Vec::new(),
|
||||
trace_facts: Vec::new(),
|
||||
trace_first_missing_facts: Vec::new(),
|
||||
checks: Vec::new(),
|
||||
json_out: None,
|
||||
all: false,
|
||||
};
|
||||
|
||||
// If the user writes: --trace-first-missing-fact dtype --within-op Add
|
||||
// we attach the next --within-op to the last pending request.
|
||||
let mut pending_within_op_for: Option<usize> = None;
|
||||
|
||||
let mut iter = std::env::args().skip(1);
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--case" => {
|
||||
let val = iter.next().expect("Missing value for --case");
|
||||
args.case = Case::from_str(&val).unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Unknown case: {}. Use: mul|sigmoid|tanh|gelu-inner|gelu",
|
||||
val
|
||||
)
|
||||
});
|
||||
}
|
||||
"--size" => {
|
||||
let val = iter.next().expect("Missing value for --size");
|
||||
args.size = val.parse().expect("Invalid --size value");
|
||||
}
|
||||
"--dump-egglog" => {
|
||||
let val = iter.next().expect("Missing value for --dump-egglog");
|
||||
args.dump_egglog = Some(val.into());
|
||||
}
|
||||
"--print-egglog" => args.print_egglog = true,
|
||||
"--analyze" => args.analyze = true,
|
||||
"--trace-fact" => {
|
||||
let fn_name = iter.next().expect("Missing function name for --trace-fact");
|
||||
let var = iter.next().expect("Missing variable for --trace-fact");
|
||||
args.trace_facts.push((fn_name, var));
|
||||
}
|
||||
"--trace-first-missing-fact" => {
|
||||
let fn_name = iter
|
||||
.next()
|
||||
.expect("Missing function name for --trace-first-missing-fact");
|
||||
args.trace_first_missing_facts.push(TraceFirstMissingFact {
|
||||
fn_name,
|
||||
within_op: String::new(),
|
||||
});
|
||||
pending_within_op_for = Some(args.trace_first_missing_facts.len() - 1);
|
||||
}
|
||||
"--within-op" => {
|
||||
let op = iter.next().expect("Missing op head for --within-op");
|
||||
let Some(idx) = pending_within_op_for.take() else {
|
||||
eprintln!("--within-op must follow a --trace-first-missing-fact");
|
||||
std::process::exit(2);
|
||||
};
|
||||
args.trace_first_missing_facts[idx].within_op = op;
|
||||
}
|
||||
"--inspect-var" => {
|
||||
let val = iter.next().expect("Missing value for --inspect-var");
|
||||
args.inspect_vars.push(val);
|
||||
}
|
||||
"--inspect-op" => {
|
||||
let val = iter.next().expect("Missing value for --inspect-op");
|
||||
let mut parts = val.split(':');
|
||||
let hlir = parts.next().unwrap_or("").to_string();
|
||||
let backend = parts.next().unwrap_or("").to_string();
|
||||
if hlir.is_empty() || backend.is_empty() || parts.next().is_some() {
|
||||
eprintln!("Invalid --inspect-op format. Expected HLIR:Backend, got {val}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
args.inspect_ops.push((hlir, backend));
|
||||
}
|
||||
"--check" => {
|
||||
let val = iter.next().expect("Missing value for --check");
|
||||
let check = match val.as_str() {
|
||||
"missing-backend" => Check::MissingBackend,
|
||||
"dtype" => Check::DType,
|
||||
"fn" | "function" => Check::Function,
|
||||
"all" => Check::All,
|
||||
_ => {
|
||||
eprintln!("Unknown --check {val}. Use: missing-backend|dtype|fn|all");
|
||||
std::process::exit(2);
|
||||
}
|
||||
};
|
||||
args.checks.push(check);
|
||||
}
|
||||
"--json" => {
|
||||
let val = iter.next().expect("Missing value for --json");
|
||||
args.json_out = Some(val.into());
|
||||
}
|
||||
"--all" => args.all = true,
|
||||
"--help" | "-h" => {
|
||||
println!(
|
||||
"Usage: debug_ops [OPTIONS]\n\n\
|
||||
Options:\n \
|
||||
--case <CASE> Test case: mul|sigmoid|tanh|gelu-inner|gelu (default: gelu)\n \
|
||||
(also: layer-norm)\n \
|
||||
--size <N> Tensor size (default: 262144)\n \
|
||||
--all Run all test cases\n \
|
||||
--analyze Run lowering analysis\n \
|
||||
--trace-fact FN VAR Trace fact FN for VAR (HLIR-only), e.g. dtype t24\n \
|
||||
--trace-first-missing-fact FN Find first missing FN within an op-head, then trace it (HLIR-only)\n \
|
||||
--within-op OPHEAD Used with --trace-first-missing-fact (e.g. Add)\n \
|
||||
--inspect-var VAR Print detailed eclass + dtype info for VAR (HLIR-only)\n \
|
||||
--inspect-op HLIR:Backend Check backend coverage for an op mapping\n \
|
||||
--check KIND Run checks: missing-backend|dtype|fn|all\n \
|
||||
--json PATH Write JSON report (use '-' for stdout)\n \
|
||||
--dump-egglog PATH Write egglog program to file\n \
|
||||
--print-egglog Print egglog program to stdout\n \
|
||||
--help Show this help"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => {
|
||||
eprintln!("Unknown argument: {}. Use --help for usage.", other);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expand checks into concrete actions and validate requirements.
|
||||
if args.checks.contains(&Check::All) {
|
||||
args.checks = vec![Check::MissingBackend, Check::DType, Check::Function];
|
||||
}
|
||||
if args.checks.contains(&Check::DType) {
|
||||
// Preserve the previous semantics: scan Add for missing dtype, then trace.
|
||||
let already_has_add_dtype = args
|
||||
.trace_first_missing_facts
|
||||
.iter()
|
||||
.any(|r| r.fn_name == "dtype" && r.within_op == "Add");
|
||||
if !already_has_add_dtype {
|
||||
args.trace_first_missing_facts.push(TraceFirstMissingFact {
|
||||
fn_name: "dtype".to_string(),
|
||||
within_op: "Add".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
if args.checks.contains(&Check::MissingBackend) && args.inspect_ops.is_empty() {
|
||||
eprintln!("--check missing-backend requires at least one --inspect-op HLIR:Backend");
|
||||
std::process::exit(2);
|
||||
}
|
||||
if args.checks.contains(&Check::Function) && args.trace_facts.is_empty() {
|
||||
eprintln!("--check fn requires at least one --trace-fact FN VAR");
|
||||
std::process::exit(2);
|
||||
}
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Logic
|
||||
// ============================================================================
|
||||
|
||||
fn run_case<B: BackendConfig>(case: Case, size: usize, args: &Args)
|
||||
where
|
||||
B::Runtime: luminal::op::Runtime,
|
||||
<B::Runtime as luminal::op::Runtime>::Ops: luminal::op::IntoEgglogOp,
|
||||
{
|
||||
println!(
|
||||
"\n=== Case: {} (size={}) [{}] ===",
|
||||
case.name(),
|
||||
size,
|
||||
B::NAME
|
||||
);
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
case.build(&mut cx, size);
|
||||
|
||||
// Summarize HLIR
|
||||
let hlir_counts = summarize_hlir_ops(&cx);
|
||||
println!("-- HLIR node types --");
|
||||
for (k, v) in &hlir_counts {
|
||||
println!(" {}: {}", k, v);
|
||||
}
|
||||
|
||||
// Get egglog program
|
||||
let (program, root) = hlir_to_egglog(&cx);
|
||||
|
||||
// Summarize egglog ops
|
||||
let egglog_counts = summarize_egglog_ops(&program);
|
||||
println!("-- Egglog op heads --");
|
||||
for (k, v) in &egglog_counts {
|
||||
println!(" {}: {}", k, v);
|
||||
}
|
||||
println!("-- Egglog root: {} --", root);
|
||||
|
||||
// Dump egglog if requested
|
||||
if let Some(ref base_path) = args.dump_egglog {
|
||||
let path = if args.all {
|
||||
let parent = base_path.parent().unwrap_or(std::path::Path::new("."));
|
||||
let stem = base_path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("debug");
|
||||
parent.join(format!("{}-{}.egg", stem, case.name()))
|
||||
} else {
|
||||
base_path.clone()
|
||||
};
|
||||
|
||||
let content = format!("; hlir_to_egglog dump\n; root: {root}\n{program}");
|
||||
std::fs::write(&path, content).expect("Failed to write egglog file");
|
||||
println!("Wrote egglog program to {}", path.display());
|
||||
}
|
||||
|
||||
if args.print_egglog {
|
||||
println!("-- Egglog program --\n{program}");
|
||||
}
|
||||
|
||||
let find_vars_by_head = |head: &str| -> Vec<String> {
|
||||
let mut vars = Vec::new();
|
||||
for line in program.lines() {
|
||||
let line = line.trim();
|
||||
if !line.starts_with("(let ") {
|
||||
continue;
|
||||
}
|
||||
let tokens: Vec<&str> = line.split_whitespace().collect();
|
||||
if tokens.len() >= 3 && tokens[0] == "(let" {
|
||||
let var = tokens[1].to_string();
|
||||
let op = tokens[2].trim_start_matches('(');
|
||||
if op == head {
|
||||
vars.push(var);
|
||||
}
|
||||
}
|
||||
}
|
||||
vars
|
||||
};
|
||||
|
||||
// Validate any pending --within-op pairing.
|
||||
for req in &args.trace_first_missing_facts {
|
||||
if req.within_op.is_empty() {
|
||||
eprintln!(
|
||||
"--trace-first-missing-fact {} requires --within-op OPHEAD",
|
||||
req.fn_name
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare fact queries needed for scan-first-missing-fact.
|
||||
let mut hlir_analysis = None;
|
||||
let mut backend_analysis = None;
|
||||
let mut fact_queries: Vec<FactQuery> = Vec::new();
|
||||
for req in &args.trace_first_missing_facts {
|
||||
fact_queries.push(FactQuery {
|
||||
fn_name: req.fn_name.clone(),
|
||||
vars: find_vars_by_head(&req.within_op),
|
||||
});
|
||||
}
|
||||
|
||||
// Only compute backend analysis if requested; compute HLIR analysis if needed
|
||||
// for either --analyze or scan-first-missing-fact.
|
||||
let need_backend_analysis = args.analyze || !args.inspect_ops.is_empty();
|
||||
let need_hlir_analysis = args.analyze || !fact_queries.is_empty();
|
||||
|
||||
if need_backend_analysis {
|
||||
let (hlir, backend) =
|
||||
analyze_lowering::<B::Runtime>(&program, &root, &fact_queries, &args.inspect_ops);
|
||||
hlir_analysis = Some(hlir);
|
||||
backend_analysis = Some(backend);
|
||||
} else if need_hlir_analysis {
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
hlir_analysis = Some(analyze_with_ops(
|
||||
&program,
|
||||
&root,
|
||||
hlir_ops,
|
||||
"HLIR",
|
||||
&fact_queries,
|
||||
&[],
|
||||
));
|
||||
}
|
||||
|
||||
if args.analyze {
|
||||
println!("-- Lowering analysis --");
|
||||
if let Some(ref hlir) = hlir_analysis {
|
||||
print_lowering_analysis(hlir);
|
||||
}
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty() {
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
}
|
||||
|
||||
// Trace facts for explicit variables.
|
||||
let mut function_traces = Vec::new();
|
||||
for (fn_name, var) in &args.trace_facts {
|
||||
if fn_name == "dtype" {
|
||||
println!("-- Trace dtype chain for {} (HLIR-only) --", var);
|
||||
let chain = analyze_hlir_dtype_chain(&program, var);
|
||||
print_dtype_chain(&chain);
|
||||
// Also record a function-trace entry for JSON output.
|
||||
function_traces.push(analyze_hlir_function_chain(&program, fn_name, var));
|
||||
} else {
|
||||
let trace = analyze_hlir_function_chain(&program, fn_name, var);
|
||||
print_function_chain(&trace);
|
||||
function_traces.push(trace);
|
||||
}
|
||||
}
|
||||
|
||||
// Scan for first missing fact within an op-head, then trace.
|
||||
for req in &args.trace_first_missing_facts {
|
||||
let Some(ref hlir) = hlir_analysis else {
|
||||
println!(
|
||||
"-- Trace first missing fact (fn={}) within op={} --",
|
||||
req.fn_name, req.within_op
|
||||
);
|
||||
println!(" error Skipped: HLIR analysis did not run");
|
||||
continue;
|
||||
};
|
||||
|
||||
let vars = find_vars_by_head(&req.within_op);
|
||||
if vars.is_empty() {
|
||||
println!(
|
||||
"-- Trace first missing fact (fn={}) within op={} --",
|
||||
req.fn_name, req.within_op
|
||||
);
|
||||
println!(" √ No matching vars found (op head not present)");
|
||||
continue;
|
||||
}
|
||||
|
||||
let table = hlir.facts.get(&req.fn_name);
|
||||
let first_missing = table.and_then(|t| {
|
||||
vars.iter()
|
||||
.find_map(|v| t.get(v).and_then(|s| s.is_missing().then(|| v.clone())))
|
||||
});
|
||||
|
||||
if let Some(var) = first_missing {
|
||||
println!(
|
||||
"-- Trace first missing fact (fn={}) within op={} --",
|
||||
req.fn_name, req.within_op
|
||||
);
|
||||
println!(" ❌ first missing at: {}", var);
|
||||
if req.fn_name == "dtype" {
|
||||
let chain = analyze_hlir_dtype_chain(&program, &var);
|
||||
print_dtype_chain(&chain);
|
||||
function_traces.push(analyze_hlir_function_chain(&program, "dtype", &var));
|
||||
} else {
|
||||
let trace = analyze_hlir_function_chain(&program, &req.fn_name, &var);
|
||||
print_function_chain(&trace);
|
||||
function_traces.push(trace);
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"-- Trace first missing fact (fn={}) within op={} --",
|
||||
req.fn_name, req.within_op
|
||||
);
|
||||
println!(" √ No missing values found");
|
||||
}
|
||||
}
|
||||
|
||||
let mut var_inspections = Vec::new();
|
||||
if !args.inspect_vars.is_empty() {
|
||||
for var in &args.inspect_vars {
|
||||
let inspection = inspect_var_hlir(&program, var);
|
||||
print_var_inspection(&inspection);
|
||||
var_inspections.push(inspection);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to build search space
|
||||
let prev_hook = std::panic::take_hook();
|
||||
std::panic::set_hook(Box::new(|_| {}));
|
||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
B::build_search_space(&mut cx);
|
||||
}));
|
||||
std::panic::set_hook(prev_hook);
|
||||
|
||||
let build_succeeded = result.is_ok();
|
||||
match result {
|
||||
Ok(()) => println!("√ build_search_space succeeded"),
|
||||
Err(_) => println!("❌ build_search_space failed"),
|
||||
}
|
||||
|
||||
if let Some(ref path) = args.json_out {
|
||||
let report = DebugReport {
|
||||
case_name: case.name().to_string(),
|
||||
size,
|
||||
hlir_counts,
|
||||
egglog_counts,
|
||||
hlir_analysis,
|
||||
backend_analysis,
|
||||
var_inspections,
|
||||
function_traces,
|
||||
build_succeeded,
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&report).expect("failed to serialize report");
|
||||
if path.as_os_str() == "-" {
|
||||
println!("{}", json);
|
||||
} else {
|
||||
std::fs::write(path, json).expect("failed to write json report");
|
||||
println!("Wrote JSON report to {}", path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
|
||||
println!("=== debug_ops ({}) ===", ActiveBackend::NAME);
|
||||
println!("Backend: {}", ActiveBackend::NAME);
|
||||
println!("Tip: Use --analyze for detailed lowering analysis.\n");
|
||||
|
||||
if args.all {
|
||||
for case in Case::all() {
|
||||
run_case::<ActiveBackend>(*case, args.size, &args);
|
||||
}
|
||||
} else {
|
||||
run_case::<ActiveBackend>(args.case, args.size, &args);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
fn main() {}
|
||||
550
crates/luminal_bench/src/egglog_debug/analysis.rs
Normal file
550
crates/luminal_bench/src/egglog_debug/analysis.rs
Normal file
@@ -0,0 +1,550 @@
|
||||
//! Core analysis functions for egglog debugging.
|
||||
|
||||
use super::{
|
||||
DTypeChainAnalysis, DTypeStatus, DependencyGraph, FactStatus, FunctionChainAnalysis,
|
||||
FunctionTraceEntry,
|
||||
};
|
||||
use egraph_serialize::ClassId;
|
||||
use luminal::egglog_utils;
|
||||
use luminal::hlir::HLIROps;
|
||||
use luminal::op::{EgglogOp, IntoEgglogOp, Runtime};
|
||||
use luminal::prelude::egglog;
|
||||
use luminal::prelude::egglog::prelude::exprs;
|
||||
use luminal::prelude::egglog_ast::{RustSpan, Span};
|
||||
use luminal::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Analysis result for lowering.
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct LoweringAnalysis {
|
||||
pub label: String,
|
||||
pub root_labels: Vec<String>,
|
||||
pub output_input_labels: Vec<String>,
|
||||
/// Optional op-coverage reports (only filled when explicitly requested).
|
||||
pub op_reports: Vec<OpLoweringReport>,
|
||||
/// Optional evaluated facts, keyed by function name then variable name.
|
||||
pub facts: BTreeMap<String, BTreeMap<String, FactStatus>>,
|
||||
}
|
||||
|
||||
/// Query for evaluating a function on a set of variables.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FactQuery {
|
||||
pub fn_name: String,
|
||||
pub vars: Vec<String>,
|
||||
}
|
||||
|
||||
/// Missing backend equivalent for a specific HLIR op instance.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpMissing {
|
||||
pub class_id: String,
|
||||
pub op: String,
|
||||
pub children: Vec<ChildInspection>,
|
||||
}
|
||||
|
||||
/// Report for op lowering coverage.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct OpLoweringReport {
|
||||
pub label: String,
|
||||
pub hlir_op: String,
|
||||
pub backend_op: String,
|
||||
pub total_classes: usize,
|
||||
pub missing: Vec<OpMissing>,
|
||||
}
|
||||
|
||||
/// Inspection of a specific variable's eclass and dtype facts.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct VarInspection {
|
||||
pub label: String,
|
||||
pub var: String,
|
||||
pub let_line: Option<String>,
|
||||
pub eval_error: Option<String>,
|
||||
pub class_id: Option<String>,
|
||||
pub class_type: Option<String>,
|
||||
pub class_labels: Vec<String>,
|
||||
pub dtype: Option<String>,
|
||||
pub enodes: Vec<EnodeInspection>,
|
||||
}
|
||||
|
||||
/// Inspection of an enode within an eclass.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct EnodeInspection {
|
||||
pub label: String,
|
||||
pub children: Vec<ChildInspection>,
|
||||
}
|
||||
|
||||
/// Inspection of a child eclass.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ChildInspection {
|
||||
pub class_id: String,
|
||||
pub class_type: String,
|
||||
pub class_labels: Vec<String>,
|
||||
pub dtype: Option<String>,
|
||||
}
|
||||
|
||||
fn find_let_line(program: &str, var: &str) -> Option<String> {
|
||||
for line in program.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with("(let ") && line.split_whitespace().nth(1) == Some(var) {
|
||||
return Some(line.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn run_egraph(program: &str, ops: Vec<Arc<Box<dyn EgglogOp>>>) -> egglog::EGraph {
|
||||
let code = egglog_utils::full_egglog(program, &ops, false);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code).unwrap();
|
||||
let _outputs = egraph.run_program(commands).unwrap();
|
||||
egraph
|
||||
}
|
||||
|
||||
fn annotate_dtypes(graph: &mut DependencyGraph, egraph: &mut egglog::EGraph) {
|
||||
for node in graph.nodes.values_mut() {
|
||||
node.dtype = Some(eval_dtype(egraph, &node.var));
|
||||
}
|
||||
}
|
||||
|
||||
fn class_labels(serialized: &egglog_utils::SerializedEGraph, class_id: &ClassId) -> Vec<String> {
|
||||
let Some((_, nodes)) = serialized.eclasses.get(class_id) else {
|
||||
return vec!["<missing>".to_string()];
|
||||
};
|
||||
let mut labels: Vec<String> = nodes
|
||||
.iter()
|
||||
.filter_map(|node_id| {
|
||||
serialized
|
||||
.enodes
|
||||
.get(node_id)
|
||||
.map(|(label, _)| label.clone())
|
||||
})
|
||||
.collect();
|
||||
labels.sort();
|
||||
labels.dedup();
|
||||
labels
|
||||
}
|
||||
|
||||
fn class_type(serialized: &egglog_utils::SerializedEGraph, class_id: &ClassId) -> String {
|
||||
serialized
|
||||
.eclasses
|
||||
.get(class_id)
|
||||
.map(|(typ, _)| typ.clone())
|
||||
.unwrap_or_else(|| "<missing>".to_string())
|
||||
}
|
||||
|
||||
fn collect_dtype_facts(serialized: &egglog_utils::SerializedEGraph) -> FxHashMap<ClassId, String> {
|
||||
let mut map: FxHashMap<ClassId, String> = FxHashMap::default();
|
||||
for (node_id, (label, children)) in &serialized.enodes {
|
||||
if !label.starts_with("dtype") {
|
||||
continue;
|
||||
}
|
||||
if children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let input_class = children[0].clone();
|
||||
let dtype_class = serialized.node_to_class[node_id].clone();
|
||||
let dtype_labels = class_labels(serialized, &dtype_class);
|
||||
let dtype_label = dtype_labels
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "<unknown>".to_string());
|
||||
map.insert(input_class, dtype_label);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
fn eval_function(egraph: &mut egglog::EGraph, fn_name: &str, var: &str) -> FactStatus {
|
||||
let expr = exprs::call(fn_name, vec![exprs::var(var)]);
|
||||
match egraph.eval_expr(&expr) {
|
||||
Ok((sort, value)) => match egraph.extract_value_to_string(&sort, value) {
|
||||
Ok((s, _)) => FactStatus::Resolved(s),
|
||||
Err(_) => FactStatus::Missing("extract-error".to_string()),
|
||||
},
|
||||
Err(err) => FactStatus::Missing(format!("{err}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Inspect a specific var with a given set of ops.
|
||||
pub fn inspect_var_with_ops(
|
||||
program: &str,
|
||||
ops: Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
var: &str,
|
||||
label: &str,
|
||||
) -> VarInspection {
|
||||
let mut egraph = run_egraph(program, ops);
|
||||
let let_line = find_let_line(program, var);
|
||||
|
||||
let mut inspection = VarInspection {
|
||||
label: label.to_string(),
|
||||
var: var.to_string(),
|
||||
let_line,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let var_expr = egglog::var!(var.to_string());
|
||||
let (sort, value) = match egraph.eval_expr(&var_expr) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
inspection.eval_error = Some(format!("{err}"));
|
||||
return inspection;
|
||||
}
|
||||
};
|
||||
|
||||
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
|
||||
let dtype_facts = collect_dtype_facts(&serialized);
|
||||
|
||||
let class_id = serialized.roots.first().cloned();
|
||||
if let Some(class_id) = class_id {
|
||||
inspection.class_id = Some(format!("{:?}", class_id));
|
||||
inspection.class_type = Some(class_type(&serialized, &class_id));
|
||||
inspection.class_labels = class_labels(&serialized, &class_id);
|
||||
inspection.dtype = dtype_facts.get(&class_id).cloned();
|
||||
|
||||
if let Some((_, nodes)) = serialized.eclasses.get(&class_id) {
|
||||
for node_id in nodes {
|
||||
let Some((label, children)) = serialized.enodes.get(node_id) else {
|
||||
continue;
|
||||
};
|
||||
let mut enode = EnodeInspection {
|
||||
label: label.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
for child in children {
|
||||
let child_labels = class_labels(&serialized, child);
|
||||
let child_type = class_type(&serialized, child);
|
||||
let dtype = dtype_facts.get(child).cloned();
|
||||
enode.children.push(ChildInspection {
|
||||
class_id: format!("{:?}", child),
|
||||
class_type: child_type,
|
||||
class_labels: child_labels,
|
||||
dtype,
|
||||
});
|
||||
}
|
||||
inspection.enodes.push(enode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inspection
|
||||
}
|
||||
|
||||
/// Inspect a specific var using HLIR-only ops.
|
||||
pub fn inspect_var_hlir(program: &str, var: &str) -> VarInspection {
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
inspect_var_with_ops(program, hlir_ops, var, "HLIR")
|
||||
}
|
||||
|
||||
/// Evaluate dtype for a variable in an egraph.
|
||||
pub fn eval_dtype(egraph: &mut egglog::EGraph, var: &str) -> DTypeStatus {
|
||||
let expr = egglog::call!("dtype", vec![egglog::var!(var.to_string())]);
|
||||
match egraph.eval_expr(&expr) {
|
||||
Ok((sort, value)) => match egraph.extract_value_to_string(&sort, value) {
|
||||
Ok((s, _)) => DTypeStatus::Resolved(s),
|
||||
Err(_) => DTypeStatus::Missing("extract-error".to_string()),
|
||||
},
|
||||
Err(err) => DTypeStatus::Missing(format!("{err}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze backend lowering for a specific HLIR op -> backend op mapping.
|
||||
pub fn analyze_op_lowering_with_ops(
|
||||
program: &str,
|
||||
ops: Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
hlir_op: &str,
|
||||
backend_op: &str,
|
||||
label: &str,
|
||||
) -> OpLoweringReport {
|
||||
let mut egraph = run_egraph(program, ops);
|
||||
let (sort, value) = egraph
|
||||
.eval_expr(&egglog::var!("t0"))
|
||||
.or_else(|_| egraph.eval_expr(&egglog::var!("t1")))
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("failed to eval any root variable (t0/t1) for op inspection");
|
||||
});
|
||||
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
|
||||
let dtype_facts = collect_dtype_facts(&serialized);
|
||||
|
||||
let mut eclass_has_backend: FxHashSet<ClassId> = FxHashSet::default();
|
||||
for (node_id, (lbl, _)) in &serialized.enodes {
|
||||
if lbl == backend_op {
|
||||
eclass_has_backend.insert(serialized.node_to_class[node_id].clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_classes: FxHashSet<ClassId> = FxHashSet::default();
|
||||
let mut missing: Vec<OpMissing> = Vec::new();
|
||||
for (node_id, (lbl, children)) in &serialized.enodes {
|
||||
if lbl != hlir_op {
|
||||
continue;
|
||||
}
|
||||
let class_id = &serialized.node_to_class[node_id];
|
||||
if seen_classes.contains(class_id) {
|
||||
continue;
|
||||
}
|
||||
seen_classes.insert(class_id.clone());
|
||||
if eclass_has_backend.contains(class_id) {
|
||||
continue;
|
||||
}
|
||||
let mut child_summaries = Vec::new();
|
||||
for child in children {
|
||||
let labels = class_labels(&serialized, child);
|
||||
let typ = class_type(&serialized, child);
|
||||
let dtype = dtype_facts.get(child).cloned();
|
||||
child_summaries.push(ChildInspection {
|
||||
class_id: format!("{:?}", child),
|
||||
class_type: typ,
|
||||
class_labels: labels,
|
||||
dtype,
|
||||
});
|
||||
}
|
||||
missing.push(OpMissing {
|
||||
class_id: format!("{:?}", class_id),
|
||||
op: hlir_op.to_string(),
|
||||
children: child_summaries,
|
||||
});
|
||||
}
|
||||
|
||||
OpLoweringReport {
|
||||
label: label.to_string(),
|
||||
hlir_op: hlir_op.to_string(),
|
||||
backend_op: backend_op.to_string(),
|
||||
total_classes: seen_classes.len(),
|
||||
missing,
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze backend lowering for a specific HLIR op using a runtime's ops.
|
||||
pub fn analyze_backend_op_lowering<R: Runtime>(
|
||||
program: &str,
|
||||
hlir_op: &str,
|
||||
backend_op: &str,
|
||||
) -> OpLoweringReport
|
||||
where
|
||||
R::Ops: IntoEgglogOp,
|
||||
{
|
||||
let mut backend_ops = R::Ops::into_vec();
|
||||
backend_ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let label = format!(
|
||||
"{}+HLIR",
|
||||
std::any::type_name::<R>()
|
||||
.split("::")
|
||||
.last()
|
||||
.unwrap_or("Backend")
|
||||
);
|
||||
analyze_op_lowering_with_ops(program, backend_ops, hlir_op, backend_op, &label)
|
||||
}
|
||||
|
||||
/// Analyze lowering with a specific set of ops.
|
||||
///
|
||||
/// This is the core analysis function that works with any set of ops.
|
||||
/// Use `analyze_lowering` for convenience when working with a specific backend.
|
||||
pub fn analyze_with_ops(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
label: &str,
|
||||
fact_queries: &[FactQuery],
|
||||
op_mappings: &[(String, String)],
|
||||
) -> LoweringAnalysis {
|
||||
let mut egraph = run_egraph(program, ops);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&egglog::var!(root)).unwrap();
|
||||
let serialized = egglog_utils::SerializedEGraph::new(&egraph, vec![(sort, value)]);
|
||||
let dtype_facts = collect_dtype_facts(&serialized);
|
||||
|
||||
let root_class_id = serialized.roots.first().unwrap();
|
||||
let root_labels = class_labels(&serialized, root_class_id);
|
||||
|
||||
let mut analysis = LoweringAnalysis {
|
||||
label: label.to_string(),
|
||||
root_labels,
|
||||
output_input_labels: Vec::new(),
|
||||
op_reports: Vec::new(),
|
||||
facts: BTreeMap::new(),
|
||||
};
|
||||
|
||||
// Output input labels (if any Output exists under this root).
|
||||
for (lbl, children) in serialized.enodes.values() {
|
||||
if lbl != "Output" || children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
analysis.output_input_labels = class_labels(&serialized, &children[0]);
|
||||
break;
|
||||
}
|
||||
|
||||
// Op coverage reports (only when explicitly requested).
|
||||
for (hlir_op, backend_op) in op_mappings {
|
||||
// Determine which eclasses contain the backend op.
|
||||
let mut eclass_has_backend: FxHashSet<ClassId> = FxHashSet::default();
|
||||
for (node_id, (lbl, _)) in &serialized.enodes {
|
||||
if lbl == backend_op {
|
||||
eclass_has_backend.insert(serialized.node_to_class[node_id].clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_classes: FxHashSet<ClassId> = FxHashSet::default();
|
||||
let mut missing: Vec<OpMissing> = Vec::new();
|
||||
for (node_id, (lbl, children)) in &serialized.enodes {
|
||||
if lbl != hlir_op {
|
||||
continue;
|
||||
}
|
||||
let class_id = &serialized.node_to_class[node_id];
|
||||
if !seen_classes.insert(class_id.clone()) {
|
||||
continue;
|
||||
}
|
||||
if eclass_has_backend.contains(class_id) {
|
||||
continue;
|
||||
}
|
||||
let mut child_summaries = Vec::new();
|
||||
for child in children {
|
||||
child_summaries.push(ChildInspection {
|
||||
class_id: format!("{:?}", child),
|
||||
class_type: class_type(&serialized, child),
|
||||
class_labels: class_labels(&serialized, child),
|
||||
dtype: dtype_facts.get(child).cloned(),
|
||||
});
|
||||
}
|
||||
missing.push(OpMissing {
|
||||
class_id: format!("{:?}", class_id),
|
||||
op: hlir_op.clone(),
|
||||
children: child_summaries,
|
||||
});
|
||||
}
|
||||
|
||||
analysis.op_reports.push(OpLoweringReport {
|
||||
label: label.to_string(),
|
||||
hlir_op: hlir_op.clone(),
|
||||
backend_op: backend_op.clone(),
|
||||
total_classes: seen_classes.len(),
|
||||
missing,
|
||||
});
|
||||
}
|
||||
|
||||
// Evaluate requested facts.
|
||||
for q in fact_queries {
|
||||
let mut table: BTreeMap<String, FactStatus> = BTreeMap::new();
|
||||
for var in &q.vars {
|
||||
table.insert(var.clone(), eval_function(&mut egraph, &q.fn_name, var));
|
||||
}
|
||||
analysis.facts.insert(q.fn_name.clone(), table);
|
||||
}
|
||||
|
||||
analysis
|
||||
}
|
||||
|
||||
/// Analyze dtype propagation chain for a specific variable with a given set of ops.
|
||||
pub fn analyze_dtype_chain_with_ops(
|
||||
program: &str,
|
||||
ops: Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
target: &str,
|
||||
) -> DTypeChainAnalysis {
|
||||
let mut egraph = run_egraph(program, ops);
|
||||
let mut graph = DependencyGraph::from_program(program);
|
||||
annotate_dtypes(&mut graph, &mut egraph);
|
||||
DTypeChainAnalysis::analyze(&graph, target)
|
||||
}
|
||||
|
||||
/// Analyze dtype propagation chain for a specific variable using HLIR-only ops.
|
||||
pub fn analyze_hlir_dtype_chain(program: &str, target: &str) -> DTypeChainAnalysis {
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
analyze_dtype_chain_with_ops(program, hlir_ops, target)
|
||||
}
|
||||
|
||||
/// Analyze function propagation chain for a specific variable with a given set of ops.
|
||||
pub fn analyze_function_chain_with_ops(
|
||||
program: &str,
|
||||
ops: Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
fn_name: &str,
|
||||
target: &str,
|
||||
) -> FunctionChainAnalysis {
|
||||
let mut egraph = run_egraph(program, ops);
|
||||
let graph = DependencyGraph::from_program(program);
|
||||
let trace = graph.trace_back(target, 20);
|
||||
let mut chain = Vec::new();
|
||||
let mut first_missing = None;
|
||||
|
||||
for entry in trace {
|
||||
let status = eval_function(&mut egraph, fn_name, &entry.var);
|
||||
if first_missing.is_none() && status.is_missing() {
|
||||
first_missing = Some(entry.var.clone());
|
||||
}
|
||||
chain.push(FunctionTraceEntry {
|
||||
depth: entry.depth,
|
||||
var: entry.var,
|
||||
op_type: entry.op_type,
|
||||
status,
|
||||
});
|
||||
}
|
||||
|
||||
let all_resolved = first_missing.is_none();
|
||||
FunctionChainAnalysis {
|
||||
target: target.to_string(),
|
||||
fn_name: fn_name.to_string(),
|
||||
chain,
|
||||
first_missing,
|
||||
all_resolved,
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze function propagation chain for a specific variable using HLIR-only ops.
|
||||
pub fn analyze_hlir_function_chain(
|
||||
program: &str,
|
||||
fn_name: &str,
|
||||
target: &str,
|
||||
) -> FunctionChainAnalysis {
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
analyze_function_chain_with_ops(program, hlir_ops, fn_name, target)
|
||||
}
|
||||
|
||||
/// Run full lowering analysis comparing HLIR-only vs Backend+HLIR.
|
||||
///
|
||||
/// This is a generic function that works with any backend implementing `Runtime`.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// * `R` - The runtime type (e.g., `MetalRuntime`, `CudaRuntime`)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `program` - The egglog program string
|
||||
/// * `root` - The root variable name
|
||||
/// * `backend_add_name` - The name of the backend's Add operation (e.g., "MetalAdd")
|
||||
pub fn analyze_lowering<R: Runtime>(
|
||||
program: &str,
|
||||
root: &str,
|
||||
fact_queries: &[FactQuery],
|
||||
op_mappings: &[(String, String)],
|
||||
) -> (LoweringAnalysis, LoweringAnalysis)
|
||||
where
|
||||
R::Ops: IntoEgglogOp,
|
||||
{
|
||||
// HLIR-only analysis
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
let hlir_analysis = analyze_with_ops(program, root, hlir_ops, "HLIR", fact_queries, &[]);
|
||||
|
||||
// Backend+HLIR analysis
|
||||
let mut backend_ops = R::Ops::into_vec();
|
||||
backend_ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let backend_label = format!(
|
||||
"{}+HLIR",
|
||||
std::any::type_name::<R>()
|
||||
.split("::")
|
||||
.last()
|
||||
.unwrap_or("Backend")
|
||||
);
|
||||
let backend_analysis = analyze_with_ops(
|
||||
program,
|
||||
root,
|
||||
backend_ops,
|
||||
&backend_label,
|
||||
fact_queries,
|
||||
op_mappings,
|
||||
);
|
||||
|
||||
(hlir_analysis, backend_analysis)
|
||||
}
|
||||
|
||||
/// Convenience function for HLIR-only analysis (no backend).
|
||||
pub fn analyze_hlir_only(program: &str, root: &str) -> LoweringAnalysis {
|
||||
let hlir_ops = <HLIROps as IntoEgglogOp>::into_vec();
|
||||
analyze_with_ops(program, root, hlir_ops, "HLIR", &[], &[])
|
||||
}
|
||||
102
crates/luminal_bench/src/egglog_debug/mod.rs
Normal file
102
crates/luminal_bench/src/egglog_debug/mod.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! Egglog debugging and analysis utilities.
|
||||
//!
|
||||
//! This module provides tools for diagnosing egglog lowering issues,
|
||||
//! particularly when HLIR operations fail to convert to backend implementations.
|
||||
|
||||
mod analysis;
|
||||
mod report;
|
||||
mod trace;
|
||||
|
||||
pub use analysis::*;
|
||||
pub use report::*;
|
||||
pub use trace::*;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// Extract the operation head from an egglog expression.
|
||||
///
|
||||
/// Example: `(Add t1 t2 ...)` -> `Some("Add")`
|
||||
pub fn egglog_op_head(code: &str) -> Option<&str> {
|
||||
let code = code.trim();
|
||||
code.strip_prefix('(')
|
||||
.and_then(|s| s.split_whitespace().next())
|
||||
}
|
||||
|
||||
/// Summarize HLIR node types in a graph.
|
||||
pub fn summarize_hlir_ops(cx: &luminal::prelude::Graph) -> BTreeMap<String, usize> {
|
||||
let mut counts: BTreeMap<String, usize> = BTreeMap::new();
|
||||
for node in cx.graph.node_indices() {
|
||||
let name = cx.graph[node].type_name().to_string();
|
||||
*counts.entry(name).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
/// Summarize egglog operation heads from a program string.
|
||||
pub fn summarize_egglog_ops(program: &str) -> BTreeMap<String, usize> {
|
||||
let mut counts: BTreeMap<String, usize> = BTreeMap::new();
|
||||
for line in program.lines() {
|
||||
// Parse lines like: (let t1 (Add ...))
|
||||
let Some(code) = line.splitn(3, ' ').nth(2) else {
|
||||
continue;
|
||||
};
|
||||
let Some(head) = egglog_op_head(code) else {
|
||||
continue;
|
||||
};
|
||||
*counts.entry(head.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
/// Result of dtype analysis for a node.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum DTypeStatus {
|
||||
/// dtype was successfully resolved
|
||||
Resolved(String),
|
||||
/// dtype lookup failed
|
||||
Missing(String),
|
||||
}
|
||||
|
||||
impl DTypeStatus {
|
||||
pub fn is_missing(&self) -> bool {
|
||||
matches!(self, DTypeStatus::Missing(_))
|
||||
}
|
||||
|
||||
pub fn is_resolved(&self) -> bool {
|
||||
matches!(self, DTypeStatus::Resolved(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DTypeStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DTypeStatus::Resolved(s) => write!(f, "{}", s),
|
||||
DTypeStatus::Missing(err) => write!(f, "<missing:{}>", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of evaluating an arbitrary egglog function for a node.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FactStatus {
|
||||
/// function value was successfully resolved
|
||||
Resolved(String),
|
||||
/// function lookup failed
|
||||
Missing(String),
|
||||
}
|
||||
|
||||
impl FactStatus {
|
||||
pub fn is_missing(&self) -> bool {
|
||||
matches!(self, FactStatus::Missing(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FactStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
FactStatus::Resolved(s) => write!(f, "{}", s),
|
||||
FactStatus::Missing(err) => write!(f, "<missing:{}>", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
247
crates/luminal_bench/src/egglog_debug/report.rs
Normal file
247
crates/luminal_bench/src/egglog_debug/report.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
//! Formatted output and reporting for egglog debug analysis.
|
||||
|
||||
use super::{
|
||||
DTypeChainAnalysis, DTypeStatus, EnodeInspection, FunctionChainAnalysis, LoweringAnalysis,
|
||||
OpLoweringReport, TraceEntry, VarInspection,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// Print HLIR node type summary.
|
||||
pub fn print_hlir_summary(counts: &BTreeMap<String, usize>) {
|
||||
println!("-- HLIR node types --");
|
||||
for (k, v) in counts {
|
||||
println!(" {}: {}", k, v);
|
||||
}
|
||||
}
|
||||
|
||||
/// Print egglog op head summary.
|
||||
pub fn print_egglog_summary(counts: &BTreeMap<String, usize>) {
|
||||
println!("-- Egglog op heads --");
|
||||
for (k, v) in counts {
|
||||
println!(" {}: {}", k, v);
|
||||
}
|
||||
}
|
||||
|
||||
/// Print lowering analysis results.
|
||||
pub fn print_lowering_analysis(analysis: &LoweringAnalysis) {
|
||||
println!("-- {} Analysis --", analysis.label);
|
||||
println!(" Root eclass labels: {}", analysis.root_labels.join("|"));
|
||||
|
||||
if !analysis.output_input_labels.is_empty() {
|
||||
println!(
|
||||
" Output input labels: {}",
|
||||
analysis.output_input_labels.join("|")
|
||||
);
|
||||
}
|
||||
|
||||
if !analysis.facts.is_empty() {
|
||||
println!(" Facts:");
|
||||
for (fn_name, table) in &analysis.facts {
|
||||
println!(" {}:", fn_name);
|
||||
for (var, status) in table {
|
||||
let prefix = if status.is_missing() { "❌" } else { "√" };
|
||||
println!(" {} {}: {}", prefix, var, status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !analysis.op_reports.is_empty() {
|
||||
for report in &analysis.op_reports {
|
||||
print_op_lowering_report(report);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Print dependency trace as a tree.
|
||||
pub fn print_trace_tree(trace: &[TraceEntry]) {
|
||||
println!("-- Dependency trace --");
|
||||
for entry in trace {
|
||||
println!("{}", entry.format_tree());
|
||||
}
|
||||
}
|
||||
|
||||
/// Print dtype chain analysis.
|
||||
pub fn print_dtype_chain(analysis: &DTypeChainAnalysis) {
|
||||
println!("-- DType chain analysis for {} --", analysis.target);
|
||||
|
||||
if analysis.all_resolved {
|
||||
println!(" √ All nodes in chain have resolved dtype");
|
||||
} else if let Some(ref first) = analysis.first_missing {
|
||||
println!(" ❌ First missing dtype at: {}", first);
|
||||
}
|
||||
|
||||
println!(" Chain:");
|
||||
for entry in &analysis.chain {
|
||||
let dtype_str = match &entry.dtype {
|
||||
Some(DTypeStatus::Resolved(s)) => format!("√ {}", s),
|
||||
Some(DTypeStatus::Missing(_)) => "❌ missing".to_string(),
|
||||
None => "? unknown".to_string(),
|
||||
};
|
||||
let indent = " ".repeat(entry.depth + 1);
|
||||
println!("{}{} ({}) {}", indent, entry.var, entry.op_type, dtype_str);
|
||||
}
|
||||
}
|
||||
|
||||
/// Print function chain analysis.
|
||||
pub fn print_function_chain(analysis: &FunctionChainAnalysis) {
|
||||
println!(
|
||||
"-- Function chain analysis for {} (fn={}) --",
|
||||
analysis.target, analysis.fn_name
|
||||
);
|
||||
|
||||
if analysis.all_resolved {
|
||||
println!(" √ All nodes in chain have resolved value");
|
||||
} else if let Some(ref first) = analysis.first_missing {
|
||||
println!(" ❌ First missing at: {}", first);
|
||||
}
|
||||
|
||||
println!(" Chain:");
|
||||
for entry in &analysis.chain {
|
||||
println!("{}", entry.format_tree());
|
||||
}
|
||||
}
|
||||
|
||||
/// Print op lowering report.
|
||||
pub fn print_op_lowering_report(report: &OpLoweringReport) {
|
||||
println!(
|
||||
"-- Op lowering [{}] {} -> {} --",
|
||||
report.label, report.hlir_op, report.backend_op
|
||||
);
|
||||
println!(" total eclasses: {}", report.total_classes);
|
||||
if report.missing.is_empty() {
|
||||
println!(" √ All eclasses have backend equivalent");
|
||||
} else {
|
||||
println!(" ❌ Missing backend in {} eclasses:", report.missing.len());
|
||||
for miss in &report.missing {
|
||||
println!(" - class={} op={}", miss.class_id, miss.op);
|
||||
for (idx, child) in miss.children.iter().enumerate() {
|
||||
let labels = if child.class_labels.is_empty() {
|
||||
"<none>".to_string()
|
||||
} else {
|
||||
child.class_labels.join("|")
|
||||
};
|
||||
let dtype = child
|
||||
.dtype
|
||||
.clone()
|
||||
.unwrap_or_else(|| "<missing>".to_string());
|
||||
println!(
|
||||
" [{}] class={} type={} labels={} dtype={}",
|
||||
idx, child.class_id, child.class_type, labels, dtype
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn print_enode(enode: &EnodeInspection) {
|
||||
println!(" - {}", enode.label);
|
||||
for (idx, child) in enode.children.iter().enumerate() {
|
||||
let dtype = child
|
||||
.dtype
|
||||
.clone()
|
||||
.unwrap_or_else(|| "<missing>".to_string());
|
||||
let labels = if child.class_labels.is_empty() {
|
||||
"<none>".to_string()
|
||||
} else {
|
||||
child.class_labels.join("|")
|
||||
};
|
||||
println!(
|
||||
" [{}] class={} type={} labels={} dtype={}",
|
||||
idx, child.class_id, child.class_type, labels, dtype
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Print inspection results for a specific variable.
|
||||
pub fn print_var_inspection(inspection: &VarInspection) {
|
||||
println!(
|
||||
"-- Var inspection [{}] {} --",
|
||||
inspection.label, inspection.var
|
||||
);
|
||||
|
||||
if let Some(ref line) = inspection.let_line {
|
||||
println!(" let: {}", line);
|
||||
}
|
||||
if let Some(ref err) = inspection.eval_error {
|
||||
println!(" eval error: {}", err);
|
||||
return;
|
||||
}
|
||||
|
||||
let class_id = inspection.class_id.as_deref().unwrap_or("<unknown>");
|
||||
let class_type = inspection.class_type.as_deref().unwrap_or("<unknown>");
|
||||
let dtype = inspection.dtype.as_deref().unwrap_or("<missing>");
|
||||
let labels = if inspection.class_labels.is_empty() {
|
||||
"<none>".to_string()
|
||||
} else {
|
||||
inspection.class_labels.join("|")
|
||||
};
|
||||
|
||||
println!(
|
||||
" class: {} type={} labels={}",
|
||||
class_id, class_type, labels
|
||||
);
|
||||
println!(" dtype: {}", dtype);
|
||||
println!(" enodes:");
|
||||
for enode in &inspection.enodes {
|
||||
print_enode(enode);
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary report for a debug session.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DebugReport {
|
||||
pub case_name: String,
|
||||
pub size: usize,
|
||||
pub hlir_counts: BTreeMap<String, usize>,
|
||||
pub egglog_counts: BTreeMap<String, usize>,
|
||||
pub hlir_analysis: Option<LoweringAnalysis>,
|
||||
pub backend_analysis: Option<LoweringAnalysis>,
|
||||
pub var_inspections: Vec<VarInspection>,
|
||||
pub function_traces: Vec<FunctionChainAnalysis>,
|
||||
pub build_succeeded: bool,
|
||||
}
|
||||
|
||||
impl DebugReport {
|
||||
/// Print full report to stdout.
|
||||
pub fn print(&self) {
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!("Case: {} (size={})", self.case_name, self.size);
|
||||
println!("{}", "=".repeat(60));
|
||||
|
||||
print_hlir_summary(&self.hlir_counts);
|
||||
println!();
|
||||
print_egglog_summary(&self.egglog_counts);
|
||||
|
||||
if let Some(ref analysis) = self.hlir_analysis {
|
||||
println!();
|
||||
print_lowering_analysis(analysis);
|
||||
}
|
||||
|
||||
if let Some(ref analysis) = self.backend_analysis {
|
||||
println!();
|
||||
print_lowering_analysis(analysis);
|
||||
}
|
||||
|
||||
if !self.function_traces.is_empty() {
|
||||
for trace in &self.function_traces {
|
||||
println!();
|
||||
print_function_chain(trace);
|
||||
}
|
||||
}
|
||||
|
||||
if !self.var_inspections.is_empty() {
|
||||
for inspection in &self.var_inspections {
|
||||
println!();
|
||||
print_var_inspection(inspection);
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
if self.build_succeeded {
|
||||
println!("√ build_search_space succeeded");
|
||||
} else {
|
||||
println!("❌ build_search_space failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
232
crates/luminal_bench/src/egglog_debug/trace.rs
Normal file
232
crates/luminal_bench/src/egglog_debug/trace.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
//! Dependency chain tracing for dtype propagation analysis.
|
||||
|
||||
use super::{DTypeStatus, FactStatus};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// A node in the dependency graph.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DepNode {
|
||||
pub var: String,
|
||||
pub op_type: String,
|
||||
pub inputs: Vec<String>,
|
||||
pub dtype: Option<DTypeStatus>,
|
||||
}
|
||||
|
||||
/// Dependency graph built from egglog program.
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct DependencyGraph {
|
||||
pub nodes: HashMap<String, DepNode>,
|
||||
pub roots: Vec<String>,
|
||||
}
|
||||
|
||||
impl DependencyGraph {
|
||||
/// Build dependency graph from egglog program.
|
||||
pub fn from_program(program: &str) -> Self {
|
||||
let mut graph = DependencyGraph::default();
|
||||
|
||||
for line in program.lines() {
|
||||
let line = line.trim();
|
||||
if !line.starts_with("(let ") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse: (let t1 (OpName args...))
|
||||
let tokens: Vec<&str> = line.split_whitespace().collect();
|
||||
if tokens.len() < 3 || tokens[0] != "(let" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let var = tokens[1].to_string();
|
||||
let op_type = tokens[2].trim_start_matches('(').to_string();
|
||||
|
||||
// Extract input variables (t followed by digits)
|
||||
let mut inputs = Vec::new();
|
||||
let bytes = line.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
if bytes[i] == b't' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
|
||||
let start = i;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i].is_ascii_digit() {
|
||||
i += 1;
|
||||
}
|
||||
let found_var = String::from_utf8_lossy(&bytes[start..i]).to_string();
|
||||
// Don't include self
|
||||
if found_var != var {
|
||||
inputs.push(found_var);
|
||||
}
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
graph.nodes.insert(
|
||||
var.clone(),
|
||||
DepNode {
|
||||
var,
|
||||
op_type,
|
||||
inputs,
|
||||
dtype: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Find roots (nodes that are not inputs to any other node)
|
||||
let all_inputs: HashSet<String> = graph
|
||||
.nodes
|
||||
.values()
|
||||
.flat_map(|n| n.inputs.iter().cloned())
|
||||
.collect();
|
||||
|
||||
graph.roots = graph
|
||||
.nodes
|
||||
.keys()
|
||||
.filter(|k| !all_inputs.contains(*k))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
graph
|
||||
}
|
||||
|
||||
/// Trace the dependency chain from a target variable back to inputs.
|
||||
pub fn trace_back(&self, target: &str, max_depth: usize) -> Vec<TraceEntry> {
|
||||
let mut result = Vec::new();
|
||||
self.trace_back_recursive(target, 0, max_depth, &mut result, &mut HashSet::new());
|
||||
result
|
||||
}
|
||||
|
||||
fn trace_back_recursive(
|
||||
&self,
|
||||
var: &str,
|
||||
depth: usize,
|
||||
max_depth: usize,
|
||||
result: &mut Vec<TraceEntry>,
|
||||
visited: &mut HashSet<String>,
|
||||
) {
|
||||
if depth > max_depth || visited.contains(var) {
|
||||
return;
|
||||
}
|
||||
visited.insert(var.to_string());
|
||||
|
||||
let node = match self.nodes.get(var) {
|
||||
Some(n) => n,
|
||||
None => {
|
||||
result.push(TraceEntry {
|
||||
depth,
|
||||
var: var.to_string(),
|
||||
op_type: "<unknown>".to_string(),
|
||||
dtype: None,
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
result.push(TraceEntry {
|
||||
depth,
|
||||
var: node.var.clone(),
|
||||
op_type: node.op_type.clone(),
|
||||
dtype: node.dtype.clone(),
|
||||
});
|
||||
|
||||
for input in &node.inputs {
|
||||
self.trace_back_recursive(input, depth + 1, max_depth, result, visited);
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the first node in a chain that has missing dtype.
|
||||
pub fn find_dtype_break(&self, target: &str) -> Option<String> {
|
||||
let trace = self.trace_back(target, 20);
|
||||
for entry in trace {
|
||||
if let Some(DTypeStatus::Missing(_)) = entry.dtype {
|
||||
return Some(entry.var);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Entry in a trace result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TraceEntry {
|
||||
pub depth: usize,
|
||||
pub var: String,
|
||||
pub op_type: String,
|
||||
pub dtype: Option<DTypeStatus>,
|
||||
}
|
||||
|
||||
impl TraceEntry {
|
||||
/// Format as indented tree line.
|
||||
pub fn format_tree(&self) -> String {
|
||||
let indent = " ".repeat(self.depth);
|
||||
let prefix = if self.depth == 0 { "" } else { "├── " };
|
||||
let dtype_str = match &self.dtype {
|
||||
Some(d) => format!(" dtype={}", d),
|
||||
None => String::new(),
|
||||
};
|
||||
format!(
|
||||
"{}{}{} ({}){}",
|
||||
indent, prefix, self.var, self.op_type, dtype_str
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Entry in a function trace result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionTraceEntry {
|
||||
pub depth: usize,
|
||||
pub var: String,
|
||||
pub op_type: String,
|
||||
pub status: FactStatus,
|
||||
}
|
||||
|
||||
impl FunctionTraceEntry {
|
||||
pub fn format_tree(&self) -> String {
|
||||
let indent = " ".repeat(self.depth);
|
||||
let prefix = if self.depth == 0 { "" } else { "├── " };
|
||||
format!(
|
||||
"{}{}{} ({}) {}",
|
||||
indent, prefix, self.var, self.op_type, self.status
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of dtype chain analysis.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DTypeChainAnalysis {
|
||||
pub target: String,
|
||||
pub chain: Vec<TraceEntry>,
|
||||
pub first_missing: Option<String>,
|
||||
pub all_resolved: bool,
|
||||
}
|
||||
|
||||
impl DTypeChainAnalysis {
|
||||
/// Create from dependency graph and target variable.
|
||||
pub fn analyze(graph: &DependencyGraph, target: &str) -> Self {
|
||||
let chain = graph.trace_back(target, 20);
|
||||
let first_missing = chain
|
||||
.iter()
|
||||
.find(|e| matches!(&e.dtype, Some(DTypeStatus::Missing(_))))
|
||||
.map(|e| e.var.clone());
|
||||
let all_resolved = chain
|
||||
.iter()
|
||||
.all(|e| matches!(&e.dtype, Some(DTypeStatus::Resolved(_)) | None));
|
||||
|
||||
DTypeChainAnalysis {
|
||||
target: target.to_string(),
|
||||
chain,
|
||||
first_missing,
|
||||
all_resolved,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of function chain analysis.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct FunctionChainAnalysis {
|
||||
pub target: String,
|
||||
pub fn_name: String,
|
||||
pub chain: Vec<FunctionTraceEntry>,
|
||||
pub first_missing: Option<String>,
|
||||
pub all_resolved: bool,
|
||||
}
|
||||
95
crates/luminal_bench/src/lib.rs
Normal file
95
crates/luminal_bench/src/lib.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
//! # Luminal Benchmark Infrastructure
|
||||
//!
|
||||
//! Universal benchmark framework for Luminal backends.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - **BenchmarkBackend**: Trait that backends implement to enable benchmarking
|
||||
//! - **BenchmarkPattern**: Trait for defining benchmark workloads
|
||||
//! - **Micro benchmarks (L1)**: Single-operator performance tests (HLIR primitives)
|
||||
//! - **Pattern benchmarks (L2)**: Composite operator performance tests
|
||||
//!
|
||||
//! Usage 和调试方式见 crate 根目录的 `README.md`。
|
||||
|
||||
mod metrics;
|
||||
mod micro;
|
||||
mod patterns;
|
||||
|
||||
/// Egglog debugging and analysis utilities.
|
||||
/// This module is backend-agnostic; specific backends are selected via feature flags
|
||||
/// in the debug_ops example.
|
||||
pub mod egglog_debug;
|
||||
|
||||
pub use metrics::*;
|
||||
pub use micro::*;
|
||||
pub use patterns::*;
|
||||
|
||||
use luminal::op::Runtime;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Hardware information for a backend device
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HardwareInfo {
|
||||
pub device_name: String,
|
||||
pub memory_gb: f64,
|
||||
/// Peak memory bandwidth in GB/s (if known)
|
||||
pub peak_bandwidth_gbps: Option<f64>,
|
||||
/// Peak compute throughput in TFLOPS (if known)
|
||||
pub peak_tflops: Option<f64>,
|
||||
}
|
||||
|
||||
/// Trait that backends implement to enable benchmarking
|
||||
pub trait BenchmarkBackend {
|
||||
type Runtime: Runtime;
|
||||
|
||||
/// Initialize the runtime
|
||||
fn initialize() -> Self::Runtime;
|
||||
|
||||
/// Get backend name (used in reports)
|
||||
fn name() -> &'static str;
|
||||
|
||||
/// Get hardware information
|
||||
fn hardware_info() -> HardwareInfo;
|
||||
}
|
||||
|
||||
/// Size configuration for benchmarks
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BenchSize {
|
||||
pub name: &'static str,
|
||||
pub value: usize,
|
||||
}
|
||||
|
||||
impl BenchSize {
|
||||
pub const fn new(name: &'static str, value: usize) -> Self {
|
||||
Self { name, value }
|
||||
}
|
||||
}
|
||||
|
||||
/// Standard benchmark sizes for micro benchmarks
|
||||
pub const MICRO_SIZES: &[BenchSize] = &[
|
||||
BenchSize::new("1k", 1_000),
|
||||
BenchSize::new("100k", 100_000),
|
||||
BenchSize::new("1m", 1_000_000),
|
||||
BenchSize::new("10m", 10_000_000),
|
||||
];
|
||||
|
||||
/// Trait for defining benchmark workloads (dyn-compatible version)
|
||||
pub trait BenchmarkPattern {
|
||||
/// Pattern name (used in reports)
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Available sizes for this pattern
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
/// Build the computation graph for this pattern
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize);
|
||||
}
|
||||
|
||||
// Re-export backend implementations when features are enabled
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal_backend;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub use metal_backend::MetalBenchmark;
|
||||
63
crates/luminal_bench/src/metal_backend.rs
Normal file
63
crates/luminal_bench/src/metal_backend.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
//! Metal backend implementation for benchmarking
|
||||
|
||||
use crate::{BenchmarkBackend, HardwareInfo};
|
||||
use luminal::op::Runtime;
|
||||
use luminal_metal::runtime::MetalRuntime;
|
||||
|
||||
/// Metal benchmark backend
|
||||
pub struct MetalBenchmark;
|
||||
|
||||
impl BenchmarkBackend for MetalBenchmark {
|
||||
type Runtime = MetalRuntime;
|
||||
|
||||
fn initialize() -> Self::Runtime {
|
||||
MetalRuntime::initialize(())
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"metal"
|
||||
}
|
||||
|
||||
fn hardware_info() -> HardwareInfo {
|
||||
// Try to get device info from Metal
|
||||
let device = metal::Device::system_default().expect("No Metal device found");
|
||||
let device_name = device.name().to_string();
|
||||
|
||||
// Estimate based on common Apple Silicon specs
|
||||
let (memory_gb, peak_bandwidth_gbps, peak_tflops) = estimate_device_specs(&device_name);
|
||||
|
||||
HardwareInfo {
|
||||
device_name,
|
||||
memory_gb,
|
||||
peak_bandwidth_gbps: Some(peak_bandwidth_gbps),
|
||||
peak_tflops: Some(peak_tflops),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate device specs based on device name
|
||||
fn estimate_device_specs(device_name: &str) -> (f64, f64, f64) {
|
||||
// Memory (GB), Bandwidth (GB/s), FP32 TFLOPS
|
||||
if device_name.contains("M3 Max") {
|
||||
(128.0, 400.0, 14.0)
|
||||
} else if device_name.contains("M3 Pro") {
|
||||
(36.0, 200.0, 7.0)
|
||||
} else if device_name.contains("M3") {
|
||||
(24.0, 100.0, 3.5)
|
||||
} else if device_name.contains("M2 Max") {
|
||||
(96.0, 400.0, 13.6)
|
||||
} else if device_name.contains("M2 Pro") {
|
||||
(32.0, 200.0, 6.8)
|
||||
} else if device_name.contains("M2") {
|
||||
(24.0, 100.0, 3.6)
|
||||
} else if device_name.contains("M1 Max") {
|
||||
(64.0, 400.0, 10.4)
|
||||
} else if device_name.contains("M1 Pro") {
|
||||
(32.0, 200.0, 5.2)
|
||||
} else if device_name.contains("M1") {
|
||||
(16.0, 68.0, 2.6)
|
||||
} else {
|
||||
// Generic fallback
|
||||
(8.0, 50.0, 1.0)
|
||||
}
|
||||
}
|
||||
331
crates/luminal_bench/src/metrics.rs
Normal file
331
crates/luminal_bench/src/metrics.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
//! Benchmark metrics and mapping
|
||||
//!
|
||||
//! Provides a mapping from benchmark names to their constant metrics (bytes, flops).
|
||||
//! Combined with Criterion's time measurements, this allows computing derived metrics
|
||||
//! like throughput, MBU, and MFU.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Constant metrics for a single benchmark configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchMetrics {
|
||||
/// Total bytes transferred (loaded + stored)
|
||||
pub bytes: usize,
|
||||
/// Bytes loaded from memory
|
||||
pub bytes_loaded: usize,
|
||||
/// Bytes stored to memory
|
||||
pub bytes_stored: usize,
|
||||
/// Floating-point operations
|
||||
pub flops: usize,
|
||||
}
|
||||
|
||||
impl BenchMetrics {
|
||||
pub fn new(bytes_loaded: usize, bytes_stored: usize, flops: usize) -> Self {
|
||||
Self {
|
||||
bytes: bytes_loaded + bytes_stored,
|
||||
bytes_loaded,
|
||||
bytes_stored,
|
||||
flops,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate throughput in GB/s given execution time in microseconds
|
||||
pub fn throughput_gbps(&self, time_us: f64) -> f64 {
|
||||
if time_us <= 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.bytes as f64 / time_us / 1000.0
|
||||
}
|
||||
|
||||
/// Calculate TFLOPS given execution time in microseconds
|
||||
pub fn tflops(&self, time_us: f64) -> f64 {
|
||||
if time_us <= 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.flops as f64 / time_us / 1_000_000.0
|
||||
}
|
||||
|
||||
/// Calculate MBU given execution time and peak bandwidth
|
||||
pub fn mbu(&self, time_us: f64, peak_bandwidth_gbps: f64) -> f64 {
|
||||
self.throughput_gbps(time_us) / peak_bandwidth_gbps * 100.0
|
||||
}
|
||||
|
||||
/// Calculate MFU given execution time and peak TFLOPS
|
||||
pub fn mfu(&self, time_us: f64, peak_tflops: f64) -> f64 {
|
||||
self.tflops(time_us) / peak_tflops * 100.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Hardware specifications for a benchmark target
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HardwareSpec {
|
||||
pub device_name: String,
|
||||
pub memory_gb: f64,
|
||||
pub peak_bandwidth_gbps: f64,
|
||||
pub peak_tflops: f64,
|
||||
}
|
||||
|
||||
/// Complete benchmark metrics mapping
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchMetricsMap {
|
||||
/// Hardware specifications
|
||||
pub hardware: HardwareSpec,
|
||||
/// Mapping from "pattern/size" to metrics
|
||||
pub benchmarks: HashMap<String, BenchMetrics>,
|
||||
}
|
||||
|
||||
impl BenchMetricsMap {
|
||||
pub fn new(hardware: HardwareSpec) -> Self {
|
||||
Self {
|
||||
hardware,
|
||||
benchmarks: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add metrics for a benchmark
|
||||
pub fn add(&mut self, pattern: &str, size: &str, metrics: BenchMetrics) {
|
||||
let key = format!("{}/{}", pattern, size);
|
||||
self.benchmarks.insert(key, metrics);
|
||||
}
|
||||
|
||||
/// Get metrics for a benchmark
|
||||
pub fn get(&self, pattern: &str, size: &str) -> Option<&BenchMetrics> {
|
||||
let key = format!("{}/{}", pattern, size);
|
||||
self.benchmarks.get(&key)
|
||||
}
|
||||
|
||||
/// Export to JSON
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string_pretty(self).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Save to file
|
||||
pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
|
||||
let json = self.to_json();
|
||||
std::fs::write(path, json)
|
||||
}
|
||||
|
||||
/// Load from file
|
||||
pub fn load(path: &std::path::Path) -> std::io::Result<Self> {
|
||||
let json = std::fs::read_to_string(path)?;
|
||||
serde_json::from_str(&json)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Legacy types (kept for compatibility)
|
||||
// ============================================================================
|
||||
|
||||
/// Result of a single benchmark run
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchResult {
|
||||
/// Backend name (e.g., "metal", "cuda")
|
||||
pub backend: String,
|
||||
/// Benchmark pattern name (e.g., "add_vec")
|
||||
pub benchmark: String,
|
||||
/// Size label (e.g., "1m")
|
||||
pub size_label: String,
|
||||
/// Actual size value
|
||||
pub size_value: usize,
|
||||
/// Mean execution time in microseconds
|
||||
pub mean_us: f64,
|
||||
/// Standard deviation in microseconds
|
||||
pub std_us: f64,
|
||||
/// Throughput in GB/s (if applicable)
|
||||
pub throughput_gbps: Option<f64>,
|
||||
/// Memory Bandwidth Utilization (if peak bandwidth known)
|
||||
pub mbu: Option<f64>,
|
||||
}
|
||||
|
||||
impl BenchResult {
|
||||
/// Calculate throughput given bytes transferred
|
||||
pub fn with_throughput(mut self, bytes: usize) -> Self {
|
||||
if self.mean_us > 0.0 {
|
||||
// bytes / microseconds = MB/s, then convert to GB/s
|
||||
self.throughput_gbps = Some((bytes as f64) / self.mean_us / 1000.0);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Calculate MBU given peak bandwidth
|
||||
pub fn with_mbu(mut self, peak_bandwidth_gbps: f64) -> Self {
|
||||
if let Some(throughput) = self.throughput_gbps {
|
||||
self.mbu = Some(throughput / peak_bandwidth_gbps * 100.0);
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Collection of benchmark results for reporting
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchReport {
|
||||
pub backend: String,
|
||||
pub hardware: String,
|
||||
pub results: Vec<BenchResult>,
|
||||
}
|
||||
|
||||
impl BenchReport {
|
||||
pub fn new(backend: &str, hardware: &str) -> Self {
|
||||
Self {
|
||||
backend: backend.to_string(),
|
||||
hardware: hardware.to_string(),
|
||||
results: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_result(&mut self, result: BenchResult) {
|
||||
self.results.push(result);
|
||||
}
|
||||
|
||||
/// Export to JSON (for CI integration)
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string_pretty(self).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Full Report with Derived Metrics
|
||||
// ============================================================================
|
||||
|
||||
/// Single benchmark result with all metrics (constant + derived)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FullBenchResult {
|
||||
pub pattern: String,
|
||||
pub size: String,
|
||||
pub size_value: usize,
|
||||
/// Execution time in microseconds
|
||||
pub time_us: f64,
|
||||
/// Bytes transferred
|
||||
pub bytes: usize,
|
||||
/// Floating-point operations
|
||||
pub flops: usize,
|
||||
/// Throughput in GB/s
|
||||
pub throughput_gbps: f64,
|
||||
/// Memory Bandwidth Utilization (%)
|
||||
pub mbu_percent: f64,
|
||||
/// Compute in TFLOPS
|
||||
pub tflops: f64,
|
||||
/// Model FLOPs Utilization (%)
|
||||
pub mfu_percent: f64,
|
||||
}
|
||||
|
||||
/// Full benchmark report with derived metrics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FullBenchReport {
|
||||
pub hardware: HardwareSpec,
|
||||
pub timestamp: String,
|
||||
pub results: Vec<FullBenchResult>,
|
||||
}
|
||||
|
||||
impl FullBenchReport {
|
||||
pub fn new(hardware: HardwareSpec) -> Self {
|
||||
let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
Self {
|
||||
hardware,
|
||||
timestamp,
|
||||
results: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_result(&mut self, result: FullBenchResult) {
|
||||
self.results.push(result);
|
||||
}
|
||||
|
||||
/// Save to JSON file
|
||||
pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
|
||||
let json = serde_json::to_string_pretty(self).unwrap_or_default();
|
||||
std::fs::write(path, json)
|
||||
}
|
||||
|
||||
/// Print summary table to terminal
|
||||
pub fn print_summary(&self) {
|
||||
println!("\n{}", "=".repeat(100));
|
||||
println!("BENCHMARK RESULTS - {}", self.hardware.device_name);
|
||||
println!(
|
||||
"Peak Bandwidth: {:.0} GB/s | Peak Compute: {:.1} TFLOPS",
|
||||
self.hardware.peak_bandwidth_gbps, self.hardware.peak_tflops
|
||||
);
|
||||
println!("{}", "=".repeat(100));
|
||||
println!(
|
||||
"{:<20} {:>8} {:>12} {:>10} {:>8} {:>10} {:>8}",
|
||||
"Pattern", "Size", "Time(μs)", "GB/s", "MBU%", "TFLOPS", "MFU%"
|
||||
);
|
||||
println!("{}", "-".repeat(100));
|
||||
|
||||
for r in &self.results {
|
||||
println!(
|
||||
"{:<20} {:>8} {:>12.2} {:>10.2} {:>7.1}% {:>10.4} {:>7.1}%",
|
||||
r.pattern,
|
||||
r.size,
|
||||
r.time_us,
|
||||
r.throughput_gbps,
|
||||
r.mbu_percent,
|
||||
r.tflops,
|
||||
r.mfu_percent
|
||||
);
|
||||
}
|
||||
println!("{}", "=".repeat(100));
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe collector for benchmark results
|
||||
#[derive(Clone)]
|
||||
pub struct BenchResultCollector {
|
||||
hardware: HardwareSpec,
|
||||
results: Arc<Mutex<Vec<FullBenchResult>>>,
|
||||
}
|
||||
|
||||
impl BenchResultCollector {
|
||||
pub fn new(hardware: HardwareSpec) -> Self {
|
||||
Self {
|
||||
hardware,
|
||||
results: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a benchmark result
|
||||
pub fn add(
|
||||
&self,
|
||||
pattern: &str,
|
||||
size: &str,
|
||||
size_value: usize,
|
||||
time_us: f64,
|
||||
metrics: &BenchMetrics,
|
||||
) {
|
||||
let throughput_gbps = metrics.throughput_gbps(time_us);
|
||||
let tflops = metrics.tflops(time_us);
|
||||
let mbu_percent = metrics.mbu(time_us, self.hardware.peak_bandwidth_gbps);
|
||||
let mfu_percent = metrics.mfu(time_us, self.hardware.peak_tflops);
|
||||
|
||||
let result = FullBenchResult {
|
||||
pattern: pattern.to_string(),
|
||||
size: size.to_string(),
|
||||
size_value,
|
||||
time_us,
|
||||
bytes: metrics.bytes,
|
||||
flops: metrics.flops,
|
||||
throughput_gbps,
|
||||
mbu_percent,
|
||||
tflops,
|
||||
mfu_percent,
|
||||
};
|
||||
|
||||
self.results.lock().unwrap().push(result);
|
||||
}
|
||||
|
||||
/// Generate full report
|
||||
pub fn into_report(self) -> FullBenchReport {
|
||||
let mut report = FullBenchReport::new(self.hardware);
|
||||
report.results = self.results.lock().unwrap().clone();
|
||||
// Sort by pattern name, then by size
|
||||
report.results.sort_by(|a, b| {
|
||||
a.pattern
|
||||
.cmp(&b.pattern)
|
||||
.then_with(|| a.size_value.cmp(&b.size_value))
|
||||
});
|
||||
report
|
||||
}
|
||||
}
|
||||
338
crates/luminal_bench/src/micro.rs
Normal file
338
crates/luminal_bench/src/micro.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
//! L1 micro benchmark patterns (single-op graphs), used by `benches/micro.rs`.
|
||||
|
||||
use crate::{BenchSize, BenchmarkPattern, MICRO_SIZES};
|
||||
use luminal::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Binary Operators
|
||||
// ============================================================================
|
||||
|
||||
/// Vector addition benchmark: a + b
|
||||
#[derive(Debug, Default)]
|
||||
pub struct AddVec;
|
||||
|
||||
impl BenchmarkPattern for AddVec {
|
||||
fn name(&self) -> &'static str {
|
||||
"add_vec"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let b = cx.tensor(size.value);
|
||||
let _ = (a + b).output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Vector multiplication benchmark: a * b
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MulVec;
|
||||
|
||||
impl BenchmarkPattern for MulVec {
|
||||
fn name(&self) -> &'static str {
|
||||
"mul_vec"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let b = cx.tensor(size.value);
|
||||
let _ = (a * b).output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Vector modulo benchmark: a % b
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ModVec;
|
||||
|
||||
impl BenchmarkPattern for ModVec {
|
||||
fn name(&self) -> &'static str {
|
||||
"mod_vec"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let b = cx.tensor(size.value);
|
||||
let _ = (a % b).output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Vector less-than comparison benchmark: a < b
|
||||
#[derive(Debug, Default)]
|
||||
pub struct LessThanVec;
|
||||
|
||||
impl BenchmarkPattern for LessThanVec {
|
||||
fn name(&self) -> &'static str {
|
||||
"less_than_vec"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let b = cx.tensor(size.value);
|
||||
let _ = a.lt(b).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Reduction Operators
|
||||
// ============================================================================
|
||||
|
||||
/// Sum reduction benchmark: sum(a)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SumReduce;
|
||||
|
||||
impl BenchmarkPattern for SumReduce {
|
||||
fn name(&self) -> &'static str {
|
||||
"sum_reduce"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.sum(0).output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Max reduction benchmark: max(a)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MaxReduce;
|
||||
|
||||
impl BenchmarkPattern for MaxReduce {
|
||||
fn name(&self) -> &'static str {
|
||||
"max_reduce"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.max(0).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Unary Operators
|
||||
// ============================================================================
|
||||
|
||||
/// Exp2 benchmark: 2^x
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp2Bench;
|
||||
|
||||
impl BenchmarkPattern for Exp2Bench {
|
||||
fn name(&self) -> &'static str {
|
||||
"exp2"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.exp2().output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Log2 benchmark: log2(x)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Log2Bench;
|
||||
|
||||
impl BenchmarkPattern for Log2Bench {
|
||||
fn name(&self) -> &'static str {
|
||||
"log2"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.log2().output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Sin benchmark: sin(x)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SinBench;
|
||||
|
||||
impl BenchmarkPattern for SinBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"sin"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.sin().output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Recip benchmark: 1/x
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RecipBench;
|
||||
|
||||
impl BenchmarkPattern for RecipBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"recip"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.reciprocal().output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Sqrt benchmark: sqrt(x)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SqrtBench;
|
||||
|
||||
impl BenchmarkPattern for SqrtBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"sqrt"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
let _ = a.sqrt().output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Indexing Operators
|
||||
// ============================================================================
|
||||
|
||||
/// Gather benchmark: gather(data, indices)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct GatherBench;
|
||||
|
||||
impl BenchmarkPattern for GatherBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"gather"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
// Simple 1D gather: data[indices]
|
||||
// data: 1D tensor of size.value elements
|
||||
// indices: 1D tensor selecting num_indices elements
|
||||
let num_indices = 1024.min(size.value);
|
||||
|
||||
let data = cx.tensor(size.value);
|
||||
// Indices must be integer type for gather operation
|
||||
let indices = cx.tensor(num_indices).as_dtype(luminal::dtype::DType::Int);
|
||||
let _ = data.gather(indices).output();
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast benchmark: type conversion (f32 -> f16 -> f32)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CastBench;
|
||||
|
||||
impl BenchmarkPattern for CastBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"cast"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
MICRO_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let a = cx.tensor(size.value);
|
||||
// Cast to f16 then back to f32 to measure round-trip cost
|
||||
let _ = a
|
||||
.cast(luminal::dtype::DType::F16)
|
||||
.cast(luminal::dtype::DType::F32)
|
||||
.output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Pattern Registry
|
||||
// ============================================================================
|
||||
|
||||
/// Get all micro benchmark patterns (HLIR primitives supported by Metal)
|
||||
pub fn all_micro_patterns() -> Vec<Box<dyn BenchmarkPattern>> {
|
||||
vec![
|
||||
// Binary operators
|
||||
Box::new(AddVec),
|
||||
Box::new(MulVec),
|
||||
Box::new(ModVec),
|
||||
Box::new(LessThanVec),
|
||||
// Reduction operators
|
||||
Box::new(SumReduce),
|
||||
Box::new(MaxReduce),
|
||||
// Unary operators
|
||||
Box::new(Exp2Bench),
|
||||
Box::new(Log2Bench),
|
||||
Box::new(SinBench),
|
||||
Box::new(RecipBench),
|
||||
Box::new(SqrtBench),
|
||||
// Indexing operators
|
||||
Box::new(GatherBench),
|
||||
// Note: CastBench removed - Metal backend does not implement Cast yet
|
||||
]
|
||||
}
|
||||
|
||||
/// Calculate bytes transferred for a benchmark pattern
|
||||
pub fn bytes_for_pattern(pattern_name: &str, size: usize) -> usize {
|
||||
let elem_size = std::mem::size_of::<f32>();
|
||||
match pattern_name {
|
||||
// Binary operators: read 2 inputs + write 1 output = 3 * size * 4 bytes
|
||||
"add_vec" | "mul_vec" | "mod_vec" | "less_than_vec" => 3 * size * elem_size,
|
||||
|
||||
// Reduction operators: read 1 input + write 1 output (scalar)
|
||||
"sum_reduce" | "max_reduce" => size * elem_size + elem_size,
|
||||
|
||||
// Unary operators: read 1 input + write 1 output = 2 * size * 4 bytes
|
||||
"exp2" | "log2" | "sin" | "recip" | "sqrt" => 2 * size * elem_size,
|
||||
|
||||
// Cast: read 1 input (f32) + write intermediate (f16) + read (f16) + write output (f32)
|
||||
// Simplified: 2 * size * 4 bytes (f32 in + f32 out)
|
||||
"cast" => 2 * size * elem_size,
|
||||
|
||||
// Gather: read indices + read gathered data + write output
|
||||
// Simple 1D gather: indices (num_indices * 4) + read data + write output
|
||||
"gather" => {
|
||||
let num_indices = 1024.min(size);
|
||||
// Read indices (i32) + read data (random access, ~num_indices elements) + write output
|
||||
num_indices * elem_size + num_indices * elem_size + num_indices * elem_size
|
||||
}
|
||||
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
299
crates/luminal_bench/src/patterns.rs
Normal file
299
crates/luminal_bench/src/patterns.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
//! L2 pattern benchmark patterns (composite graphs), used by `benches/patterns.rs`.
|
||||
|
||||
use crate::{BenchSize, BenchmarkPattern};
|
||||
use luminal::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Size Configurations
|
||||
// ============================================================================
|
||||
|
||||
/// Matrix multiplication size configuration
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct MatMulSize {
|
||||
pub name: &'static str,
|
||||
pub m: usize,
|
||||
pub k: usize,
|
||||
pub n: usize,
|
||||
}
|
||||
|
||||
impl MatMulSize {
|
||||
pub const fn new(name: &'static str, m: usize, k: usize, n: usize) -> Self {
|
||||
Self { name, m, k, n }
|
||||
}
|
||||
}
|
||||
|
||||
/// Dummy size for patterns that handle sizes internally
|
||||
pub const CUSTOM_SIZE: &[BenchSize] = &[BenchSize::new("custom", 0)];
|
||||
|
||||
/// Standard matrix multiplication sizes
|
||||
pub const MATMUL_SIZES: &[MatMulSize] = &[
|
||||
// Square matrices
|
||||
MatMulSize::new("128x128", 128, 128, 128),
|
||||
MatMulSize::new("512x512", 512, 512, 512),
|
||||
MatMulSize::new("1024x1024", 1024, 1024, 1024),
|
||||
// LLM-like shapes (batch=1, hidden_dim, ffn_dim)
|
||||
MatMulSize::new("1x4096x4096", 1, 4096, 4096),
|
||||
// MatMulSize::new("32x4096x4096", 32, 4096, 4096),
|
||||
];
|
||||
|
||||
/// Transformer-like sizes for softmax, layernorm, etc.
|
||||
pub const TRANSFORMER_SIZES: &[BenchSize] = &[
|
||||
BenchSize::new("128x128", 128 * 128), // small attention
|
||||
BenchSize::new("512x512", 512 * 512), // medium attention
|
||||
BenchSize::new("2048x128", 2048 * 128), // typical seq_len x head_dim
|
||||
// BenchSize::new("4096x128", 4096 * 128), // long context
|
||||
];
|
||||
|
||||
/// Attention size configurations (seq_len, head_dim)
|
||||
pub const ATTENTION_SIZES: &[(usize, usize)] = &[
|
||||
(128, 64), // small: seq=128, head_dim=64
|
||||
(512, 64), // medium: seq=512, head_dim=64
|
||||
(1024, 64), // large: seq=1024, head_dim=64
|
||||
// (2048, 64), // xlarge: seq=2048, head_dim=64
|
||||
];
|
||||
|
||||
// ============================================================================
|
||||
// MatMul Pattern
|
||||
// ============================================================================
|
||||
|
||||
/// Matrix multiplication benchmark: C = A @ B
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct MatMulBench {
|
||||
pub size: MatMulSize,
|
||||
}
|
||||
|
||||
impl MatMulBench {
|
||||
pub fn new(size: MatMulSize) -> Self {
|
||||
Self { size }
|
||||
}
|
||||
}
|
||||
|
||||
impl BenchmarkPattern for MatMulBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"matmul"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
CUSTOM_SIZE
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, _size: BenchSize) {
|
||||
let a = cx.tensor((self.size.m, self.size.k));
|
||||
let b = cx.tensor((self.size.k, self.size.n));
|
||||
let _ = a.matmul(b).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Softmax Pattern
|
||||
// ============================================================================
|
||||
|
||||
/// Softmax benchmark: softmax(x, axis=-1)
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SoftmaxBench;
|
||||
|
||||
impl BenchmarkPattern for SoftmaxBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"softmax"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
TRANSFORMER_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
// Reshape to 2D for softmax along last axis
|
||||
// Assume size.value = rows * cols, use sqrt for balanced shape
|
||||
let dim = (size.value as f64).sqrt() as usize;
|
||||
let rows = size.value / dim;
|
||||
let cols = dim;
|
||||
|
||||
let x = cx.tensor((rows, cols));
|
||||
// Softmax along last axis (axis 1)
|
||||
let _ = x.softmax(1).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LayerNorm Pattern
|
||||
// ============================================================================
|
||||
|
||||
/// Layer normalization benchmark
|
||||
#[derive(Debug, Default)]
|
||||
pub struct LayerNormBench;
|
||||
|
||||
impl BenchmarkPattern for LayerNormBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"layer_norm"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
TRANSFORMER_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
// Typical shape: (batch * seq_len, hidden_dim)
|
||||
let hidden_dim = 128;
|
||||
let batch_seq = size.value / hidden_dim;
|
||||
|
||||
let x = cx.tensor((batch_seq.max(1), hidden_dim));
|
||||
// LayerNorm along last axis with epsilon
|
||||
let _ = x.layer_norm(1, 1e-5).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GeLU Pattern
|
||||
// ============================================================================
|
||||
|
||||
/// GeLU activation benchmark
|
||||
#[derive(Debug, Default)]
|
||||
pub struct GeLUBench;
|
||||
|
||||
impl BenchmarkPattern for GeLUBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"gelu"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
TRANSFORMER_SIZES
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, size: BenchSize) {
|
||||
let x = cx.tensor(size.value);
|
||||
let _ = x.gelu().output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Attention Pattern
|
||||
// ============================================================================
|
||||
|
||||
/// Self-attention benchmark: softmax(Q @ K^T / sqrt(d)) @ V
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AttentionBench {
|
||||
pub seq_len: usize,
|
||||
pub head_dim: usize,
|
||||
}
|
||||
|
||||
impl AttentionBench {
|
||||
pub fn new(seq_len: usize, head_dim: usize) -> Self {
|
||||
Self { seq_len, head_dim }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AttentionBench {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
seq_len: 512,
|
||||
head_dim: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BenchmarkPattern for AttentionBench {
|
||||
fn name(&self) -> &'static str {
|
||||
"attention"
|
||||
}
|
||||
|
||||
fn sizes(&self) -> &[BenchSize] {
|
||||
CUSTOM_SIZE
|
||||
}
|
||||
|
||||
fn build_graph(&self, cx: &mut Graph, _size: BenchSize) {
|
||||
let seq_len = self.seq_len;
|
||||
let head_dim = self.head_dim;
|
||||
|
||||
// Q, K, V tensors: (seq_len, head_dim)
|
||||
let q = cx.tensor((seq_len, head_dim));
|
||||
let k = cx.tensor((seq_len, head_dim));
|
||||
let v = cx.tensor((seq_len, head_dim));
|
||||
|
||||
// Attention: softmax(Q @ K^T / sqrt(d)) @ V
|
||||
// Q @ K^T -> (seq_len, seq_len)
|
||||
let scores = q.matmul(k.permute((1, 0)));
|
||||
|
||||
// Scale by 1/sqrt(head_dim)
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let scaled_scores = scores * scale;
|
||||
|
||||
// Softmax along last axis
|
||||
let attn_weights = scaled_scores.softmax(1);
|
||||
|
||||
// @ V -> (seq_len, head_dim)
|
||||
let _ = attn_weights.matmul(v).output();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Pattern Registry
|
||||
// ============================================================================
|
||||
|
||||
/// Get all high-priority pattern benchmarks
|
||||
pub fn all_pattern_benchmarks() -> Vec<Box<dyn BenchmarkPattern>> {
|
||||
let mut patterns: Vec<Box<dyn BenchmarkPattern>> = vec![];
|
||||
|
||||
// MatMul patterns with different sizes
|
||||
for size in MATMUL_SIZES {
|
||||
patterns.push(Box::new(MatMulBench::new(*size)));
|
||||
}
|
||||
|
||||
// Softmax
|
||||
patterns.push(Box::new(SoftmaxBench));
|
||||
|
||||
// LayerNorm
|
||||
patterns.push(Box::new(LayerNormBench));
|
||||
|
||||
// GeLU
|
||||
patterns.push(Box::new(GeLUBench));
|
||||
|
||||
// Attention patterns with different sizes
|
||||
for (seq_len, head_dim) in ATTENTION_SIZES {
|
||||
patterns.push(Box::new(AttentionBench::new(*seq_len, *head_dim)));
|
||||
}
|
||||
|
||||
patterns
|
||||
}
|
||||
|
||||
/// Calculate bytes transferred for pattern benchmarks
|
||||
pub fn bytes_for_pattern_bench(
|
||||
pattern_name: &str,
|
||||
size: usize,
|
||||
extra: Option<(usize, usize, usize)>,
|
||||
) -> usize {
|
||||
let elem_size = std::mem::size_of::<f32>();
|
||||
|
||||
match pattern_name {
|
||||
"matmul" => {
|
||||
if let Some((m, k, n)) = extra {
|
||||
// Read A (m*k) + Read B (k*n) + Write C (m*n)
|
||||
(m * k + k * n + m * n) * elem_size
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
"softmax" => {
|
||||
// Read input + Write output (same size)
|
||||
2 * size * elem_size
|
||||
}
|
||||
"layer_norm" => {
|
||||
// Read input + Write output
|
||||
2 * size * elem_size
|
||||
}
|
||||
"gelu" => {
|
||||
// Read input + Write output
|
||||
2 * size * elem_size
|
||||
}
|
||||
"attention" => {
|
||||
if let Some((seq_len, head_dim, _)) = extra {
|
||||
// Q, K, V reads: 3 * seq_len * head_dim
|
||||
// scores: seq_len * seq_len
|
||||
// output: seq_len * head_dim
|
||||
(3 * seq_len * head_dim + seq_len * seq_len + seq_len * head_dim) * elem_size
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
2
crates/luminal_cuda_lite/.cargo/config.toml
Normal file
2
crates/luminal_cuda_lite/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[env]
|
||||
RUST_TEST_THREADS = "1"
|
||||
33
crates/luminal_cuda_lite/Cargo.toml
Normal file
33
crates/luminal_cuda_lite/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "luminal_cuda_lite"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
description = "Cuda compiler for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
safetensors = "0.7.0"
|
||||
tracing = "0.1.43"
|
||||
half = { version = "2.7.1", features = ["num-traits"] }
|
||||
pretty-duration = "0.1.1"
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
lru = "0.16.2"
|
||||
libc = "0.2"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
num-traits = "0.2"
|
||||
29
crates/luminal_cuda_lite/README.md
Normal file
29
crates/luminal_cuda_lite/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
## luminal_cuda_lite
|
||||
|
||||
This crate contains the CUDA backend for Luminal.
|
||||
|
||||
The backend can be broken down into several main types of ops. Starting from the highest level and going lower:
|
||||
|
||||
#### Host Ops
|
||||
|
||||
Host ops are opaque operations executed from the host (can execute on device, simply launched in an opaque manner). cuBLAS is a good example of this type of op. Luminal can't assume much about these operations since they are so opaque. These ops implement the `HostOp` trait.
|
||||
|
||||
#### Kernel Ops
|
||||
|
||||
Kernel ops are operations encoded as a kernel and launch parameters. Luminal can put these into CUDA graphs. Cutlass kernels are good examples of these. These ops implement the `KernelOp` trait.
|
||||
|
||||
#### Block Ops
|
||||
|
||||
Block ops are operations encoded on the threadblock level, which implement an operation that runs for a duration within a single threadblock. These are required to use a fixed number of threads per threadblock (or gate unused threads out), and are given a fixed-size shared memory scratchpad. Luminal can fuse these operations into megakernels. These ops impelement the `BlockOp` trait.
|
||||
|
||||
#### Warp Ops
|
||||
|
||||
Warp ops are not yet merged. Stay tuned!
|
||||
|
||||
#### Thread Ops
|
||||
|
||||
Thread ops are not yet merged. Stay tuned!
|
||||
|
||||
### Architecture
|
||||
|
||||
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
258
crates/luminal_cuda_lite/src/host/cublas/mod.rs
Normal file
258
crates/luminal_cuda_lite/src/host/cublas/mod.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
|
||||
|
||||
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
|
||||
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
// Strip quotes if present (egglog strings are stored with quotes)
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasSgemmV2 {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
/// Lazily initialized cuBLAS handle - created on first execute
|
||||
cublas: OnceLock<Arc<CudaBlas>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasSgemmV2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
cublas: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cublas: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasSgemmV2 {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as i32;
|
||||
let n = self.n.exec(dyn_map).unwrap() as i32;
|
||||
let k = self.k.exec(dyn_map).unwrap() as i32;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i32;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * 4,
|
||||
b_buf.len(),
|
||||
k * n * 4,
|
||||
c_buf.len(),
|
||||
m * n * 4
|
||||
);
|
||||
let _sgemm_span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLAS_SGEMM_V2",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
?a_layout,
|
||||
?b_layout,
|
||||
)
|
||||
.entered();
|
||||
|
||||
// Use shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
|
||||
|
||||
// Set the stream for this operation (cuBLAS handle can work with any stream)
|
||||
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
|
||||
unsafe {
|
||||
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
|
||||
}
|
||||
|
||||
let status = unsafe {
|
||||
cublasSgemm_v2(
|
||||
*cublas.handle(),
|
||||
a_layout,
|
||||
b_layout,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha as *const f32,
|
||||
a_ptr as *const f32,
|
||||
lda,
|
||||
b_ptr as *const f32,
|
||||
ldb,
|
||||
&beta as *const f32,
|
||||
c_ptr as *mut f32,
|
||||
ldc,
|
||||
)
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLAS SGEMM TN failed with status: {:?}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
|
||||
self.output_size() * 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,139 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; In broadcast [batch, m, n, k] space:
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
; A strides in [batch, m, n, k]
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; B strides in [batch, m, n, k]
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
476
crates/luminal_cuda_lite/src/host/cublaslt/mod.rs
Normal file
476
crates/luminal_cuda_lite/src/host/cublaslt/mod.rs
Normal file
@@ -0,0 +1,476 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, EXPRESSION, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
CudaBlasLT, MatmulShared,
|
||||
sys::{
|
||||
cublasComputeType_t, cublasLtMatmul, cublasLtMatmulAlgoGetHeuristic,
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescCreate, cublasLtMatmulDescDestroy,
|
||||
cublasLtMatmulDescSetAttribute, cublasLtMatmulHeuristicResult_t,
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t,
|
||||
cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceDestroy,
|
||||
cublasLtMatmulPreferenceSetAttribute, cublasLtMatrixLayout_t,
|
||||
cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasLt {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
batch_count: Expression,
|
||||
stride_a: Expression,
|
||||
stride_b: Expression,
|
||||
stride_c: Expression,
|
||||
dtype: DType,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasLt {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N,
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T,
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
batch_count: 1.into(),
|
||||
stride_a: 0.into(),
|
||||
stride_b: 0.into(),
|
||||
stride_c: 0.into(),
|
||||
dtype: DType::F32,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLt {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublaslt",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
|
||||
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
|
||||
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
|
||||
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
|
||||
// (not the Mul eclass), so they survive the cascade.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
// Extract batch parameters
|
||||
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
|
||||
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
|
||||
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
|
||||
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
|
||||
|
||||
// Extract dtype from egglog
|
||||
let dtype = extract_dtype(egraph, kind_children[12]);
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
batch_count,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
dtype,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert DType to CUDA types for cuBLAS LT
|
||||
/// Returns (matrix_dtype, compute_type, scale_dtype)
|
||||
fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cudaDataType) {
|
||||
match dtype {
|
||||
// F64: matrix=f64, compute=f64, scale=f64
|
||||
DType::F64 => (
|
||||
cudaDataType::CUDA_R_64F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_64F,
|
||||
cudaDataType::CUDA_R_64F,
|
||||
),
|
||||
// F32: matrix=f32, compute=f32, scale=f32
|
||||
DType::F32 => (
|
||||
cudaDataType::CUDA_R_32F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// F16: matrix=f16, compute=f32 (FP32 accumulation for accuracy), scale=f32
|
||||
DType::F16 => (
|
||||
cudaDataType::CUDA_R_16F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// BF16: matrix=bf16, compute=f32 with tensor cores, scale=f32
|
||||
DType::Bf16 => (
|
||||
cudaDataType::CUDA_R_16BF,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// TF32: stored as f32, use fast TF32 tensor core path
|
||||
DType::TF32 => (
|
||||
cudaDataType::CUDA_R_32F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// FP8 E4M3: matrix=fp8_e4m3, compute=f32, scale=f32
|
||||
DType::F8E4M3 => (
|
||||
cudaDataType::CUDA_R_8F_E4M3,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// FP8 E5M2: matrix=fp8_e5m2, compute=f32, scale=f32
|
||||
DType::F8E5M2 => (
|
||||
cudaDataType::CUDA_R_8F_E5M2,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
DType::Int => panic!("cuBLAS LT does not support integer matmul"),
|
||||
DType::Bool => panic!("cuBLAS LT does not support bool matmul"),
|
||||
other => todo!("cuBLAS LT matmul not yet implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::cudarc::cublaslt::sys::{
|
||||
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
|
||||
};
|
||||
|
||||
// GEMM parameters — resolve z→1 for element stride before exec
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
|
||||
// Get CUDA types based on dtype
|
||||
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
|
||||
let element_size = (self.dtype.bits() / 8) as u64;
|
||||
assert!(
|
||||
element_size > 0,
|
||||
"cuBLAS LT does not support sub-byte dtype {}",
|
||||
self.dtype
|
||||
);
|
||||
|
||||
// Alpha/beta scale values (all dtypes use F32 scale type)
|
||||
let alpha_f32: f32 = 1.0;
|
||||
let beta_f32: f32 = 0.0;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Clamp leading dimensions to minimum valid values.
|
||||
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
|
||||
// dimension may be 0 in the egglog representation, but cuBLAS requires
|
||||
// lda >= rows_of_A and ldb >= rows_of_B.
|
||||
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
m
|
||||
} else {
|
||||
k
|
||||
};
|
||||
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
k
|
||||
} else {
|
||||
n
|
||||
};
|
||||
let lda = std::cmp::max(lda, a_ld_min as i64);
|
||||
let ldb = std::cmp::max(ldb, b_ld_min as i64);
|
||||
let ldc = std::cmp::max(ldc, m as i64);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT",
|
||||
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
// Allocate workspace (32 MiB)
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
|
||||
unsafe {
|
||||
// Create matmul descriptor (compute_type, scale_type for alpha/beta)
|
||||
cublasLtMatmulDescCreate(&mut matmul_desc, compute_type, scale_dtype).result()?;
|
||||
|
||||
// Set transpose attributes
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&a_layout as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&b_layout as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Create matrix layout descriptors
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
|
||||
cublasLtMatrixLayoutCreate(&mut a_desc, cuda_dtype, a_rows, a_cols, lda).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
|
||||
|
||||
// Set batched GEMM attributes if batch_count > 1
|
||||
if batch_count > 1 {
|
||||
for (desc, stride) in [(a_desc, stride_a), (b_desc, stride_b), (c_desc, stride_c)] {
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i32>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i64>(),
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create preference and set workspace size
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Get heuristic (best algorithm)
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
c_desc, // D layout same as C
|
||||
preference,
|
||||
1, // Request 1 result
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
|
||||
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
|
||||
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
alpha_ptr,
|
||||
a_ptr as *const std::ffi::c_void,
|
||||
a_desc,
|
||||
b_ptr as *const std::ffi::c_void,
|
||||
b_desc,
|
||||
beta_ptr,
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
c_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Cleanup
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
}
|
||||
63
crates/luminal_cuda_lite/src/host/mod.rs
Normal file
63
crates/luminal_cuda_lite/src/host/mod.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
moe::GLUMoE,
|
||||
);
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
/// Execute the operation with access to buffers via a map.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `stream` - The CUDA stream to execute on
|
||||
/// * `self_node` - The NodeIndex of this op in the llir_graph (used as output buffer)
|
||||
/// * `inputs` - NodeIndices of input nodes (in edge order from the graph)
|
||||
/// * `buffers` - Map from NodeIndex to device buffer for all allocated nodes
|
||||
/// * `dyn_map` - Dynamic dimension values
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Returns the output buffer size in elements.
|
||||
/// Return 0 if this op doesn't have a single output buffer (e.g., CudaGraphOp).
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns additional nodes (beyond graph edges) that this op needs buffers for.
|
||||
///
|
||||
/// For most ops, this returns empty (buffers determined by graph edges).
|
||||
/// For CudaGraphOp, this returns all internal kernel nodes.
|
||||
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
/// For CudaGraphOp, this returns sizes for all internal kernel output buffers.
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
FxHashMap::default()
|
||||
}
|
||||
|
||||
/// Returns the name of this host op for stats reporting, or None if not reportable.
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
}
|
||||
128
crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg
Normal file
128
crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg
Normal file
@@ -0,0 +1,128 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
)
|
||||
662
crates/luminal_cuda_lite/src/host/moe/mod.rs
Normal file
662
crates/luminal_cuda_lite/src/host/moe/mod.rs
Normal file
@@ -0,0 +1,662 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device,
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
CudaBlasLT, MatmulShared,
|
||||
sys::{
|
||||
cublasComputeType_t, cublasLtMatmul, cublasLtMatmulAlgoGetHeuristic,
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, cublasLtMatmulDescCreate,
|
||||
cublasLtMatmulDescDestroy, cublasLtMatmulDescSetAttribute,
|
||||
cublasLtMatmulHeuristicResult_t, cublasLtMatmulPreference_t,
|
||||
cublasLtMatmulPreferenceAttributes_t, cublasLtMatmulPreferenceCreate,
|
||||
cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceSetAttribute,
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy,
|
||||
cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
/// 1: topk_indices [seq, k] Int
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
dn_io: Expression,
|
||||
/// K dimension of gate_up matmul (= hidden)
|
||||
gu_matmul_k: Expression,
|
||||
/// K dimension of down matmul (= intermediate)
|
||||
dn_matmul_k: Expression,
|
||||
/// K experts to sum over (= top_k)
|
||||
output_k: Expression,
|
||||
/// Total elements in a single gate_up expert weight matrix
|
||||
gu_within_range: Expression,
|
||||
/// Total elements in a single down expert weight matrix
|
||||
dn_within_range: Expression,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
dn_matmul_k: Expression::default(),
|
||||
output_k: Expression::default(),
|
||||
gu_within_range: Expression::default(),
|
||||
dn_within_range: Expression::default(),
|
||||
cublaslt: OnceLock::new(),
|
||||
module: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
.field("dn_matmul_k", &self.dn_matmul_k)
|
||||
.field("output_k", &self.output_k)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
dn_matmul_k: self.dn_matmul_k,
|
||||
output_k: self.output_k,
|
||||
gu_within_range: self.gu_within_range,
|
||||
dn_within_range: self.dn_within_range,
|
||||
cublaslt: OnceLock::new(),
|
||||
module: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
) -> &(Arc<CudaModule>, CudaFunction, CudaFunction) {
|
||||
self.module.get_or_init(|| {
|
||||
let src = r#"
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long long out_ptr, int n) {
|
||||
const float* in_ = (const float*)in_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for GLUMoE {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"GLUMoE",
|
||||
&[
|
||||
("gu_io", EXPRESSION),
|
||||
("dn_io", EXPRESSION),
|
||||
("gu_matmul_k", EXPRESSION),
|
||||
("dn_matmul_k", EXPRESSION),
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
|
||||
let extracted = GLUMoE {
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
dn_matmul_k,
|
||||
output_k,
|
||||
gu_within_range,
|
||||
dn_within_range,
|
||||
cublaslt: OnceLock::new(),
|
||||
module: OnceLock::new(),
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for GLUMoE {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
|
||||
// Get input/output buffers
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
let gate_up_ptr = buf_ptr(gate_up_buf, stream);
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
let gate_up_out_buf = unsafe { stream.alloc::<u8>(gate_up_dim * 2)? }; // BF16 per-token
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
let blocks = (n_cast as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(f32_to_bf16_fn)
|
||||
.arg(&x_ptr)
|
||||
.arg(&xbf16_ptr)
|
||||
.arg(&n_cast)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
let expert_idx = expert_idx as usize;
|
||||
|
||||
// a. Gate+Up matmul (BF16 in, BF16 out)
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
hidden as u64,
|
||||
expert_gu_ptr,
|
||||
cublasOperation_t::CUBLAS_OP_T,
|
||||
hidden as i64,
|
||||
x_t_ptr,
|
||||
cublasOperation_t::CUBLAS_OP_N,
|
||||
hidden as i64,
|
||||
gu_out_ptr,
|
||||
gate_up_dim as i64,
|
||||
cudaDataType::CUDA_R_16BF,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
1.0f32,
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
}
|
||||
|
||||
// c. Down matmul (BF16 in → F32 out) with fused accumulate
|
||||
let expert_down_ptr = down_ptr + expert_idx as u64 * down_stride;
|
||||
let out_t_ptr = output_ptr + (t * hidden * 4) as u64; // F32
|
||||
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
intermediate as u64,
|
||||
expert_down_ptr,
|
||||
cublasOperation_t::CUBLAS_OP_T,
|
||||
intermediate as i64,
|
||||
hid_ptr,
|
||||
cublasOperation_t::CUBLAS_OP_N,
|
||||
intermediate as i64,
|
||||
out_t_ptr,
|
||||
hidden as i64,
|
||||
weight,
|
||||
beta,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
// Output is [seq, hidden] F32 → seq * hidden elements
|
||||
// But seq is dynamic. We derive from first input size / hidden.
|
||||
// Actually, output_bytes is what matters for allocation:
|
||||
Expression::from('s') * self.gu_matmul_k
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
Expression::from('s') * self.gu_matmul_k * 4 // F32
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("GLUMoE")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cublas_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
workspace_ptr: u64,
|
||||
m: u64,
|
||||
n: u64,
|
||||
k: u64,
|
||||
a_ptr: u64,
|
||||
a_op: cublasOperation_t,
|
||||
lda: i64,
|
||||
b_ptr: u64,
|
||||
b_op: cublasOperation_t,
|
||||
ldb: i64,
|
||||
c_ptr: u64,
|
||||
ldc: i64,
|
||||
dtype: cudaDataType,
|
||||
compute: cublasComputeType_t,
|
||||
alpha: f32,
|
||||
beta: f32,
|
||||
) -> anyhow::Result<()> {
|
||||
let scale_type = cudaDataType::CUDA_R_32F;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
unsafe {
|
||||
cublasLtMatmulDescCreate(&mut matmul_desc, compute, scale_type).result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&a_op as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&b_op as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
let (a_rows, a_cols) = if a_op == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_op == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
|
||||
cublasLtMatrixLayoutCreate(&mut a_desc, dtype, a_rows, a_cols, lda).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, dtype, m, n, ldc).result()?;
|
||||
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
c_desc,
|
||||
preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
&alpha as *const _ as *const std::ffi::c_void,
|
||||
a_ptr as *const std::ffi::c_void,
|
||||
a_desc,
|
||||
b_ptr as *const std::ffi::c_void,
|
||||
b_desc,
|
||||
&beta as *const _ as *const std::ffi::c_void,
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
c_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cublas_matmul_mixed(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
workspace_ptr: u64,
|
||||
m: u64,
|
||||
n: u64,
|
||||
k: u64,
|
||||
a_ptr: u64,
|
||||
a_op: cublasOperation_t,
|
||||
lda: i64,
|
||||
b_ptr: u64,
|
||||
b_op: cublasOperation_t,
|
||||
ldb: i64,
|
||||
c_ptr: u64,
|
||||
ldc: i64,
|
||||
alpha: f32,
|
||||
beta: f32,
|
||||
) -> anyhow::Result<()> {
|
||||
let ab_dtype = cudaDataType::CUDA_R_16BF;
|
||||
let cd_dtype = cudaDataType::CUDA_R_32F;
|
||||
let compute = cublasComputeType_t::CUBLAS_COMPUTE_32F;
|
||||
let scale_type = cudaDataType::CUDA_R_32F;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut d_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
unsafe {
|
||||
cublasLtMatmulDescCreate(&mut matmul_desc, compute, scale_type).result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&a_op as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&b_op as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
let (a_rows, a_cols) = if a_op == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_op == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
|
||||
cublasLtMatrixLayoutCreate(&mut a_desc, ab_dtype, a_rows, a_cols, lda).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, ab_dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, cd_dtype, m, n, ldc).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut d_desc, cd_dtype, m, n, ldc).result()?;
|
||||
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
d_desc,
|
||||
preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(d_desc);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
return Err(anyhow::anyhow!(
|
||||
"No suitable cuBLASLT algorithm found for mixed matmul"
|
||||
));
|
||||
}
|
||||
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
&alpha as *const _ as *const std::ffi::c_void,
|
||||
a_ptr as *const std::ffi::c_void,
|
||||
a_desc,
|
||||
b_ptr as *const std::ffi::c_void,
|
||||
b_desc,
|
||||
&beta as *const _ as *const std::ffi::c_void,
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
d_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(d_desc);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
656
crates/luminal_cuda_lite/src/kernel/cuda_graph.rs
Normal file
656
crates/luminal_cuda_lite/src/kernel/cuda_graph.rs
Normal file
@@ -0,0 +1,656 @@
|
||||
#![allow(clippy::missing_safety_doc, clippy::not_unsafe_ptr_arg_deref)]
|
||||
//! CUDA Graph API wrappers for explicit graph construction and surgical updates.
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaContext, CudaFunction, CudaStream, DriverError,
|
||||
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
|
||||
};
|
||||
|
||||
/// A CUDA graph that can be modified and instantiated.
|
||||
pub struct CudaGraphHandle {
|
||||
pub(crate) cu_graph: CUgraph,
|
||||
pub(crate) ctx: Arc<CudaContext>,
|
||||
}
|
||||
|
||||
impl CudaGraphHandle {
|
||||
/// Creates a new empty CUDA graph.
|
||||
pub fn new(ctx: Arc<CudaContext>) -> Result<Self, DriverError> {
|
||||
ctx.bind_to_thread()?;
|
||||
let mut graph = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphCreate(graph.as_mut_ptr(), 0).result()?;
|
||||
Ok(Self {
|
||||
cu_graph: graph.assume_init(),
|
||||
ctx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a kernel node to the graph. kernel_params must remain valid for graph lifetime.
|
||||
pub unsafe fn add_kernel_node(
|
||||
&mut self,
|
||||
dependencies: &[CUgraphNode],
|
||||
func: CUfunction,
|
||||
grid_dim: (u32, u32, u32),
|
||||
block_dim: (u32, u32, u32),
|
||||
shared_mem_bytes: u32,
|
||||
kernel_params: *mut *mut c_void,
|
||||
) -> Result<CUgraphNode, DriverError> {
|
||||
let params = sys::CUDA_KERNEL_NODE_PARAMS {
|
||||
func,
|
||||
gridDimX: grid_dim.0,
|
||||
gridDimY: grid_dim.1,
|
||||
gridDimZ: grid_dim.2,
|
||||
blockDimX: block_dim.0,
|
||||
blockDimY: block_dim.1,
|
||||
blockDimZ: block_dim.2,
|
||||
sharedMemBytes: shared_mem_bytes,
|
||||
kernelParams: kernel_params,
|
||||
extra: std::ptr::null_mut(),
|
||||
kern: std::ptr::null_mut(), // Not using CUkernel-based launch
|
||||
ctx: std::ptr::null_mut(), // Use default context
|
||||
};
|
||||
|
||||
let mut node = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphAddKernelNode_v2(
|
||||
node.as_mut_ptr(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
dependencies.len(),
|
||||
¶ms,
|
||||
)
|
||||
.result()?;
|
||||
Ok(node.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds an event record node to the graph for timing.
|
||||
pub fn add_event_record_node(
|
||||
&mut self,
|
||||
dependencies: &[CUgraphNode],
|
||||
event: CUevent,
|
||||
) -> Result<CUgraphNode, DriverError> {
|
||||
let mut node = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphAddEventRecordNode(
|
||||
node.as_mut_ptr(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
dependencies.len(),
|
||||
event,
|
||||
)
|
||||
.result()?;
|
||||
Ok(node.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Instantiates the graph, creating an executable graph.
|
||||
pub fn instantiate(&self) -> Result<CudaGraphExecHandle, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut graph_exec = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphInstantiateWithFlags(graph_exec.as_mut_ptr(), self.cu_graph, 0).result()?;
|
||||
Ok(CudaGraphExecHandle {
|
||||
cu_graph_exec: graph_exec.assume_init(),
|
||||
ctx: self.ctx.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphHandle {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.ctx.bind_to_thread();
|
||||
if !self.cu_graph.is_null() {
|
||||
unsafe {
|
||||
let _ = sys::cuGraphDestroy(self.cu_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An instantiated CUDA graph that can be launched and updated.
|
||||
pub struct CudaGraphExecHandle {
|
||||
pub(crate) cu_graph_exec: CUgraphExec,
|
||||
pub(crate) ctx: Arc<CudaContext>,
|
||||
}
|
||||
|
||||
impl CudaGraphExecHandle {
|
||||
/// Launches the graph on the given stream.
|
||||
pub fn launch(&self, stream: &CudaStream) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe { sys::cuGraphLaunch(self.cu_graph_exec, stream.cu_stream()).result() }
|
||||
}
|
||||
|
||||
/// Surgically updates a kernel node's parameters without rebuilding the graph.
|
||||
pub unsafe fn update_kernel_node(
|
||||
&mut self,
|
||||
node: CUgraphNode,
|
||||
func: CUfunction,
|
||||
grid_dim: (u32, u32, u32),
|
||||
block_dim: (u32, u32, u32),
|
||||
shared_mem_bytes: u32,
|
||||
kernel_params: *mut *mut c_void,
|
||||
) -> Result<(), DriverError> {
|
||||
let params = sys::CUDA_KERNEL_NODE_PARAMS {
|
||||
func,
|
||||
gridDimX: grid_dim.0,
|
||||
gridDimY: grid_dim.1,
|
||||
gridDimZ: grid_dim.2,
|
||||
blockDimX: block_dim.0,
|
||||
blockDimY: block_dim.1,
|
||||
blockDimZ: block_dim.2,
|
||||
sharedMemBytes: shared_mem_bytes,
|
||||
kernelParams: kernel_params,
|
||||
extra: std::ptr::null_mut(),
|
||||
kern: std::ptr::null_mut(),
|
||||
ctx: std::ptr::null_mut(),
|
||||
};
|
||||
|
||||
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, ¶ms) }
|
||||
.result()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphExecHandle {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.ctx.bind_to_thread();
|
||||
if !self.cu_graph_exec.is_null() {
|
||||
unsafe {
|
||||
let _ = sys::cuGraphExecDestroy(self.cu_graph_exec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension trait to get the raw CUfunction handle from CudaFunction.
|
||||
pub trait CudaFunctionExt {
|
||||
unsafe fn raw_function(&self) -> CUfunction;
|
||||
}
|
||||
|
||||
impl CudaFunctionExt for CudaFunction {
|
||||
unsafe fn raw_function(&self) -> CUfunction {
|
||||
// CudaFunction fields are reordered by Rust - cu_function is at offset 8
|
||||
debug_assert_eq!(
|
||||
std::mem::size_of::<CudaFunction>(),
|
||||
std::mem::size_of::<CUfunction>() + std::mem::size_of::<usize>()
|
||||
);
|
||||
unsafe {
|
||||
let ptr = (self as *const CudaFunction as *const u8).add(8) as *const CUfunction;
|
||||
std::ptr::read(ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stored kernel parameters that persist for the lifetime of a CUDA graph.
|
||||
#[derive(Debug)]
|
||||
pub struct KernelParams {
|
||||
values: Box<[u64]>,
|
||||
ptrs: Box<[*mut c_void]>,
|
||||
/// Index of the dyn_dims pointer in values array (if present)
|
||||
dyn_dims_idx: Option<usize>,
|
||||
}
|
||||
|
||||
impl KernelParams {
|
||||
pub fn new(output_ptr: u64, input_ptrs: &[u64]) -> Self {
|
||||
let mut values: Vec<u64> = Vec::with_capacity(1 + input_ptrs.len());
|
||||
values.push(output_ptr);
|
||||
values.extend_from_slice(input_ptrs);
|
||||
let values = values.into_boxed_slice();
|
||||
let ptrs: Vec<*mut c_void> = values
|
||||
.iter()
|
||||
.map(|v| v as *const u64 as *mut c_void)
|
||||
.collect();
|
||||
Self {
|
||||
values,
|
||||
ptrs: ptrs.into_boxed_slice(),
|
||||
dyn_dims_idx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create kernel params with a dyn_dims pointer as the last parameter.
|
||||
pub fn with_dyn_dims(output_ptr: u64, input_ptrs: &[u64], dyn_dims_ptr: u64) -> Self {
|
||||
let mut values: Vec<u64> = Vec::with_capacity(2 + input_ptrs.len());
|
||||
values.push(output_ptr);
|
||||
values.extend_from_slice(input_ptrs);
|
||||
let dyn_dims_idx = values.len();
|
||||
values.push(dyn_dims_ptr);
|
||||
let values = values.into_boxed_slice();
|
||||
let ptrs: Vec<*mut c_void> = values
|
||||
.iter()
|
||||
.map(|v| v as *const u64 as *mut c_void)
|
||||
.collect();
|
||||
Self {
|
||||
values,
|
||||
ptrs: ptrs.into_boxed_slice(),
|
||||
dyn_dims_idx: Some(dyn_dims_idx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_cuda_params(&mut self) -> *mut *mut c_void {
|
||||
self.ptrs.as_mut_ptr()
|
||||
}
|
||||
|
||||
pub fn update_output(&mut self, ptr: u64) {
|
||||
self.values[0] = ptr;
|
||||
}
|
||||
|
||||
pub fn update_input(&mut self, index: usize, ptr: u64) {
|
||||
self.values[1 + index] = ptr;
|
||||
}
|
||||
|
||||
/// Update the dyn_dims pointer if this kernel uses one.
|
||||
pub fn update_dyn_dims(&mut self, ptr: u64) {
|
||||
if let Some(idx) = self.dyn_dims_idx {
|
||||
self.values[idx] = ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stored kernel parameters for megakernels that persist for the lifetime of a CUDA graph.
|
||||
/// Params: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
|
||||
#[derive(Debug)]
|
||||
pub struct MegakernelParams {
|
||||
/// Parameter values: [tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims]
|
||||
values: Box<[u64]>,
|
||||
/// Pointer array for CUDA kernel launch
|
||||
ptrs: Box<[*mut c_void]>,
|
||||
}
|
||||
|
||||
impl MegakernelParams {
|
||||
/// Create megakernel params with all internal buffer pointers and dyn_dims.
|
||||
/// Order: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
tasks_ptr: u64,
|
||||
head_ptr: u64,
|
||||
ready_ptr: u64,
|
||||
queue_lock_ptr: u64,
|
||||
timings_ptr: u64,
|
||||
start_times_ptr: u64,
|
||||
buffers_ptr: u64,
|
||||
dyn_dims_ptr: u64,
|
||||
) -> Self {
|
||||
let values: Box<[u64]> = vec![
|
||||
tasks_ptr,
|
||||
head_ptr,
|
||||
ready_ptr,
|
||||
queue_lock_ptr,
|
||||
timings_ptr,
|
||||
start_times_ptr,
|
||||
buffers_ptr,
|
||||
dyn_dims_ptr,
|
||||
]
|
||||
.into_boxed_slice();
|
||||
let ptrs: Box<[*mut c_void]> = values
|
||||
.iter()
|
||||
.map(|v| v as *const u64 as *mut c_void)
|
||||
.collect();
|
||||
Self { values, ptrs }
|
||||
}
|
||||
|
||||
pub fn as_cuda_params(&mut self) -> *mut *mut c_void {
|
||||
// Rebuild pointers (in case struct was moved)
|
||||
for (i, v) in self.values.iter().enumerate() {
|
||||
self.ptrs[i] = v as *const u64 as *mut c_void;
|
||||
}
|
||||
self.ptrs.as_mut_ptr()
|
||||
}
|
||||
|
||||
/// Update the buffers pointer (index 6).
|
||||
pub fn update_buffers(&mut self, ptr: u64) {
|
||||
self.values[6] = ptr;
|
||||
}
|
||||
|
||||
/// Update the dyn_dims pointer (index 7).
|
||||
pub fn update_dyn_dims(&mut self, ptr: u64) {
|
||||
self.values[7] = ptr;
|
||||
}
|
||||
|
||||
/// Get the current buffers pointer value.
|
||||
pub fn buffers_ptr(&self) -> u64 {
|
||||
self.values[6]
|
||||
}
|
||||
}
|
||||
|
||||
/// Timing data for a single kernel in a CUDA graph.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CudaGraphKernelTiming {
|
||||
pub kernel_name: &'static str,
|
||||
pub start_ns: u64,
|
||||
pub end_ns: u64,
|
||||
}
|
||||
|
||||
/// Timing data for a CUDA graph execution.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CudaGraphTiming {
|
||||
pub kernel_timings: Vec<CudaGraphKernelTiming>,
|
||||
/// Time from launch call until first kernel started on GPU
|
||||
pub launch_latency_ns: u64,
|
||||
/// Elapsed time (in nanoseconds) from span entry to just before graph launch.
|
||||
/// This captures the setup overhead (constants, buffers, graph building) that
|
||||
/// occurs before the GPU actually starts executing.
|
||||
pub setup_duration_ns: u64,
|
||||
}
|
||||
|
||||
pub fn create_cuda_event(ctx: &Arc<CudaContext>) -> Result<CUevent, DriverError> {
|
||||
ctx.bind_to_thread()?;
|
||||
let mut event = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuEventCreate(
|
||||
event.as_mut_ptr(),
|
||||
sys::CUevent_flags::CU_EVENT_DEFAULT as u32,
|
||||
)
|
||||
.result()?;
|
||||
Ok(event.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn destroy_cuda_event(ctx: &Arc<CudaContext>, event: CUevent) {
|
||||
if !event.is_null() {
|
||||
let _ = ctx.bind_to_thread();
|
||||
unsafe {
|
||||
let _ = sys::cuEventDestroy_v2(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn event_elapsed_ms(
|
||||
ctx: &Arc<CudaContext>,
|
||||
start: CUevent,
|
||||
end: CUevent,
|
||||
) -> Result<f32, DriverError> {
|
||||
ctx.bind_to_thread()?;
|
||||
let mut ms: f32 = 0.0;
|
||||
unsafe {
|
||||
sys::cuEventElapsedTime_v2(&mut ms, start, end).result()?;
|
||||
}
|
||||
Ok(ms)
|
||||
}
|
||||
|
||||
pub fn record_event_on_stream(
|
||||
ctx: &Arc<CudaContext>,
|
||||
event: CUevent,
|
||||
stream: &CudaStream,
|
||||
) -> Result<(), DriverError> {
|
||||
ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuEventRecord(event, stream.cu_stream()).result()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::{Device, Tensor};
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_empty_graph() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
assert!(CudaGraphHandle::new(ctx).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kernel_params() {
|
||||
let mut params = KernelParams::new(0x1000, &[0x2000, 0x3000]);
|
||||
assert!(!params.as_cuda_params().is_null());
|
||||
params.update_output(0x4000);
|
||||
params.update_input(0, 0x5000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cuda_function_size() {
|
||||
assert_eq!(
|
||||
std::mem::size_of::<CudaFunction>(),
|
||||
std::mem::size_of::<CUfunction>() + std::mem::size_of::<usize>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_raw_function_extraction() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
let func = module.load_function("test_kernel").unwrap();
|
||||
let cu_func = unsafe { func.raw_function() };
|
||||
assert!(!cu_func.is_null());
|
||||
let mut max_threads: i32 = 0;
|
||||
let result = unsafe {
|
||||
sys::cuFuncGetAttribute(
|
||||
&mut max_threads,
|
||||
sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
|
||||
cu_func,
|
||||
)
|
||||
};
|
||||
assert!(result == sys::cudaError_enum::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_with_kernel() {
|
||||
use cudarc::driver::{CudaSlice, DevicePtr};
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
let func = module.load_function("test_kernel").unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
let output: CudaSlice<f32> = unsafe { stream.alloc(1) }.unwrap();
|
||||
let mut input: CudaSlice<f32> = unsafe { stream.alloc(1) }.unwrap();
|
||||
stream.memcpy_htod(&[5.0f32], &mut input).unwrap();
|
||||
let cu_func = unsafe { func.raw_function() };
|
||||
let mut graph = CudaGraphHandle::new(ctx.clone()).unwrap();
|
||||
let mut params =
|
||||
KernelParams::new(output.device_ptr(&stream).0, &[input.device_ptr(&stream).0]);
|
||||
let _node = unsafe {
|
||||
graph.add_kernel_node(
|
||||
&[],
|
||||
cu_func,
|
||||
(1, 1, 1),
|
||||
(1, 1, 1),
|
||||
0,
|
||||
params.as_cuda_params(),
|
||||
)
|
||||
}
|
||||
.unwrap();
|
||||
let exec = graph.instantiate().unwrap();
|
||||
exec.launch(&stream).unwrap();
|
||||
stream.synchronize().unwrap();
|
||||
let mut result = [0.0f32];
|
||||
stream.memcpy_dtoh(&output, &mut result).unwrap();
|
||||
assert_eq!(result[0], 6.0f32);
|
||||
}
|
||||
|
||||
// CUDA Graph Tests
|
||||
|
||||
#[test]
|
||||
fn test_cuda_graph_basic_execution() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let size = 1024;
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = ((a + b) * a + b).output();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&result1, &rt.get_f32(c), tol, tol);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a + b)
|
||||
.collect();
|
||||
assert_close(&result1, &expected, tol, tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cuda_graph_multiple_executions() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let size = 2048;
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = (a + b + a + b).output();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
results.push(rt.get_f32(c));
|
||||
}
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
for result in &results {
|
||||
assert_close(result, &results[0], tol, tol);
|
||||
}
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| a + b + a + b)
|
||||
.collect();
|
||||
assert_close(&results[0], &expected, tol, tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cuda_graph_dyn_dims_surgical_update() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let size = 512;
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor('s');
|
||||
let b = cx.tensor('s');
|
||||
let c = (a + b).output();
|
||||
let d = (c * a).output();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&rt.get_f32(d), &expected, tol, tol);
|
||||
let size = 1024;
|
||||
let data_a2 = random_f32_vec(size, 44, -0.5, 0.5);
|
||||
let data_b2 = random_f32_vec(size, 45, -0.5, 0.5);
|
||||
rt.set_data(a, data_a2.clone());
|
||||
rt.set_data(b, data_b2.clone());
|
||||
cx.set_dim('s', size);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected2: Vec<f32> = data_a2
|
||||
.iter()
|
||||
.zip(&data_b2)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
assert_close(&rt.get_f32(d), &expected2, tol, tol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_kernel_in_graph() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let size = 1024;
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(size);
|
||||
let b = cx.tensor(size);
|
||||
let c = (a + b).output();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
assert!(rt.last_kernel_stats.iter().any(|s| s.name == "CudaGraph"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cuda_graph_chain_performance() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let size = 4096;
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let mut result = a + b;
|
||||
for _ in 0..5 {
|
||||
result += a;
|
||||
result *= b;
|
||||
}
|
||||
let output = result.output();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
let mut expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
for _ in 0..5 {
|
||||
expected = expected.iter().zip(&data_a).map(|(r, a)| r + a).collect();
|
||||
expected = expected.iter().zip(&data_b).map(|(r, b)| r * b).collect();
|
||||
}
|
||||
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
|
||||
}
|
||||
}
|
||||
3227
crates/luminal_cuda_lite/src/kernel/hlir.rs
Normal file
3227
crates/luminal_cuda_lite/src/kernel/hlir.rs
Normal file
File diff suppressed because it is too large
Load Diff
289
crates/luminal_cuda_lite/src/kernel/mod.rs
Normal file
289
crates/luminal_cuda_lite/src/kernel/mod.rs
Normal file
@@ -0,0 +1,289 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::*;
|
||||
use luminal_tracing::schema::{
|
||||
self as schema, TrackEvent, debug_annotation::NameField, trace_packet, track_event,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod cuda_graph;
|
||||
pub mod hlir;
|
||||
pub mod other_ops;
|
||||
|
||||
pub use cuda_graph::*;
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
use luminal_tracing::schema::trace_packet;
|
||||
let mut interned: std::collections::HashMap<(u32, u64), String> =
|
||||
std::collections::HashMap::new();
|
||||
for packet in &trace.packet {
|
||||
let seq_id = match &packet.optional_trusted_packet_sequence_id {
|
||||
Some(trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(seq)) => {
|
||||
*seq
|
||||
}
|
||||
_ => 0,
|
||||
};
|
||||
// interned_data is a field on TracePacket, not a Data variant
|
||||
if let Some(data) = &packet.interned_data {
|
||||
for entry in &data.debug_annotation_names {
|
||||
if let Some(name) = &entry.name {
|
||||
interned.insert((seq_id, entry.iid()), name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
interned
|
||||
}
|
||||
|
||||
/// Check if a debug annotation has key "id" and the given UUID value.
|
||||
fn annotation_matches_id(
|
||||
a: &schema::DebugAnnotation,
|
||||
id: &Uuid,
|
||||
interned: &std::collections::HashMap<(u32, u64), String>,
|
||||
seq_id: u32,
|
||||
) -> bool {
|
||||
let key_matches = match &a.name_field {
|
||||
Some(NameField::Name(k)) => k == "id",
|
||||
Some(NameField::NameIid(iid)) => interned
|
||||
.get(&(seq_id, *iid))
|
||||
.map(|s| s == "id")
|
||||
.unwrap_or(false),
|
||||
None => false,
|
||||
};
|
||||
if !key_matches {
|
||||
return false;
|
||||
}
|
||||
match &a.value {
|
||||
Some(luminal_tracing::schema::debug_annotation::Value::StringValue(v)) => {
|
||||
*v == format!("{id}")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record CUDA graph kernel timings as nested slices in perfetto trace
|
||||
pub fn record_cuda_graph_timings(
|
||||
trace: &schema::Trace,
|
||||
cuda_graph_timings: &[(CudaGraphTiming, Uuid)],
|
||||
) -> Vec<schema::TracePacket> {
|
||||
use luminal_tracing::schema::{trace_packet, track_descriptor};
|
||||
|
||||
// Build interned string lookup table
|
||||
let interned = build_interned_strings(trace);
|
||||
|
||||
let mut packets = Vec::new();
|
||||
for (graph_timing, id) in cuda_graph_timings {
|
||||
let parent_info = trace.packet.iter().find_map(|p| {
|
||||
let seq_id = match &p.optional_trusted_packet_sequence_id {
|
||||
Some(trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
|
||||
seq,
|
||||
)) => *seq,
|
||||
_ => 0,
|
||||
};
|
||||
match &p.data {
|
||||
Some(trace_packet::Data::TrackEvent(TrackEvent {
|
||||
r#type: ty,
|
||||
track_uuid,
|
||||
debug_annotations,
|
||||
..
|
||||
})) if *ty == Some(track_event::Type::SliceBegin as i32)
|
||||
&& debug_annotations
|
||||
.iter()
|
||||
.any(|a| annotation_matches_id(a, id, &interned, seq_id)) =>
|
||||
{
|
||||
Some((p.timestamp?, p.timestamp_clock_id?, (*track_uuid)?, seq_id))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
let Some((span_start_time, clock_id, track_uuid, sequence_id)) = parent_info else {
|
||||
continue;
|
||||
};
|
||||
// Use span_start_time + setup_duration + launch_latency as the base for kernel timings.
|
||||
// - setup_duration_ns: time spent on host between span entry and launch call
|
||||
// - launch_latency_ns: GPU-side time from launch to first kernel execution
|
||||
// This ensures kernel spans are accurately positioned within the cuda_graph span.
|
||||
let base_time =
|
||||
span_start_time + graph_timing.setup_duration_ns + graph_timing.launch_latency_ns;
|
||||
for kernel_timing in &graph_timing.kernel_timings {
|
||||
packets.push(schema::TracePacket {
|
||||
timestamp: Some(base_time + kernel_timing.start_ns),
|
||||
timestamp_clock_id: Some(clock_id),
|
||||
optional_trusted_packet_sequence_id: Some(
|
||||
trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
|
||||
sequence_id,
|
||||
),
|
||||
),
|
||||
data: Some(trace_packet::Data::TrackEvent(schema::TrackEvent {
|
||||
track_uuid: Some(track_uuid),
|
||||
r#type: Some(track_event::Type::SliceBegin as i32),
|
||||
name_field: Some(track_event::NameField::Name(
|
||||
kernel_timing.kernel_name.to_owned(),
|
||||
)),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
packets.push(schema::TracePacket {
|
||||
timestamp: Some(base_time + kernel_timing.end_ns),
|
||||
timestamp_clock_id: Some(clock_id),
|
||||
optional_trusted_packet_sequence_id: Some(
|
||||
trace_packet::OptionalTrustedPacketSequenceId::TrustedPacketSequenceId(
|
||||
sequence_id,
|
||||
),
|
||||
),
|
||||
data: Some(trace_packet::Data::TrackEvent(schema::TrackEvent {
|
||||
track_uuid: Some(track_uuid),
|
||||
r#type: Some(track_event::Type::SliceEnd as i32),
|
||||
name_field: Some(track_event::NameField::Name(
|
||||
kernel_timing.kernel_name.to_owned(),
|
||||
)),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
packets
|
||||
}
|
||||
|
||||
pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
/// Returns the output buffer size in elements.
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
|
||||
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
|
||||
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.output_size().dyn_vars().into_iter().collect()
|
||||
}
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns the DType of this kernel's output buffer.
|
||||
/// Used by has_nan_outputs to interpret buffer bytes correctly.
|
||||
/// Default: F32 (most kernels output float).
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this kernel will load from global memory.
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this kernel will store to global memory.
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of floating point operations this kernel performs.
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the name of this kernel for profiling display.
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Unknown"
|
||||
}
|
||||
|
||||
/// Allocate internal buffers this kernel needs. Called once during graph building.
|
||||
/// Default: no internal buffers.
|
||||
fn allocate_internal_buffers(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<CudaSlice<u8>> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns the set of dynamic dimensions that affect internal buffer sizes.
|
||||
/// When any of these dimensions change, internal buffers should be reallocated.
|
||||
/// Default: empty set (no dimensions affect internal buffers).
|
||||
fn internal_buffer_dyn_dims(&self) -> FxHashSet<char> {
|
||||
FxHashSet::default()
|
||||
}
|
||||
|
||||
/// Build kernel parameters. Returns the u64 values to pass to the kernel.
|
||||
/// Default: [output_ptr, input_ptrs..., dyn_dims_ptr (if non-zero)]
|
||||
fn build_params(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
_internal_bufs: &[CudaSlice<u8>],
|
||||
dyn_dims_ptr: u64,
|
||||
) -> Vec<u64> {
|
||||
let mut params = vec![output_ptr];
|
||||
params.extend_from_slice(input_ptrs);
|
||||
if dyn_dims_ptr != 0 {
|
||||
params.push(dyn_dims_ptr);
|
||||
}
|
||||
params
|
||||
}
|
||||
|
||||
/// Called before each kernel execution. Update internal state if needed.
|
||||
/// `all_buffer_ptrs` contains pointers for all buffers this kernel might use.
|
||||
/// `constants` are device constants returned by compile() that may need updating.
|
||||
fn pre_execute(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_internal_bufs: &mut [CudaSlice<u8>],
|
||||
_constants: &mut FxHashMap<char, CudaSlice<u8>>,
|
||||
_all_buffer_ptrs: &FxHashMap<NodeIndex, u64>,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) {
|
||||
}
|
||||
|
||||
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
|
||||
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
/// If this kernel's output is derived from one of its inputs (copy-then-modify
|
||||
/// or in-place write), return that input index. Used by `resolve_data_node` to
|
||||
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
|
||||
/// roundtrip pattern.
|
||||
///
|
||||
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
|
||||
/// (like Scatter which copies dest→output then scatters into it).
|
||||
fn output_data_input(&self) -> Option<usize> {
|
||||
self.output_aliases_input()
|
||||
}
|
||||
|
||||
/// Returns indices of internal buffers containing timing data, if any.
|
||||
/// Returns (timings_idx, start_times_idx, sm_count).
|
||||
fn timing_buffer_indices(&self) -> Option<(usize, usize, usize)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
luminal::impl_into_ops!(KernelOp);
|
||||
|
||||
// Kernel to host op compilation
|
||||
mod to_host;
|
||||
pub use to_host::{CudaGraphOp, kernel_to_host};
|
||||
1768
crates/luminal_cuda_lite/src/kernel/other_ops.rs
Normal file
1768
crates/luminal_cuda_lite/src/kernel/other_ops.rs
Normal file
File diff suppressed because it is too large
Load Diff
839
crates/luminal_cuda_lite/src/kernel/to_host.rs
Normal file
839
crates/luminal_cuda_lite/src/kernel/to_host.rs
Normal file
@@ -0,0 +1,839 @@
|
||||
//! Compiles KernelOp subgraphs into HostOp (CudaGraphOp).
|
||||
//!
|
||||
//! CudaGraphOp wraps a subgraph of KernelOps into a single executable unit
|
||||
//! that can be executed like any other HostOp.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, sys::CUgraphNode,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
},
|
||||
};
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::HostOp,
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
};
|
||||
|
||||
/// A compiled kernel within a CudaGraphOp.
|
||||
#[derive(Debug)]
|
||||
struct CompiledKernel {
|
||||
/// The node index in the original llir_graph
|
||||
node: NodeIndex,
|
||||
/// The compiled CUDA function
|
||||
function: CudaFunction,
|
||||
/// Launch grid dimensions (blocks)
|
||||
grid: (Expression, Expression, Expression),
|
||||
/// Launch block dimensions (threads)
|
||||
block: (Expression, Expression, Expression),
|
||||
/// Shared memory size
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
/// Graph node handle (set after graph is built)
|
||||
graph_node: Option<CUgraphNode>,
|
||||
/// Kernel name for profiling
|
||||
kernel_name: &'static str,
|
||||
}
|
||||
|
||||
impl CompiledKernel {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
node: NodeIndex,
|
||||
function: CudaFunction,
|
||||
grid: (Expression, Expression, Expression),
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
Self {
|
||||
node,
|
||||
function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
kernel_name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unified kernel params that can hold any number of u64 values.
|
||||
struct UnifiedKernelParams {
|
||||
values: Vec<u64>,
|
||||
ptrs: Vec<*mut std::ffi::c_void>,
|
||||
}
|
||||
|
||||
impl UnifiedKernelParams {
|
||||
fn new(values: Vec<u64>) -> Self {
|
||||
let ptrs = values
|
||||
.iter()
|
||||
.map(|v| v as *const u64 as *mut std::ffi::c_void)
|
||||
.collect();
|
||||
Self { values, ptrs }
|
||||
}
|
||||
|
||||
fn as_cuda_params(&mut self) -> *mut *mut std::ffi::c_void {
|
||||
// Rebuild pointers (in case struct was moved)
|
||||
for (i, v) in self.values.iter().enumerate() {
|
||||
self.ptrs[i] = v as *const u64 as *mut std::ffi::c_void;
|
||||
}
|
||||
self.ptrs.as_mut_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state for CudaGraphOp that needs interior mutability.
|
||||
struct CudaGraphOpState {
|
||||
/// Compiled kernels in topological order
|
||||
kernels: Vec<CompiledKernel>,
|
||||
/// Shared device buffer for dynamic dimensions
|
||||
dyn_dims_buffer: Option<CudaSlice<i32>>,
|
||||
/// CUDA graph handle
|
||||
cuda_graph: Option<CudaGraphHandle>,
|
||||
/// CUDA graph exec handle
|
||||
cuda_graph_exec: Option<CudaGraphExecHandle>,
|
||||
/// Mapping from kernel node to graph node
|
||||
node_to_graph_node: FxHashMap<NodeIndex, CUgraphNode>,
|
||||
/// Kernel params for each kernel
|
||||
kernel_params: Vec<UnifiedKernelParams>,
|
||||
/// Last dynamic dimension values (for change detection)
|
||||
last_dyn_values: FxHashMap<char, usize>,
|
||||
/// Last buffer pointers (for change detection)
|
||||
last_buffer_ptrs: FxHashMap<NodeIndex, u64>,
|
||||
/// Timing events for profiling
|
||||
timing_events: Vec<cudarc::driver::sys::CUevent>,
|
||||
}
|
||||
|
||||
impl CudaGraphOpState {
|
||||
fn new(kernels: Vec<CompiledKernel>) -> Self {
|
||||
Self {
|
||||
kernels,
|
||||
dyn_dims_buffer: None,
|
||||
cuda_graph: None,
|
||||
cuda_graph_exec: None,
|
||||
node_to_graph_node: FxHashMap::default(),
|
||||
kernel_params: Vec::new(),
|
||||
last_dyn_values: FxHashMap::default(),
|
||||
last_buffer_ptrs: FxHashMap::default(),
|
||||
timing_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A CUDA graph operation that implements HostOp.
|
||||
///
|
||||
/// This wraps a subgraph of KernelOps into a single executable CUDA graph.
|
||||
/// It manages graph building, execution, and dynamic updates.
|
||||
pub struct CudaGraphOp {
|
||||
/// All nodes that this graph needs buffers for (kernels + their inputs)
|
||||
buffer_nodes: Vec<NodeIndex>,
|
||||
/// Buffer size requirements for extra nodes (node -> size in elements)
|
||||
buffer_sizes: FxHashMap<NodeIndex, Expression>,
|
||||
/// Dynamic dimensions used by this graph (sorted alphabetically)
|
||||
dyn_dims_order: Vec<char>,
|
||||
/// The CUDA stream (needed for operations)
|
||||
stream: Arc<CudaStream>,
|
||||
/// Mutable state wrapped in RefCell for interior mutability
|
||||
state: RefCell<CudaGraphOpState>,
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn new(
|
||||
buffer_nodes: Vec<NodeIndex>,
|
||||
buffer_sizes: FxHashMap<NodeIndex, Expression>,
|
||||
dyn_dims_order: Vec<char>,
|
||||
stream: Arc<CudaStream>,
|
||||
state: CudaGraphOpState,
|
||||
) -> Self {
|
||||
Self {
|
||||
buffer_nodes,
|
||||
buffer_sizes,
|
||||
dyn_dims_order,
|
||||
stream,
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let state = self.state.borrow();
|
||||
f.debug_struct("CudaGraphOp")
|
||||
.field("n_kernels", &state.kernels.len())
|
||||
.field("n_buffer_nodes", &self.buffer_nodes.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaGraphOp {
|
||||
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
|
||||
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
_kind_children: &[&'a luminal::prelude::ENodeId],
|
||||
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
|
||||
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
|
||||
panic!("CudaGraphOp should not be extracted from egglog")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CudaGraphOp {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
// CudaGraphOp doesn't have a single output - individual kernels have outputs
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CudaGraphOp doesn't have a single output - individual kernels have outputs
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
|
||||
// Only return nodes that actually have buffers
|
||||
// Filter out nodes in buffer_sizes with size 0 (like MegakernelOps)
|
||||
// Keep nodes not in buffer_sizes (external inputs that have their own buffers)
|
||||
self.buffer_nodes
|
||||
.iter()
|
||||
.filter(|n| {
|
||||
match self.buffer_sizes.get(n) {
|
||||
Some(size) => size.exec(&FxHashMap::default()).unwrap_or(1) != 0,
|
||||
None => true, // Not a kernel output, might be an external input
|
||||
}
|
||||
})
|
||||
.copied()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("CudaGraph")
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
let _span = span!(Level::TRACE, "cuda_graph", kernels = state.kernels.len()).entered();
|
||||
|
||||
// Check if dyn_map changed
|
||||
let dyn_map_changed = dyn_map.len() != state.last_dyn_values.len()
|
||||
|| dyn_map
|
||||
.iter()
|
||||
.any(|(k, v)| state.last_dyn_values.get(k) != Some(v));
|
||||
|
||||
// Check if any kernel's internal buffer dimensions changed
|
||||
let mut needs_internal_realloc = false;
|
||||
for kernel in state.kernels.iter() {
|
||||
let internal_dims = kernel.kernel_op.internal_buffer_dyn_dims();
|
||||
if internal_dims
|
||||
.iter()
|
||||
.any(|d| dyn_map.get(d) != state.last_dyn_values.get(d))
|
||||
{
|
||||
needs_internal_realloc = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Reallocate internal buffers if needed
|
||||
if needs_internal_realloc {
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
}
|
||||
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
|
||||
if dyn_map_changed || needs_internal_realloc {
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
state.kernel_params.clear();
|
||||
}
|
||||
|
||||
// Allocate dyn_dims_buffer if needed
|
||||
if !self.dyn_dims_order.is_empty() && state.dyn_dims_buffer.is_none() {
|
||||
state.dyn_dims_buffer = Some(
|
||||
stream
|
||||
.alloc_zeros::<i32>(self.dyn_dims_order.len())
|
||||
.expect("Failed to allocate dyn_dims buffer"),
|
||||
);
|
||||
}
|
||||
|
||||
// Update shared dyn_dims buffer if dyn_map changed
|
||||
if dyn_map_changed && !self.dyn_dims_order.is_empty() {
|
||||
let values: Vec<i32> = self
|
||||
.dyn_dims_order
|
||||
.iter()
|
||||
.map(|d| dyn_map.get(d).copied().unwrap_or(0) as i32)
|
||||
.collect();
|
||||
if let Some(buf) = state.dyn_dims_buffer.as_mut() {
|
||||
stream.memcpy_htod(&values, buf)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Build CUDA graph if needed
|
||||
if state.cuda_graph.is_none() {
|
||||
self.build_graph(&mut state, stream, buffers, dyn_map)?;
|
||||
}
|
||||
|
||||
// Collect current buffer pointers
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply output-aliases-input
|
||||
for kernel in state.kernels.iter() {
|
||||
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
|
||||
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
|
||||
{
|
||||
current_buffer_ptrs.insert(kernel.node, input_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Always call pre_execute for each kernel to reset internal state
|
||||
// (e.g., MegakernelOps need work queue, head, barriers, lock reset every execution)
|
||||
for idx in 0..state.kernels.len() {
|
||||
let kernel = &mut state.kernels[idx];
|
||||
kernel.kernel_op.pre_execute(
|
||||
stream,
|
||||
&mut kernel.internal_bufs,
|
||||
&mut kernel.constants,
|
||||
¤t_buffer_ptrs,
|
||||
dyn_map,
|
||||
);
|
||||
}
|
||||
|
||||
// Check if we need to update the graph
|
||||
let buffer_ptrs_changed = current_buffer_ptrs != state.last_buffer_ptrs;
|
||||
let needs_update = dyn_map_changed || buffer_ptrs_changed;
|
||||
|
||||
if needs_update {
|
||||
// Update kernel params
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Build params for each kernel first
|
||||
let num_kernels = state.kernels.len();
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &state.kernels[idx];
|
||||
let output_ptr = current_buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
|
||||
// Now update CUDA graph nodes
|
||||
state
|
||||
.cuda_graph_exec
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.ctx
|
||||
.bind_to_thread()?;
|
||||
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &state.kernels[idx];
|
||||
let graph_node = state.node_to_graph_node[&kernel.node];
|
||||
|
||||
let grid_dim = (
|
||||
kernel.grid.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let block_dim = (
|
||||
kernel.block.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
// Get params pointer first to avoid borrowing state twice
|
||||
let params_ptr = state.kernel_params[idx].as_cuda_params();
|
||||
let exec = state.cuda_graph_exec.as_mut().unwrap();
|
||||
unsafe {
|
||||
exec.update_kernel_node(
|
||||
graph_node, cu_func, grid_dim, block_dim, shared_mem, params_ptr,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
state.last_dyn_values = dyn_map.clone();
|
||||
state.last_buffer_ptrs = current_buffer_ptrs;
|
||||
}
|
||||
|
||||
// Launch the graph
|
||||
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the CUDA graph from compiled kernels.
|
||||
fn build_graph(
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
let mut graph = CudaGraphHandle::new(ctx.clone())?;
|
||||
|
||||
let num_kernels = state.kernels.len();
|
||||
state.kernel_params.clear();
|
||||
state.kernel_params.reserve(num_kernels);
|
||||
|
||||
let tracing_enabled = enabled!(Level::TRACE);
|
||||
if tracing_enabled {
|
||||
let needed_events = num_kernels + 1;
|
||||
while state.timing_events.len() < needed_events {
|
||||
state.timing_events.push(create_cuda_event(&ctx)?);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect buffer pointers
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
|
||||
graph.ctx.bind_to_thread()?;
|
||||
|
||||
let mut prev_graph_node: Option<CUgraphNode> = None;
|
||||
|
||||
for idx in 0..num_kernels {
|
||||
// Allocate internal buffers if not already done
|
||||
{
|
||||
let kernel = &mut state.kernels[idx];
|
||||
if kernel.internal_bufs.is_empty() {
|
||||
kernel.internal_bufs =
|
||||
kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
// Call pre_execute to initialize internal state (e.g., populate buffer array for MegakernelOps)
|
||||
{
|
||||
let kernel = &mut state.kernels[idx];
|
||||
kernel.kernel_op.pre_execute(
|
||||
stream,
|
||||
&mut kernel.internal_bufs,
|
||||
&mut kernel.constants,
|
||||
&buffer_ptrs,
|
||||
dyn_map,
|
||||
);
|
||||
}
|
||||
|
||||
let kernel = &state.kernels[idx];
|
||||
let grid_dim = (
|
||||
kernel.grid.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let block_dim = (
|
||||
kernel.block.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
Some(state.timing_events[idx])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let deps: &[CUgraphNode] = match (&prev_graph_node, timing_event) {
|
||||
(Some(prev), Some(event)) => {
|
||||
let event_node = graph.add_event_record_node(&[*prev], event)?;
|
||||
prev_graph_node = Some(event_node);
|
||||
std::slice::from_ref(prev_graph_node.as_ref().unwrap())
|
||||
}
|
||||
(None, Some(event)) => {
|
||||
let event_node = graph.add_event_record_node(&[], event)?;
|
||||
prev_graph_node = Some(event_node);
|
||||
std::slice::from_ref(prev_graph_node.as_ref().unwrap())
|
||||
}
|
||||
(Some(prev), None) => std::slice::from_ref(prev),
|
||||
(None, None) => &[],
|
||||
};
|
||||
|
||||
let graph_node = unsafe {
|
||||
graph.add_kernel_node(
|
||||
deps,
|
||||
cu_func,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
shared_mem,
|
||||
params.as_cuda_params(),
|
||||
)?
|
||||
};
|
||||
|
||||
state.node_to_graph_node.insert(kernel_node, graph_node);
|
||||
state.kernels[idx].graph_node = Some(graph_node);
|
||||
state.kernel_params.push(params);
|
||||
prev_graph_node = Some(graph_node);
|
||||
}
|
||||
|
||||
if tracing_enabled && let Some(prev) = prev_graph_node {
|
||||
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
|
||||
}
|
||||
|
||||
let exec = graph.instantiate()?;
|
||||
|
||||
state.cuda_graph = Some(graph);
|
||||
state.cuda_graph_exec = Some(exec);
|
||||
state.last_dyn_values = dyn_map.clone();
|
||||
state.last_buffer_ptrs = buffer_ptrs;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphOp {
|
||||
fn drop(&mut self) {
|
||||
let mut state = self.state.borrow_mut();
|
||||
|
||||
// Destroy timing events first
|
||||
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
|
||||
if let Some(ctx) = ctx {
|
||||
for event in state.timing_events.drain(..) {
|
||||
destroy_cuda_event(&ctx, event);
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
|
||||
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
|
||||
// so it must be destroyed first to avoid dangling pointer issues.
|
||||
drop(state.cuda_graph_exec.take());
|
||||
drop(state.cuda_graph.take());
|
||||
|
||||
// Now safe to free dynamically allocated GPU buffers
|
||||
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
|
||||
|
||||
// Constants point to __constant__ memory in the CUDA module,
|
||||
// not dynamically allocated — must not be freed.
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
let constants = std::mem::take(&mut kernel.constants);
|
||||
for (_k, v) in constants {
|
||||
std::mem::forget(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile KernelOp subgraphs in the LLIR graph into CudaGraphOps.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Finds all KernelOp nodes in the graph
|
||||
/// 2. Partitions them into convex subgraphs
|
||||
/// 3. For each subgraph, creates a CudaGraphOp (which implements HostOp)
|
||||
/// 4. Adds the CudaGraphOp node to the llir_graph with appropriate edges
|
||||
///
|
||||
/// Note: KernelOp nodes remain in the graph for buffer allocation and edge tracking.
|
||||
/// Their execution is handled by the CudaGraphOp via the CUDA graph API.
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn kernel_to_host(
|
||||
llir_graph: &mut LLIRGraph,
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) {
|
||||
let _span = span!(Level::TRACE, "kernel_to_host").entered();
|
||||
|
||||
let kernel_ops_in_graph = llir_graph
|
||||
.node_indices()
|
||||
.filter(|n| llir_graph[*n].to_dialect::<dyn KernelOp>().is_some())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
if kernel_ops_in_graph.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
// Track all CudaGraphOp nodes and their subgraphs for edge creation
|
||||
let mut cuda_graph_subgraphs: Vec<(NodeIndex, FxHashSet<NodeIndex>)> = Vec::new();
|
||||
|
||||
for subgraph in kernel_subgraphs {
|
||||
// Compile kernels in topological order
|
||||
let topo_order: Vec<_> = toposort(&*llir_graph, None)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.filter(|n| subgraph.contains(n))
|
||||
.collect();
|
||||
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
|
||||
}
|
||||
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Compile all kernels with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
let final_global = get_global_dyn_dims();
|
||||
// Clear global ordering now that all kernels are compiled
|
||||
clear_global_dyn_dims();
|
||||
|
||||
// Use the final global ordering if it was extended during compilation
|
||||
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
|
||||
final_order
|
||||
} else {
|
||||
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
|
||||
dims.sort();
|
||||
dims
|
||||
};
|
||||
|
||||
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
|
||||
|
||||
// Create CudaGraphOp with RefCell for interior mutability
|
||||
let state = CudaGraphOpState::new(kernels);
|
||||
|
||||
let cuda_graph_op = CudaGraphOp::new(
|
||||
buffer_nodes,
|
||||
all_buffer_sizes,
|
||||
dyn_dims_order,
|
||||
cuda_stream.clone(),
|
||||
state,
|
||||
);
|
||||
|
||||
// Add CudaGraphOp to llir_graph as a HostOp
|
||||
let cuda_graph_node =
|
||||
llir_graph.add_node(LLIROp::new(Box::new(cuda_graph_op) as Box<dyn HostOp>));
|
||||
|
||||
// Track which kernel nodes belong to this CudaGraphOp
|
||||
for kernel_node in &subgraph {
|
||||
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
llir_graph.add_edge(*input, cuda_graph_node, ());
|
||||
}
|
||||
|
||||
// Note: We intentionally keep the kernel nodes in the graph.
|
||||
// They are needed for:
|
||||
// 1. Buffer allocation (their output_size determines buffer sizes)
|
||||
// 2. Edge tracking (other ops like cuBLAS reference specific kernel outputs)
|
||||
// The CudaGraphOp handles their execution via the CUDA graph API.
|
||||
}
|
||||
|
||||
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
|
||||
// This ensures proper execution ordering when a kernel in one CudaGraphOp
|
||||
// produces output consumed by a kernel in another CudaGraphOp.
|
||||
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
|
||||
|
||||
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
|
||||
// Find external consumers that are kernels belonging to other CudaGraphOps
|
||||
for producer_node in subgraph {
|
||||
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
|
||||
let consumer = edge.target();
|
||||
if subgraph.contains(&consumer) {
|
||||
continue; // Same subgraph
|
||||
}
|
||||
// Check if consumer is a kernel in another CudaGraphOp
|
||||
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer)
|
||||
&& consumer_cuda_graph != *cuda_graph_node
|
||||
{
|
||||
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
|
||||
}
|
||||
// Also add edges to HostOps (like cuBLAS ops) that consume our outputs
|
||||
if llir_graph[consumer]
|
||||
.to_dialect::<dyn super::super::host::HostOp>()
|
||||
.is_some()
|
||||
{
|
||||
edges_to_add.push((*cuda_graph_node, consumer));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
|
||||
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
|
||||
let topo = toposort(&*llir_graph, None).unwrap();
|
||||
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, n) in topo.iter().enumerate() {
|
||||
topo_pos.insert(*n, i);
|
||||
}
|
||||
for (src, dst) in edges_to_add {
|
||||
// Only add forward edges (src before dst in topo order) to avoid creating cycles
|
||||
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
|
||||
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
|
||||
if src_pos >= dst_pos {
|
||||
continue; // Skip back-edges
|
||||
}
|
||||
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
}
|
||||
}
|
||||
309
crates/luminal_cuda_lite/src/lib.rs
Normal file
309
crates/luminal_cuda_lite/src/lib.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::{
|
||||
driver::{CudaContext, DriverError, sys as driver_sys},
|
||||
nvrtc::{
|
||||
Ptx,
|
||||
result::{self as nvrtc_result, NvrtcError},
|
||||
sys as nvrtc_sys,
|
||||
},
|
||||
};
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::F64 => "double",
|
||||
DType::F32 => "float",
|
||||
DType::F16 => "half",
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
DType::U8 => "unsigned char",
|
||||
DType::Bool => "unsigned char",
|
||||
DType::F8E4M3 => "__nv_fp8_e4m3",
|
||||
DType::F8E5M2 => "__nv_fp8_e5m2",
|
||||
DType::F8UE8M0 => "__nv_fp8_e8m0",
|
||||
DType::F6E2M3 => "__nv_fp6_e2m3",
|
||||
DType::F6E3M2 => "__nv_fp6_e3m2",
|
||||
DType::F4E2M1 => "__nv_fp4_e2m1",
|
||||
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
|
||||
}
|
||||
}
|
||||
|
||||
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum CudaModuleImageCompileFailure {
|
||||
ComputeCapability(DriverError),
|
||||
Nvrtc {
|
||||
stage: &'static str,
|
||||
error: NvrtcError,
|
||||
},
|
||||
NoModuleImageProduced,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CudaModuleImageCompileError {
|
||||
pub target_arch: Option<String>,
|
||||
pub driver_version: Option<i32>,
|
||||
pub runtime_version: Option<i32>,
|
||||
pub nvrtc_options: Vec<String>,
|
||||
pub nvrtc_log: Option<String>,
|
||||
pub failure: CudaModuleImageCompileFailure,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CudaModuleImageCompileError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to compile CUDA module image")?;
|
||||
if let Some(target_arch) = &self.target_arch {
|
||||
write!(f, " for {target_arch}")?;
|
||||
}
|
||||
match &self.failure {
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error) => {
|
||||
write!(f, ": failed to query compute capability: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
|
||||
write!(f, ": NVRTC {stage} failed: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced => {
|
||||
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
|
||||
}
|
||||
}
|
||||
if let Some(version) = self.driver_version {
|
||||
write!(f, " | driver {}", format_cuda_version(version))?;
|
||||
}
|
||||
if let Some(version) = self.runtime_version {
|
||||
write!(f, " | runtime {}", format_cuda_version(version))?;
|
||||
}
|
||||
if !self.nvrtc_options.is_empty() {
|
||||
write!(f, " | options {:?}", self.nvrtc_options)?;
|
||||
}
|
||||
if let Some(log) = &self.nvrtc_log {
|
||||
write!(f, " | log: {log}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CudaModuleImageCompileError {}
|
||||
|
||||
fn format_cuda_version(version: i32) -> String {
|
||||
format!("{}.{}", version / 1000, (version % 1000) / 10)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_include_paths() -> Vec<String> {
|
||||
let mut include_paths = Vec::new();
|
||||
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
|
||||
if let Ok(root) = std::env::var(env_var) {
|
||||
let path = format!("{root}/include");
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
for path in CUDA_NVRTC_INCLUDE_PATHS {
|
||||
let path = path.to_string();
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
include_paths
|
||||
}
|
||||
|
||||
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
let mut driver_version = 0;
|
||||
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
|
||||
.result()
|
||||
.ok()
|
||||
.map(|_| driver_version);
|
||||
|
||||
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
|
||||
// resolves newer libcudart symbols that may not exist in the installed runtime.
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
.map(|path| format!("--include-path={path}"))
|
||||
.collect::<Vec<_>>();
|
||||
options.push(format!("--gpu-architecture={target_arch}"));
|
||||
options
|
||||
}
|
||||
|
||||
fn build_module_image_compile_error(
|
||||
target_arch: Option<String>,
|
||||
driver_version: Option<i32>,
|
||||
runtime_version: Option<i32>,
|
||||
nvrtc_options: &[String],
|
||||
nvrtc_log: Option<String>,
|
||||
failure: CudaModuleImageCompileFailure,
|
||||
) -> CudaModuleImageCompileError {
|
||||
CudaModuleImageCompileError {
|
||||
target_arch,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
nvrtc_options: nvrtc_options.to_vec(),
|
||||
nvrtc_log,
|
||||
failure,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
|
||||
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
|
||||
.to_string_lossy()
|
||||
.trim_end_matches('\0')
|
||||
.trim()
|
||||
.to_string();
|
||||
if log.is_empty() { None } else { Some(log) }
|
||||
}
|
||||
|
||||
#[allow(clippy::slow_vector_initialization)]
|
||||
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
let mut cubin_size = 0usize;
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
|
||||
if cubin_size == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
ctx: &Arc<CudaContext>,
|
||||
src: S,
|
||||
) -> Result<Ptx, CudaModuleImageCompileError> {
|
||||
let (driver_version, runtime_version) = cuda_driver_diagnostics();
|
||||
let (major, minor) = ctx.compute_capability().map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
None,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&[],
|
||||
None,
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error),
|
||||
)
|
||||
})?;
|
||||
let target_arch = format!("sm_{major}{minor}");
|
||||
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
|
||||
|
||||
let source = CString::new(src.as_ref().as_bytes())
|
||||
.expect("CUDA source code cannot contain null terminators");
|
||||
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
Some(target_arch.clone()),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
None,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "create_program",
|
||||
error,
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "compile_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let cubin = match get_cubin(program) {
|
||||
Ok(cubin) => cubin,
|
||||
Err(error) => {
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "get_cubin",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "destroy_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
if cubin.is_empty() {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ptx::from_binary(cubin))
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 273,
|
||||
"NVIDIA H100 PCIe" => 2_000,
|
||||
"NVIDIA H100 SXM" => 3_350,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in TFLOPs
|
||||
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 125, // forced to use tf32 flops
|
||||
"NVIDIA H100 PCIe" => 756,
|
||||
"NVIDIA H100 SXM" => 989,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
1805
crates/luminal_cuda_lite/src/runtime.rs
Normal file
1805
crates/luminal_cuda_lite/src/runtime.rs
Normal file
File diff suppressed because it is too large
Load Diff
344
crates/luminal_cuda_lite/src/tests/bucket_tests.rs
Normal file
344
crates/luminal_cuda_lite/src/tests/bucket_tests.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::*;
|
||||
use luminal::prelude::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
/// Helper: build a simple graph with dynamic dim 's' that does element-wise computation.
|
||||
/// Returns (cx, input_node, output_node).
|
||||
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', 4));
|
||||
let b = (a + a).output();
|
||||
(cx, a.id, b.id)
|
||||
}
|
||||
|
||||
/// Helper: build a matmul graph with dynamic dim 's'.
|
||||
/// Computes (s, K) @ (K, N) -> (s, N)
|
||||
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let c = a.matmul(b).output();
|
||||
(cx, a.id, b.id, c.id)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_dispatch_simple() {
|
||||
// Tests that bucketed compilation produces correct results for different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Set dummy input for search
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..4], &expected, 1e-5, 1e-5);
|
||||
|
||||
// Test bucket 2: s=3
|
||||
cx.set_dim('s', 3);
|
||||
let input_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..12], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_matmul_dynamic() {
|
||||
// Tests matmul with bucketed dynamic dim
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 8;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 100, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s1 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=1 (1xK @ KxN -> 1xN)
|
||||
let mut expected_s1 = vec![0.0f32; n];
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s1[j] += a_data[i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);
|
||||
|
||||
// Execute at s=4
|
||||
cx.set_dim('s', 4);
|
||||
let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_4.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s4 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=4 (4xK @ KxN -> 4xN)
|
||||
let mut expected_s4 = vec![0.0f32; 4 * n];
|
||||
for row in 0..4 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s4[row * n + j] += a_data_4[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_results_match_unbucketed() {
|
||||
// Tests that bucketed results match non-bucketed results for the same graph
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let seed = 42u64;
|
||||
|
||||
// Non-bucketed run
|
||||
let (mut cx1, a1, b1) = build_dynamic_add_graph();
|
||||
cx1.set_dim('s', 3);
|
||||
cx1.build_search_space::<CudaRuntime>();
|
||||
let mut rt1 = CudaRuntime::initialize(stream.clone());
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
|
||||
// Bucketed run with bucket that covers s=3
|
||||
let (mut cx2, a2, b2) = build_dynamic_add_graph();
|
||||
cx2.set_dim('s', 3);
|
||||
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
|
||||
cx2.build_search_space::<CudaRuntime>();
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
|
||||
// Results should match — same graph, same search seed, same dyn_map
|
||||
assert_eq!(result_unbucketed.len(), result_bucketed.len());
|
||||
assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "No bucket matches")]
|
||||
fn test_bucket_out_of_range_panics() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
// Can't trigger panic without GPU, skip gracefully
|
||||
panic!("No bucket matches dyn_map");
|
||||
};
|
||||
|
||||
let (mut cx, a, _b) = build_dynamic_add_graph();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
rt.set_data(a, vec![1.0f32; 40]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_no_buckets_backward_compat() {
|
||||
// No buckets set → should behave identically to old path
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
cx.set_dim('s', 2);
|
||||
|
||||
// No set_dim_buckets call
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..8], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_representative_override() {
|
||||
// Tests that custom representative works
|
||||
let bucket = DimBucket::new(2, 32).representative(16);
|
||||
assert_eq!(bucket.representative_value(), 16);
|
||||
|
||||
let bucket_default = DimBucket::new(2, 32);
|
||||
assert_eq!(bucket_default.representative_value(), 17); // (2+32)/2 = 17
|
||||
|
||||
let exact = DimBucket::new(1, 1);
|
||||
assert_eq!(exact.representative_value(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_switch_preserves_weights() {
|
||||
// Tests that switching between buckets still sees the correct weight data
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 4;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 300, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 301, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1a = rt.get_f32(c);
|
||||
|
||||
// Switch to bucket 2 (s=3)
|
||||
cx.set_dim('s', 3);
|
||||
let a_data_3 = random_f32_vec(3 * k, 302, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_3.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_3 = rt.get_f32(c);
|
||||
|
||||
// Switch back to bucket 1 (s=1) — weights should still work
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1b = rt.get_f32(c);
|
||||
|
||||
// First and last s=1 results should match exactly
|
||||
assert_close(&result_1a[..n], &result_1b[..n], 1e-6, 1e-6);
|
||||
|
||||
// Verify s=3 result correctness
|
||||
let mut expected_3 = vec![0.0f32; 3 * n];
|
||||
for row in 0..3 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_3[row * n + j] += a_data_3[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_3[..3 * n], &expected_3, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_multiple_executions_same_bucket() {
|
||||
// Tests multiple executions within the same bucket with different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
cx.set_dim('s', s);
|
||||
let n = s * 4;
|
||||
let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..n], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Overlapping buckets")]
|
||||
fn test_bucket_overlapping_ranges_panics() {
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dim_bucket_contains() {
|
||||
let b = DimBucket::new(2, 10);
|
||||
assert!(!b.contains(1));
|
||||
assert!(b.contains(2));
|
||||
assert!(b.contains(5));
|
||||
assert!(b.contains(10));
|
||||
assert!(!b.contains(11));
|
||||
|
||||
// Exact bucket
|
||||
let exact = DimBucket::new(3, 3);
|
||||
assert!(!exact.contains(2));
|
||||
assert!(exact.contains(3));
|
||||
assert!(!exact.contains(4));
|
||||
}
|
||||
416
crates/luminal_cuda_lite/src/tests/consumed_buffer_tests.rs
Normal file
416
crates/luminal_cuda_lite/src/tests/consumed_buffer_tests.rs
Normal file
@@ -0,0 +1,416 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use rand::SeedableRng;
|
||||
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice, validate_choice_set};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Helper: build search space and extract all possible kernel names across many random choices.
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_names = Vec::new();
|
||||
// Try many random extractions to cover both alternatives
|
||||
for _ in 0..20 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
let name = k.kernel_name().to_string();
|
||||
if !all_names.contains(&name) {
|
||||
all_names.push(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_names
|
||||
}
|
||||
|
||||
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
|
||||
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
|
||||
/// the ConsumedBuffer (not in any other ICons).
|
||||
#[test]
|
||||
fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into dest at indexes
|
||||
let _result = src.scatter(indexes, dest).output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should be available (dest is not shared)
|
||||
assert!(
|
||||
names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy to be available but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
|
||||
/// cleanup rule should fire, deleting the ConsumedBuffer. This makes KernelScatterNoCopy
|
||||
/// invalid, so it should NOT appear in any extraction.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_when_dest_shared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into dest at indexes
|
||||
let scatter_result = src.scatter(indexes, dest);
|
||||
|
||||
// Also use dest directly in another op (add with itself) — this makes dest shared
|
||||
let _dest_also_used = (dest + dest).output();
|
||||
let _result = scatter_result.output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should NOT be available (dest is shared with the add op)
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest is shared, got: {:?}",
|
||||
names
|
||||
);
|
||||
// Regular KernelScatter should be present
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// Actually execute the scatter and verify correctness.
|
||||
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
|
||||
#[test]
|
||||
fn test_scatter_execution_correctness() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: [0.0, 1.0, 2.0, 3.0, 4.0]
|
||||
let dest = cx.tensor(5).persist();
|
||||
// src: [10.0, 20.0, 30.0]
|
||||
let src = cx.tensor(3).persist();
|
||||
// indexes: [1, 3, 4]
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
|
||||
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
|
||||
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
|
||||
|
||||
// Try many random extractions to cover both Scatter and ScatterNoCopy
|
||||
let mut rng = rand::rng();
|
||||
let mut tested_scatter = false;
|
||||
let mut tested_nocopy = false;
|
||||
|
||||
for _ in 0..50 {
|
||||
let choices = random_initial_choice(egraph, &mut rng);
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
// Check which scatter variant was selected
|
||||
let mut has_nocopy = false;
|
||||
let mut has_scatter = false;
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
match k.kernel_name() {
|
||||
"ScatterNoCopy" => has_nocopy = true,
|
||||
"Scatter" => has_scatter = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(dest, vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(src, vec![10.0f32, 20.0, 30.0]);
|
||||
rt.set_data(indexes, vec![1i32, 3, 4]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let actual = rt.get_f32(result);
|
||||
|
||||
let variant = if has_nocopy {
|
||||
tested_nocopy = true;
|
||||
"ScatterNoCopy"
|
||||
} else if has_scatter {
|
||||
tested_scatter = true;
|
||||
"Scatter"
|
||||
} else {
|
||||
"Unknown"
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
|
||||
actual, expected
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
|
||||
tested_scatter, tested_nocopy
|
||||
);
|
||||
assert!(
|
||||
tested_nocopy,
|
||||
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test the KV-cache round-trip pattern: scatter → remove_buffer → set_buffer → scatter again.
|
||||
/// This mimics how the llama model uses scatter for KV cache updates.
|
||||
#[test]
|
||||
fn test_scatter_kv_cache_roundtrip() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// KV cache: [5] elements (simulating a small cache)
|
||||
let cache_in = cx.named_tensor("cache", 5).persist();
|
||||
// New value to scatter: [1] element
|
||||
let src = cx.tensor(1).persist();
|
||||
// Index: [1] element (position to write)
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into cache at index position
|
||||
let cache_out = src.scatter(indexes, cache_in);
|
||||
// Also read the scatter output (simulates attention reading from cache)
|
||||
let read_out = (cache_out + 0.0).output();
|
||||
// Return cache for round-trip
|
||||
let cache_output = cache_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
// Must set input data BEFORE search (profiler needs valid buffers)
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Print which scatter variant was selected
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read1 = rt.get_f32(read_out);
|
||||
println!("After step 1 (scatter 10.0 at pos 0): {:?}", read1);
|
||||
assert_eq!(
|
||||
read1,
|
||||
vec![10.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"Step 1 read_out mismatch"
|
||||
);
|
||||
|
||||
// Round-trip: remove cache output buffer, set as new cache input
|
||||
let cache_buf = rt.remove_buffer(cache_output);
|
||||
rt.set_buffer(cache_in, cache_buf);
|
||||
|
||||
// Step 2: Scatter 20.0 at position 1
|
||||
rt.set_data(src, vec![20.0f32]);
|
||||
rt.set_data(indexes, vec![1i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read2 = rt.get_f32(read_out);
|
||||
println!("After step 2 (scatter 20.0 at pos 1): {:?}", read2);
|
||||
assert_eq!(
|
||||
read2,
|
||||
vec![10.0, 20.0, 0.0, 0.0, 0.0],
|
||||
"Step 2 read_out mismatch"
|
||||
);
|
||||
|
||||
// Round-trip again
|
||||
let cache_buf = rt.remove_buffer(cache_output);
|
||||
rt.set_buffer(cache_in, cache_buf);
|
||||
|
||||
// Step 3: Scatter 30.0 at position 2
|
||||
rt.set_data(src, vec![30.0f32]);
|
||||
rt.set_data(indexes, vec![2i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read3 = rt.get_f32(read_out);
|
||||
println!("After step 3 (scatter 30.0 at pos 2): {:?}", read3);
|
||||
assert_eq!(
|
||||
read3,
|
||||
vec![10.0, 20.0, 30.0, 0.0, 0.0],
|
||||
"Step 3 read_out mismatch"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
|
||||
/// Also verifies graph_break interaction.
|
||||
#[test]
|
||||
fn test_scatter_dual_cache_with_graph_break() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// Two caches (like K and V)
|
||||
let k_cache = cx.named_tensor("k_cache", 5).persist();
|
||||
let v_cache = cx.named_tensor("v_cache", 5).persist();
|
||||
|
||||
// Input values
|
||||
let k_new = cx.tensor(1).persist();
|
||||
let v_new = cx.tensor(1).persist();
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
// Scatter into both caches
|
||||
let k_out = k_new.scatter(indexes, k_cache);
|
||||
let v_out = v_new.scatter(indexes, v_cache);
|
||||
|
||||
// Read both (simulates attention using the scattered caches)
|
||||
let k_read = k_out + 0.0;
|
||||
let v_read = v_out + 0.0;
|
||||
|
||||
// Compute something from the scattered values (simulates attention output)
|
||||
let attn = k_read * v_read;
|
||||
|
||||
// Output everything
|
||||
let attn_out = attn.output();
|
||||
let k_cache_out = k_out.output();
|
||||
let v_cache_out = v_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(v_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(k_new, vec![2.0f32]);
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Dual test selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(v_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(k_new, vec![2.0f32]);
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn1 = rt.get_f32(attn_out);
|
||||
println!("Attn step 1: {:?}", attn1);
|
||||
// k=[2,0,0,0,0], v=[3,0,0,0,0], attn = k*v = [6,0,0,0,0]
|
||||
assert_eq!(attn1, vec![6.0, 0.0, 0.0, 0.0, 0.0], "Step 1 attn mismatch");
|
||||
|
||||
// Round-trip
|
||||
let k_buf = rt.remove_buffer(k_cache_out);
|
||||
let v_buf = rt.remove_buffer(v_cache_out);
|
||||
rt.set_buffer(k_cache, k_buf);
|
||||
rt.set_buffer(v_cache, v_buf);
|
||||
|
||||
// Step 2: scatter k=4.0, v=5.0 at position 1
|
||||
rt.set_data(k_new, vec![4.0f32]);
|
||||
rt.set_data(v_new, vec![5.0f32]);
|
||||
rt.set_data(indexes, vec![1i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn2 = rt.get_f32(attn_out);
|
||||
println!("Attn step 2: {:?}", attn2);
|
||||
// k=[2,4,0,0,0], v=[3,5,0,0,0], attn = k*v = [6,20,0,0,0]
|
||||
assert_eq!(
|
||||
attn2,
|
||||
vec![6.0, 20.0, 0.0, 0.0, 0.0],
|
||||
"Step 2 attn mismatch"
|
||||
);
|
||||
|
||||
// Round-trip
|
||||
let k_buf = rt.remove_buffer(k_cache_out);
|
||||
let v_buf = rt.remove_buffer(v_cache_out);
|
||||
rt.set_buffer(k_cache, k_buf);
|
||||
rt.set_buffer(v_cache, v_buf);
|
||||
|
||||
// Step 3: scatter k=6.0, v=7.0 at position 2
|
||||
rt.set_data(k_new, vec![6.0f32]);
|
||||
rt.set_data(v_new, vec![7.0f32]);
|
||||
rt.set_data(indexes, vec![2i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn3 = rt.get_f32(attn_out);
|
||||
println!("Attn step 3: {:?}", attn3);
|
||||
// k=[2,4,6,0,0], v=[3,5,7,0,0], attn = k*v = [6,20,42,0,0]
|
||||
assert_eq!(
|
||||
attn3,
|
||||
vec![6.0, 20.0, 42.0, 0.0, 0.0],
|
||||
"Step 3 attn mismatch"
|
||||
);
|
||||
}
|
||||
14
crates/luminal_cuda_lite/src/tests/mod.rs
Normal file
14
crates/luminal_cuda_lite/src/tests/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
pub mod utilities;
|
||||
|
||||
#[cfg(test)]
|
||||
mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
685
crates/luminal_cuda_lite/src/tests/model_fuzz.rs
Normal file
685
crates/luminal_cuda_lite/src/tests/model_fuzz.rs
Normal file
@@ -0,0 +1,685 @@
|
||||
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
|
||||
//!
|
||||
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
|
||||
//! reference to catch incorrect HLIR kernel fallback rewrites.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
use super::utilities::{assert_close, fuzz_genomes, get_cuda_stream, random_f32_vec};
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Number of genomes to fuzz per test (higher than default GENOME_FUZZ_COUNT=20).
|
||||
const FUZZ_COUNT: usize = 100;
|
||||
|
||||
// ============================================================================
|
||||
// RMSNorm helper (used by all three models)
|
||||
// ============================================================================
|
||||
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
fn rms_norm_ref(
|
||||
x: &candle_core::Tensor,
|
||||
weight: &candle_core::Tensor,
|
||||
eps: f64,
|
||||
) -> candle_core::Tensor {
|
||||
let dims = x.dims();
|
||||
let last_dim = dims[dims.len() - 1];
|
||||
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
|
||||
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
|
||||
let normed = x.broadcast_mul(&rsqrt).unwrap();
|
||||
normed
|
||||
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SwiGLU MLP helper (used by all three models)
|
||||
// ============================================================================
|
||||
|
||||
fn swiglu_mlp(
|
||||
x: GraphTensor,
|
||||
w_gate: GraphTensor,
|
||||
w_up: GraphTensor,
|
||||
w_down: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let gate = x.matmul(w_gate.t()).swish();
|
||||
let up = x.matmul(w_up.t());
|
||||
(gate * up).matmul(w_down.t())
|
||||
}
|
||||
|
||||
fn swiglu_mlp_ref(
|
||||
x: &candle_core::Tensor,
|
||||
w_gate: &candle_core::Tensor,
|
||||
w_up: &candle_core::Tensor,
|
||||
w_down: &candle_core::Tensor,
|
||||
) -> candle_core::Tensor {
|
||||
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
|
||||
let up = x.matmul(&w_up.t().unwrap()).unwrap();
|
||||
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Generic test functions
|
||||
// ============================================================================
|
||||
|
||||
/// Test a SwiGLU MLP block at given dimensions with genome fuzzing.
|
||||
fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((seq, hidden));
|
||||
let w_gate = cx.tensor((intermediate, hidden));
|
||||
let w_up = cx.tensor((intermediate, hidden));
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
|
||||
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
|
||||
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
|
||||
let ref_gate =
|
||||
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_up =
|
||||
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_down =
|
||||
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
|
||||
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
1e-2,
|
||||
1e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Test RMSNorm + matmul projection at given dimensions with genome fuzzing.
|
||||
fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((seq, hidden));
|
||||
let norm_w = cx.tensor(hidden);
|
||||
let proj_w = cx.tensor((proj_dim, hidden));
|
||||
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
let norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
|
||||
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), hidden, &device).unwrap();
|
||||
let ref_proj =
|
||||
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
|
||||
let normed = rms_norm_ref(&ref_input, &ref_norm, eps as f64);
|
||||
let expected = normed.matmul(&ref_proj.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
1e-2,
|
||||
1e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Test a full transformer layer (norm -> proj -> norm -> MLP) without attention.
|
||||
fn fuzz_layer_no_attn(
|
||||
seq: usize,
|
||||
hidden: usize,
|
||||
intermediate: usize,
|
||||
proj_dim: usize,
|
||||
eps: f32,
|
||||
seed: u64,
|
||||
) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((seq, hidden));
|
||||
let attn_norm_w = cx.tensor(hidden);
|
||||
let proj_w = cx.tensor((proj_dim, hidden));
|
||||
let o_proj_w = cx.tensor((hidden, proj_dim));
|
||||
let mlp_norm_w = cx.tensor(hidden);
|
||||
let w_gate = cx.tensor((intermediate, hidden));
|
||||
let w_up = cx.tensor((intermediate, hidden));
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
|
||||
let normed = rms_norm(input, attn_norm_w, eps);
|
||||
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
|
||||
let x = input + proj_out;
|
||||
let mlp_normed = rms_norm(x, mlp_norm_w, eps);
|
||||
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
|
||||
let out = (x + mlp_out).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
let attn_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 1, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let proj_data = random_f32_vec(proj_dim * hidden, seed + 2, -0.3, 0.3);
|
||||
let o_proj_data = random_f32_vec(hidden * proj_dim, seed + 3, -0.3, 0.3);
|
||||
let mlp_norm_data: Vec<f32> = random_f32_vec(hidden, seed + 4, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let gate_data = random_f32_vec(intermediate * hidden, seed + 5, -0.3, 0.3);
|
||||
let up_data = random_f32_vec(intermediate * hidden, seed + 6, -0.3, 0.3);
|
||||
let down_data = random_f32_vec(hidden * intermediate, seed + 7, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(attn_norm_w, attn_norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt.set_data(o_proj_w, o_proj_data.clone());
|
||||
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
// Candle reference
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
|
||||
let ref_attn_norm =
|
||||
candle_core::Tensor::from_vec(attn_norm_data.clone(), hidden, &device).unwrap();
|
||||
let ref_proj =
|
||||
candle_core::Tensor::from_vec(proj_data.clone(), (proj_dim, hidden), &device).unwrap();
|
||||
let ref_o_proj =
|
||||
candle_core::Tensor::from_vec(o_proj_data.clone(), (hidden, proj_dim), &device).unwrap();
|
||||
let ref_mlp_norm =
|
||||
candle_core::Tensor::from_vec(mlp_norm_data.clone(), hidden, &device).unwrap();
|
||||
let ref_gate =
|
||||
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_up =
|
||||
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_down =
|
||||
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
|
||||
|
||||
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, eps as f64);
|
||||
let proj_out = normed
|
||||
.matmul(&ref_proj.t().unwrap())
|
||||
.unwrap()
|
||||
.matmul(&ref_o_proj.t().unwrap())
|
||||
.unwrap();
|
||||
let x_ref = (&ref_input + proj_out).unwrap();
|
||||
let mlp_normed = rms_norm_ref(&x_ref, &ref_mlp_norm, eps as f64);
|
||||
let mlp_out = swiglu_mlp_ref(&mlp_normed, &ref_gate, &ref_up, &ref_down);
|
||||
let expected_t = (x_ref + mlp_out).unwrap();
|
||||
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 2e-2, 2e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(attn_norm_w, attn_norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt.set_data(o_proj_w, o_proj_data.clone());
|
||||
rt.set_data(mlp_norm_w, mlp_norm_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
2e-2,
|
||||
2e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Test a SwiGLU MLP with HLIR-only to specifically verify
|
||||
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
|
||||
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((seq, hidden));
|
||||
let w_gate = cx.tensor((intermediate, hidden));
|
||||
let w_up = cx.tensor((intermediate, hidden));
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
let gate_data = random_f32_vec(intermediate * hidden, seed + 1, -0.3, 0.3);
|
||||
let up_data = random_f32_vec(intermediate * hidden, seed + 2, -0.3, 0.3);
|
||||
let down_data = random_f32_vec(hidden * intermediate, seed + 3, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.clone(), (seq, hidden), &device).unwrap();
|
||||
let ref_gate =
|
||||
candle_core::Tensor::from_vec(gate_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_up =
|
||||
candle_core::Tensor::from_vec(up_data.clone(), (intermediate, hidden), &device).unwrap();
|
||||
let ref_down =
|
||||
candle_core::Tensor::from_vec(down_data.clone(), (hidden, intermediate), &device).unwrap();
|
||||
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
1e-2,
|
||||
1e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Llama-specific tests
|
||||
// Llama 3 8B: HIDDEN=4096, INTERMEDIATE=14336, HEAD_DIM=128
|
||||
// Using scaled-down dims that preserve architectural ratios
|
||||
// ============================================================================
|
||||
|
||||
mod llama {
|
||||
use super::*;
|
||||
|
||||
const SEQ: usize = 4;
|
||||
const HIDDEN: usize = 256;
|
||||
const INTERMEDIATE: usize = 896; // ~3.5x hidden, matching 14336/4096
|
||||
const PROJ_DIM: usize = 256; // Q_DIM == HIDDEN for llama
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
#[test]
|
||||
fn fuzz_llama_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_llama_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_llama_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_llama_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_llama_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
|
||||
}
|
||||
|
||||
/// Force HLIR-only (no block ops) to specifically test the fallback path.
|
||||
#[test]
|
||||
fn fuzz_llama_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Gemma-specific tests
|
||||
// Gemma 3 4B: HIDDEN=2560, INTERMEDIATE=10240, HEAD_DIM=256, Q_DIM=2048
|
||||
// Key difference: Q_DIM != HIDDEN, and 4 extra RMSNorm layers per block
|
||||
// ============================================================================
|
||||
|
||||
mod gemma {
|
||||
use super::*;
|
||||
|
||||
const SEQ: usize = 4;
|
||||
const HIDDEN: usize = 320; // divisible by 8 (N_HEADS)
|
||||
const INTERMEDIATE: usize = 1280; // 4x hidden, matching 10240/2560
|
||||
const Q_DIM: usize = 256; // scaled from 2048 (N_HEADS * HEAD_DIM)
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
fn fuzz_gemma_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_gemma_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_gemma_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
|
||||
}
|
||||
|
||||
/// Gemma has extra post-attention and post-feedforward norms.
|
||||
#[test]
|
||||
fn fuzz_gemma_layer_full_norms() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let attn_norm_w = cx.tensor(HIDDEN);
|
||||
let post_attn_norm_w = cx.tensor(HIDDEN);
|
||||
let pre_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let post_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let proj_w = cx.tensor((Q_DIM, HIDDEN));
|
||||
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
|
||||
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
|
||||
let normed = rms_norm(input, attn_norm_w, EPS);
|
||||
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
|
||||
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
|
||||
let x = input + attn_normed;
|
||||
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
|
||||
let mlp_out = swiglu_mlp(ff_normed, w_gate, w_up, w_down);
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 800u64;
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
|
||||
let attn_norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let post_attn_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 2, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let pre_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 3, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let post_ff_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 4, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let proj_data = random_f32_vec(Q_DIM * HIDDEN, seed + 5, -0.3, 0.3);
|
||||
let o_proj_data = random_f32_vec(HIDDEN * Q_DIM, seed + 6, -0.3, 0.3);
|
||||
let gate_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 7, -0.3, 0.3);
|
||||
let up_data = random_f32_vec(INTERMEDIATE * HIDDEN, seed + 8, -0.3, 0.3);
|
||||
let down_data = random_f32_vec(HIDDEN * INTERMEDIATE, seed + 9, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(attn_norm_w, attn_norm_data.clone());
|
||||
rt.set_data(post_attn_norm_w, post_attn_data.clone());
|
||||
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
|
||||
rt.set_data(post_ff_norm_w, post_ff_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt.set_data(o_proj_w, o_proj_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
// Candle reference
|
||||
let device = candle_core::Device::Cpu;
|
||||
let t = |data: &[f32], shape: &[usize]| {
|
||||
candle_core::Tensor::from_vec(data.to_vec(), shape, &device).unwrap()
|
||||
};
|
||||
let ref_input = t(&input_data, &[SEQ, HIDDEN]);
|
||||
let ref_attn_norm = t(&attn_norm_data, &[HIDDEN]);
|
||||
let ref_post_attn = t(&post_attn_data, &[HIDDEN]);
|
||||
let ref_pre_ff = t(&pre_ff_data, &[HIDDEN]);
|
||||
let ref_post_ff = t(&post_ff_data, &[HIDDEN]);
|
||||
let ref_proj = t(&proj_data, &[Q_DIM, HIDDEN]);
|
||||
let ref_o_proj = t(&o_proj_data, &[HIDDEN, Q_DIM]);
|
||||
let ref_gate = t(&gate_data, &[INTERMEDIATE, HIDDEN]);
|
||||
let ref_up = t(&up_data, &[INTERMEDIATE, HIDDEN]);
|
||||
let ref_down = t(&down_data, &[HIDDEN, INTERMEDIATE]);
|
||||
|
||||
let normed = rms_norm_ref(&ref_input, &ref_attn_norm, EPS as f64);
|
||||
let proj_out = normed
|
||||
.matmul(&ref_proj.t().unwrap())
|
||||
.unwrap()
|
||||
.matmul(&ref_o_proj.t().unwrap())
|
||||
.unwrap();
|
||||
let attn_normed = rms_norm_ref(&proj_out, &ref_post_attn, EPS as f64);
|
||||
let x_ref = (&ref_input + attn_normed).unwrap();
|
||||
let ff_normed = rms_norm_ref(&x_ref, &ref_pre_ff, EPS as f64);
|
||||
let mlp_out = swiglu_mlp_ref(&ff_normed, &ref_gate, &ref_up, &ref_down);
|
||||
let mlp_normed = rms_norm_ref(&mlp_out, &ref_post_ff, EPS as f64);
|
||||
let expected_t = (x_ref + mlp_normed).unwrap();
|
||||
let expected: Vec<f32> = expected_t.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 2e-2, 2e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(attn_norm_w, attn_norm_data.clone());
|
||||
rt.set_data(post_attn_norm_w, post_attn_data.clone());
|
||||
rt.set_data(pre_ff_norm_w, pre_ff_data.clone());
|
||||
rt.set_data(post_ff_norm_w, post_ff_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt.set_data(o_proj_w, o_proj_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
2e-2,
|
||||
2e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_gemma_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test fallback path with Gemma dimensions.
|
||||
#[test]
|
||||
fn fuzz_gemma_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Qwen-specific tests
|
||||
// Qwen3-4B: HIDDEN=2560, INTERMEDIATE=9728, HEAD_DIM=128, Q_DIM=4096
|
||||
// Key difference: Q_DIM > HIDDEN, tied embeddings (lm_head = embedding.t())
|
||||
// ============================================================================
|
||||
|
||||
mod qwen {
|
||||
use super::*;
|
||||
|
||||
const SEQ: usize = 4;
|
||||
const HIDDEN: usize = 256;
|
||||
const INTERMEDIATE: usize = 768; // ~3x hidden, matching 9728/2560
|
||||
const Q_DIM: usize = 512; // scaled from 4096 (Q_DIM > HIDDEN)
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
fn fuzz_qwen_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_qwen_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_qwen_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
|
||||
}
|
||||
|
||||
/// Qwen uses tied embeddings: lm_head = embedding^T
|
||||
#[test]
|
||||
fn fuzz_qwen_lm_head() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const VOCAB: usize = 512;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let norm_w = cx.tensor(HIDDEN);
|
||||
let embedding = cx.tensor((VOCAB, HIDDEN));
|
||||
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 1300u64;
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
|
||||
let norm_data: Vec<f32> = random_f32_vec(HIDDEN, seed + 1, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
let emb_data = random_f32_vec(VOCAB * HIDDEN, seed + 2, -0.3, 0.3);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.clone(), (SEQ, HIDDEN), &device).unwrap();
|
||||
let ref_norm = candle_core::Tensor::from_vec(norm_data.clone(), HIDDEN, &device).unwrap();
|
||||
let ref_emb =
|
||||
candle_core::Tensor::from_vec(emb_data.clone(), (VOCAB, HIDDEN), &device).unwrap();
|
||||
let normed = rms_norm_ref(&ref_input, &ref_norm, EPS as f64);
|
||||
let expected = normed.matmul(&ref_emb.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
},
|
||||
out.id,
|
||||
&expected,
|
||||
1e-2,
|
||||
1e-2,
|
||||
FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_qwen_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fuzz_qwen_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test fallback path with Qwen dimensions.
|
||||
#[test]
|
||||
fn fuzz_qwen_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
|
||||
}
|
||||
}
|
||||
606
crates/luminal_cuda_lite/src/tests/op_functional_tests.rs
Normal file
606
crates/luminal_cuda_lite/src/tests/op_functional_tests.rs
Normal file
@@ -0,0 +1,606 @@
|
||||
use candle_core::{Device, Tensor};
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
use luminal::egglog_utils::{
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use super::utilities::{
|
||||
GENOME_FUZZ_COUNT, TOLERANCE_SAFETY_FACTOR, assert_close, dtype_epsilon, fuzz_genomes,
|
||||
gen_slice_range, get_cuda_stream, gpu_supports_dtype, random_f32_vec, random_i32_vec,
|
||||
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
|
||||
};
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[test]
|
||||
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul(
|
||||
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
|
||||
(1usize..128, 1usize..128, 1usize..128, any::<bool>(), any::<bool>(),
|
||||
any::<(bool, bool)>(), any::<(bool, bool)>(), any::<(bool, bool)>(),
|
||||
prop::sample::select(&[luminal::dtype::DType::F32, luminal::dtype::DType::F16, luminal::dtype::DType::Bf16]))
|
||||
.prop_perturb(|(m, n, k, a_cm, b_cm, m_sl, k_sl, n_sl, dt), mut rng| {
|
||||
(m, n, k, a_cm, b_cm,
|
||||
gen_slice_range(m, m_sl.0, m_sl.1, &mut rng),
|
||||
gen_slice_range(k, k_sl.0, k_sl.1, &mut rng),
|
||||
gen_slice_range(n, n_sl.0, n_sl.1, &mut rng),
|
||||
dt)
|
||||
}),
|
||||
seed in any::<u64>()
|
||||
) {
|
||||
prop_assume!(gpu_supports_dtype(dtype), "GPU does not support {:?}", dtype);
|
||||
|
||||
let (m_start, m_end) = m_slice;
|
||||
let (k_start, k_end) = k_slice;
|
||||
let (n_start, n_end) = n_slice;
|
||||
let effective_m = m_end - m_start;
|
||||
let effective_k = k_end - k_start;
|
||||
let effective_n = n_end - n_start;
|
||||
|
||||
// Column-major achieved by storing transposed then calling .t()
|
||||
let (a_shape, b_shape): ((usize, usize), (usize, usize)) = match (a_col_major, b_col_major) {
|
||||
(false, false) => ((m, k), (k, n)), // Rm x Rm
|
||||
(false, true) => ((m, k), (n, k)), // Rm x Cm
|
||||
(true, false) => ((k, m), (k, n)), // Cm x Rm
|
||||
(true, true) => ((k, m), (n, k)), // Cm x Cm
|
||||
};
|
||||
|
||||
let candle_dtype = to_candle_dtype(dtype);
|
||||
|
||||
let luminal_op = move |a: GraphTensor, b: GraphTensor| {
|
||||
let a = a.cast(dtype);
|
||||
let b = b.cast(dtype);
|
||||
let a = if a_col_major { a.t() } else { a };
|
||||
let b = if b_col_major { b.t() } else { b };
|
||||
// After transpose: A is (m, k), B is (k, n)
|
||||
let a = a.slice((m_start..m_end, k_start..k_end));
|
||||
let b = b.slice((k_start..k_end, n_start..n_end));
|
||||
a.matmul(b).cast(luminal::dtype::DType::F32)
|
||||
};
|
||||
let candle_op = move |a: Tensor, b: Tensor| {
|
||||
let a = a.to_dtype(candle_dtype).unwrap();
|
||||
let b = b.to_dtype(candle_dtype).unwrap();
|
||||
let a = if a_col_major { a.t().unwrap() } else { a };
|
||||
let b = if b_col_major { b.t().unwrap() } else { b };
|
||||
// After transpose: A is (m, k), B is (k, n)
|
||||
let a = a.narrow(0, m_start, effective_m).unwrap()
|
||||
.narrow(1, k_start, effective_k).unwrap()
|
||||
.contiguous().unwrap();
|
||||
let b = b.narrow(0, k_start, effective_k).unwrap()
|
||||
.narrow(1, n_start, effective_n).unwrap()
|
||||
.contiguous().unwrap();
|
||||
a.matmul(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap()
|
||||
};
|
||||
|
||||
// Matmul tolerance: rtol scales with sqrt(k) for accumulated rounding error
|
||||
let eps = dtype_epsilon(dtype);
|
||||
let sqrt_k = (effective_k as f32).sqrt();
|
||||
let rtol = eps * sqrt_k;
|
||||
let atol = 5.0 * eps;
|
||||
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
#[test]
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
#[test]
|
||||
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
test_mod(x, x, |a, b| a % b, seed);
|
||||
test_mod((y, x), (y, x), |a, b| a % b, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
let total = rows * cols;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((rows, cols));
|
||||
let sorted_dim0 = input.stable_argsort(0, true).output(); // descend
|
||||
let sorted_dim1 = input.stable_argsort(1, false).output(); // ascend
|
||||
|
||||
// random and unique data using seed
|
||||
let data: Vec<f32> = random_f32_vec(total, seed, 0.0, 1.0);
|
||||
|
||||
let sorted_cols: Vec<Vec<i32>> = (0..cols)
|
||||
.map(|col| {
|
||||
let mut indices: Vec<i32> = (0..rows as i32).collect();
|
||||
indices.sort_by(|&a, &b| {
|
||||
let va = data[(a as usize) * cols + col];
|
||||
let vb = data[(b as usize) * cols + col];
|
||||
vb.partial_cmp(&va).unwrap()
|
||||
});
|
||||
indices
|
||||
})
|
||||
.collect();
|
||||
|
||||
let expected_dim0: Vec<i32> = (0..rows)
|
||||
.flat_map(|row| {
|
||||
(0..cols)
|
||||
.map(|col| sorted_cols[col][row])
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let expected_dim1: Vec<i32> = (0..rows)
|
||||
.flat_map(|row| {
|
||||
let mut indices: Vec<i32> = (0..cols as i32).collect();
|
||||
indices.sort_by(|&a, &b| {
|
||||
let va = data[row * cols + (a as usize)];
|
||||
let vb = data[row * cols + (b as usize)];
|
||||
va.partial_cmp(&vb).unwrap()
|
||||
});
|
||||
indices
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, 10);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out_dim0 = rt.get_i32(sorted_dim0.id);
|
||||
let out_dim1 = rt.get_i32(sorted_dim1.id);
|
||||
|
||||
assert_eq!(out_dim0.len(), expected_dim0.len(), "dim0 length mismatch");
|
||||
assert_eq!(out_dim1.len(), expected_dim1.len(), "dim1 length mismatch");
|
||||
|
||||
// Debug: check for out-of-range values (indices should be 0..rows for dim0, 0..cols for dim1)
|
||||
let max_valid_dim0 = rows as i32 - 1;
|
||||
let max_valid_dim1 = cols as i32 - 1;
|
||||
let bad_dim0: Vec<_> = out_dim0
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, &v)| v < 0 || v > max_valid_dim0)
|
||||
.take(10)
|
||||
.collect();
|
||||
let bad_dim1: Vec<_> = out_dim1
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, &v)| v < 0 || v > max_valid_dim1)
|
||||
.take(10)
|
||||
.collect();
|
||||
|
||||
if !bad_dim0.is_empty() {
|
||||
panic!(
|
||||
"dim0 has out-of-range values (valid: 0-{max_valid_dim0}): {:?}\nFirst 20 values: {:?}",
|
||||
bad_dim0,
|
||||
&out_dim0[..20.min(out_dim0.len())]
|
||||
);
|
||||
}
|
||||
if !bad_dim1.is_empty() {
|
||||
panic!(
|
||||
"dim1 has out-of-range values (valid: 0-{max_valid_dim1}): {:?}",
|
||||
bad_dim1
|
||||
);
|
||||
}
|
||||
|
||||
for i in 0..out_dim0.len() {
|
||||
assert_eq!(
|
||||
out_dim0[i], expected_dim0[i],
|
||||
"dim0 mismatch at {i}: got {}, expected {}",
|
||||
out_dim0[i], expected_dim0[i]
|
||||
);
|
||||
}
|
||||
|
||||
for i in 0..out_dim1.len() {
|
||||
assert_eq!(
|
||||
out_dim1[i], expected_dim1[i],
|
||||
"dim1 mismatch at {i}: got {}, expected {}",
|
||||
out_dim1[i], expected_dim1[i]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Argsort proptest disabled due to pre-existing bug where argsort output shape
|
||||
// through e-graph compilation returns only `rows` elements instead of `rows * cols`.
|
||||
// proptest! {
|
||||
// #![proptest_config(ProptestConfig::with_cases(10))]
|
||||
// #[test]
|
||||
// fn test_argsort(seed in any::<u64>()) {
|
||||
// run_argsort_test(5, 500, seed);
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
|
||||
#[test]
|
||||
#[allow(clippy::approx_constant, clippy::excessive_precision)]
|
||||
pub fn test_cast_f16_edge_cases() {
|
||||
use luminal::dtype::DType;
|
||||
|
||||
// Fixed edge-case values that exercise F16 behavior
|
||||
let edge_cases: Vec<f32> = vec![
|
||||
0.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.5,
|
||||
0.333333333, // Will truncate: F16 can't represent 1/3 exactly
|
||||
0.1, // Will truncate: 0.1 isn't exact in binary
|
||||
1.0009765625, // Exactly representable in F16 (1 + 1/1024)
|
||||
1.00048828125, // Rounds to 1.0 in F16 (1 + 1/2048, below F16 precision)
|
||||
1.0007324219, // Between two F16 values, will round
|
||||
-3.140625, // Exactly representable
|
||||
3.14159265, // Pi - will truncate
|
||||
65504.0, // Max normal F16
|
||||
-65504.0, // Min normal F16
|
||||
0.000060976, // Near F16 min positive normal
|
||||
1e-7, // Subnormal in F16
|
||||
100.0,
|
||||
-100.0,
|
||||
12.345678, // Arbitrary value requiring truncation
|
||||
];
|
||||
|
||||
// Generator that ignores seed and returns edge cases
|
||||
let gen_edge_cases = |_n: usize, _seed: u64| edge_cases.clone();
|
||||
|
||||
test_unary_cuda(
|
||||
edge_cases.len(),
|
||||
|a| a.cast(DType::F16).cast(DType::F32),
|
||||
|a| {
|
||||
a.to_dtype(candle_core::DType::F16)
|
||||
.unwrap()
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
gen_edge_cases,
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
|
||||
#[test]
|
||||
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
|
||||
use luminal::dtype::DType;
|
||||
|
||||
// Use range beyond F16 limits so some values overflow to infinity
|
||||
let f16_max = half::f16::MAX.to_f32();
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -2.0 * f16_max, 2.0 * f16_max);
|
||||
|
||||
test_unary_cuda(
|
||||
size,
|
||||
|a| a.cast(DType::F16).cast(DType::F32),
|
||||
|a| {
|
||||
a.to_dtype(candle_core::DType::F16)
|
||||
.unwrap()
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fuzz test that generates many random genomes and verifies they all produce correct results.
|
||||
/// This tests the genetic algorithm search by validating each genome individually.
|
||||
/// Uses proptest seed for reproducibility - if this test fails, proptest will print the seed
|
||||
/// which can be used to reproduce the failure.
|
||||
fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping test");
|
||||
return;
|
||||
};
|
||||
|
||||
println!("Running fuzz_test_cuda_genomes with seed: {}", seed);
|
||||
|
||||
// Build a graph with operations that have rewrite alternatives
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor((4, 8));
|
||||
let b = cx.tensor((8, 4));
|
||||
let c = cx.tensor((4, 4));
|
||||
|
||||
// Matmul + add + relu creates opportunities for rewrites
|
||||
let d = a.matmul(b);
|
||||
let e = (d + c).relu();
|
||||
let out = e.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().unwrap();
|
||||
let ops = cx.egglog_ops().unwrap();
|
||||
|
||||
// Count mutable eclasses
|
||||
let mutable_eclasses: usize = egraph
|
||||
.eclasses
|
||||
.iter()
|
||||
.filter(|(_, (label, enodes))| {
|
||||
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
|
||||
})
|
||||
.count();
|
||||
println!(
|
||||
"CUDA search space: {} total eclasses, {} mutable",
|
||||
egraph.eclasses.len(),
|
||||
mutable_eclasses
|
||||
);
|
||||
|
||||
// Use seeded RNG for full reproducibility
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
|
||||
// Generate test data with seeded RNG (reproducible)
|
||||
let a_data: Vec<f32> = (0..32).map(|_| rng.random::<f32>()).collect();
|
||||
let b_data: Vec<f32> = (0..32).map(|_| rng.random::<f32>()).collect();
|
||||
let c_data: Vec<f32> = (0..16).map(|_| rng.random::<f32>()).collect();
|
||||
|
||||
// Compute reference result using candle
|
||||
let device = Device::Cpu;
|
||||
let ref_a = Tensor::from_vec(a_data.clone(), (4, 8), &device).unwrap();
|
||||
let ref_b = Tensor::from_vec(b_data.clone(), (8, 4), &device).unwrap();
|
||||
let ref_c = Tensor::from_vec(c_data.clone(), (4, 4), &device).unwrap();
|
||||
let ref_d = ref_a.matmul(&ref_b).unwrap();
|
||||
let ref_e = (&ref_d + &ref_c).unwrap().relu().unwrap();
|
||||
let expected: Vec<f32> = ref_e.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
|
||||
// Test initial genome
|
||||
let initial = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&initial));
|
||||
|
||||
if let Err(e) = validate_choice_set(egraph, &initial, ops) {
|
||||
panic!("Initial genome invalid: {}", e);
|
||||
}
|
||||
|
||||
// Extract and execute initial genome
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
initial.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut rt: CudaRuntime = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt.set_data(c, c_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&result, &expected, tol, tol);
|
||||
println!("Initial genome: correct");
|
||||
|
||||
// If no mutable eclasses, only one valid graph exists
|
||||
if mutable_eclasses == 0 {
|
||||
println!("No mutable eclasses, only one valid graph - test passed");
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate and test many genomes
|
||||
let mut base = initial;
|
||||
let mut tested = 0;
|
||||
let target = 50;
|
||||
|
||||
for _generation in 0..100 {
|
||||
let offspring = extract_generation(egraph, &base, 10, 2, &mut prev_selected, &mut rng);
|
||||
|
||||
if offspring.is_empty() {
|
||||
println!("Search space exhausted");
|
||||
break;
|
||||
}
|
||||
|
||||
for genome in offspring {
|
||||
// Validate
|
||||
if let Err(e) = validate_choice_set(egraph, &genome, ops) {
|
||||
panic!("Invalid genome: {}", e);
|
||||
}
|
||||
|
||||
// Extract and execute
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
// Create fresh runtime for this genome
|
||||
let mut rt: CudaRuntime = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt.set_data(c, c_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
// Verify correctness
|
||||
assert_close(&result, &expected, tol, tol);
|
||||
|
||||
tested += 1;
|
||||
base = genome;
|
||||
|
||||
if tested >= target {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if tested >= target {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Fuzz test: verified {} genomes produce correct results",
|
||||
tested
|
||||
);
|
||||
assert!(tested > 0, "No genomes were tested");
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(3))]
|
||||
|
||||
#[test]
|
||||
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
|
||||
fuzz_test_cuda_genomes_impl(seed);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping test");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let token_ids = cx.tensor(seq_len).as_dtype(luminal::dtype::DType::Int);
|
||||
let embed_table = cx.tensor((vocab_size, embed_dim));
|
||||
let output = embed_table
|
||||
.gather(
|
||||
(token_ids * embed_dim).expand_dim(1, embed_dim)
|
||||
+ cx.arange(embed_dim).expand_dim(0, seq_len),
|
||||
)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let token_data: Vec<i32> = random_i32_vec(seq_len, seed, 0, vocab_size as i32 - 1);
|
||||
let embed_data: Vec<f32> = random_f32_vec(vocab_size * embed_dim, seed, -0.5, 0.5);
|
||||
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0f32; seq_len * embed_dim];
|
||||
for i in 0..seq_len {
|
||||
let tid = token_data[i] as usize;
|
||||
for j in 0..embed_dim {
|
||||
expected[i * embed_dim + j] = embed_data[tid * embed_dim + j];
|
||||
}
|
||||
}
|
||||
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&result, &expected, tol, tol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
},
|
||||
output.id,
|
||||
&expected,
|
||||
tol,
|
||||
tol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[test]
|
||||
fn test_embed_proptest(
|
||||
vocab_size in 10usize..200,
|
||||
embed_dim in 8usize..128,
|
||||
seq_len in 1usize..32,
|
||||
seed in any::<u64>(),
|
||||
) {
|
||||
run_embed_test(vocab_size, embed_dim, seq_len, seed);
|
||||
}
|
||||
}
|
||||
94
crates/luminal_cuda_lite/src/tests/performance_tests.rs
Normal file
94
crates/luminal_cuda_lite/src/tests/performance_tests.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use tracing::{Level, enabled};
|
||||
|
||||
use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Test that measures bandwidth utilization for a large element-wise add kernel.
|
||||
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
|
||||
#[test]
|
||||
pub fn kernel_add_bandwidth_test() {
|
||||
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
|
||||
let size = 64 * 1024 * 1024;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let output = (a + b).output();
|
||||
|
||||
// Generate test data
|
||||
let data_a: Vec<f32> = (0..size).map(|i| (i % 1000) as f32 * 0.001).collect();
|
||||
let data_b: Vec<f32> = (0..size)
|
||||
.map(|i| ((i + 500) % 1000) as f32 * 0.001)
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Run and measure
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Print stats
|
||||
println!("\n=== Large KernelAdd Bandwidth Test ===");
|
||||
println!(
|
||||
"Tensor size: {} elements ({} MB per tensor)",
|
||||
size,
|
||||
size * 4 / 1024 / 1024
|
||||
);
|
||||
println!(
|
||||
"Total memory traffic: {} MB (2 reads + 1 write)",
|
||||
size * 4 * 3 / 1024 / 1024
|
||||
);
|
||||
if enabled!(Level::INFO) {
|
||||
rt.print_execution_stats();
|
||||
}
|
||||
|
||||
// Verify correctness (spot check)
|
||||
let result = rt.get_f32(output);
|
||||
for i in [0, size / 2, size - 1] {
|
||||
let expected = data_a[i] + data_b[i];
|
||||
let got = result[i];
|
||||
assert!(
|
||||
(got - expected).abs() < 1e-5,
|
||||
"Mismatch at {}: expected {}, got {}",
|
||||
i,
|
||||
expected,
|
||||
got
|
||||
);
|
||||
}
|
||||
|
||||
// Check bandwidth is reasonable (at least 50% of peak for large kernels)
|
||||
if let Some(peak_bw) = cuda_bandwidth_gbps(&ctx) {
|
||||
for stat in &rt.last_kernel_stats {
|
||||
let total_bytes = stat.bytes_loaded + stat.bytes_stored;
|
||||
if stat.name == "Add" && total_bytes > 0 {
|
||||
let utilization = stat.bandwidth_gbps / peak_bw as f64 * 100.0;
|
||||
println!(
|
||||
"\nAdd kernel achieved {:.1} GB/s ({:.1}% of {:.0} GB/s peak)",
|
||||
stat.bandwidth_gbps, utilization, peak_bw
|
||||
);
|
||||
println!(
|
||||
" Loaded: {} bytes, Stored: {} bytes",
|
||||
stat.bytes_loaded, stat.bytes_stored
|
||||
);
|
||||
// Large adds should achieve decent bandwidth
|
||||
assert!(
|
||||
utilization > 50.0,
|
||||
"Bandwidth utilization too low: {:.1}%",
|
||||
utilization
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
510
crates/luminal_cuda_lite/src/tests/transformer.rs
Normal file
510
crates/luminal_cuda_lite/src/tests/transformer.rs
Normal file
@@ -0,0 +1,510 @@
|
||||
//! Fuzz tests for small transformer models on CUDA.
|
||||
//!
|
||||
//! Builds a mini Llama-like transformer (RMSNorm + causal self-attention + SwiGLU MLP)
|
||||
//! and verifies CUDA execution against a CPU reference implementation using candle.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
// ---- Tiny Llama hyperparameters ----
|
||||
const SEQ: usize = 4;
|
||||
const HIDDEN: usize = 16;
|
||||
const INTERMEDIATE: usize = 32;
|
||||
|
||||
// ---- Graph-based mini transformer (Luminal) ----
|
||||
|
||||
/// RMSNorm: x * rsqrt(mean(x^2) + eps), optionally scaled by weight
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
/// Build self-attention using a simple single-head approach.
|
||||
/// Input: (seq, hidden), outputs: (seq, hidden)
|
||||
fn self_attention(
|
||||
x: GraphTensor,
|
||||
wq: GraphTensor,
|
||||
wk: GraphTensor,
|
||||
wv: GraphTensor,
|
||||
wo: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
// Project to Q, K, V: (seq, hidden) @ (hidden, hidden)^T = (seq, hidden)
|
||||
let q = x.matmul(wq.t());
|
||||
let k = x.matmul(wk.t());
|
||||
let v = x.matmul(wv.t());
|
||||
|
||||
// Simple single-head scaled dot-product attention (no causal mask for simplicity)
|
||||
let scale = 1.0 / (HIDDEN as f32).sqrt();
|
||||
let scores = q.matmul(k.t()) * scale; // (seq, seq)
|
||||
let attn_weights = scores.softmax(1); // softmax over key dim
|
||||
|
||||
// Apply attention to values and output projection
|
||||
attn_weights.matmul(v).matmul(wo.t())
|
||||
}
|
||||
|
||||
/// SwiGLU MLP: down(swish(gate(x)) * up(x))
|
||||
fn swiglu_mlp(
|
||||
x: GraphTensor,
|
||||
w_gate: GraphTensor,
|
||||
w_up: GraphTensor,
|
||||
w_down: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let gate = x.matmul(w_gate.t()).swish();
|
||||
let up = x.matmul(w_up.t());
|
||||
(gate * up).matmul(w_down.t())
|
||||
}
|
||||
|
||||
/// Build a single transformer layer on the graph.
|
||||
struct MiniTransformerLayer {
|
||||
attn_norm_w: GraphTensor,
|
||||
wq: GraphTensor,
|
||||
wk: GraphTensor,
|
||||
wv: GraphTensor,
|
||||
wo: GraphTensor,
|
||||
mlp_norm_w: GraphTensor,
|
||||
w_gate: GraphTensor,
|
||||
w_up: GraphTensor,
|
||||
w_down: GraphTensor,
|
||||
}
|
||||
|
||||
impl MiniTransformerLayer {
|
||||
fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attn_norm_w: cx.tensor(HIDDEN),
|
||||
wq: cx.tensor((HIDDEN, HIDDEN)),
|
||||
wk: cx.tensor((HIDDEN, HIDDEN)),
|
||||
wv: cx.tensor((HIDDEN, HIDDEN)),
|
||||
wo: cx.tensor((HIDDEN, HIDDEN)),
|
||||
mlp_norm_w: cx.tensor(HIDDEN),
|
||||
w_gate: cx.tensor((INTERMEDIATE, HIDDEN)),
|
||||
w_up: cx.tensor((INTERMEDIATE, HIDDEN)),
|
||||
w_down: cx.tensor((HIDDEN, INTERMEDIATE)),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
// Pre-norm attention with residual
|
||||
let normed = rms_norm(x, self.attn_norm_w, 1e-5);
|
||||
let attn_out = self_attention(normed, self.wq, self.wk, self.wv, self.wo);
|
||||
let x = x + attn_out;
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
let normed = rms_norm(x, self.mlp_norm_w, 1e-5);
|
||||
let mlp_out = swiglu_mlp(normed, self.w_gate, self.w_up, self.w_down);
|
||||
x + mlp_out
|
||||
}
|
||||
|
||||
/// Return all weight tensors and their sizes for data loading
|
||||
fn weights(&self) -> Vec<(GraphTensor, usize)> {
|
||||
vec![
|
||||
(self.attn_norm_w, HIDDEN),
|
||||
(self.wq, HIDDEN * HIDDEN),
|
||||
(self.wk, HIDDEN * HIDDEN),
|
||||
(self.wv, HIDDEN * HIDDEN),
|
||||
(self.wo, HIDDEN * HIDDEN),
|
||||
(self.mlp_norm_w, HIDDEN),
|
||||
(self.w_gate, INTERMEDIATE * HIDDEN),
|
||||
(self.w_up, INTERMEDIATE * HIDDEN),
|
||||
(self.w_down, HIDDEN * INTERMEDIATE),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Candle CPU reference ----
|
||||
|
||||
/// CPU reference for RMSNorm using candle
|
||||
fn rms_norm_ref(
|
||||
x: &candle_core::Tensor,
|
||||
weight: &candle_core::Tensor,
|
||||
eps: f64,
|
||||
) -> candle_core::Tensor {
|
||||
let dims = x.dims();
|
||||
let last_dim = dims[dims.len() - 1];
|
||||
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
|
||||
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
|
||||
let normed = x.broadcast_mul(&rsqrt).unwrap();
|
||||
normed
|
||||
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// CPU reference for self-attention (single-head, no causal mask)
|
||||
fn self_attention_ref(
|
||||
x: &candle_core::Tensor,
|
||||
wq: &candle_core::Tensor,
|
||||
wk: &candle_core::Tensor,
|
||||
wv: &candle_core::Tensor,
|
||||
wo: &candle_core::Tensor,
|
||||
) -> candle_core::Tensor {
|
||||
let q = x.matmul(&wq.t().unwrap()).unwrap();
|
||||
let k = x.matmul(&wk.t().unwrap()).unwrap();
|
||||
let v = x.matmul(&wv.t().unwrap()).unwrap();
|
||||
|
||||
let scale = 1.0 / (HIDDEN as f64).sqrt();
|
||||
let scores = q.matmul(&k.t().unwrap()).unwrap();
|
||||
let scores = (scores * scale).unwrap();
|
||||
|
||||
// Softmax over key dimension (dim 1)
|
||||
let max_val = scores.max(1).unwrap().unsqueeze(1).unwrap();
|
||||
let shifted = scores.broadcast_sub(&max_val).unwrap();
|
||||
let exps = shifted.exp().unwrap();
|
||||
let sum_exps = exps.sum(1).unwrap().unsqueeze(1).unwrap();
|
||||
let attn_weights = exps.broadcast_div(&sum_exps).unwrap();
|
||||
|
||||
attn_weights
|
||||
.matmul(&v)
|
||||
.unwrap()
|
||||
.matmul(&wo.t().unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// CPU reference for SwiGLU MLP
|
||||
fn swiglu_mlp_ref(
|
||||
x: &candle_core::Tensor,
|
||||
w_gate: &candle_core::Tensor,
|
||||
w_up: &candle_core::Tensor,
|
||||
w_down: &candle_core::Tensor,
|
||||
) -> candle_core::Tensor {
|
||||
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
|
||||
let up = x.matmul(&w_up.t().unwrap()).unwrap();
|
||||
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
|
||||
}
|
||||
|
||||
/// CPU reference for one transformer layer
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transformer_layer_ref(
|
||||
x: &candle_core::Tensor,
|
||||
attn_norm_w: &candle_core::Tensor,
|
||||
wq: &candle_core::Tensor,
|
||||
wk: &candle_core::Tensor,
|
||||
wv: &candle_core::Tensor,
|
||||
wo: &candle_core::Tensor,
|
||||
mlp_norm_w: &candle_core::Tensor,
|
||||
w_gate: &candle_core::Tensor,
|
||||
w_up: &candle_core::Tensor,
|
||||
w_down: &candle_core::Tensor,
|
||||
) -> candle_core::Tensor {
|
||||
let normed = rms_norm_ref(x, attn_norm_w, 1e-5);
|
||||
let attn_out = self_attention_ref(&normed, wq, wk, wv, wo);
|
||||
let x = (x + attn_out).unwrap();
|
||||
|
||||
let normed = rms_norm_ref(&x, mlp_norm_w, 1e-5);
|
||||
let mlp_out = swiglu_mlp_ref(&normed, w_gate, w_up, w_down);
|
||||
(x + mlp_out).unwrap()
|
||||
}
|
||||
|
||||
// ---- Helper to generate weight data for a layer ----
|
||||
|
||||
fn generate_layer_weights(
|
||||
layer: &MiniTransformerLayer,
|
||||
base_seed: u64,
|
||||
) -> Vec<(GraphTensor, Vec<f32>)> {
|
||||
layer
|
||||
.weights()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (tensor, size))| {
|
||||
let data = random_f32_vec(*size, base_seed + i as u64, -0.5, 0.5);
|
||||
// RMSNorm weights should be initialized to ~1.0
|
||||
let data = if *size == HIDDEN {
|
||||
data.iter().map(|x| x + 1.0).collect::<Vec<_>>()
|
||||
} else {
|
||||
data
|
||||
};
|
||||
(*tensor, data)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build_candle_ref(input_data: &[f32], weight_data: &[(GraphTensor, Vec<f32>)]) -> Vec<f32> {
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input =
|
||||
candle_core::Tensor::from_vec(input_data.to_vec(), (SEQ, HIDDEN), &device).unwrap();
|
||||
|
||||
// weight_data: [attn_norm_w, wq, wk, wv, wo, mlp_norm_w, w_gate, w_up, w_down]
|
||||
let w = |idx: usize, shape: &[usize]| {
|
||||
candle_core::Tensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
|
||||
};
|
||||
let ref_attn_norm_w = w(0, &[HIDDEN]);
|
||||
let ref_wq = w(1, &[HIDDEN, HIDDEN]);
|
||||
let ref_wk = w(2, &[HIDDEN, HIDDEN]);
|
||||
let ref_wv = w(3, &[HIDDEN, HIDDEN]);
|
||||
let ref_wo = w(4, &[HIDDEN, HIDDEN]);
|
||||
let ref_mlp_norm_w = w(5, &[HIDDEN]);
|
||||
let ref_w_gate = w(6, &[INTERMEDIATE, HIDDEN]);
|
||||
let ref_w_up = w(7, &[INTERMEDIATE, HIDDEN]);
|
||||
let ref_w_down = w(8, &[HIDDEN, INTERMEDIATE]);
|
||||
|
||||
let expected = transformer_layer_ref(
|
||||
&ref_input,
|
||||
&ref_attn_norm_w,
|
||||
&ref_wq,
|
||||
&ref_wk,
|
||||
&ref_wv,
|
||||
&ref_wo,
|
||||
&ref_mlp_norm_w,
|
||||
&ref_w_gate,
|
||||
&ref_w_up,
|
||||
&ref_w_down,
|
||||
);
|
||||
expected.flatten_all().unwrap().to_vec1().unwrap()
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
/// Test a single transformer layer on CUDA against candle CPU reference.
|
||||
#[test]
|
||||
fn test_mini_transformer_layer() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
rt.set_data(input, input_data.clone());
|
||||
|
||||
let weight_data = generate_layer_weights(&layer, 100);
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
// Use minimal search iterations to avoid excessive graph rewriting
|
||||
// which can cause float drift through softmax/RMSNorm reordering
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let expected = build_candle_ref(&input_data, &weight_data);
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test a two-layer transformer on CUDA against candle CPU reference.
|
||||
#[test]
|
||||
fn test_mini_transformer_two_layers() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let layer1 = MiniTransformerLayer::init(&mut cx);
|
||||
let layer2 = MiniTransformerLayer::init(&mut cx);
|
||||
let x = layer1.forward(input).graph_break();
|
||||
let out = layer2.forward(x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
rt.set_data(input, input_data.clone());
|
||||
|
||||
let layer1_weights = generate_layer_weights(&layer1, 200);
|
||||
let layer2_weights = generate_layer_weights(&layer2, 300);
|
||||
|
||||
for (tensor, data) in layer1_weights.iter().chain(layer2_weights.iter()) {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
// Run two layers on CPU reference
|
||||
let device = candle_core::Device::Cpu;
|
||||
let mut ref_x = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
|
||||
|
||||
for weights in [&layer1_weights, &layer2_weights] {
|
||||
let w = |idx: usize, shape: &[usize]| {
|
||||
candle_core::Tensor::from_vec(weights[idx].1.clone(), shape, &device).unwrap()
|
||||
};
|
||||
ref_x = transformer_layer_ref(
|
||||
&ref_x,
|
||||
&w(0, &[HIDDEN]),
|
||||
&w(1, &[HIDDEN, HIDDEN]),
|
||||
&w(2, &[HIDDEN, HIDDEN]),
|
||||
&w(3, &[HIDDEN, HIDDEN]),
|
||||
&w(4, &[HIDDEN, HIDDEN]),
|
||||
&w(5, &[HIDDEN]),
|
||||
&w(6, &[INTERMEDIATE, HIDDEN]),
|
||||
&w(7, &[INTERMEDIATE, HIDDEN]),
|
||||
&w(8, &[HIDDEN, INTERMEDIATE]),
|
||||
);
|
||||
}
|
||||
|
||||
let expected: Vec<f32> = ref_x.flatten_all().unwrap().to_vec1().unwrap();
|
||||
// Two layers accumulate more drift
|
||||
assert_close(&result, &expected, 2e-2, 2e-2);
|
||||
}
|
||||
|
||||
/// Test the transformer with multiple random data seeds to catch data-dependent bugs.
|
||||
#[test]
|
||||
fn test_transformer_multi_seed() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
for seed in [42u64, 99, 777] {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
|
||||
rt.set_data(input, input_data.clone());
|
||||
|
||||
let weight_data = generate_layer_weights(&layer, seed + 100);
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let expected = build_candle_ref(&input_data, &weight_data);
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test just the RMSNorm component on CUDA
|
||||
#[test]
|
||||
fn test_rms_norm_cuda() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let weight = cx.tensor(HIDDEN);
|
||||
let out = rms_norm(input, weight, 1e-5).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 1, -0.5, 0.5);
|
||||
let weight_data: Vec<f32> = random_f32_vec(HIDDEN, 2, -0.5, 0.5)
|
||||
.iter()
|
||||
.map(|x| x + 1.0)
|
||||
.collect();
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(weight, weight_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
|
||||
let ref_weight = candle_core::Tensor::from_vec(weight_data, HIDDEN, &device).unwrap();
|
||||
let expected = rms_norm_ref(&ref_input, &ref_weight, 1e-5);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
/// Test just the self-attention on CUDA
|
||||
#[test]
|
||||
fn test_self_attention_cuda() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let wq = cx.tensor((HIDDEN, HIDDEN));
|
||||
let wk = cx.tensor((HIDDEN, HIDDEN));
|
||||
let wv = cx.tensor((HIDDEN, HIDDEN));
|
||||
let wo = cx.tensor((HIDDEN, HIDDEN));
|
||||
let out = self_attention(input, wq, wk, wv, wo).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 10, -0.5, 0.5);
|
||||
let wq_data = random_f32_vec(HIDDEN * HIDDEN, 11, -0.5, 0.5);
|
||||
let wk_data = random_f32_vec(HIDDEN * HIDDEN, 12, -0.5, 0.5);
|
||||
let wv_data = random_f32_vec(HIDDEN * HIDDEN, 13, -0.5, 0.5);
|
||||
let wo_data = random_f32_vec(HIDDEN * HIDDEN, 14, -0.5, 0.5);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(wq, wq_data.clone());
|
||||
rt.set_data(wk, wk_data.clone());
|
||||
rt.set_data(wv, wv_data.clone());
|
||||
rt.set_data(wo, wo_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
|
||||
let ref_wq = candle_core::Tensor::from_vec(wq_data, (HIDDEN, HIDDEN), &device).unwrap();
|
||||
let ref_wk = candle_core::Tensor::from_vec(wk_data, (HIDDEN, HIDDEN), &device).unwrap();
|
||||
let ref_wv = candle_core::Tensor::from_vec(wv_data, (HIDDEN, HIDDEN), &device).unwrap();
|
||||
let ref_wo = candle_core::Tensor::from_vec(wo_data, (HIDDEN, HIDDEN), &device).unwrap();
|
||||
|
||||
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test just the SwiGLU MLP on CUDA
|
||||
#[test]
|
||||
fn test_swiglu_mlp_cuda() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
println!("CUDA not available, skipping");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 20, -0.5, 0.5);
|
||||
let gate_data = random_f32_vec(INTERMEDIATE * HIDDEN, 21, -0.5, 0.5);
|
||||
let up_data = random_f32_vec(INTERMEDIATE * HIDDEN, 22, -0.5, 0.5);
|
||||
let down_data = random_f32_vec(HIDDEN * INTERMEDIATE, 23, -0.5, 0.5);
|
||||
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
let device = candle_core::Device::Cpu;
|
||||
let ref_input = candle_core::Tensor::from_vec(input_data, (SEQ, HIDDEN), &device).unwrap();
|
||||
let ref_gate =
|
||||
candle_core::Tensor::from_vec(gate_data, (INTERMEDIATE, HIDDEN), &device).unwrap();
|
||||
let ref_up = candle_core::Tensor::from_vec(up_data, (INTERMEDIATE, HIDDEN), &device).unwrap();
|
||||
let ref_down =
|
||||
candle_core::Tensor::from_vec(down_data, (HIDDEN, INTERMEDIATE), &device).unwrap();
|
||||
|
||||
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3, 1e-3);
|
||||
}
|
||||
496
crates/luminal_cuda_lite/src/tests/utilities.rs
Normal file
496
crates/luminal_cuda_lite/src/tests/utilities.rs
Normal file
@@ -0,0 +1,496 @@
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use luminal::egglog_utils::{
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use num_traits::{Num, Signed};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::runtime::{CudaRuntime, ToCudaInput};
|
||||
|
||||
/// Safety factor multiplied with epsilon for tolerance calculations
|
||||
pub const TOLERANCE_SAFETY_FACTOR: f32 = 2.0;
|
||||
|
||||
/// Number of genomes to fuzz per op test invocation.
|
||||
pub const GENOME_FUZZ_COUNT: usize = 20;
|
||||
|
||||
/// Trait for test-compatible data types that can be used in generic test functions.
|
||||
/// Bridges luminal's runtime types with candle's tensor types.
|
||||
pub trait TestDType:
|
||||
Clone + Sized + WithDType + PartialEq + Copy + std::fmt::Debug + 'static
|
||||
where
|
||||
Vec<Self>: ToCudaInput,
|
||||
{
|
||||
/// The corresponding luminal DType
|
||||
const DTYPE: luminal::dtype::DType;
|
||||
|
||||
/// Retrieve data from the runtime in this dtype
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self>;
|
||||
/// Extract a Vec from a candle Tensor
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self>;
|
||||
/// Compare two result vectors. Float types use tolerance; exact types use equality.
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32);
|
||||
}
|
||||
|
||||
impl TestDType for f32 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F32;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_f32(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<f32>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, rtol, atol);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for f16 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F16;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_f16(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<f16>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, f16::from_f32(rtol), f16::from_f32(atol));
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for bf16 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Bf16;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_bf16(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<bf16>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, bf16::from_f32(rtol), bf16::from_f32(atol));
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for i32 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Int;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_i32(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<i32>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], _rtol: f32, _atol: f32) {
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn random_i32_vec(n: usize, seed: u64, low: i32, high: i32) -> Vec<i32> {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
(0..n).map(|_| rng.random_range(low..=high)).collect()
|
||||
}
|
||||
|
||||
pub fn random_f32_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<f32> {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
(0..n).map(|_| rng.random_range(low..high)).collect()
|
||||
}
|
||||
|
||||
/// Assert two vectors are close following NumPy/PyTorch conventions.
|
||||
/// Formula: |a - b| <= atol + rtol * |b|
|
||||
/// Generic version that works with any Float type (f32, f16, bf16).
|
||||
pub fn assert_close<T: Num + Signed + PartialOrd + Copy + std::fmt::Display>(
|
||||
a_vec: &[T],
|
||||
b_vec: &[T],
|
||||
rtol: T,
|
||||
atol: T,
|
||||
) {
|
||||
assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match");
|
||||
for (i, (a, b)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
|
||||
let diff = (*a - *b).abs();
|
||||
let tolerance = atol + rtol * b.abs();
|
||||
|
||||
if diff > tolerance {
|
||||
panic!("{a} is not close to {b}, index {i}, diff: {diff}, tolerance: {tolerance}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
ctx.bind_to_thread().ok()?;
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
ctx.compute_capability().ok()
|
||||
}
|
||||
|
||||
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
|
||||
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
|
||||
let Some((major, _)) = gpu_compute_cap() else {
|
||||
return false;
|
||||
};
|
||||
match dtype {
|
||||
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
|
||||
luminal::dtype::DType::F4E2M1
|
||||
| luminal::dtype::DType::F8E4M3
|
||||
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Machine epsilon for each dtype (approximate)
|
||||
pub fn dtype_epsilon(dtype: luminal::dtype::DType) -> f32 {
|
||||
match dtype {
|
||||
luminal::dtype::DType::F32 => 1.19e-7, // 2^-23
|
||||
luminal::dtype::DType::F16 => 9.77e-4, // 2^-10
|
||||
luminal::dtype::DType::Bf16 => 7.81e-3, // 2^-7
|
||||
luminal::dtype::DType::Int => 0.0,
|
||||
luminal::dtype::DType::Bool => 0.0,
|
||||
other => todo!("dtype_epsilon not implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a luminal DType to the corresponding candle DType.
|
||||
pub fn to_candle_dtype(dtype: luminal::dtype::DType) -> candle_core::DType {
|
||||
match dtype {
|
||||
luminal::dtype::DType::F32 => candle_core::DType::F32,
|
||||
luminal::dtype::DType::F16 => candle_core::DType::F16,
|
||||
luminal::dtype::DType::Bf16 => candle_core::DType::BF16,
|
||||
luminal::dtype::DType::Int => candle_core::DType::I32,
|
||||
luminal::dtype::DType::Bool => candle_core::DType::U8,
|
||||
other => todo!("candle dtype mapping not implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Base unary test function with input generator (CUDA version)
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
pub fn test_unary_cuda<T: TestDType>(
|
||||
shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor) -> Tensor,
|
||||
generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
seed: u64,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let shape: Vec<usize> = shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let n_elements: usize = shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(shape.clone());
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
|
||||
// Reference using candle on CUDA
|
||||
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
|
||||
let ref_a = Tensor::from_slice(&input_data, shape, &device).unwrap();
|
||||
let ref_b = ref_func(ref_a).flatten_all().unwrap();
|
||||
let ref_vec = T::candle_to_vec(&ref_b);
|
||||
|
||||
let eps = dtype_epsilon(<T as TestDType>::DTYPE);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
T::assert_match(&result, &ref_vec, tol, tol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<T>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| rt.set_data(a, input_data.clone()),
|
||||
b.id,
|
||||
&ref_vec,
|
||||
tol,
|
||||
tol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Base binary test function with input generators
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn test_binary_cuda<T: TestDType>(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor, Tensor) -> Tensor,
|
||||
a_generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
b_generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
seed: u64,
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let a_shape: Vec<usize> = a_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let b_shape: Vec<usize> = b_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let a_elements: usize = a_shape.iter().product();
|
||||
let b_elements: usize = b_shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a: GraphTensor = cx.tensor(a_shape.clone());
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = a_generator(a_elements, seed);
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
|
||||
// Reference using candle on CUDA
|
||||
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
|
||||
let ref_a = Tensor::from_slice(&a_data, a_shape, &device).unwrap();
|
||||
let ref_b = Tensor::from_slice(&b_data, b_shape, &device).unwrap();
|
||||
let ref_c = ref_func(ref_a, ref_b).flatten_all().unwrap();
|
||||
let ref_vec = T::candle_to_vec(&ref_c);
|
||||
|
||||
T::assert_match(&result, &ref_vec, rtol, atol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<T>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
},
|
||||
c.id,
|
||||
&ref_vec,
|
||||
rtol,
|
||||
atol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Test mod operation with element-wise reference using Rust's % operator
|
||||
pub fn test_mod(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
seed: u64,
|
||||
) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let a_shape: Vec<usize> = a_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let b_shape: Vec<usize> = b_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let a_elements: usize = a_shape.iter().product();
|
||||
let b_elements: usize = b_shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(a_shape.clone());
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
|
||||
// Generate divisor values away from zero (0.1 to 0.5) to avoid division issues
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
// Reference: Rust's % operator matches CUDA's fmodf (IEEE 754 remainder)
|
||||
let expected: Vec<f32> = a_data
|
||||
.iter()
|
||||
.zip(b_data.iter())
|
||||
.map(|(x, y)| x % y)
|
||||
.collect();
|
||||
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let rtol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let atol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&result, &expected, rtol, atol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
},
|
||||
c.id,
|
||||
&expected,
|
||||
rtol,
|
||||
atol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generate a slice range for an axis of given size.
|
||||
/// If do_start is true, randomly choose a start offset (leaving at least 1 element).
|
||||
/// If do_end is true, randomly choose an end before the axis end.
|
||||
pub fn gen_slice_range(
|
||||
size: usize,
|
||||
do_start: bool,
|
||||
do_end: bool,
|
||||
rng: &mut impl Rng,
|
||||
) -> (usize, usize) {
|
||||
let start = if do_start && size > 1 {
|
||||
rng.random_range(0..size)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let remaining = size - start;
|
||||
let end = if do_end && remaining > 1 {
|
||||
start + rng.random_range(1..remaining)
|
||||
} else {
|
||||
size
|
||||
};
|
||||
(start, end)
|
||||
}
|
||||
|
||||
/// Fuzz test multiple genomes from the e-graph search space.
|
||||
///
|
||||
/// After a graph has been built and compared against a reference, this function
|
||||
/// extracts random genomes via mutation and verifies they all produce results
|
||||
/// matching the expected reference output. This catches bugs where graph rewrites
|
||||
/// produce incorrect computation.
|
||||
///
|
||||
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn fuzz_genomes<T: TestDType>(
|
||||
cx: &Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
setup_inputs: impl Fn(&mut CudaRuntime),
|
||||
output_id: NodeIndex,
|
||||
expected: &[T],
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
num_genomes: usize,
|
||||
seed: u64,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(egraph) = cx.egraph() else {
|
||||
return;
|
||||
};
|
||||
let Some(ops) = cx.egglog_ops() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check if there are alternative genomes to explore
|
||||
let mutable_eclasses: usize = egraph
|
||||
.eclasses
|
||||
.iter()
|
||||
.filter(|(_, (label, enodes))| {
|
||||
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
|
||||
})
|
||||
.count();
|
||||
if mutable_eclasses == 0 {
|
||||
return; // Only one valid graph, nothing to fuzz
|
||||
}
|
||||
|
||||
// Use a different seed offset to avoid correlating with the search seed
|
||||
let mut rng = StdRng::seed_from_u64(seed.wrapping_add(7777));
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
|
||||
let initial = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&initial));
|
||||
|
||||
let mut base = initial;
|
||||
let mut tested = 0;
|
||||
|
||||
for _ in 0..100 {
|
||||
let offspring = extract_generation(egraph, &base, 10, 2, &mut prev_selected, &mut rng);
|
||||
|
||||
if offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
for genome in offspring {
|
||||
if validate_choice_set(egraph, &genome, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
setup_inputs(&mut rt);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = T::get_from_runtime(&rt, output_id);
|
||||
T::assert_match(&result, expected, rtol, atol);
|
||||
|
||||
tested += 1;
|
||||
base = genome;
|
||||
|
||||
if tested >= num_genomes {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
22
crates/luminal_metal/Cargo.toml
Normal file
22
crates/luminal_metal/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "luminal_metal"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
description = "Metal backend for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
metal = "0.31"
|
||||
objc = "0.2"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
half = "2.7.1"
|
||||
tracing = "0.1.43"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
proptest = "1.9.0"
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
227
crates/luminal_metal/src/kernel/matmul.rs
Normal file
227
crates/luminal_metal/src/kernel/matmul.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use super::{MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum MetalMatmulFamily {
|
||||
#[default]
|
||||
Naive,
|
||||
RegularTiled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulDescriptor {
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub batch_shape: Vec<Expression>,
|
||||
pub lhs_strides: Vec<Expression>,
|
||||
pub rhs_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub transpose_lhs: bool,
|
||||
pub transpose_rhs: bool,
|
||||
}
|
||||
|
||||
impl MatmulDescriptor {
|
||||
pub fn from_mul_and_sum(
|
||||
mul_info: &MetalMulInfo,
|
||||
sum_info: &MetalSumReduceInfo,
|
||||
) -> Option<Self> {
|
||||
let zero = Expression::from(0);
|
||||
let z = Expression::from('z');
|
||||
|
||||
let is_simple_2d_matmul = mul_info.shape.len() == 3
|
||||
&& sum_info.shape.len() == 2
|
||||
&& mul_info.a_strides.len() == 3
|
||||
&& mul_info.b_strides.len() == 3
|
||||
&& sum_info.strides.len() == 2
|
||||
&& mul_info.shape[0] == sum_info.shape[0]
|
||||
&& mul_info.shape[1] == sum_info.shape[1]
|
||||
&& mul_info.shape[2] == sum_info.iters
|
||||
&& mul_info.a_strides[1] == zero
|
||||
&& mul_info.a_strides[2] == z
|
||||
&& mul_info.b_strides[0] == zero
|
||||
&& mul_info.b_strides[1] == z
|
||||
&& sum_info.strides[1] == z
|
||||
&& sum_info.iter_stride == z;
|
||||
|
||||
if !is_simple_2d_matmul {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
m: sum_info.shape[0],
|
||||
n: sum_info.shape[1],
|
||||
k: sum_info.iters,
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: mul_info.a_strides.clone(),
|
||||
rhs_strides: mul_info.b_strides.clone(),
|
||||
out_strides: sum_info.strides.clone(),
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulPlan {
|
||||
pub family: MetalMatmulFamily,
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub lda: Expression,
|
||||
pub ldb: Expression,
|
||||
pub ldd: Expression,
|
||||
pub batch_size: u32,
|
||||
pub batch_stride_a: u32,
|
||||
pub batch_stride_b: u32,
|
||||
pub batch_stride_d: u32,
|
||||
pub bm: u16,
|
||||
pub bn: u16,
|
||||
pub bk: u16,
|
||||
pub wm: u16,
|
||||
pub wn: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct MetalMatmulPlanner;
|
||||
|
||||
impl MetalMatmulPlanner {
|
||||
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
|
||||
let family = if desc.batch_shape.is_empty()
|
||||
&& desc.m.as_num().is_some_and(|m| m >= 32)
|
||||
&& desc.n.as_num().is_some_and(|n| n >= 32)
|
||||
&& desc.k.as_num().is_some_and(|k| k >= 32)
|
||||
{
|
||||
MetalMatmulFamily::RegularTiled
|
||||
} else {
|
||||
MetalMatmulFamily::Naive
|
||||
};
|
||||
MatmulPlan {
|
||||
family,
|
||||
m: desc.m,
|
||||
n: desc.n,
|
||||
k: desc.k,
|
||||
lda: desc.lhs_strides[0],
|
||||
ldb: desc.rhs_strides[2],
|
||||
ldd: desc.out_strides[0],
|
||||
batch_size: 1,
|
||||
batch_stride_a: 0,
|
||||
batch_stride_b: 0,
|
||||
batch_stride_d: 0,
|
||||
bm: 16,
|
||||
bn: 16,
|
||||
bk: 8,
|
||||
wm: 2,
|
||||
wn: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn descriptor_recovers_simple_2d_matmul() {
|
||||
let mul = MetalMulInfo {
|
||||
shape: vec![
|
||||
Expression::from(4),
|
||||
Expression::from(8),
|
||||
Expression::from(16),
|
||||
],
|
||||
a_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
b_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
output_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from('z') * 8,
|
||||
Expression::from('z'),
|
||||
],
|
||||
};
|
||||
let sum = MetalSumReduceInfo {
|
||||
shape: vec![Expression::from(4), Expression::from(8)],
|
||||
strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
iters: Expression::from(16),
|
||||
iter_stride: Expression::from('z'),
|
||||
};
|
||||
|
||||
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
|
||||
assert_eq!(desc.m, Expression::from(4));
|
||||
assert_eq!(desc.n, Expression::from(8));
|
||||
assert_eq!(desc.k, Expression::from(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_keeps_small_problems_on_naive_path() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(4),
|
||||
n: Expression::from(8),
|
||||
k: Expression::from(16),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::Naive);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
assert_eq!(plan.lda, Expression::from('z') * 16);
|
||||
assert_eq!(plan.ldb, Expression::from('z') * 8);
|
||||
assert_eq!(plan.ldd, Expression::from('z') * 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_promotes_large_problems_to_regular_tiled() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(64),
|
||||
n: Expression::from(64),
|
||||
k: Expression::from(64),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 64,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 64,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
}
|
||||
}
|
||||
81
crates/luminal_metal/src/kernel/mod.rs
Normal file
81
crates/luminal_metal/src/kernel/mod.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
mod matmul;
|
||||
mod ops;
|
||||
pub use matmul::*;
|
||||
pub use ops::*;
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalMulInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
pub a_strides: Vec<Expression>,
|
||||
pub b_strides: Vec<Expression>,
|
||||
pub output_strides: Vec<Expression>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalSumReduceInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
pub strides: Vec<Expression>,
|
||||
pub iters: Expression,
|
||||
pub iter_stride: Expression,
|
||||
}
|
||||
|
||||
pub trait MetalKernelOp: EgglogOp {
|
||||
fn compile(
|
||||
&self,
|
||||
device: &Device,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) -> ComputePipelineState;
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.first().copied().unwrap_or(DType::F32)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
pipeline: &ComputePipelineState,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
);
|
||||
|
||||
// ========================================================================
|
||||
// Performance Metrics for MBU/MFU Calculation
|
||||
// ========================================================================
|
||||
|
||||
fn bytes_loaded(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn bytes_stored(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn flops(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn mul_info(&self) -> Option<MetalMulInfo> {
|
||||
None
|
||||
}
|
||||
|
||||
fn sum_reduce_info(&self) -> Option<MetalSumReduceInfo> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_matmul(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
luminal::impl_into_ops!(MetalKernelOp);
|
||||
2535
crates/luminal_metal/src/kernel/ops.rs
Normal file
2535
crates/luminal_metal/src/kernel/ops.rs
Normal file
File diff suppressed because it is too large
Load Diff
12
crates/luminal_metal/src/lib.rs
Normal file
12
crates/luminal_metal/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use metal::{Buffer, Device, MTLResourceOptions};
|
||||
pub use objc::rc::autoreleasepool;
|
||||
pub use runtime::MetalRuntime;
|
||||
|
||||
// Re-export kernel ops
|
||||
pub use kernel::MetalOps;
|
||||
555
crates/luminal_metal/src/runtime.rs
Normal file
555
crates/luminal_metal/src/runtime.rs
Normal file
@@ -0,0 +1,555 @@
|
||||
use crate::kernel::{
|
||||
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
|
||||
};
|
||||
use half::f16;
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::LLIRGraph,
|
||||
hlir::{Input, NativeData, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
prelude::{
|
||||
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
|
||||
FxHashMap, NodeIndex, ToId,
|
||||
},
|
||||
};
|
||||
use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptions};
|
||||
use objc::runtime::Object;
|
||||
use std::time::Duration;
|
||||
|
||||
pub struct MetalRuntime {
|
||||
device: Device,
|
||||
command_queue: CommandQueue,
|
||||
/// Host-side input tensors provided by the user.
|
||||
input_data: FxHashMap<NodeIndex, NativeData>,
|
||||
/// Buffers for HLIR input tensors (set by user)
|
||||
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Buffers for LLIR intermediate/output tensors
|
||||
pub buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Dynamic dimensions table (a-z), shared across all kernels.
|
||||
dyn_buffer: Buffer,
|
||||
/// The current LLIR graph
|
||||
llir_graph: LLIRGraph,
|
||||
/// Inferred runtime dtype for each LLIR node.
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
|
||||
let mut graph = llir_graph.clone();
|
||||
let planner = MetalMatmulPlanner;
|
||||
let mut rewrites = Vec::new();
|
||||
|
||||
for sum_node in graph.node_indices().collect::<Vec<_>>() {
|
||||
let Some(sum_info) = graph[sum_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.sum_reduce_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let input_edges: Vec<_> = graph
|
||||
.edges_directed(sum_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if input_edges.len() != 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mul_node = input_edges[0];
|
||||
let Some(mul_info) = graph[mul_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.mul_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mul_inputs: Vec<_> = graph
|
||||
.edges_directed(mul_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
|
||||
}
|
||||
|
||||
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
|
||||
graph[sum_node] =
|
||||
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
|
||||
m: plan.m,
|
||||
n: plan.n,
|
||||
k: plan.k,
|
||||
lda: plan.lda,
|
||||
ldb: plan.ldb,
|
||||
ldd: plan.ldd,
|
||||
family: plan.family,
|
||||
bm: plan.bm,
|
||||
bn: plan.bn,
|
||||
bk: plan.bk,
|
||||
wm: plan.wm,
|
||||
wn: plan.wn,
|
||||
batch_size: plan.batch_size,
|
||||
batch_stride_a: plan.batch_stride_a,
|
||||
batch_stride_b: plan.batch_stride_b,
|
||||
batch_stride_d: plan.batch_stride_d,
|
||||
}));
|
||||
|
||||
graph.remove_node(mul_node);
|
||||
graph.add_edge(mul_inputs[0], sum_node, ());
|
||||
graph.add_edge(mul_inputs[1], sum_node, ());
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
#[cfg(test)]
|
||||
pub(crate) fn contains_matmul(&self) -> bool {
|
||||
self.llir_graph.node_indices().any(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.is_some_and(|op| op.is_matmul())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn debug_kernel_ops(&self) -> Vec<String> {
|
||||
self.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.map(|op| format!("{op:?}"))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
|
||||
self.input_data.insert(id.to_id(), data.into());
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let id = id.to_id();
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
|
||||
let data_id = self
|
||||
.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
let buffer = self
|
||||
.buffers
|
||||
.get(&data_id)
|
||||
.or_else(|| {
|
||||
// If data_id is an Input node, get from hlir_buffers
|
||||
if let Some(Input { node, .. }) = self.llir_graph[data_id].to_op::<Input>() {
|
||||
self.hlir_buffers.get(&NodeIndex::new(*node))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.expect("Cannot find tensor in runtime!");
|
||||
let dtype = self
|
||||
.node_dtypes
|
||||
.get(&data_id)
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
self.llir_graph[data_id]
|
||||
.to_op::<Input>()
|
||||
.map(|inp| inp.dtype)
|
||||
})
|
||||
.unwrap_or(DType::F32);
|
||||
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F16 => {
|
||||
let ptr = buffer.contents() as *const f16;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f16>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect()
|
||||
}
|
||||
DType::Int => {
|
||||
let ptr = buffer.contents() as *const i32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<i32>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| *v as f32)
|
||||
.collect()
|
||||
}
|
||||
_ => {
|
||||
let ptr = buffer.contents() as *const f32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f32>();
|
||||
std::slice::from_raw_parts(ptr, len).to_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Runtime for MetalRuntime {
|
||||
type Ops = crate::kernel::MetalOps;
|
||||
type CompileArg = ();
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = Duration;
|
||||
|
||||
fn initialize(_: Self::CompileArg) -> Self {
|
||||
let device = Device::system_default().expect("No Metal device found!");
|
||||
let command_queue = device.new_command_queue();
|
||||
let dyn_buffer = device.new_buffer(
|
||||
(DYN_SLOT_COUNT * std::mem::size_of::<i32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
Self {
|
||||
device,
|
||||
command_queue,
|
||||
input_data: FxHashMap::default(),
|
||||
hlir_buffers: FxHashMap::default(),
|
||||
buffers: FxHashMap::default(),
|
||||
dyn_buffer,
|
||||
llir_graph: StableGraph::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
self.buffers.clear();
|
||||
self.hlir_buffers.clear();
|
||||
self.node_dtypes.clear();
|
||||
self.llir_graph = Self::fuse_matmuls(llir_graph);
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
self.node_dtypes.insert(node, input.dtype);
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn profile(
|
||||
&mut self,
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
|
||||
let trials = trials.max(1);
|
||||
let mut duration = Duration::default();
|
||||
for _ in 0..trials {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
duration += start.elapsed();
|
||||
}
|
||||
duration /= trials as u32;
|
||||
|
||||
(duration, format!("{:.2?}", duration))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) -> Self::ExecReturn {
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
// Bind dyn dims right after the output slot:
|
||||
// [inputs..., output, dyn, bytes...]
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeStats for MetalRuntime {
|
||||
fn execute_with_stats(&mut self, dyn_map: &FxHashMap<char, usize>) -> Option<ExecutionStats> {
|
||||
let mut total_bytes_loaded = 0usize;
|
||||
let mut total_bytes_stored = 0usize;
|
||||
let mut total_flops = 0usize;
|
||||
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
total_bytes_loaded += kernel_op.bytes_loaded(dyn_map);
|
||||
total_bytes_stored += kernel_op.bytes_stored(dyn_map);
|
||||
total_flops += kernel_op.flops(dyn_map);
|
||||
}
|
||||
}
|
||||
let (time_us, timing_method) = self.execute_timed(dyn_map);
|
||||
|
||||
Some(ExecutionStats::with_timing_method(
|
||||
time_us,
|
||||
total_bytes_loaded,
|
||||
total_bytes_stored,
|
||||
total_flops,
|
||||
timing_method,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
DType::F16 => {
|
||||
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
DType::Int => {
|
||||
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
unsupported => panic!("Metal input dtype {unsupported:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let buffer = self.device.new_buffer(
|
||||
(size * dtype.bits().div_ceil(8)) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_dyn_buffer(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let ptr = self.dyn_buffer.contents() as *mut i32;
|
||||
unsafe {
|
||||
for idx in 0..DYN_SLOT_COUNT {
|
||||
*ptr.add(idx) = 0;
|
||||
}
|
||||
for (&symbol, &value) in dyn_map {
|
||||
if symbol.is_ascii_lowercase() {
|
||||
let slot = (symbol as u8 - b'a') as usize;
|
||||
if slot < DYN_SLOT_COUNT {
|
||||
*ptr.add(slot) = value as i32;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute and return GPU-side execution time in microseconds.
|
||||
fn execute_timed(&mut self, dyn_map: &FxHashMap<char, usize>) -> (f64, TimingMethod) {
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
// gpuStartTime and gpuEndTime are available on macOS 10.15+
|
||||
let gpu_start: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUStartTime]
|
||||
};
|
||||
let gpu_end: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUEndTime]
|
||||
};
|
||||
|
||||
let gpu_time_seconds = gpu_end - gpu_start;
|
||||
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
|
||||
|
||||
(gpu_time_us, TimingMethod::DeviceTimestamp)
|
||||
}
|
||||
}
|
||||
1014
crates/luminal_metal/src/tests.rs
Normal file
1014
crates/luminal_metal/src/tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
17
crates/luminal_nn/Cargo.toml
Normal file
17
crates/luminal_nn/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "luminal_nn"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.12.1"
|
||||
luminal = { path = "../.." }
|
||||
rustc-hash = "1.1.0"
|
||||
rand = "0.9.2"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = { version = "0.13", features = ["f16"] }
|
||||
paste = "1.0.14"
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
125
crates/luminal_nn/src/activation.rs
Normal file
125
crates/luminal_nn/src/activation.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Rectified Linear Unit activation function
|
||||
#[derive(Default)]
|
||||
pub struct ReLU;
|
||||
|
||||
impl ReLU {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
input.relu()
|
||||
}
|
||||
}
|
||||
|
||||
/// Gaussian Error Linear Unit activation function
|
||||
#[derive(Default)]
|
||||
pub struct GeLU;
|
||||
|
||||
impl GeLU {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
input.gelu()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid activation function
|
||||
#[derive(Default)]
|
||||
pub struct Sigmoid;
|
||||
|
||||
impl Sigmoid {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
input.sigmoid()
|
||||
}
|
||||
}
|
||||
|
||||
/// Swish activation function
|
||||
#[derive(Default)]
|
||||
pub struct Swish;
|
||||
|
||||
impl Swish {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
input.swish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tanh activation function
|
||||
#[derive(Default)]
|
||||
pub struct Tanh;
|
||||
|
||||
impl Tanh {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
input.tanh()
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::ReLU;
|
||||
// use crate::Linear;
|
||||
// use dfdx::prelude::{Module as DfdxModule, *};
|
||||
// use luminal::{
|
||||
// prelude::{Module, *},
|
||||
// tests::assert_close,
|
||||
// };
|
||||
|
||||
// #[test]
|
||||
// fn test_relu_and_linear() {
|
||||
// // Test single and batch, unoptimized and optimized
|
||||
// let mut cx = Graph::new();
|
||||
// let batch = cx.tensor((2, 3)).set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
// let a = cx.tensor(3).set(vec![1.0, 2.0, 3.0]);
|
||||
|
||||
// let model = (
|
||||
// Linear::new(3, 4, false, &mut cx),
|
||||
// ReLU,
|
||||
// Linear::new(4, 2, false, &mut cx),
|
||||
// );
|
||||
// model
|
||||
// .0
|
||||
// .weight
|
||||
// .set(vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
// model.2.weight.set(vec![1., 2., 3., 1., 2., 3., 1., 2.]);
|
||||
// let mut b = model.forward(a).retrieve();
|
||||
// let mut batch_out = model.forward(batch).retrieve();
|
||||
|
||||
// cx.execute();
|
||||
|
||||
// let unoptimized_b = b.data();
|
||||
// let unoptimized_batch_out = batch_out.data();
|
||||
|
||||
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
|
||||
// cx.execute();
|
||||
|
||||
// assert_close(&unoptimized_b, &b.data());
|
||||
// assert_close(&unoptimized_batch_out, &batch_out.data());
|
||||
|
||||
// // Test against dfdx
|
||||
// let dev = Cpu::default();
|
||||
// let mut model = <(
|
||||
// dfdx::nn::modules::builders::UnbiasedLinear<3, 4>,
|
||||
// dfdx::nn::modules::builders::ReLU,
|
||||
// dfdx::nn::modules::builders::UnbiasedLinear<4, 2>,
|
||||
// )>::build_on_device(&dev);
|
||||
// // Set weights
|
||||
// model.0.weight = dev
|
||||
// .tensor_from_vec(
|
||||
// vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.],
|
||||
// (dfdx::shapes::Const::<3>, dfdx::shapes::Const::<4>),
|
||||
// )
|
||||
// .permute();
|
||||
// model.2.weight = dev
|
||||
// .tensor_from_vec(
|
||||
// vec![1., 2., 3., 1., 2., 3., 1., 2.],
|
||||
// (dfdx::shapes::Const::<4>, dfdx::shapes::Const::<2>),
|
||||
// )
|
||||
// .permute();
|
||||
// let a = dev.tensor_from_vec(vec![1.0, 2.0, 3.0], (dfdx::shapes::Const::<3>,));
|
||||
// let d_batch = dev.tensor_from_vec(
|
||||
// vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
|
||||
// (dfdx::shapes::Const::<2>, dfdx::shapes::Const::<3>),
|
||||
// );
|
||||
// let out = model.forward(a);
|
||||
// let d_batch_out = model.forward(d_batch);
|
||||
|
||||
// assert_close(&unoptimized_b, &out.as_vec());
|
||||
// assert_close(&unoptimized_batch_out, &d_batch_out.as_vec());
|
||||
// }
|
||||
// }
|
||||
451
crates/luminal_nn/src/attention.rs
Normal file
451
crates/luminal_nn/src/attention.rs
Normal file
@@ -0,0 +1,451 @@
|
||||
use luminal::prelude::*;
|
||||
use luminal::shape::Expression;
|
||||
|
||||
/// Gather entire rows from a 2D tensor using row indices.
|
||||
///
|
||||
/// - `data`: (R, D) tensor
|
||||
/// - `indices`: (N,) Int tensor of row indices
|
||||
/// - `d`: the number of columns (D), must match data's second dimension
|
||||
///
|
||||
/// Returns: (N, D) tensor where output[i] = data[indices[i]]
|
||||
pub fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
|
||||
assert_eq!(indices.dtype, DType::Int);
|
||||
let n = indices.dims1();
|
||||
|
||||
// base[i] = indices[i] * D → flat starting position for each row
|
||||
let base = (indices * d).expand_dim(1, d); // (N, D) broadcast along cols
|
||||
|
||||
// col[j] = j → column offsets 0..D
|
||||
let col = data.graph().arange(d as i32).expand_dim(0, n); // (N, D) broadcast along rows
|
||||
|
||||
// flat_idx[i,j] = indices[i] * D + j
|
||||
let flat_idx = base + col;
|
||||
|
||||
data.gather(flat_idx)
|
||||
}
|
||||
|
||||
/// Scatter entire rows into a 2D tensor using row indices.
|
||||
///
|
||||
/// - `src`: (N, D) tensor of values to write
|
||||
/// - `indices`: (N,) Int tensor of destination row indices
|
||||
/// - `dest`: (R, D) tensor to write into (copied first, then overwritten at index positions)
|
||||
/// - `d`: the number of columns (D)
|
||||
///
|
||||
/// Returns: (R, D) tensor where output = copy(dest); output[indices[i]] = src[i]
|
||||
pub fn scatter_rows(
|
||||
src: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
dest: GraphTensor,
|
||||
d: usize,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(indices.dtype, DType::Int);
|
||||
let n = indices.dims1();
|
||||
|
||||
// Same index expansion as gather_rows
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = src.graph().arange(d as i32).expand_dim(0, n);
|
||||
let flat_idx = base + col;
|
||||
|
||||
src.scatter(flat_idx, dest)
|
||||
}
|
||||
|
||||
/// Pure HLIR paged attention for one layer with causal masking.
|
||||
///
|
||||
/// Inputs:
|
||||
/// - `q`: (s, hidden) f32 — query vectors
|
||||
/// - `k_new`: (s, kv_dim) f32 — new key vectors
|
||||
/// - `v_new`: (s, kv_dim) f32 — new value vectors
|
||||
/// - `k_cache`: (num_slots, kv_dim) f32 — key cache (preallocated)
|
||||
/// - `v_cache`: (num_slots, kv_dim) f32 — value cache (preallocated)
|
||||
/// - `gather_idx`: (ctx_len,) Int — which cache slots to read
|
||||
/// - `scatter_idx`: (s,) Int — which cache slots to write new KV into
|
||||
/// - `prev_seq`: number of previously cached tokens (for causal mask offset)
|
||||
/// - `n_heads`: number of query heads
|
||||
/// - `n_kv_heads`: number of KV heads (for GQA)
|
||||
/// - `head_dim`: dimension per head
|
||||
///
|
||||
/// Returns: (attn_out, k_cache_new, v_cache_new)
|
||||
/// - `attn_out`: (s, hidden) f32
|
||||
/// - `k_cache_new`: (num_slots, kv_dim) f32
|
||||
/// - `v_cache_new`: (num_slots, kv_dim) f32
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_attention(
|
||||
q: GraphTensor,
|
||||
k_new: GraphTensor,
|
||||
v_new: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
prev_seq: Expression,
|
||||
n_heads: usize,
|
||||
n_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let kv_dim = n_kv_heads * head_dim;
|
||||
let kv_groups = n_heads / n_kv_heads;
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let s = q.dims()[0];
|
||||
let ctx = gather_idx.dims()[0];
|
||||
let cx = q.graph();
|
||||
|
||||
// ── Phase 1: Write new KV into cache ──
|
||||
let k_cache = scatter_rows(k_new, scatter_idx, k_cache, kv_dim);
|
||||
let v_cache = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
|
||||
|
||||
// ── Phase 2: Gather context KV from cache ──
|
||||
let k = gather_rows(k_cache, gather_idx, kv_dim); // (ctx, kv_dim)
|
||||
let v = gather_rows(v_cache, gather_idx, kv_dim); // (ctx, kv_dim)
|
||||
|
||||
// ── Phase 3: Reshape for multi-head attention ──
|
||||
// Q: (s, hidden) → (s, n_heads, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
|
||||
// → (n_kv_heads, kv_groups, s, head_dim)
|
||||
let q = q
|
||||
.split_dims(1, head_dim) // (s, n_heads, head_dim)
|
||||
.split_dims(1, kv_groups) // (s, n_kv_heads, kv_groups, head_dim)
|
||||
.permute((1, 2, 0, 3)); // (n_kv_heads, kv_groups, s, head_dim)
|
||||
|
||||
// K: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, head_dim, ctx)
|
||||
let k = k
|
||||
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
|
||||
.permute((1, 2, 0)); // (n_kv_heads, head_dim, ctx)
|
||||
|
||||
// V: (ctx, kv_dim) → (ctx, n_kv_heads, head_dim) → (n_kv_heads, ctx, head_dim)
|
||||
let v = v
|
||||
.split_dims(1, head_dim) // (ctx, n_kv_heads, head_dim)
|
||||
.permute((1, 0, 2)); // (n_kv_heads, ctx, head_dim)
|
||||
|
||||
// ── Phase 4: Attention ──
|
||||
// Broadcast K, V over kv_groups dimension
|
||||
let k = k.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, head_dim, ctx)
|
||||
let v = v.expand_dim(1, kv_groups); // (n_kv_heads, kv_groups, ctx, head_dim)
|
||||
|
||||
// QK^T: (n_kv_heads, kv_groups, s, head_dim) @ (n_kv_heads, kv_groups, head_dim, ctx)
|
||||
// → (n_kv_heads, kv_groups, s, ctx)
|
||||
let scores = q.matmul(k) * scale;
|
||||
|
||||
// Build causal mask: query at position prev_seq+i can attend to context j iff j <= prev_seq+i.
|
||||
// row_vals[i] = prev_seq + i, col_vals[j] = j
|
||||
// mask[i,j] = -1e9 where row_vals[i] < col_vals[j], else 0
|
||||
let z = Expression::from('z');
|
||||
let row_vals = cx.iota(z + prev_seq, s).expand_dim(1, ctx); // (s, ctx)
|
||||
let col_vals = cx.arange(ctx).expand_dim(0, s); // (s, ctx)
|
||||
let mask = row_vals
|
||||
.cast(DType::F32)
|
||||
.lt(col_vals.cast(DType::F32))
|
||||
.cast(DType::F32)
|
||||
* -1e9;
|
||||
|
||||
// Broadcast (s, ctx) → (n_kv_heads, kv_groups, s, ctx)
|
||||
let mask = mask.expand_dim(0, n_kv_heads).expand_dim(1, kv_groups);
|
||||
let scores = scores + mask;
|
||||
|
||||
// Softmax over context dimension (axis 3)
|
||||
let weights = scores.softmax(3);
|
||||
|
||||
// Weighted sum: (n_kv_heads, kv_groups, s, ctx) @ (n_kv_heads, kv_groups, ctx, head_dim)
|
||||
// → (n_kv_heads, kv_groups, s, head_dim)
|
||||
let out = weights.matmul(v);
|
||||
|
||||
// ── Phase 5: Reshape output ──
|
||||
// (n_kv_heads, kv_groups, s, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
|
||||
let mut out = out.permute((2, 0, 1, 3));
|
||||
out.shape = ShapeTracker::new((s, n_heads * head_dim));
|
||||
|
||||
(out, k_cache, v_cache)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gather_rows() {
|
||||
let mut cx = Graph::new();
|
||||
let data = cx.tensor((4, 3)); // 4 rows, 3 cols
|
||||
let indices = cx.tensor(3).as_dtype(DType::Int);
|
||||
let result = gather_rows(data, indices, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
|
||||
rt.set_data(
|
||||
data.id,
|
||||
vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
|
||||
);
|
||||
// Gather rows 0, 2, 3
|
||||
rt.set_data(indices.id, vec![0, 2, 3]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(
|
||||
*rt.get_f32(result.id),
|
||||
vec![1., 2., 3., 7., 8., 9., 10., 11., 12.]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_rows() {
|
||||
let mut cx = Graph::new();
|
||||
let src = cx.tensor((2, 3));
|
||||
let indices = cx.tensor(2).as_dtype(DType::Int);
|
||||
let dest = cx.tensor((4, 3));
|
||||
let result = scatter_rows(src, indices, dest, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
|
||||
rt.set_data(indices.id, vec![1, 3]);
|
||||
rt.set_data(dest.id, vec![0.; 12]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(
|
||||
*rt.get_f32(result.id),
|
||||
vec![0., 0., 0., 10., 20., 30., 0., 0., 0., 40., 50., 60.]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_then_gather_roundtrip() {
|
||||
let mut cx = Graph::new();
|
||||
let kv_new = cx.tensor((2, 4)); // 2 new rows, dim=4
|
||||
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
|
||||
let cache = cx.tensor((6, 4)); // 6 slots
|
||||
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
|
||||
|
||||
// Scatter new rows into cache, then gather them back
|
||||
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
|
||||
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
|
||||
rt.set_data(cache.id, vec![0.; 24]); // Zero cache
|
||||
rt.set_data(gather_idx.id, vec![1, 4]); // Read back from same slots
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(
|
||||
*rt.get_f32(gathered.id),
|
||||
vec![1., 2., 3., 4., 5., 6., 7., 8.]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_shape_and_cache_update() {
|
||||
// Minimal config: n_heads=2, n_kv_heads=2, head_dim=2, kv_groups=1
|
||||
// hidden = 4, kv_dim = 4
|
||||
let n_heads = 2;
|
||||
let n_kv_heads = 2;
|
||||
let head_dim = 2;
|
||||
let hidden = n_heads * head_dim; // 4
|
||||
let kv_dim = n_kv_heads * head_dim; // 4
|
||||
let num_slots = 8;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((1, hidden)); // 1 new token
|
||||
let k_new = cx.tensor((1, kv_dim));
|
||||
let v_new = cx.tensor((1, kv_dim));
|
||||
let k_cache = cx.tensor((num_slots, kv_dim));
|
||||
let v_cache = cx.tensor((num_slots, kv_dim));
|
||||
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context tokens
|
||||
let scatter_idx = cx.tensor(1).as_dtype(DType::Int); // 1 new token
|
||||
|
||||
// prev_seq=2: this is the 3rd token (positions 0,1 cached, position 2 is new)
|
||||
let (attn_out, k_cache_new, v_cache_new) = paged_attention(
|
||||
q,
|
||||
k_new,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
2.into(),
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
let k_cache_new = k_cache_new.output();
|
||||
let v_cache_new = v_cache_new.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
|
||||
rt.set_data(q.id, vec![1., 0., 1., 0.]);
|
||||
// k_new = [0.5, 0.5, 0.5, 0.5]
|
||||
rt.set_data(k_new.id, vec![0.5, 0.5, 0.5, 0.5]);
|
||||
// v_new = [1, 2, 3, 4]
|
||||
rt.set_data(v_new.id, vec![1., 2., 3., 4.]);
|
||||
// Zero caches
|
||||
rt.set_data(k_cache.id, vec![0.; num_slots * kv_dim]);
|
||||
rt.set_data(v_cache.id, vec![0.; num_slots * kv_dim]);
|
||||
// Scatter new KV to slot 2
|
||||
rt.set_data(scatter_idx.id, vec![2]);
|
||||
// Gather context from slots 0, 1, 2 (slots 0,1 are zeros, slot 2 is the new KV)
|
||||
rt.set_data(gather_idx.id, vec![0, 1, 2]);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Verify output shape: (1, hidden=4)
|
||||
let out = rt.get_f32(attn_out.id);
|
||||
assert_eq!(out.len(), hidden);
|
||||
|
||||
// Verify KV cache was updated: k_cache_new should have [0.5, 0.5, 0.5, 0.5] at slot 2
|
||||
let k_out = rt.get_f32(k_cache_new.id);
|
||||
assert_eq!(k_out.len(), num_slots * kv_dim);
|
||||
// Slot 2 is at offset 2*4=8..12
|
||||
assert_eq!(&k_out[8..12], &[0.5, 0.5, 0.5, 0.5]);
|
||||
// Slot 0 should still be zeros
|
||||
assert_eq!(&k_out[0..4], &[0., 0., 0., 0.]);
|
||||
|
||||
let v_out = rt.get_f32(v_cache_new.id);
|
||||
assert_eq!(&v_out[8..12], &[1., 2., 3., 4.]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_known_values() {
|
||||
// Test with values where we can compute expected attention output.
|
||||
// n_heads=1, n_kv_heads=1, head_dim=2, kv_groups=1
|
||||
// hidden=2, kv_dim=2
|
||||
let n_heads = 1;
|
||||
let n_kv_heads = 1;
|
||||
let head_dim = 2;
|
||||
let hidden = 2;
|
||||
let kv_dim = 2;
|
||||
let num_slots = 4;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((1, hidden));
|
||||
let k_new = cx.tensor((1, kv_dim));
|
||||
let v_new = cx.tensor((1, kv_dim));
|
||||
let k_cache = cx.tensor((num_slots, kv_dim));
|
||||
let v_cache = cx.tensor((num_slots, kv_dim));
|
||||
let gather_idx = cx.tensor(2).as_dtype(DType::Int);
|
||||
let scatter_idx = cx.tensor(1).as_dtype(DType::Int);
|
||||
|
||||
// prev_seq=1: 1 cached token + 1 new token, context len=2
|
||||
// Query at absolute position 1 can attend to context positions 0 and 1
|
||||
let (attn_out, _, _) = paged_attention(
|
||||
q,
|
||||
k_new,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
1.into(),
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
|
||||
// K cached at slot 0: [1, 0]
|
||||
// K new (written to slot 1): [0, 1]
|
||||
// V cached at slot 0: [10, 20]
|
||||
// V new (written to slot 1): [30, 40]
|
||||
// Q: [1, 1]
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
k_cache_data[0] = 1.;
|
||||
k_cache_data[1] = 0.; // slot 0 K = [1, 0]
|
||||
let mut v_cache_data = vec![0.; num_slots * kv_dim];
|
||||
v_cache_data[0] = 10.;
|
||||
v_cache_data[1] = 20.; // slot 0 V = [10, 20]
|
||||
|
||||
rt.set_data(q.id, vec![1., 1.]);
|
||||
rt.set_data(k_new.id, vec![0., 1.]); // new K = [0, 1]
|
||||
rt.set_data(v_new.id, vec![30., 40.]); // new V = [30, 40]
|
||||
rt.set_data(k_cache.id, k_cache_data);
|
||||
rt.set_data(v_cache.id, v_cache_data);
|
||||
rt.set_data(scatter_idx.id, vec![1]); // write to slot 1
|
||||
rt.set_data(gather_idx.id, vec![0, 1]); // gather slots 0, 1
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out = rt.get_f32(attn_out.id);
|
||||
assert_eq!(out.len(), hidden);
|
||||
let expected = vec![20.0, 30.0];
|
||||
for (a, b) in out.iter().zip(&expected) {
|
||||
assert!((a - b).abs() < 0.1, "Expected {expected:?}, got {out:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_causal_mask() {
|
||||
// Verify that the causal mask blocks future positions.
|
||||
// n_heads=1, n_kv_heads=1, head_dim=2
|
||||
let n_heads = 1;
|
||||
let n_kv_heads = 1;
|
||||
let head_dim = 2;
|
||||
let hidden = 2;
|
||||
let kv_dim = 2;
|
||||
let num_slots = 4;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((2, hidden)); // 2 new tokens
|
||||
let k_new = cx.tensor((2, kv_dim));
|
||||
let v_new = cx.tensor((2, kv_dim));
|
||||
let k_cache = cx.tensor((num_slots, kv_dim));
|
||||
let v_cache = cx.tensor((num_slots, kv_dim));
|
||||
let gather_idx = cx.tensor(3).as_dtype(DType::Int); // 3 context (1 cached + 2 new)
|
||||
let scatter_idx = cx.tensor(2).as_dtype(DType::Int);
|
||||
|
||||
// prev_seq=1: 1 cached token, 2 new tokens → context len=3
|
||||
// Query 0 at absolute pos 1: can see ctx 0,1 (not 2)
|
||||
// Query 1 at absolute pos 2: can see ctx 0,1,2
|
||||
let (attn_out, _, _) = paged_attention(
|
||||
q,
|
||||
k_new,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
1.into(),
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Cache has 1 token at slot 0
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
k_cache_data[0] = 1.;
|
||||
k_cache_data[1] = 0.; // slot 0: K=[1,0]
|
||||
let mut v_cache_data = vec![0.; num_slots * kv_dim];
|
||||
v_cache_data[0] = 100.;
|
||||
v_cache_data[1] = 0.; // slot 0: V=[100,0]
|
||||
|
||||
// 2 new tokens
|
||||
rt.set_data(q.id, vec![1., 0., 0., 1.]);
|
||||
rt.set_data(k_new.id, vec![0., 1., 1., 1.]); // token0 K=[0,1], token1 K=[1,1]
|
||||
rt.set_data(v_new.id, vec![0., 10., 0., 20.]); // token0 V=[0,10], token1 V=[0,20]
|
||||
rt.set_data(k_cache.id, k_cache_data);
|
||||
rt.set_data(v_cache.id, v_cache_data);
|
||||
rt.set_data(scatter_idx.id, vec![1, 2]); // write to slots 1, 2
|
||||
rt.set_data(gather_idx.id, vec![0, 1, 2]); // gather all 3
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out = rt.get_f32(attn_out.id);
|
||||
assert_eq!(out.len(), 2 * hidden);
|
||||
|
||||
// Token 0 (abs pos 1): attends to ctx 0,1 only (ctx 2 is masked)
|
||||
// Token 1 (abs pos 2): attends to ctx 0,1,2
|
||||
// Verify output has valid (non-NaN, non-inf) values and correct length
|
||||
for val in out.iter() {
|
||||
assert!(val.is_finite(), "Output contains non-finite value: {}", val);
|
||||
}
|
||||
}
|
||||
}
|
||||
408
crates/luminal_nn/src/convolution.rs
Normal file
408
crates/luminal_nn/src/convolution.rs
Normal file
@@ -0,0 +1,408 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Generic N-dimensional convolution layer implemented with the GraphTensor `unfold` helper.
|
||||
///
|
||||
/// The layer expects inputs shaped like `[batch..., channels, spatial...]` where the number of
|
||||
/// spatial dimensions is greater than zero. The kernel configuration controls how many spatial
|
||||
/// axes are convolved (N) and must be shorter than the input rank (K): `K > N` is asserted.
|
||||
pub struct ConvND {
|
||||
pub weight: GraphTensor, // (ch_out, ch_in * kernel_product)
|
||||
pub bias: Option<GraphTensor>,
|
||||
kernel: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
dilation: Vec<usize>,
|
||||
padding: Vec<usize>,
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
}
|
||||
|
||||
impl ConvND {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
kernel: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
dilation: Vec<usize>,
|
||||
padding: Vec<usize>,
|
||||
bias: bool,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
assert!(
|
||||
!kernel.is_empty(),
|
||||
"ConvND requires at least one spatial dimension in the kernel",
|
||||
);
|
||||
let k = kernel.len();
|
||||
assert_eq!(
|
||||
stride.len(),
|
||||
k,
|
||||
"Stride dimensions ({}) must match kernel dimensions ({k})",
|
||||
stride.len()
|
||||
);
|
||||
assert_eq!(
|
||||
dilation.len(),
|
||||
k,
|
||||
"Dilation dimensions ({}) must match kernel dimensions ({k})",
|
||||
dilation.len()
|
||||
);
|
||||
assert_eq!(
|
||||
padding.len(),
|
||||
k,
|
||||
"Padding dimensions ({}) must match kernel dimensions ({k})",
|
||||
padding.len()
|
||||
);
|
||||
|
||||
let kernel_product: usize = kernel.iter().product();
|
||||
|
||||
Self {
|
||||
weight: cx
|
||||
.named_tensor("ConvWeight", (ch_out, ch_in * kernel_product))
|
||||
.persist(),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("ConvBias", ch_out).persist())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
kernel,
|
||||
stride,
|
||||
dilation,
|
||||
padding,
|
||||
ch_in,
|
||||
ch_out,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply convolution to an input shaped `[batch..., channels, spatial...]`.
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
let input_dims = input.dims();
|
||||
let rank = input_dims.len();
|
||||
let spatial = self.kernel.len();
|
||||
|
||||
assert!(
|
||||
rank > spatial,
|
||||
"ConvND expects input rank ({rank}) to be greater than kernel dims ({spatial})",
|
||||
);
|
||||
|
||||
let batch_len = rank - spatial - 1;
|
||||
assert_eq!(
|
||||
input_dims[batch_len],
|
||||
Expression::from(self.ch_in),
|
||||
"Input channel dimension ({}) must match ch_in ({})",
|
||||
input_dims[batch_len],
|
||||
self.ch_in
|
||||
);
|
||||
assert_eq!(
|
||||
self.weight.dims()[0],
|
||||
Expression::from(self.ch_out),
|
||||
"Weight output channels ({}) must match ch_out ({})",
|
||||
self.weight.dims()[0],
|
||||
self.ch_out
|
||||
);
|
||||
|
||||
// Pad only the spatial dimensions.
|
||||
let mut padding = vec![(Expression::from(0), Expression::from(0)); rank];
|
||||
for (i, pad) in self.padding.iter().enumerate() {
|
||||
let axis = batch_len + 1 + i;
|
||||
padding[axis] = (Expression::from(*pad), Expression::from(*pad));
|
||||
}
|
||||
let padded = input.pad(padding, 0.0);
|
||||
|
||||
// Build unfold parameters with ones for non-spatial axes.
|
||||
let mut kernel_shape = vec![1; rank];
|
||||
let mut stride_shape = vec![1; rank];
|
||||
let mut dilation_shape = vec![1; rank];
|
||||
for i in 0..spatial {
|
||||
let axis = batch_len + 1 + i;
|
||||
kernel_shape[axis] = self.kernel[i];
|
||||
stride_shape[axis] = self.stride[i];
|
||||
dilation_shape[axis] = self.dilation[i];
|
||||
}
|
||||
|
||||
let unfolded = padded.unfold(kernel_shape, stride_shape, dilation_shape);
|
||||
|
||||
// Move window dimensions to the front for easier indexing.
|
||||
let mut order: Vec<usize> = (rank..2 * rank).collect();
|
||||
order.extend(0..rank);
|
||||
let unfolded = unfolded.permute(order);
|
||||
let unfolded_dims = unfolded.dims();
|
||||
|
||||
// Capture output spatial dimensions from the unfolded view.
|
||||
let output_dims: Vec<Expression> =
|
||||
unfolded_dims[batch_len + 1..batch_len + 1 + spatial].to_vec();
|
||||
|
||||
// Reorder to [batch..., out..., channels, kernel_spatial..., kernel_batch..., kernel_channel].
|
||||
let mut order2 = Vec::with_capacity(2 * rank);
|
||||
// window batch dims
|
||||
order2.extend(0..batch_len);
|
||||
// window spatial dims (outputs)
|
||||
order2.extend(batch_len + 1..batch_len + 1 + spatial);
|
||||
// window channel dim
|
||||
order2.push(batch_len);
|
||||
// kernel spatial dims
|
||||
order2.extend(rank + batch_len + 1..rank + batch_len + 1 + spatial);
|
||||
// kernel batch dims and kernel channel dim (to be merged away)
|
||||
order2.extend(rank..rank + batch_len + 1);
|
||||
let mut patches = unfolded.permute(order2);
|
||||
|
||||
// Drop kernel axes for batch + channel by merging them into the previous dimension.
|
||||
for _ in 0..=batch_len {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
|
||||
// Flatten channel and kernel spatial dimensions together.
|
||||
for _ in 0..spatial {
|
||||
let channel_axis = batch_len + spatial;
|
||||
patches = patches.merge_dims(channel_axis, channel_axis + 1);
|
||||
}
|
||||
|
||||
// Collapse batch dimensions into one and output dimensions into one for matmul.
|
||||
for _ in 1..batch_len {
|
||||
patches = patches.merge_dims(0, 1);
|
||||
}
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
|
||||
let mut out = patches.matmul(self.weight.permute((1, 0)));
|
||||
|
||||
// Restore batch and spatial dimensions.
|
||||
for dim in self.input_batch_dims(&input_dims, batch_len).iter().rev() {
|
||||
out = out.split_dims(0, *dim);
|
||||
}
|
||||
for dim in output_dims.iter().rev() {
|
||||
out = out.split_dims(batch_len, *dim);
|
||||
}
|
||||
|
||||
// Move channel dimension ahead of the spatial axes: [batch..., ch_out, spatial...]
|
||||
let mut final_order: Vec<usize> = (0..batch_len).collect();
|
||||
final_order.push(batch_len + spatial);
|
||||
final_order.extend(batch_len..batch_len + spatial);
|
||||
out = out.permute(final_order);
|
||||
|
||||
if let Some(_b) = self.bias {
|
||||
todo!()
|
||||
// out += b.expand(out.shape);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn input_batch_dims(&self, input_dims: &[Expression], batch_len: usize) -> Vec<Expression> {
|
||||
input_dims[..batch_len].to_vec()
|
||||
}
|
||||
|
||||
pub fn infer_output_shape(&self, input: &[usize]) -> Vec<usize> {
|
||||
let rank = input.len();
|
||||
let spatial = self.kernel.len();
|
||||
|
||||
assert!(rank > spatial, "expected input rank > spatial dims");
|
||||
let batch_len = rank - spatial - 1;
|
||||
assert_eq!(
|
||||
input[batch_len], self.ch_in,
|
||||
"input channel dimension does not match ch_in",
|
||||
);
|
||||
|
||||
let batch_prefix = &input[..batch_len];
|
||||
let spatial_dims = &input[batch_len + 1..];
|
||||
let out_spatial: Vec<usize> = spatial_dims
|
||||
.iter()
|
||||
.zip(
|
||||
self.kernel
|
||||
.iter()
|
||||
.zip(self.stride.iter())
|
||||
.zip(self.dilation.iter())
|
||||
.zip(self.padding.iter()),
|
||||
)
|
||||
.map(|(dim, (((k, s), d), p))| (dim + 2 * p - d * (k - 1) - 1) / s + 1)
|
||||
.collect();
|
||||
|
||||
let mut shape = batch_prefix.to_vec();
|
||||
shape.push(self.ch_out);
|
||||
shape.extend(out_spatial);
|
||||
shape
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ConvND;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn assert_close(a: &[f32], b: &[f32]) {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"length mismatch: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
);
|
||||
for (idx, (lhs, rhs)) in a.iter().zip(b.iter()).enumerate() {
|
||||
let diff = (lhs - rhs).abs();
|
||||
if diff > 1e-4 {
|
||||
panic!("values differ at {idx}: {lhs} vs {rhs} (diff {diff})");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn candle_conv1d_output(
|
||||
conv: &ConvND,
|
||||
input: &[f32],
|
||||
width: usize,
|
||||
weight: &[f32],
|
||||
bias: Option<&[f32]>,
|
||||
) -> candle_core::Result<Vec<f32>> {
|
||||
let device = Device::Cpu;
|
||||
let input = Tensor::from_vec(input.to_vec(), (1, conv.ch_in, width), &device)?;
|
||||
let weight = Tensor::from_vec(
|
||||
weight.to_vec(),
|
||||
(conv.ch_out, conv.ch_in, conv.kernel[0]),
|
||||
&device,
|
||||
)?;
|
||||
let bias = match bias {
|
||||
Some(b) => Some(Tensor::from_vec(b.to_vec(), conv.ch_out, &device)?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let output = input.conv1d(
|
||||
&weight,
|
||||
conv.padding[0],
|
||||
conv.stride[0],
|
||||
conv.dilation[0],
|
||||
1,
|
||||
)?;
|
||||
let output = match bias {
|
||||
Some(bias) => {
|
||||
let bias = bias.reshape((1, conv.ch_out, 1))?;
|
||||
output.broadcast_add(&bias)?
|
||||
}
|
||||
None => output,
|
||||
};
|
||||
output.flatten_all()?.to_vec1::<f32>()
|
||||
}
|
||||
|
||||
fn candle_conv2d_output(
|
||||
conv: &ConvND,
|
||||
input: &[f32],
|
||||
height: usize,
|
||||
width: usize,
|
||||
weight: &[f32],
|
||||
bias: Option<&[f32]>,
|
||||
) -> candle_core::Result<Vec<f32>> {
|
||||
let device = Device::Cpu;
|
||||
let input = Tensor::from_vec(input.to_vec(), (1, conv.ch_in, height, width), &device)?;
|
||||
let weight = Tensor::from_vec(
|
||||
weight.to_vec(),
|
||||
(conv.ch_out, conv.ch_in, conv.kernel[0], conv.kernel[1]),
|
||||
&device,
|
||||
)?;
|
||||
let bias = match bias {
|
||||
Some(b) => Some(Tensor::from_vec(b.to_vec(), conv.ch_out, &device)?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
conv.padding[0], conv.padding[1],
|
||||
"Candle conv2d only supports equal padding"
|
||||
);
|
||||
assert_eq!(
|
||||
conv.stride[0], conv.stride[1],
|
||||
"Candle conv2d only supports equal stride"
|
||||
);
|
||||
assert_eq!(
|
||||
conv.dilation[0], conv.dilation[1],
|
||||
"Candle conv2d only supports equal dilation"
|
||||
);
|
||||
|
||||
let output = input.conv2d(
|
||||
&weight,
|
||||
conv.padding[0],
|
||||
conv.stride[0],
|
||||
conv.dilation[0],
|
||||
1,
|
||||
)?;
|
||||
let output = match bias {
|
||||
Some(bias) => {
|
||||
let bias = bias.reshape((1, conv.ch_out, 1, 1))?;
|
||||
output.broadcast_add(&bias)?
|
||||
}
|
||||
None => output,
|
||||
};
|
||||
output.flatten_all()?.to_vec1::<f32>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv1d_values_match_expected_window_sums() -> candle_core::Result<()> {
|
||||
let mut cx = luminal::graph::Graph::new();
|
||||
let conv = ConvND::new(1, 1, vec![3], vec![1], vec![1], vec![1], true, &mut cx);
|
||||
|
||||
let input = [1., 2., 3., 4., 5.];
|
||||
let weight = [1., 1., 1.];
|
||||
let bias = [0.5];
|
||||
|
||||
let out = candle_conv1d_output(&conv, &input, input.len(), &weight, Some(&bias))?;
|
||||
|
||||
assert_close(&out, &[3.5, 6.5, 9.5, 12.5, 9.5]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_values_accumulate_across_channels() -> candle_core::Result<()> {
|
||||
let mut cx = luminal::graph::Graph::new();
|
||||
let conv = ConvND::new(
|
||||
2,
|
||||
1,
|
||||
vec![2, 2],
|
||||
vec![1, 1],
|
||||
vec![1, 1],
|
||||
vec![0, 0],
|
||||
true,
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
let input = [
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 9., // channel 0
|
||||
9., 8., 7., 6., 5., 4., 3., 2., 1., // channel 1
|
||||
];
|
||||
let weight = [1., 1., 1., 1., 2., 2., 2., 2.];
|
||||
let bias = [0.25];
|
||||
|
||||
let out = candle_conv2d_output(&conv, &input, 3, 3, &weight, Some(&bias))?;
|
||||
|
||||
assert_close(&out, &[68.25, 64.25, 56.25, 52.25]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv1d_shapes_follow_stride_and_padding() {
|
||||
let mut cx = luminal::graph::Graph::new();
|
||||
let conv = ConvND::new(1, 1, vec![3], vec![2], vec![1], vec![1], false, &mut cx);
|
||||
|
||||
// expected length: floor((padded_len - dilation*(k-1) -1)/stride +1)
|
||||
// padded_len = 7 + 2 = 9
|
||||
// effective kernel = 3
|
||||
// => (9 -3)/2 +1 = 4
|
||||
let inferred = conv.infer_output_shape(&[2, 1, 7]);
|
||||
assert_eq!(inferred, vec![2, 1, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_shapes_follow_stride_and_padding() {
|
||||
let mut cx = luminal::graph::Graph::new();
|
||||
let conv = ConvND::new(
|
||||
3,
|
||||
2,
|
||||
vec![2, 3],
|
||||
vec![1, 2],
|
||||
vec![1, 1],
|
||||
vec![0, 1],
|
||||
true,
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
// height: (5 - dilation*(2-1) -1 + 0 +0)/1 +1 = 4
|
||||
// width: (6 - dilation*(3-1) -1 + 1 +1)/2 +1 = 3
|
||||
let inferred = conv.infer_output_shape(&[1, 3, 5, 6]);
|
||||
assert_eq!(inferred, vec![1, 2, 4, 3]);
|
||||
}
|
||||
}
|
||||
116
crates/luminal_nn/src/embedding.rs
Normal file
116
crates/luminal_nn/src/embedding.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
// use luminal::{prelude::*, tests::random_vec};
|
||||
|
||||
// pub struct Embedding {
|
||||
// permute: bool,
|
||||
// pub weight: GraphTensor, // n embeddings x embedding dim
|
||||
// embedding_dim: usize,
|
||||
// }
|
||||
|
||||
// impl Embedding {
|
||||
// pub fn new(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
|
||||
// Self {
|
||||
// weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
|
||||
// permute: false,
|
||||
// embedding_dim,
|
||||
// }
|
||||
// }
|
||||
|
||||
// pub fn new_permuted(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
|
||||
// Self {
|
||||
// weight: cx.named_tensor("Embedding Weight", (embedding_dim, n_embeddings)),
|
||||
// permute: true,
|
||||
// embedding_dim,
|
||||
// }
|
||||
// }
|
||||
|
||||
// pub fn initialize(self) -> Self {
|
||||
// self.weight.set(random_vec(
|
||||
// self.weight.shape.n_elements().to_usize().unwrap(),
|
||||
// ));
|
||||
// self
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl SerializeModule for Embedding {
|
||||
// fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
// s.tensor("weight", self.weight);
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl Module<GraphTensor> for Embedding {
|
||||
// type Output = GraphTensor;
|
||||
|
||||
// fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
// // Flatten batches
|
||||
// let batch_size = input.shape.n_elements();
|
||||
// let inp = input.reshape(batch_size);
|
||||
// // Gather
|
||||
// let out = if self.permute {
|
||||
// self.weight.permute((1, 0)).gather(inp)
|
||||
// } else {
|
||||
// self.weight.gather(inp)
|
||||
// };
|
||||
// // Unflatten
|
||||
// let mut new_shape = input.dims();
|
||||
// new_shape.push(self.embedding_dim.into());
|
||||
// out.reshape(new_shape)
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl Embedding {
|
||||
// // Reverse from embedding to token distribution
|
||||
// pub fn reverse(&self, input: GraphTensor) -> GraphTensor {
|
||||
// if self.permute {
|
||||
// input.matmul(self.weight)
|
||||
// } else {
|
||||
// input.matmul(self.weight.permute((1, 0)))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use dfdx::{
|
||||
// prelude::Module as DfdxModule,
|
||||
// tensor::{Cpu, TensorFromVec},
|
||||
// };
|
||||
|
||||
// use luminal::prelude::Module;
|
||||
|
||||
// use super::Embedding;
|
||||
// use dfdx::nn::BuildOnDevice;
|
||||
// luminal::test_imports!();
|
||||
|
||||
// #[test]
|
||||
// fn test_embedding() {
|
||||
// let mut cx = Graph::new();
|
||||
// let batch = cx.tensor((2, 3)).set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
|
||||
// let a = cx.tensor(3).set(vec![1.0, 0.0, 1.0]).retrieve();
|
||||
|
||||
// let model = Embedding::new(3, 4, &mut cx).initialize();
|
||||
// model
|
||||
// .weight
|
||||
// .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
|
||||
// let mut b = model.forward(a).retrieve();
|
||||
// let mut batch_out = model.forward(batch).retrieve();
|
||||
|
||||
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
|
||||
|
||||
// cx.execute();
|
||||
|
||||
// let d_dev = Cpu::default();
|
||||
// let mut d_model = <dfdx::nn::modules::builders::Embedding<3, 4>>::build_on_device(&d_dev);
|
||||
// d_model.weight = d_dev.tensor_from_vec(
|
||||
// vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.],
|
||||
// (DConst::<3>, DConst::<4>),
|
||||
// );
|
||||
// let d_a = d_dev.tensor_from_vec(vec![1, 0, 1], (DConst::<3>,));
|
||||
// let d_batch = d_dev.tensor_from_vec(vec![1, 0, 2, 1, 0, 1], (DConst::<2>, DConst::<3>));
|
||||
|
||||
// let d_b = d_model.forward(d_a);
|
||||
// let d_batch_out = d_model.forward(d_batch);
|
||||
|
||||
// assert_close(&b.data(), &d_b.as_vec());
|
||||
// assert_close(&batch_out.data(), &d_batch_out.as_vec());
|
||||
// }
|
||||
// }
|
||||
18
crates/luminal_nn/src/lib.rs
Normal file
18
crates/luminal_nn/src/lib.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
#![allow(unused_imports)]
|
||||
|
||||
mod activation;
|
||||
pub use activation::*;
|
||||
mod convolution;
|
||||
pub use convolution::*;
|
||||
mod embedding;
|
||||
pub use embedding::*;
|
||||
mod linear;
|
||||
pub use linear::*;
|
||||
mod norm;
|
||||
pub use norm::*;
|
||||
mod pooling;
|
||||
pub use pooling::*;
|
||||
mod moe;
|
||||
pub use moe::*;
|
||||
mod attention;
|
||||
pub use attention::*;
|
||||
76
crates/luminal_nn/src/linear.rs
Normal file
76
crates/luminal_nn/src/linear.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A simple unbiased linear layer
|
||||
pub struct Linear {
|
||||
pub weight: GraphTensor,
|
||||
pub bias: Option<GraphTensor>,
|
||||
permute: bool,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (inp, out)).persist(),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", out).persist())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
permute: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_permuted(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (out, inp)).persist(),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", out).persist())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
permute: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
let output = input.matmul(if self.permute {
|
||||
self.weight.permute((1, 0))
|
||||
} else {
|
||||
self.weight
|
||||
});
|
||||
if let Some(_bias) = self.bias {
|
||||
todo!()
|
||||
// output += bias.expand(output.shape);
|
||||
}
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::Linear;
|
||||
// use luminal::{prelude::*, tests::assert_close};
|
||||
// #[test]
|
||||
// fn test_linear() {
|
||||
// let mut cx = Graph::new();
|
||||
// let batch = cx.tensor((2, 3)).set([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
// let a = cx.tensor(3).set([1.0, 2.0, 3.0]);
|
||||
|
||||
// let model = Linear::new(3, 4, false, &mut cx).init_rand();
|
||||
// let mut b = model.forward(a).retrieve();
|
||||
// let mut batch_out = model.forward(batch).retrieve();
|
||||
|
||||
// cx.execute();
|
||||
|
||||
// let unoptimized_b = b.data();
|
||||
// let unoptimized_batch_out = batch_out.data();
|
||||
|
||||
// cx.compile(GenericCompiler::default(), (&mut b, &mut batch_out));
|
||||
// cx.execute();
|
||||
|
||||
// assert_close(&unoptimized_b, &b.data());
|
||||
// assert_close(&unoptimized_batch_out, &batch_out.data());
|
||||
// }
|
||||
// }
|
||||
513
crates/luminal_nn/src/moe.rs
Normal file
513
crates/luminal_nn/src/moe.rs
Normal file
@@ -0,0 +1,513 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A layer of E experts and a router
|
||||
pub struct MoE {
|
||||
pub expert_weights: GraphTensor, // [E, in, out]
|
||||
pub router: GraphTensor, // [in, E]
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl MoE {
|
||||
pub fn forward(&self, activations: GraphTensor) -> GraphTensor {
|
||||
let n = activations.dims().len();
|
||||
let e_dim = *self.router.dims().last().unwrap();
|
||||
let (_, in_size, out_size) = self.expert_weights.dims3();
|
||||
let io = in_size * out_size;
|
||||
let k_expr = Expression::from(self.k);
|
||||
|
||||
// 1. Routing probabilities: [batch.., E]
|
||||
let routing_weights = activations.matmul(self.router).softmax(n - 1);
|
||||
|
||||
// 2. Top-k expert indices: [batch.., k] (Int)
|
||||
let top_k_indices = routing_weights.topk_indexes(self.k, n - 1);
|
||||
|
||||
// 3. Gather top-k routing values: [batch.., k]
|
||||
// flat_idx = batch_row * E + expert_idx
|
||||
// iota(z / k * E) gives batch_row * E at each position in [batch.., k]
|
||||
let row_offsets = activations
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx); // [batch.., k]
|
||||
|
||||
// 4. Gather expert weight matrices: [batch.., k, in, out]
|
||||
// flat_idx[.., ki, i, o] = expert_idx[.., ki] * in*out + i * out + o
|
||||
let base = (top_k_indices * io).cast(DType::F32); // [batch.., k]
|
||||
let within = activations
|
||||
.graph()
|
||||
.iota(Expression::from('z'), (in_size, out_size))
|
||||
.cast(DType::F32); // [in, out] values 0..in*out-1
|
||||
|
||||
// Expand base to [batch.., k, in, out]
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base
|
||||
.expand_dim(n_base, in_size)
|
||||
.expand_dim(n_base + 1, out_size);
|
||||
|
||||
// Expand within to [batch.., k, in, out]
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
let gathered = self.expert_weights.gather(expert_flat_idx); // [batch.., k, in, out]
|
||||
|
||||
// 5. Batched matmul: [batch.., k, 1, in] @ [batch.., k, in, out] → [batch.., k, out]
|
||||
let expanded_act = activations
|
||||
.expand_dim(n - 1, self.k) // [batch.., k, in]
|
||||
.unsqueeze(n); // [batch.., k, 1, in]
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::MoE;
|
||||
use luminal::prelude::*;
|
||||
use rand::{rng, Rng};
|
||||
|
||||
fn random_vec(n: usize) -> Vec<f32> {
|
||||
let mut r = rng();
|
||||
(0..n).map(|_| r.random_range(-0.5..0.5)).collect()
|
||||
}
|
||||
|
||||
fn assert_close(a: &[f32], b: &[f32]) {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"length mismatch: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
);
|
||||
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
|
||||
let diff = (x - y).abs();
|
||||
if diff > 1e-3 {
|
||||
panic!(
|
||||
"{x} is not close to {y} at index {i}, diff={diff}\n actual: {a:?}\n expected: {b:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reference MoE computation for a single input vector.
|
||||
/// input: [in_dim], router: [in_dim, n_experts] (row-major),
|
||||
/// expert_weights: [n_experts, in_dim, out_dim] (row-major)
|
||||
fn moe_reference_1d(
|
||||
input: &[f32],
|
||||
router: &[f32],
|
||||
expert_weights: &[f32],
|
||||
n_experts: usize,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
k: usize,
|
||||
) -> Vec<f32> {
|
||||
// 1. Router logits: input @ router → [n_experts]
|
||||
let mut logits = vec![0.0f32; n_experts];
|
||||
for e in 0..n_experts {
|
||||
for i in 0..in_dim {
|
||||
logits[e] += input[i] * router[i * n_experts + e];
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Softmax
|
||||
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = logits.iter().map(|x| (x - max_l).exp()).collect();
|
||||
let sum_e: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|x| x / sum_e).collect();
|
||||
|
||||
// 3. Top-k indices (descending by probability)
|
||||
let mut indices: Vec<usize> = (0..n_experts).collect();
|
||||
indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
|
||||
let top_k_idx = &indices[..k];
|
||||
let top_k_w: Vec<f32> = top_k_idx.iter().map(|&i| probs[i]).collect();
|
||||
|
||||
// 4. Weighted sum of expert outputs (no renormalization, matching code intent)
|
||||
let mut output = vec![0.0f32; out_dim];
|
||||
for (ki, &eidx) in top_k_idx.iter().enumerate() {
|
||||
for o in 0..out_dim {
|
||||
let mut val = 0.0f32;
|
||||
for i in 0..in_dim {
|
||||
val += input[i] * expert_weights[eidx * in_dim * out_dim + i * out_dim + o];
|
||||
}
|
||||
output[o] += top_k_w[ki] * val;
|
||||
}
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
/// Reference MoE for batched input [batch, in_dim]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn moe_reference_batch(
|
||||
input: &[f32],
|
||||
router: &[f32],
|
||||
expert_weights: &[f32],
|
||||
n_experts: usize,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
k: usize,
|
||||
batch: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut output = Vec::with_capacity(batch * out_dim);
|
||||
for b in 0..batch {
|
||||
let inp = &input[b * in_dim..(b + 1) * in_dim];
|
||||
let out = moe_reference_1d(inp, router, expert_weights, n_experts, in_dim, out_dim, k);
|
||||
output.extend_from_slice(&out);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
// ── Test: 1D input, k=1, strongly-routed to expert 0 ────────────────
|
||||
#[test]
|
||||
fn test_moe_1d_k1() {
|
||||
let n_experts = 2;
|
||||
let in_dim = 3;
|
||||
let out_dim = 2;
|
||||
let k = 1;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let input = cx.tensor(in_dim);
|
||||
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
|
||||
let router_w = cx.tensor((in_dim, n_experts));
|
||||
|
||||
let moe = MoE {
|
||||
expert_weights: expert_w,
|
||||
router: router_w,
|
||||
k,
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
let input_data = vec![1.0, 2.0, 3.0];
|
||||
// Router strongly favors expert 0
|
||||
let router_data = vec![
|
||||
10.0, -10.0, // feature 0
|
||||
10.0, -10.0, // feature 1
|
||||
10.0, -10.0, // feature 2
|
||||
];
|
||||
// Expert 0: simple linear, Expert 1: different
|
||||
let expert_data = vec![
|
||||
// Expert 0: [3x2]
|
||||
1.0, 0.0, 0.0, 1.0, 1.0, 1.0, // Expert 1: [3x2]
|
||||
2.0, 0.0, 0.0, 2.0, 2.0, 2.0,
|
||||
];
|
||||
|
||||
rt.set_data(input.id, input_data.clone());
|
||||
rt.set_data(router_w.id, router_data.clone());
|
||||
rt.set_data(expert_w.id, expert_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let expected = moe_reference_1d(
|
||||
&input_data,
|
||||
&router_data,
|
||||
&expert_data,
|
||||
n_experts,
|
||||
in_dim,
|
||||
out_dim,
|
||||
k,
|
||||
);
|
||||
// With strong routing to expert 0: output ≈ [1,2,3]@[[1,0],[0,1],[1,1]] = [4, 5]
|
||||
assert_close(rt.get_f32(output.id), &expected);
|
||||
}
|
||||
|
||||
// ── Test: 1D input, k=E (all experts selected) ─────────────────────
|
||||
#[test]
|
||||
fn test_moe_1d_k_equals_e() {
|
||||
let n_experts = 3;
|
||||
let in_dim = 2;
|
||||
let out_dim = 2;
|
||||
let k = 3; // select all experts
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let input = cx.tensor(in_dim);
|
||||
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
|
||||
let router_w = cx.tensor((in_dim, n_experts));
|
||||
|
||||
let moe = MoE {
|
||||
expert_weights: expert_w,
|
||||
router: router_w,
|
||||
k,
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
let input_data = vec![1.0, 1.0];
|
||||
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
|
||||
let router_data = vec![0.01, 0.02, 0.03, 0.01, 0.02, 0.03];
|
||||
// Each expert: identity-scaled by index+1
|
||||
let expert_data = vec![
|
||||
// Expert 0: identity
|
||||
1.0, 0.0, 0.0, 1.0, // Expert 1: 2x
|
||||
2.0, 0.0, 0.0, 2.0, // Expert 2: 3x
|
||||
3.0, 0.0, 0.0, 3.0,
|
||||
];
|
||||
|
||||
rt.set_data(input.id, input_data.clone());
|
||||
rt.set_data(router_w.id, router_data.clone());
|
||||
rt.set_data(expert_w.id, expert_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let expected = moe_reference_1d(
|
||||
&input_data,
|
||||
&router_data,
|
||||
&expert_data,
|
||||
n_experts,
|
||||
in_dim,
|
||||
out_dim,
|
||||
k,
|
||||
);
|
||||
// Equal routing: each expert weight = 1/3
|
||||
// output = 1/3 * [1,1] + 1/3 * [2,2] + 1/3 * [3,3] = [2, 2]
|
||||
assert_close(rt.get_f32(output.id), &expected);
|
||||
}
|
||||
|
||||
// ── Test: 2D batched input ──────────────────────────────────────────
|
||||
#[test]
|
||||
fn test_moe_batched() {
|
||||
let n_experts = 2;
|
||||
let in_dim = 3;
|
||||
let out_dim = 2;
|
||||
let k = 1;
|
||||
let batch = 2;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let input = cx.tensor((batch, in_dim));
|
||||
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
|
||||
let router_w = cx.tensor((in_dim, n_experts));
|
||||
|
||||
let moe = MoE {
|
||||
expert_weights: expert_w,
|
||||
router: router_w,
|
||||
k,
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
let input_data = vec![
|
||||
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
|
||||
0.0, 1.0, 0.0, // batch 1: routes to expert via feature 1
|
||||
];
|
||||
// Router: feature 0 → expert 0, feature 1 → expert 1
|
||||
let router_data = vec![
|
||||
10.0, -10.0, // feature 0 → expert 0
|
||||
-10.0, 10.0, // feature 1 → expert 1
|
||||
0.0, 0.0, // feature 2 → neutral
|
||||
];
|
||||
let expert_data = vec![
|
||||
// Expert 0: [3x2]
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // Expert 1: [3x2]
|
||||
7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
];
|
||||
|
||||
rt.set_data(input.id, input_data.clone());
|
||||
rt.set_data(router_w.id, router_data.clone());
|
||||
rt.set_data(expert_w.id, expert_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let expected = moe_reference_batch(
|
||||
&input_data,
|
||||
&router_data,
|
||||
&expert_data,
|
||||
n_experts,
|
||||
in_dim,
|
||||
out_dim,
|
||||
k,
|
||||
batch,
|
||||
);
|
||||
assert_close(rt.get_f32(output.id), &expected);
|
||||
}
|
||||
|
||||
// ── Test: random inputs with k=2 ────────────────────────────────────
|
||||
#[test]
|
||||
fn test_moe_random_k2() {
|
||||
let n_experts = 4;
|
||||
let in_dim = 8;
|
||||
let out_dim = 4;
|
||||
let k = 2;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let input = cx.tensor(in_dim);
|
||||
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
|
||||
let router_w = cx.tensor((in_dim, n_experts));
|
||||
|
||||
let moe = MoE {
|
||||
expert_weights: expert_w,
|
||||
router: router_w,
|
||||
k,
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
let input_data = random_vec(in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
let expert_data = random_vec(n_experts * in_dim * out_dim);
|
||||
|
||||
rt.set_data(input.id, input_data.clone());
|
||||
rt.set_data(router_w.id, router_data.clone());
|
||||
rt.set_data(expert_w.id, expert_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let expected = moe_reference_1d(
|
||||
&input_data,
|
||||
&router_data,
|
||||
&expert_data,
|
||||
n_experts,
|
||||
in_dim,
|
||||
out_dim,
|
||||
k,
|
||||
);
|
||||
assert_close(rt.get_f32(output.id), &expected);
|
||||
}
|
||||
|
||||
// ── Test: batched random inputs ─────────────────────────────────────
|
||||
#[test]
|
||||
fn test_moe_batched_random() {
|
||||
let n_experts = 3;
|
||||
let in_dim = 4;
|
||||
let out_dim = 3;
|
||||
let k = 2;
|
||||
let batch = 4;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let input = cx.tensor((batch, in_dim));
|
||||
let expert_w = cx.tensor((n_experts, in_dim, out_dim));
|
||||
let router_w = cx.tensor((in_dim, n_experts));
|
||||
|
||||
let moe = MoE {
|
||||
expert_weights: expert_w,
|
||||
router: router_w,
|
||||
k,
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
|
||||
let input_data = random_vec(batch * in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
let expert_data = random_vec(n_experts * in_dim * out_dim);
|
||||
|
||||
rt.set_data(input.id, input_data.clone());
|
||||
rt.set_data(router_w.id, router_data.clone());
|
||||
rt.set_data(expert_w.id, expert_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let expected = moe_reference_batch(
|
||||
&input_data,
|
||||
&router_data,
|
||||
&expert_data,
|
||||
n_experts,
|
||||
in_dim,
|
||||
out_dim,
|
||||
k,
|
||||
batch,
|
||||
);
|
||||
assert_close(rt.get_f32(output.id), &expected);
|
||||
}
|
||||
|
||||
/// Dump the egglog HLIR for a QwenMoE-style GLU-MoE pattern.
|
||||
/// This helps identify the exact pattern for the GLUMoE backend HostOp.
|
||||
#[test]
|
||||
fn dump_glu_moe_egglog() {
|
||||
use luminal::dtype::DType;
|
||||
use luminal::egglog_utils::hlir_to_egglog;
|
||||
|
||||
let n_experts = 4;
|
||||
let hidden = 8;
|
||||
let intermediate = 4;
|
||||
let top_k: usize = 2;
|
||||
|
||||
let mut cx = Graph::new();
|
||||
|
||||
// Input tensors
|
||||
let x = cx.tensor(('s', hidden));
|
||||
let router = cx.tensor((n_experts, hidden));
|
||||
let gate_up_weights = cx
|
||||
.tensor((n_experts, intermediate * 2, hidden))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((n_experts, hidden, intermediate))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len(); // 2
|
||||
let e_dim = *router.dims().first().unwrap(); // E
|
||||
let k_expr = luminal::shape::Expression::from(top_k);
|
||||
|
||||
// 1. Router: softmax(x @ router^T) → [s, E]
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
|
||||
// 2. TopK expert selection → [s, k] (Int)
|
||||
let top_k_indices = routing_weights.topk_indexes(top_k, n - 1);
|
||||
|
||||
// 3. Gather top-k routing values → [s, k]
|
||||
let row_offsets = cx.iota(
|
||||
luminal::shape::Expression::from('z') / k_expr * e_dim,
|
||||
top_k_indices.dims(),
|
||||
);
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
|
||||
let gate_up_gathered =
|
||||
gather_experts_test(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, top_k).unsqueeze(n); // [s, k, 1, H]
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n); // [s, k, intermediate*2]
|
||||
|
||||
// 5. SwiGLU: silu(gate) * up → [s, k, intermediate]
|
||||
let gate = gate_up_out.slice((.., .., ..intermediate));
|
||||
let up = gate_up_out.slice((.., .., intermediate..));
|
||||
let hidden_act = gate.silu() * up;
|
||||
|
||||
// 6. Gather down expert weights → [s, k, H, intermediate]
|
||||
let down_gathered = gather_experts_test(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let hidden_exp = hidden_act.unsqueeze(2); // [s, k, 1, intermediate]
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
let (program, root) = hlir_to_egglog(&cx);
|
||||
println!("=== GLU-MoE HLIR Egglog Dump ===");
|
||||
println!("Root: {root}");
|
||||
println!("{program}");
|
||||
}
|
||||
|
||||
/// Helper: gather expert weight matrices using topk indices.
|
||||
fn gather_experts_test(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = (top_k_indices * io).cast(DType::F32);
|
||||
let within = graph_source
|
||||
.graph()
|
||||
.iota(luminal::shape::Expression::from('z'), (d1, d2))
|
||||
.cast(DType::F32);
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
}
|
||||
44
crates/luminal_nn/src/norm.rs
Normal file
44
crates/luminal_nn/src/norm.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A simple layer norm with an optional weight and bias
|
||||
#[derive(Default)]
|
||||
pub struct LayerNorm {
|
||||
pub weight: Option<GraphTensor>,
|
||||
pub bias: Option<GraphTensor>,
|
||||
mean_norm: bool,
|
||||
epsilon: f32,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
weight: Option<&str>,
|
||||
bias: Option<&str>,
|
||||
mean_norm: bool,
|
||||
epsilon: f32,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: weight.map(|w| cx.named_tensor(w, dim).persist()),
|
||||
bias: bias.map(|b| cx.named_tensor(b, dim).persist()),
|
||||
mean_norm,
|
||||
epsilon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
|
||||
if self.mean_norm {
|
||||
input = input.mean_norm(input.shape.last_axis());
|
||||
}
|
||||
input = input.std_norm(input.shape.last_axis(), self.epsilon);
|
||||
if let Some(w) = self.weight {
|
||||
input *= w.expand_lhs(&input.dims()[..input.dims().len() - 1]);
|
||||
}
|
||||
if let Some(b) = self.bias {
|
||||
input += b.expand_lhs(&input.dims()[..input.dims().len() - 1]);
|
||||
}
|
||||
input
|
||||
}
|
||||
}
|
||||
106
crates/luminal_nn/src/pooling.rs
Normal file
106
crates/luminal_nn/src/pooling.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
// use luminal::prelude::*;
|
||||
|
||||
// pub struct AvgPool2D {
|
||||
// kernel: (usize, usize),
|
||||
// stride: (usize, usize),
|
||||
// }
|
||||
|
||||
// impl AvgPool2D {
|
||||
// pub fn new(kernel: (usize, usize), stride: (usize, usize)) -> Self {
|
||||
// Self { kernel, stride }
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl SerializeModule for AvgPool2D {
|
||||
// fn serialize(&self, _s: &mut luminal::module::Serializer) {
|
||||
// // No parameters to serialize for average pooling
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl AvgPool2D {
|
||||
// pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
|
||||
// // Input: (batch (optional), ch_in, dimx_in, dimy_in)
|
||||
// let mut expanded = false;
|
||||
// if input.shape.len() == 3 {
|
||||
// // Expand batch
|
||||
// input = input.expand_dim(0, 1);
|
||||
// expanded = true;
|
||||
// }
|
||||
// let (batch, ch_in, dimx_in, dimy_in) = input.dims4();
|
||||
// let dimx_out = ((dimx_in - self.kernel.0) / self.stride.0 + 1).simplify();
|
||||
// let dimy_out = ((dimy_in - self.kernel.1) / self.stride.1 + 1).simplify();
|
||||
|
||||
// let output = input
|
||||
// .pool_last_dim(self.kernel.1, self.stride.1, 1) // dilation = 1 for pooling
|
||||
// .permute((0, 1, 3, 4, 2))
|
||||
// .pool_last_dim(self.kernel.0, self.stride.0, 1)
|
||||
// .permute((0, 1, 5, 3, 4, 2))
|
||||
// .reshape((
|
||||
// batch,
|
||||
// ch_in,
|
||||
// self.kernel.0 * self.kernel.1,
|
||||
// dimx_out * dimy_out,
|
||||
// ))
|
||||
// .mean(2) // Average over the kernel dimension
|
||||
// .reshape((batch, ch_in, dimx_out, dimy_out));
|
||||
|
||||
// if expanded {
|
||||
// output.reshape((ch_in, dimx_out, dimy_out))
|
||||
// } else {
|
||||
// output
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// pub struct AdaptiveAvgPool2D {
|
||||
// output_size: (usize, usize),
|
||||
// }
|
||||
|
||||
// impl AdaptiveAvgPool2D {
|
||||
// pub fn new(output_size: (usize, usize)) -> Self {
|
||||
// Self { output_size }
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl SerializeModule for AdaptiveAvgPool2D {
|
||||
// fn serialize(&self, _s: &mut luminal::module::Serializer) {
|
||||
// // No learnable parameters
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl AdaptiveAvgPool2D {
|
||||
// pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
|
||||
// let mut expanded = false;
|
||||
// // Handle missing batch dimension
|
||||
// if input.shape.len() == 3 {
|
||||
// input = input.expand_dim(0, 1);
|
||||
// expanded = true;
|
||||
// }
|
||||
|
||||
// // Extract dimensions
|
||||
// let (batch, ch, h_in, w_in) = input.dims4();
|
||||
// let (h_out, w_out) = self.output_size;
|
||||
|
||||
// let stride_h = (h_in / h_out).simplify();
|
||||
// let stride_w = (w_in / w_out).simplify();
|
||||
// let kernel_h = (h_in - (h_out - 1) * stride_h).simplify();
|
||||
// let kernel_w = (w_in - (w_out - 1) * stride_w).simplify();
|
||||
|
||||
// // Two-stage pooling (Y then X), followed by averaging over the kernel window
|
||||
// let mut output = input
|
||||
// .pool_last_dim(kernel_w, stride_w, 1)
|
||||
// .permute((0, 1, 3, 4, 2))
|
||||
// .pool_last_dim(kernel_h, stride_h, 1)
|
||||
// .permute((0, 1, 5, 3, 4, 2))
|
||||
// .reshape((batch, ch, kernel_h * kernel_w, h_out * w_out))
|
||||
// .mean(2)
|
||||
// .reshape((batch, ch, h_out, w_out));
|
||||
|
||||
// // Remove batch dim if it was originally absent
|
||||
// if expanded {
|
||||
// output = output.reshape((ch, h_out, w_out));
|
||||
// }
|
||||
|
||||
// output
|
||||
// }
|
||||
// }
|
||||
5
crates/luminal_python/.gitignore
vendored
Normal file
5
crates/luminal_python/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
*.onnx
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
.venv
|
||||
120
crates/luminal_python/CLAUDE.md
Normal file
120
crates/luminal_python/CLAUDE.md
Normal file
@@ -0,0 +1,120 @@
|
||||
## Python Environment
|
||||
|
||||
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
|
||||
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
|
||||
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
At the end of any session that involved a hard or non-obvious bug, append an entry to
|
||||
`LessonsLearned.md` in this directory. A "hard bug" means any bug that required significant
|
||||
investigation — intermittent failures, wrong output without a crash, egglog/optimizer issues,
|
||||
or anything that took more than a few minutes to locate.
|
||||
|
||||
Each entry should cover:
|
||||
1. **What the symptom was** (test failure, wrong output, panic, etc.)
|
||||
2. **What the actual root cause was** (the specific code/logic that was wrong)
|
||||
3. **Why it was hard to find** (what made it non-obvious or intermittent)
|
||||
4. **The fix** (what changed and why it works)
|
||||
5. **A general principle** extracted from the bug — something that helps avoid the same
|
||||
class of mistake in future code
|
||||
|
||||
The goal is to build a living record of codebase-specific pitfalls that future sessions can
|
||||
consult before writing new egglog rules, CUDA kernels, or optimizer passes.
|
||||
1. If you want to run tests:
|
||||
- `./run_test.sh` - runs tests with the native backend
|
||||
- `./run_tests_cuda.sh` - runs tests with the CUDA backend
|
||||
|
||||
## Testing Best Practices
|
||||
|
||||
### Overview
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
|
||||
### Test Pattern (CORRECT)
|
||||
|
||||
All tests should follow this standard pattern:
|
||||
|
||||
```python
|
||||
def test_operation():
|
||||
"""Brief description of what operation is being tested."""
|
||||
# 1. Instantiate PyTorch model
|
||||
model: torch.nn.Module = OperationTestModel()
|
||||
|
||||
# 2. Compile with luminal backend
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
# 3. Create test input
|
||||
x: torch.Tensor = torch.tensor([...]) # or torch.rand(...)
|
||||
|
||||
# 4. Run both original and compiled versions
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
|
||||
# 5. Verify outputs match
|
||||
assert torch.allclose(output, original)
|
||||
```
|
||||
|
||||
### Test Models
|
||||
|
||||
- Define test model classes in `tests/test_models.py`
|
||||
- Each model should be a simple `torch.nn.Module` that demonstrates one operation or pattern
|
||||
- Use clear, descriptive class names (e.g., `AddTestModel`, `TransposeTestModel`)
|
||||
- Include docstrings explaining what the model tests
|
||||
|
||||
Example:
|
||||
```python
|
||||
class AddTestModel(torch.nn.Module):
|
||||
"""Tests element-wise addition."""
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
```
|
||||
|
||||
### What NOT to Do
|
||||
|
||||
**❌ DO NOT create ONNX files directly in tests:**
|
||||
```python
|
||||
# WRONG - bypasses the PyTorch integration
|
||||
model_path = create_onnx_model(...)
|
||||
graph_result = luminal.process_onnx(model_path, backend='native')
|
||||
```
|
||||
|
||||
**✓ DO create PyTorch models and use torch.compile:**
|
||||
```python
|
||||
# CORRECT - tests actual user workflow
|
||||
model: torch.nn.Module = MyTestModel()
|
||||
model_compiled = torch.compile(model, backend=luminal_backend)
|
||||
```
|
||||
|
||||
### Rationale
|
||||
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
|
||||
- **User-facing API**: Tests use the same API that users will use (torch.compile)
|
||||
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
|
||||
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
|
||||
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
|
||||
|
||||
### Special Cases
|
||||
|
||||
**Testing constants:**
|
||||
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0])
|
||||
return x + constant
|
||||
```
|
||||
|
||||
**Testing type casts:**
|
||||
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
```
|
||||
|
||||
**Testing complex operations:**
|
||||
Chain operations naturally in PyTorch - ONNX export handles the conversion:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
transposed = x.transpose(0, 1)
|
||||
scaled = transposed * 2.0
|
||||
return scaled + 1.0
|
||||
```
|
||||
758
crates/luminal_python/LessonsLearned.md
Normal file
758
crates/luminal_python/LessonsLearned.md
Normal file
@@ -0,0 +1,758 @@
|
||||
# Lessons Learned
|
||||
|
||||
This file documents hard bugs encountered in this codebase, their root causes, and principles
|
||||
to prevent similar issues in the future.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-24 — Intermittent CUDA Backend Failures: Embed False Match + Batched Matmul Dimension Drop
|
||||
|
||||
### Background: Why the Failures Were Intermittent
|
||||
|
||||
Both bugs only appeared on roughly 50% of test runs. The source of non-determinism is
|
||||
`FxHashMap` (a fixed-seed hash map). The egglog optimizer's `SerializedEGraph::new` builds
|
||||
`Vec<NodeId>` orderings for each e-class by iterating a `FxHashMap`, producing non-deterministic
|
||||
node orderings. `random_initial_choice()` in `src/egglog_utils/mod.rs` then randomly picks one
|
||||
e-node per e-class as the starting representation for the profiling phase. The combination means
|
||||
some runs pick a correct kernel and some pick a broken one from the same e-class.
|
||||
|
||||
**Lesson**: When a test fails intermittently at a roughly 50% rate, suspect the egglog extractor
|
||||
choosing between two e-nodes in the same e-class — one correct, one broken. The fix is always in
|
||||
the broken e-node's rewrite rule.
|
||||
|
||||
---
|
||||
|
||||
### Bug 1: `test_gather_elements` — KernelEmbed and RowEmbed False Match
|
||||
|
||||
**Files changed**:
|
||||
- `crates/luminal_cuda/src/kernel/hlir.rs` (KernelEmbed, 4 rules)
|
||||
- `crates/luminal_cuda/src/block/ops.rs` (RowEmbed, 4 rules)
|
||||
|
||||
#### What happened
|
||||
|
||||
`gather_elements` (axis-aware gather) decomposes into a flat gather by computing:
|
||||
|
||||
```
|
||||
flat_idx = Add(
|
||||
Mul(indices, stride[axis]),
|
||||
Mul(Expand(Iota(dim_size)), stride[non_axis])
|
||||
)
|
||||
```
|
||||
|
||||
`KernelEmbed` and `RowEmbed` are optimized embedding lookup kernels. A genuine embedding
|
||||
lookup produces:
|
||||
|
||||
```
|
||||
flat_idx = Add(
|
||||
Mul(Cast(token_ids), embed_dim),
|
||||
Iota(embed_dim) ← bare Iota, the position within an embedding row
|
||||
)
|
||||
```
|
||||
|
||||
The egglog rewrite rules for both ops matched `Add(?mul_result, ?iota_result)` where
|
||||
`?iota_result` was **unconstrained** — it could bind to anything, including
|
||||
`Mul(Expand(Iota(n)), stride)` from `gather_elements`. This created a `KernelEmbed`/`RowEmbed`
|
||||
node in the same e-class as the `Gather` node. When the extractor picked it, `build_payload`
|
||||
called `flatten_mul_strides(range, token_stride)` which asserted `range.len() == token_stride.len()`:
|
||||
- `range` came from `RemoveNthFromEnd(idx_shape, 0)` → length 1
|
||||
- `token_stride` came from the indices strides → length 2
|
||||
- Assertion failed → panic.
|
||||
|
||||
#### The fix
|
||||
|
||||
Add `(= ?iota_result (Iota ?iota_expr ?iota_range))` to all 8 rules, requiring the positional
|
||||
component to be a bare `Iota` node:
|
||||
|
||||
```egglog
|
||||
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
|
||||
(= ?iota_result (Iota ?iota_expr ?iota_range)) ← added
|
||||
(= ?mul_result (Mul ...))
|
||||
```
|
||||
|
||||
#### Investigation note
|
||||
|
||||
The initial plan correctly identified `KernelEmbed` as faulty, but missed `RowEmbed`. The two
|
||||
ops are structurally identical but live in different parts of the codebase (`kernel/` vs
|
||||
`block/`). The second bug was only discovered when the backtrace pointed to
|
||||
`RowEmbed::build_payload` instead of `KernelEmbed::compile`. Always search for sibling
|
||||
implementations when fixing a pattern-matching bug in one op.
|
||||
|
||||
---
|
||||
|
||||
### Bug 2: `test_matmul_batched` — CuBlasLt Drops Batch Dimension
|
||||
|
||||
**Files changed**:
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg`
|
||||
|
||||
#### What happened
|
||||
|
||||
The luminal frontend decomposes `(2,3,4) @ (2,4,5)` into:
|
||||
|
||||
```rust
|
||||
let w = rhs.permute((0, 2, 1)); // (2,4,5) → (2,5,4)
|
||||
let mul = self.expand_dim(2, d) // (2,3,4) → (2,3,5,4)
|
||||
* w.expand_dim(1, b); // (2,5,4) → (2,3,5,4)
|
||||
mul.sum(3) // → (2,3,5), correct out_shape
|
||||
```
|
||||
|
||||
All four cublaslt rewrite rules extracted `m` and `n` from the output shape using
|
||||
`nth_from_end`, which succeeds for any rank:
|
||||
|
||||
```egglog
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
```
|
||||
|
||||
For `out_shape = [2, 3, 5]`: `?m = 3`, `?n = 5`. The batch dim `2` is never extracted or
|
||||
stored. The rules also validated stride patterns using `nth_from_end` on the stride arrays —
|
||||
but for this batched case, **all stride checks coincidentally passed** because the last three
|
||||
strides of the 4D expanded tensors happened to satisfy the 2D row/column-major patterns.
|
||||
|
||||
The resulting `CuBlasLt` node had `output_size() = m * n = 15`. The batch dimension was
|
||||
silently discarded. The runtime allocated a 15-element output buffer, cuBLAS wrote a 3×5
|
||||
result, and the test got back 15 values instead of 30.
|
||||
|
||||
#### The fix
|
||||
|
||||
Add `(= (len ?out_shape) 2)` to all 4 rules:
|
||||
|
||||
```egglog
|
||||
(= (len ?out_shape) 2) ← added: cuBLAS is 2D only
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
```
|
||||
|
||||
`len` counts elements in the `ECons`-list shape. With this constraint, any `Sum` node with a
|
||||
3D+ output shape (batched matmul) is not matched by cuBLAS rules and falls through to
|
||||
`KernelSumReduce + KernelMul` (or the tiling block ops), which correctly use
|
||||
`out_shape.iter().product()` for their output sizes.
|
||||
|
||||
Note: `TileMatmulSplitK` and `TileMatmulFullSplit` do NOT need this fix — their `output_size()`
|
||||
already returns `untiled_range.iter().product()` which includes all dimensions.
|
||||
|
||||
---
|
||||
|
||||
### General Principle: Always Constrain Shape Rank in Egglog Rules
|
||||
|
||||
Both bugs share the same structural cause: **egglog rewrite rules that used `nth_from_end` to
|
||||
extract dimensions from a shape list without constraining the list's length.** Since
|
||||
`nth_from_end` silently succeeds for any list with enough trailing elements, rules written for
|
||||
2D tensors accidentally matched higher-rank tensors.
|
||||
|
||||
**Rule for writing egglog rewrite rules in this codebase**:
|
||||
|
||||
> If a rule is designed for a specific tensor rank, always add an explicit
|
||||
> `(= (len ?shape) N)` constraint. If a rule is designed to handle arbitrary ranks but an
|
||||
> op's output only covers a subset of dimensions (like cuBLAS covering only the last 2),
|
||||
> that is a correctness bug — either implement strided batched cuBLAS or add the rank
|
||||
> constraint and fall back to a kernel that handles all dimensions.
|
||||
|
||||
---
|
||||
|
||||
### Debugging Intermittent CUDA Failures: Effective Approach
|
||||
|
||||
The investigation used extensive `eprintln!` debug logging to trace which kernels were compiled
|
||||
vs. skipped. Key observations:
|
||||
|
||||
1. **In the passing case**: `KernelSumReduce::compile()` was called, kernels were allocated.
|
||||
2. **In the failing case**: `KernelSumReduce::compile()` was never called, yet output was produced.
|
||||
|
||||
This asymmetry pointed to a `HostOp` path (cuBLAS) executing instead of the `KernelOp` path,
|
||||
which narrowed the search to cublaslt rewrite rules. The HLIR-level `SumReduce::to_egglog` log
|
||||
confirmed the correct HLIR node existed — the bug was in the e-graph optimization choosing
|
||||
a different (broken) e-node from the same e-class.
|
||||
|
||||
**Effective debug strategy for egglog non-determinism bugs**:
|
||||
1. Add logging at compile time for each kernel type (`KernelFoo::compile`, `HostFoo::execute`)
|
||||
2. Compare passing vs. failing runs to see which kernels are/aren't invoked
|
||||
3. The missing kernel's e-class contains a broken alternative — find it via the egglog rewrite rules
|
||||
4. Check the op that *is* executing — its `output_size()` reveals what's wrong with the false match
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-25 — OneHot Test Panic: Cast(Int→F32) Produces Int Output
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_onehot` panicked at `src/hlir.rs:1625` in `get_f32()`: the output buffer was
|
||||
`NativeData::Int` instead of the expected `NativeData::F32`.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The Cast parser's `* 1.0` workaround for `Int → F32` casts used `input * one_expanded`
|
||||
(Int GraphTensor on the left, F32 constant on the right). However, `Mul for GraphTensor`
|
||||
always uses `self.dtype` (the **left** operand's dtype) for the result, and the native
|
||||
runtime's `Mul::execute` dispatches on the **first** input's `NativeData` variant. So
|
||||
`Int * F32` produced `DType::Int` / `NativeData::Int` — the exact opposite of the intended
|
||||
F32 output.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **The OneHot parser was a red herring**: The initial plan assumed the OneHot ONNX node
|
||||
was being parsed, but `torch.onnx.export` decomposes `one_hot` into
|
||||
`Unsqueeze → Equal → Cast(Bool→Int) → Cast(Int→F32)`. The OneHot parser was never called.
|
||||
2. **The `* 1.0` workaround looked correct**: It was used successfully in many other parsers,
|
||||
but those all had F32 inputs (where `F32 * F32 = F32`). The Int→F32 case was the only
|
||||
path where the left operand was Int.
|
||||
3. **Operand order matters silently**: Nothing warns about mixed-dtype Mul — it just takes
|
||||
the left operand's dtype.
|
||||
|
||||
### The fix
|
||||
|
||||
In `ops_parse/unary.rs` `parse_cast_node`, split the combined condition into two cases:
|
||||
- **No-op cast** (`cast_result.id == input.id`): `input * one_expanded` — preserves dtype
|
||||
- **Int source** (`input.dtype == DType::Int`): `one_expanded * input` — F32 on the left
|
||||
ensures F32 output
|
||||
|
||||
### General principle
|
||||
|
||||
**In luminal, binary op dtype is always the LEFT operand's dtype.** When constructing
|
||||
`GraphTensor * constant_float(1.0)` for type materialization, always put the operand
|
||||
whose dtype you want to preserve on the LEFT side. When converting Int→F32, the F32
|
||||
constant must be the left operand.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-26 — ScatterND Fails on CUDA: "does not produce an egraph"
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_scatter_nd` passed on native backend but failed on CUDA with "does not produce an
|
||||
egraph". The CUDA compilation could not extract a valid program from the e-graph.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`scatter_nd` in `movement.rs` does `indices * 1` (line 353) to materialize the tensor for
|
||||
reshaping. The `* 1` dispatches to `Mul<S: Into<Expression>>`, which creates a `constant(1)`
|
||||
→ `Iota(1,1)` → `DType::Int`. But the ONNX parser creates all tensors as `DType::F32`
|
||||
(via `named_tensor()` in `compiled_graph.rs:70`), so indices arrive as F32. This produces
|
||||
`Mul(F32, Int)` — mixed dtypes.
|
||||
|
||||
The HLIR Mul dtype rule (`hlir.rs:886-888`) uses `(= ?dty (dtype ?lhs))` and
|
||||
`(= ?dty (dtype ?rhs))` with the same `?dty` variable, requiring both inputs to have
|
||||
matching dtypes. `F32 != Int` → the rule never fires → the Mul node gets **no dtype**.
|
||||
|
||||
Every downstream op checks `(= ?dty (dtype ?upstream))`. Without dtype on the Mul, no
|
||||
CUDA kernel rewrite rules fire for any downstream op (KernelMul, KernelAdd, KernelLessThan,
|
||||
etc.). When `cleanup_hlir` runs (enabled for CUDA, disabled for native), it deletes all
|
||||
unrewritten HLIR ops, leaving empty e-classes → egraph extraction fails.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Works on native**: `cleanup_hlir = false` for NativeRuntime, so unrewritten HLIR ops
|
||||
are never deleted. NativeOp dispatches on actual runtime data, not egglog dtype.
|
||||
2. **Cascading failure**: The root cause (missing dtype on one Mul) silently propagated
|
||||
through every downstream op, making it look like a systemic CUDA issue rather than a
|
||||
single dtype mismatch.
|
||||
3. **`scatter_elements` works fine**: The sibling op already cast indices via
|
||||
`(idx_f32 + (is_neg * adj)).cast(DType::Int)`, so only `scatter_nd` had this bug.
|
||||
|
||||
### The fix
|
||||
|
||||
Added `let indices = indices.cast(DType::Int);` at the top of `scatter_nd` in
|
||||
`movement.rs`, before any arithmetic on indices. `GraphTensor::cast()` short-circuits
|
||||
when `self.dtype == dtype`, so this is safe for callers already passing Int indices.
|
||||
Also added the same cast in `parse_scatter_nd_node` for explicitness.
|
||||
|
||||
### General principle
|
||||
|
||||
**Always cast index tensors to `DType::Int` before arithmetic in graph-building code.**
|
||||
ONNX tensors arrive as F32 from the Python bridge. Any `indices * stride` or
|
||||
`indices * 1` will produce `Mul(F32, Int)` which breaks HLIR dtype propagation on CUDA.
|
||||
The pattern `let indices = indices.cast(DType::Int);` at the top of any index-consuming
|
||||
function is defensive and free (no-op when already Int).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-04 — Dynamic Shapes: Empty Buffer for BOOL Scalar Initializer
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_decode_loop_dynamic` panicked at `bin_fn: a index 0 out of bounds (a.len=0), shape=[1, 1, 4, 4], strides=[0, 0, 0, 0]`. An Input node labeled `"new_ones"` had an empty buffer at runtime.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
Two issues combined:
|
||||
|
||||
1. **`load_tensor_floats` didn't handle ONNX data_type=9 (BOOL)**. The `new_ones` initializer was a BOOL scalar (1 byte in `raw_data`). `load_tensor_floats` fell through to the fallback case, which tried `chunks_exact(4)` on 1 byte → produced 0 chunks → returned empty vec `[]`. The buffer was set with empty data.
|
||||
|
||||
2. **Scalar initializers with empty `dims` created 0-dimensional tensors**. ONNX represents scalars with `dims=[]`. The initializer loop computed `shape = init.dims.iter().map(|&d| d as usize).collect()` → empty vec `[]`, then called `named_tensor(name, [])` which created a tensor with 0 dimensions instead of the intended scalar `[1]`.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Misdiagnosed as ConstantOfShape issue**: The original plan targeted `ConstantOfShape` with dynamic shapes. The shape `[1,1,4,4]` with strides `[0,0,0,0]` looked like a broadcast from a constant fill. But `parse_constant_of_shape` was never called — the `new_ones` tensor came from an ONNX initializer, not a computation node.
|
||||
|
||||
2. **The BOOL data type is unusual**: Most ONNX tensors are FLOAT, INT32, or INT64. BOOL initializers only appear in specific patterns (like `torch.ones()` in attention mask computation). `load_initializer_as_f32` already handled BOOL, but its sibling `load_tensor_floats` didn't.
|
||||
|
||||
3. **Empty vec is valid data**: `set_data(node_id, [])` doesn't panic — it silently sets an empty buffer. The error only manifests later when a downstream op tries to read index 0.
|
||||
|
||||
### The fix
|
||||
|
||||
1. Added `data_type=9` (BOOL) handling to `load_tensor_floats` in `util.rs` — same logic as `load_initializer_as_f32`: 1 byte per element, non-zero → 1.0, zero → 0.0.
|
||||
|
||||
2. In `compiled_graph.rs`, initializer tensor creation: if `shape.is_empty()`, set `shape = vec![1]` (scalar representation in luminal).
|
||||
|
||||
### General principle
|
||||
|
||||
**Keep data loading functions in sync.** `load_tensor_floats` and `load_initializer_as_f32` serve the same purpose (loading ONNX TensorProto data as f32) but had different data type coverage. When adding a new data type to one, check and update the other. Better yet, refactor them into a single function.
|
||||
|
||||
**ONNX scalars have `dims=[]`, luminal scalars have shape `[1]`.** Always convert empty dims to `[1]` when creating luminal tensors from ONNX data.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-04 — Where Node Missing Broadcast: KernelMul flatten_strides Panic on CUDA
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama3_1b_decode_loop_dynamic` panicked at `flatten_strides` with `left: 4, right: 1` during
|
||||
CUDA `KernelMul::compile`. The `KernelMul` had `out_shape=[1, 1, a, a]` but `b_stride=[z]` (1D).
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`parse_where_node` called `x.cond(condition, y)` without broadcasting the inputs to matching ranks.
|
||||
The ONNX Where op for the attention mask had condition=[1,1,a,a] (4D), x=[1] (scalar), y=[1] (scalar).
|
||||
Luminal's `cond` doesn't auto-broadcast — it passes the shape trackers directly to the HLIR node.
|
||||
The resulting Mul had input A with 4D strides and input B with 1D strides.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Only triggered by 1B model**: The tiny model's Where inputs all had matching ranks (no scalars).
|
||||
2. **CUDA-only**: The native runtime's `bin_fn` uses `StridedIterator` which handles mismatched
|
||||
strides more gracefully. CUDA's `KernelMul::compile` calls `flatten_strides` which asserts
|
||||
`range.len() == strides.len()`.
|
||||
3. **Delayed crash**: The mismatch was created during ONNX parsing but only manifested during
|
||||
CUDA kernel compilation (graph search phase).
|
||||
|
||||
### The fix
|
||||
|
||||
Added numpy-style broadcasting to `parse_where_node`: compute the broadcast shape across all 3
|
||||
inputs, then `broadcast_to_expr` each to the common shape before calling `cond`.
|
||||
|
||||
### General principle
|
||||
|
||||
**ONNX binary/ternary ops all use numpy broadcasting.** When parsing ONNX ops that take multiple
|
||||
tensor inputs (Where, Add, Mul, etc.), always broadcast all inputs to a common shape BEFORE
|
||||
calling the luminal graph operation. Luminal graph ops do NOT auto-broadcast — they expect inputs
|
||||
with matching shape tracker dimensions.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
|
||||
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
|
||||
the value at column 0 of each row (all three top-k positions got the same value).
|
||||
Native backend was fine.
|
||||
|
||||
2. **Root cause**: `gather_elements` was called with a non-contiguous index tensor produced by
|
||||
`argsort(axis=1) → slice_along(..k, axis=1)`. The slice creates a ShapeTracker view of the
|
||||
[4,8] argsort buffer with dims [4,3] and strides [8,1]. When this flowed through the
|
||||
gather_elements Int arithmetic chain (cast, multiply, add) and into the final Gather CUDA
|
||||
kernel, the non-contiguous strides caused incorrect index reads for later rows.
|
||||
|
||||
3. **Why it was hard to find**: `test_topk_indices` passed (it only tests argsort+slice, not
|
||||
the downstream gather_elements). A standalone `test_gather_elements` with constant indices
|
||||
also passed because constant indices are contiguous. The bug only manifested when runtime-
|
||||
computed non-contiguous indices were used with data of a different size along the gather axis.
|
||||
|
||||
4. **Fix**: In `parse_topk_node`, compute `gather_elements(x, full_argsort, axis)` with the
|
||||
full [4,8] argsort result (same size as data), then slice the gathered values to [4,3].
|
||||
This ensures gather_elements always operates on same-sized contiguous tensors.
|
||||
|
||||
5. **General principle**: When building graph operations that chain shape-tracker views
|
||||
(slice, transpose, etc.) into downstream HLIR ops on CUDA, prefer operating on full
|
||||
contiguous tensors first and slicing the result afterward. Non-contiguous views flowing
|
||||
through multiple CUDA kernels can trigger stride-related bugs in the egglog-compiled code.
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-07 — Non-deterministic CUDA_ERROR_ILLEGAL_ADDRESS: Multiple Missing Rank Constraints
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_tiny` on CUDA failed ~70% of runs with `CUDA_ERROR_ILLEGAL_ADDRESS`. Failures
|
||||
were non-deterministic due to egglog's `FxHashMap` iteration order in `random_initial_choice()`.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Multiple** matmul egglog rules lacked `(= (len ?out_shape) 2)` constraints:
|
||||
|
||||
1. `TileMatmulSplitK` in `block/ops.rs` (disabled via comment but rule still registered)
|
||||
2. `TileMatmulFullSplit` in `block/ops.rs`
|
||||
3. All 4 `sgemm_v2_*.egg` rules in `host/cublas/`
|
||||
|
||||
The `cublaslt_*.egg` rules already had the constraint. When egglog picked TileMatmul or sgemm
|
||||
for a 3D+ batched matmul, the generated CUDA kernels accessed out-of-bounds memory.
|
||||
|
||||
Additionally, `KernelEmbed` in `kernel/hlir.rs` had an output indexing bug:
|
||||
`out[out_offset * embed_dim + embed_idx]` should be `out[out_offset + embed_idx]` because
|
||||
`out_offset` already includes the embed_dim factor from `flatten_strides`.
|
||||
|
||||
**Most critically**, the KernelEmbed and RowEmbed "with cast" egglog rules passed the
|
||||
**pre-cast** float token_ids (`?token_ids`) to the embed kernel instead of the **post-cast**
|
||||
int token_ids (`?token_ids_cast`). The CUDA kernel reads token_ids as `const int*`, so float
|
||||
data gets reinterpreted as enormous garbage integers, causing out-of-bounds embed table access.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Multiple independent bug sources**: The ~70% failure rate was caused by three separate bugs
|
||||
(matmul rank, embed output indexing, embed pre-cast input). Each fix only reduced the rate
|
||||
partially, making it seem like each fix was insufficient.
|
||||
2. **CudaGraph wrapping**: The crash occurred inside `CudaGraphOp::execute_internal` which
|
||||
batches multiple kernels via CUDA graphs. The error just said "CudaGraph" — it
|
||||
didn't identify which kernel crashed. Adding per-kernel debug launches was essential.
|
||||
3. **Cascading failures**: When the Megakernel (containing RowEmbed with the pre-cast bug)
|
||||
corrupted the embed output, the NEXT CudaGraph group's kernels crashed reading the garbage.
|
||||
This made the Megakernel appear to be the victim, not the source.
|
||||
4. **The pre-cast bug only crashes SOMETIMES**: Egglog's random choice determines whether
|
||||
KernelEmbed/RowEmbed is selected (crash) or the generic Gather path is used (works).
|
||||
Float token_id 1.0 (= 0x3F800000 = 1065353216 as int) produces an astronomically large
|
||||
embed table index, causing ILLEGAL_ADDRESS.
|
||||
|
||||
### The fix
|
||||
|
||||
- Added `(= (len ?out_shape) 2)` to TileMatmulSplitK, TileMatmulFullSplit, and all 4 sgemm_v2 rules
|
||||
- Fixed KernelEmbed output indexing: `out[out_offset + embed_idx]`
|
||||
- **Fixed KernelEmbed/RowEmbed "with cast" rules**: Changed input from `?token_ids` to
|
||||
`?token_ids_cast` — using the post-Cast int tensor instead of the pre-Cast float tensor
|
||||
|
||||
### Results
|
||||
|
||||
Failure rate: ~70% → 0% (20/20 passing). All three bugs needed to be fixed together.
|
||||
|
||||
### General principle
|
||||
|
||||
**When an egglog rule matches a sub-expression chain (like Cast→Mul→Add), be precise about
|
||||
which intermediate result becomes each input.** The "with cast" embed rules matched
|
||||
`Cast(?token_ids, ...)` to verify the Cast existed, but then passed `?token_ids` (the Cast
|
||||
INPUT) instead of `?token_ids_cast` (the Cast OUTPUT) to the embed kernel. The kernel expects
|
||||
int data, so the pre-cast float data was reinterpreted as garbage ints.
|
||||
|
||||
**Always search for sibling implementations**: KernelEmbed (in `kernel/hlir.rs`) and RowEmbed
|
||||
(in `block/ops.rs`) had the SAME bug in their "with cast" rules. Fixing one without the other
|
||||
only reduces the failure rate — both must be fixed.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-09 — TileMatmulFullSplit Matches Element-wise Square+Sum from LayerNorm
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_qwen_image_transformer_tiny` on CUDA produced NaN in specific output rows. The failure
|
||||
was non-deterministic (~85% failure rate) due to egglog's random e-class extraction picking
|
||||
TileMatmulFullSplit for some operations.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The `TileMatmulFullSplit` rewrite rule in `block/ops.rs` matched any `Mul + Sum` pattern with
|
||||
a 2D output, contiguous K-strides, and F32 inputs. This correctly matched real matmuls, but
|
||||
ALSO matched the element-wise `x * x + Sum(last_dim)` pattern from LayerNorm/RMSNorm
|
||||
(Pow(x, 2) → ReduceMean).
|
||||
|
||||
For a [1, 4, 64] activation tensor `x`:
|
||||
- `Mul(x, x)` shape: [1, 4, 64], strides: [256z, 64z, z] for both inputs
|
||||
- `Sum(dim=2)` output: [1, 4], len=2 ✓
|
||||
|
||||
TileMatmulFullSplit interpreted this as a [1, 64] × [64, 4] → [1, 4] matmul with:
|
||||
- A = row 0 of x (64 elements), B = same buffer at column offsets
|
||||
|
||||
The kernel computed `C[j] = sum_k x[k] * x[j*64+k]` (cross-products) instead of the correct
|
||||
`C[j] = sum_k x[j*64+k]^2` (squared sums). This produced subtly wrong values for j > 0
|
||||
(correct for j=0 since cross-product with self = squared sum). These wrong values propagated
|
||||
through LayerNorm → downstream operations → softmax → NaN.
|
||||
|
||||
Key diagnostic: adding `printf` to the kernel showed `a_ptr == b_ptr` (same buffer for both
|
||||
inputs), confirming the kernel was operating on `x * x` not a real matmul.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Individual op tests passed**: Simple Gemm tests, attention tests, and all other bisection
|
||||
tests passed because they didn't have the specific `x*x → Sum` pattern.
|
||||
2. **Non-deterministic**: The bug only manifested when egglog selected TileMatmulFullSplit
|
||||
over the kernel fallback for the square+sum operation.
|
||||
3. **No NaN from TileMatmulFullSplit itself**: The kernel produced wrong-but-finite values.
|
||||
NaN only appeared downstream through softmax (exp(large) → ∞ → ∞/∞ = NaN).
|
||||
4. **Systematic elimination needed**: Had to disable all block ops, then enable one at a time,
|
||||
to narrow down TileMatmulFullSplit as the culprit.
|
||||
|
||||
### The fix
|
||||
|
||||
Added matmul broadcast constraints to both `TileMatmulFullSplit` and `TileMatmulSplitK` rules:
|
||||
|
||||
```egglog
|
||||
; Assert proper matmul broadcast pattern:
|
||||
; A is broadcast over N (a_n_stride = 0), B is broadcast over M (b_m_stride = 0)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
```
|
||||
|
||||
In a real matmul `[M, K] × [K, N]`, the Mul is created by expanding dims:
|
||||
- A is broadcast over N → a_n_stride = 0
|
||||
- B is broadcast over M → b_m_stride = 0
|
||||
|
||||
In element-wise `x * x`, both strides are identical (non-zero for all dims), so the
|
||||
constraints correctly reject it. The cuBLAS `.egg` rules already had these constraints.
|
||||
|
||||
### General principle
|
||||
|
||||
**Matmul Mul+Sum patterns have specific broadcast structure: one input is broadcast over M
|
||||
and the other over N.** When writing egglog rules that match `Mul + Sum` patterns for matmul
|
||||
optimization, always verify the broadcast pattern (`a_n_stride = 0` and `b_m_stride = 0`).
|
||||
This prevents matching element-wise operations like `x*x → sum` that happen to have a 2D
|
||||
output and contiguous strides.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-09 — Conv3D Permute Axis Mismatch in ONNX Conv Parser
|
||||
|
||||
### Symptom
|
||||
|
||||
`test_qwen_image_vae_decoder_tiny` panicked with:
|
||||
> Permute axes (5) doesn't match shape axes (6)
|
||||
|
||||
at `src/shape/tracker.rs:153`, during `parse_conv_node`.
|
||||
|
||||
### Root cause
|
||||
|
||||
The Conv parser's unfold → matmul algorithm used two consecutive permutes with incorrect
|
||||
index calculations. After unfold produces a 2N-dimensional tensor
|
||||
`[win_0..win_{N-1}, k_0..k_{N-1}]`, the first permute swapped kernel dims to the front.
|
||||
But the second permute's index math still assumed the original (pre-first-permute) ordering,
|
||||
confusing kernel dimensions with window dimensions. Additionally:
|
||||
|
||||
1. `output_spatial_dims` was captured from wrong indices (kernel dims instead of window
|
||||
spatial dims)
|
||||
2. The `split_dims` loop iterated `spatial` times instead of `spatial-1`, creating a
|
||||
spurious size-1 dimension
|
||||
3. The final permute array had `1+spatial` elements for a tensor with `2+spatial` dims
|
||||
|
||||
For Conv2D (spatial=2) this was never caught because the xfail'd VAE decoder test was the
|
||||
only test exercising the Conv parser — the transformer tests don't use Conv ONNX nodes.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
The Conv parser was written and the VAE test immediately xfail'd due to a *different* bug
|
||||
(`merge_dims` being `todo!()`). Once `merge_dims` was implemented, the Conv parser's own
|
||||
bugs surfaced for the first time.
|
||||
|
||||
### Fix
|
||||
|
||||
Rewrote the unfold → matmul section with a single correct permute:
|
||||
|
||||
1. **One permute** to `[N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]`
|
||||
— groups batch | output spatial | channel+kernel
|
||||
2. **Capture** `output_spatial_dims` from correct indices `[1..1+spatial]`
|
||||
3. **Merge** all channel+kernel dims from the end into one
|
||||
4. **Merge** spatial dims into one → `[N, spatial_product, C_in*kernel_product]`
|
||||
5. **Matmul** → `[N, spatial_product, C_out]`
|
||||
6. **Split** spatial back with `spatial-1` splits (not `spatial`)
|
||||
7. **Permute** C_out to position 1 with correct `2+spatial` element array
|
||||
|
||||
### General principle
|
||||
|
||||
**When chaining permutes on high-dimensional tensors, prefer a single combined permute.**
|
||||
Multiple permutes with hand-computed index arrays are error-prone because each permute
|
||||
redefines what indices mean. A single permute from the original layout to the target layout
|
||||
is easier to verify and less likely to confuse source/destination ordering. Also, ensure
|
||||
`split_dims` loop counts match: splitting N dims out of a product requires N-1 splits
|
||||
(the outermost dim is the quotient, not split out separately).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-18 — CUDA Search Rejects All Candidates: Zero Dummy Data Causes NaN for Div/Pow/Mod/Erf
|
||||
|
||||
### What the symptom was
|
||||
|
||||
6 CUDA tests (`test_pow`, `test_pow_broadcast`, `test_div`, `test_mod`, `test_mod_broadcast`,
|
||||
`test_erf`) consistently failed with `Failed to find a viable initial genome for group 0 after
|
||||
100 attempts`. All 6 passed on native backend.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The CUDA two-phase initialization in `build_cuda_backend` set ALL input tensor buffers to
|
||||
`0.0f32` as dummy data for profiling. When `torch.compile` decomposes a model, it passes
|
||||
model weights as additional ONNX graph inputs (not initializers). Since there were no ONNX
|
||||
initializers to overwrite the zeros, weight buffers stayed all-zero during search.
|
||||
|
||||
Operations with zero inputs produced NaN:
|
||||
- `fmod(0, 0) = NaN` (Mod test)
|
||||
- `weight * recip(0) = weight * inf` → with any zero weight → `0 * inf = NaN` (Div test)
|
||||
- `abs(0).log() = log(0) = -inf` → downstream NaN (Pow test)
|
||||
- `sign(0)` chain → operations on zero inputs (Erf test)
|
||||
|
||||
The `has_nan_outputs` check rejected every candidate genome, exhausting all 100 attempts.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **No panic, no crash — silent NaN rejection**: The error message said "Failed to find a
|
||||
viable initial genome" which suggested an egglog rewrite issue, not a data issue.
|
||||
2. **Works on native**: `NativeRuntime::has_nan_outputs()` returns `false` by default (no NaN
|
||||
check), so zero inputs never caused problems on native.
|
||||
3. **torch.compile vs direct export difference**: Directly exporting a model via
|
||||
`torch.onnx.export(model, ...)` produces initializers. But `torch.compile`'s backend
|
||||
receives a `GraphModule` where weights are graph inputs, not initializers. The ONNX file
|
||||
from `torch.compile` has 0 initializers.
|
||||
4. **CudaRuntime's own `allocate_dummy_input` already uses 1.0**: The runtime knew zeros
|
||||
were problematic (comment: "Zero inputs often hide numerical issues"), but the
|
||||
`compiled_graph.rs` code used `0.0f32` independently.
|
||||
|
||||
### The fix
|
||||
|
||||
Changed dummy data from `vec![0.0f32; n_elements]` to `vec![1.0f32; n_elements]` in
|
||||
`build_cuda_backend`. Using 1.0 is numerically safe: `fmod(1,1)=0`, `recip(1)=1`,
|
||||
`log(1)=0`, `exp(1)≈2.7` — no NaN or inf. Profiling timing is unaffected (same number
|
||||
of FLOPs and memory accesses).
|
||||
|
||||
### General principle
|
||||
|
||||
**Use small non-zero values (1.0) for dummy profiling data, never zeros.** Zero is a
|
||||
singularity for many floating-point operations (division, log, fmod with zero divisor).
|
||||
The CUDA runtime's `allocate_dummy_input` already followed this principle — the ONNX
|
||||
pipeline's `build_cuda_backend` was inconsistent. When creating dummy data for GPU
|
||||
profiling, always match the runtime's safer default.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-18 — Dynamic Decode Loop Fails: HLIR Weight Buffers Consumed After First Execute
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama3_1b_decode_loop_dynamic` passed step 0 (seq_len=6) but panicked on step 1
|
||||
(seq_len=7) with `no entry found for key` at `cublaslt/mod.rs:294` — the CuBlasLt op couldn't
|
||||
find its weight input buffer.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Two bugs:**
|
||||
|
||||
1. **Missing `)` in egglog rule** (`luminal_cuda_lite/src/kernel/hlir.rs:3042`): The fourth
|
||||
KernelEmbed rule ("kernel embed with mul reversed") had 3 closing parens after `INil` instead
|
||||
of 4. The missing `)` failed to close the `(= ?mul_result ...)` form. This caused an egglog
|
||||
parse error during search, caught by `catch_unwind`. The rule was dead code — it never fired,
|
||||
but the parse error consumed a search iteration.
|
||||
|
||||
2. **HLIR buffer consumption killed weight buffers** (`luminal_cuda_lite/src/runtime.rs:1010-1057`):
|
||||
After each `execute()`, the runtime removed all HLIR buffers (weights, constants) except those
|
||||
directly connected to Output nodes. This was intended to free one-shot input data, but it also
|
||||
deleted all 168 weight buffers. On the next `graph.run()`, CuBlasLt couldn't find any of its
|
||||
weight inputs — `hlir_buffers` had 1 entry (the just-set `input_ids`) instead of 169.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Misdirection by the egglog syntax error**: The plan identified the missing `)` as THE cause.
|
||||
Fixing it allowed the rule to parse correctly, but the real runtime failure was independent.
|
||||
2. **Step 0 always succeeds**: The weight consumption happens AFTER a successful execution. So
|
||||
the first `graph.run()` works perfectly — all 169 HLIR buffers exist. The panic only occurs
|
||||
on the second call, after consumption has cleared 168 of them.
|
||||
3. **The consumption code was deliberately designed**: Comments said "weight tensors must have
|
||||
`.persist()` to survive." The ONNX pipeline didn't call `.persist()` on weights, but this
|
||||
had never been a problem before because single-shot inference only calls `execute()` once.
|
||||
4. **Search phase panics masked by `catch_unwind`**: The same "no entry found for key" error
|
||||
occurred during profiling of search candidates, but was silently caught. This made it look
|
||||
like only certain LLIR variants had the issue, not all of them.
|
||||
5. **Debug output needed 4 iterations to find**: The first debug showed which NodeIndex was
|
||||
missing, the second showed it was an Input node, the third showed the HLIR mapping, and
|
||||
the fourth revealed `hlir_buffers_count` dropping from 169 to 1 between steps.
|
||||
|
||||
### The fix
|
||||
|
||||
1. Added missing `)` to the KernelEmbed egglog rule at `hlir.rs:3042`.
|
||||
2. In `compiled_graph.rs`, added `.persist()` calls on all weight/constant tensors (anything
|
||||
not in `input_names`) after `process_onnx_nodes` completes. `.persist()` creates an Output
|
||||
node connected to the Input, which the consumption code recognizes as "do not consume."
|
||||
User inputs (like `input_ids`) are intentionally NOT persisted — they are consumed after
|
||||
each `execute()` and re-set via `set_input()` before the next call.
|
||||
|
||||
### General principle
|
||||
|
||||
**Mark weight/constant tensors as persistent in the graph-building pipeline.** The runtime's
|
||||
`execute()` consumes all HLIR buffers not connected to Output nodes. This is correct behavior
|
||||
for one-shot user inputs, but weights must survive across calls. Always call `.persist()` on
|
||||
tensors that should outlive a single execution. In the ONNX pipeline, the distinction is clear:
|
||||
`input_names` (user-provided data per step) vs everything else (weights/constants loaded once).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-20 — PT2 CUDA Search Rejects All Candidates: Integer Buffers Misinterpreted as Float NaN
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_tiny` on CUDA via PT2 failed with:
|
||||
`pyo3_runtime.PanicException: Failed to find a viable initial genome for group 0 after 100 attempts`
|
||||
|
||||
The search tried 100 different egglog rewrites and ALL were rejected by the `has_nan_outputs` check.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Two issues, both required to fix:**
|
||||
|
||||
1. **Integer buffers misinterpreted as float in NaN check.** `has_nan_outputs` in
|
||||
`luminal_cuda_lite/src/runtime.rs` checks ALL `self.buffers` by reinterpreting raw bytes
|
||||
as `f32` and calling `is_nan()`. The PT2-translated graph has integer intermediate
|
||||
buffers (from `arange`, `cast(Int)`, integer arithmetic for embedding index computation).
|
||||
Certain valid `i32` bit patterns (e.g., large integers from `token_id * hidden_dim`)
|
||||
have exponent=0xFF and non-zero mantissa when reinterpreted as f32 — matching the
|
||||
IEEE 754 NaN pattern. This caused false NaN rejections for EVERY candidate genome.
|
||||
|
||||
2. **Real weights/constants loaded before search contain -inf.** The PT2 path loaded real
|
||||
safetensors weights and model constants (including the causal attention mask with `-inf`
|
||||
values) BEFORE the search. While the ONNX path also loads real initializer data before
|
||||
search, the PT2 graph's different structure (more explicit integer operations) made the
|
||||
integer NaN false-positive the blocking issue.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
- The original plan diagnosed this as the same zero-dummy-data bug fixed on 2026-03-18.
|
||||
Changing `0.0` to `1.0` was insufficient because the root cause was different.
|
||||
- `has_nan_outputs` checking ALL intermediate buffers (not just outputs) masked the real
|
||||
issue — the NaN was in integer index-computation buffers, not in the model's float outputs.
|
||||
- The ONNX-translated graph didn't have this problem because it doesn't produce as many
|
||||
integer intermediate buffers (ONNX embedding uses different ops).
|
||||
- The NaN pattern was identical across all 100 search attempts, which was the key clue:
|
||||
it was deterministic and independent of egglog rewrite choices, pointing to input data
|
||||
or buffer interpretation rather than graph optimization issues.
|
||||
|
||||
### The fix
|
||||
|
||||
Four changes:
|
||||
|
||||
1. **`luminal_cuda_lite/src/kernel/mod.rs`** (`KernelOp` trait): Added `output_dtype()`
|
||||
method with default `DType::F32`. Each kernel now reports its actual output dtype.
|
||||
|
||||
2. **`luminal_cuda_lite/src/kernel/hlir.rs`** and **`other_ops.rs`**: Overrode
|
||||
`output_dtype()` in all kernels with a `dtype` field (returns `self.dtype`), plus
|
||||
special cases: `KernelIota` → `DType::Int`, `KernelLessThan` → `DType::Bool`,
|
||||
`KernelCast` → `self.out_dtype`.
|
||||
|
||||
3. **`luminal_cuda_lite/src/runtime.rs`** (`has_nan_outputs`): Replaced fragile
|
||||
`format!("{:?}").contains("dtype: Int")` string matching with proper
|
||||
`op.to_dialect::<dyn KernelOp>().output_dtype()` check. Only F32 buffers are
|
||||
checked for NaN; integer and bool buffers are skipped.
|
||||
|
||||
4. **`rust/src/pt2_compiled_model.rs`** (`init_cuda_runtime`): Set ALL input nodes
|
||||
(weights, constants, user inputs) to `vec![1.0f32; n_elements]` before search via
|
||||
new `set_all_inputs_dummy_cuda` function, then reload real data after search.
|
||||
This prevents any -inf values from the causal mask from polluting intermediate
|
||||
float computations during profiling.
|
||||
|
||||
### General principle
|
||||
|
||||
**Never reinterpret integer buffer bytes as float for NaN checking.** When a graph has
|
||||
mixed-dtype operations (float model computation + integer index computation), raw byte
|
||||
buffers from integer kernels contain valid i32 values that look like NaN when cast to f32.
|
||||
The search's `has_nan_outputs` must be dtype-aware — use the kernel's `output_dtype()`
|
||||
method rather than string-matching on Debug output. Additionally, when diagnosing "all
|
||||
candidates rejected" during search, check whether the rejection is from actual float NaN
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
|
||||
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
|
||||
|
||||
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
|
||||
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
|
||||
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
0
crates/luminal_python/README.md
Normal file
0
crates/luminal_python/README.md
Normal file
369
crates/luminal_python/modal_pytest_runner.py
Normal file
369
crates/luminal_python/modal_pytest_runner.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""Run pytest on Modal with a dynamically selected GPU.
|
||||
|
||||
Usage:
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
|
||||
uv run modal run modal_pytest_runner.py --gpu T4 tests/
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import modal
|
||||
from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
|
||||
SRC_PATH = f"{PROJECT_DIR}/src"
|
||||
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
|
||||
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
|
||||
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
|
||||
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
HF_TOKEN_ENV_KEY = "HF_TOKEN"
|
||||
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
|
||||
HF_CACHE_VOLUME = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
|
||||
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
|
||||
.uv_sync(
|
||||
str(LOCAL_PROJECT_DIR),
|
||||
frozen=False,
|
||||
groups=["dev"],
|
||||
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
|
||||
)
|
||||
.workdir(PROJECT_DIR)
|
||||
.add_local_dir(
|
||||
str(LOCAL_PROJECT_DIR.parent.parent),
|
||||
remote_path="/root/luminal",
|
||||
copy=True,
|
||||
ignore=[
|
||||
".git",
|
||||
".claude-project",
|
||||
".cargo-local",
|
||||
"**/.venv",
|
||||
"**/.pytest_cache",
|
||||
"**/__pycache__",
|
||||
"**/luminal_artifacts",
|
||||
"**/target",
|
||||
"docs",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _utc_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _hf_token_secret() -> modal.Secret | None:
|
||||
hf_token = os.environ.get(HF_TOKEN_ENV_KEY)
|
||||
if not hf_token:
|
||||
return None
|
||||
return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: hf_token})
|
||||
|
||||
|
||||
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
|
||||
return any(arg == flag for arg in pytest_args)
|
||||
|
||||
|
||||
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
|
||||
return (
|
||||
cli_profile
|
||||
or _has_pytest_flag(pytest_args, "--profile")
|
||||
or _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
)
|
||||
|
||||
|
||||
def _run_id() -> str:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _prepare_scratch_dir(scratch_dir: Path) -> None:
|
||||
scratch_dir.mkdir(parents=True, exist_ok=True)
|
||||
linked_names = {
|
||||
".venv",
|
||||
".pytest_cache",
|
||||
"__pycache__",
|
||||
"luminal_artifacts",
|
||||
"prof",
|
||||
}
|
||||
for entry in Path(PROJECT_DIR).iterdir():
|
||||
if entry.name in linked_names:
|
||||
continue
|
||||
|
||||
target = scratch_dir / entry.name
|
||||
if target.exists() or target.is_symlink():
|
||||
continue
|
||||
|
||||
target.symlink_to(entry, target_is_directory=entry.is_dir())
|
||||
|
||||
|
||||
def _default_profile_output_dir(run_id: str) -> Path:
|
||||
return (LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT / run_id).resolve()
|
||||
|
||||
|
||||
def _prepare_local_profile_dir(output_dir: Path) -> None:
|
||||
if output_dir.exists() and not output_dir.is_dir():
|
||||
raise NotADirectoryError(f"{output_dir} is not a directory")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = output_dir / "prof"
|
||||
if prof_dir.exists():
|
||||
shutil.rmtree(prof_dir)
|
||||
|
||||
manifest_path = output_dir / "manifest.json"
|
||||
if manifest_path.exists():
|
||||
manifest_path.unlink()
|
||||
|
||||
|
||||
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
|
||||
entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
|
||||
_prepare_local_profile_dir(output_dir)
|
||||
|
||||
for entry in entries:
|
||||
relative_path = Path(entry.path).relative_to(run_id)
|
||||
if relative_path == Path("."):
|
||||
continue
|
||||
|
||||
destination = output_dir / relative_path
|
||||
if entry.type == FileEntryType.DIRECTORY:
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
||||
if entry.type != FileEntryType.FILE:
|
||||
continue
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("wb") as handle:
|
||||
for chunk in PROFILE_VOLUME.read_file(entry.path):
|
||||
handle.write(chunk)
|
||||
|
||||
|
||||
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
try:
|
||||
PROFILE_VOLUME.remove_file(run_id, recursive=True)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
def run(
|
||||
self,
|
||||
pytest_args: list[str],
|
||||
pytest_addopts: str = "",
|
||||
profile_enabled: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
started_at = _utc_now()
|
||||
run_id = _run_id() if profile_enabled else None
|
||||
scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id if run_id else None
|
||||
if scratch_dir is not None:
|
||||
_prepare_scratch_dir(scratch_dir)
|
||||
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
env["HF_HOME"] = HF_CACHE_PATH
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
arg for arg in pytest_args if arg not in {"--profile", "--profile-svg"}
|
||||
]
|
||||
if profile_enabled:
|
||||
sanitized_pytest_args.append("--profile")
|
||||
if dot_available:
|
||||
sanitized_pytest_args.append("--profile-svg")
|
||||
elif original_svg_requested:
|
||||
print(
|
||||
"Graphviz 'dot' is unavailable in the Modal container; "
|
||||
"falling back to raw .prof artifacts only.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
svg_requested = profile_enabled and dot_available
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
*sanitized_pytest_args,
|
||||
]
|
||||
exit_code = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
cwd=str(scratch_dir) if scratch_dir is not None else PROJECT_DIR,
|
||||
).returncode
|
||||
HF_CACHE_VOLUME.commit()
|
||||
finished_at = _utc_now()
|
||||
|
||||
if not profile_enabled:
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": None,
|
||||
"profile_enabled": False,
|
||||
"remote_profile_dir": None,
|
||||
"local_default_dirname": None,
|
||||
}
|
||||
|
||||
volume_root = Path(PROFILE_VOLUME_PATH)
|
||||
if not volume_root.exists():
|
||||
raise RuntimeError(
|
||||
"Profiling requested but the profile volume is not mounted."
|
||||
)
|
||||
|
||||
remote_run_dir = volume_root / run_id
|
||||
remote_run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = scratch_dir / "prof"
|
||||
if prof_dir.is_dir():
|
||||
shutil.copytree(prof_dir, remote_run_dir / "prof")
|
||||
|
||||
svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()
|
||||
manifest = {
|
||||
"exit_code": exit_code,
|
||||
"finished_at": finished_at,
|
||||
"profile_enabled": True,
|
||||
"pytest_args": sanitized_pytest_args,
|
||||
"run_id": run_id,
|
||||
"started_at": started_at,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
(remote_run_dir / "manifest.json").write_text(
|
||||
json.dumps(manifest, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
PROFILE_VOLUME.commit()
|
||||
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": run_id,
|
||||
"profile_enabled": True,
|
||||
"remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
|
||||
"local_default_dirname": run_id,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
allow_abbrev=False,
|
||||
description="Run pytest on Modal with a dynamically selected GPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
required=True,
|
||||
help="GPU type to request from Modal (for example: A100, T4, H100).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Enable pytest-profiling and download the resulting artifacts locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-output-dir",
|
||||
help="Directory to download profiling artifacts into when profiling is enabled.",
|
||||
)
|
||||
parsed, pytest_args = parser.parse_known_args(cli_args)
|
||||
|
||||
if pytest_args and pytest_args[0] == "--":
|
||||
pytest_args = pytest_args[1:]
|
||||
if not pytest_args:
|
||||
pytest_args = ["tests/"]
|
||||
|
||||
return (
|
||||
parsed.gpu,
|
||||
parsed.timeout,
|
||||
parsed.profile,
|
||||
parsed.profile_output_dir,
|
||||
pytest_args,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main(*cli_args: str):
|
||||
gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
|
||||
cli_args
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
if hf_token_secret is not None:
|
||||
runner_options["secrets"] = [hf_token_secret]
|
||||
runner = TestRunner.with_options(**runner_options)()
|
||||
result = runner.run.remote(
|
||||
pytest_args=pytest_args,
|
||||
pytest_addopts=pytest_addopts,
|
||||
profile_enabled=profile_enabled,
|
||||
)
|
||||
|
||||
if result["profile_enabled"] and result["run_id"] is not None:
|
||||
if profile_output_dir:
|
||||
output_dir = Path(profile_output_dir).expanduser().resolve()
|
||||
else:
|
||||
output_dir = _default_profile_output_dir(result["local_default_dirname"])
|
||||
|
||||
try:
|
||||
_download_profile_artifacts(result["run_id"], output_dir)
|
||||
print(f"Profile artifacts downloaded to {output_dir}")
|
||||
_cleanup_remote_profile_artifacts(result["run_id"])
|
||||
except FileNotFoundError as exc:
|
||||
print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
|
||||
except OSError as exc:
|
||||
print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
|
||||
|
||||
sys.exit(result["exit_code"])
|
||||
65
crates/luminal_python/pyproject.toml
Normal file
65
crates/luminal_python/pyproject.toml
Normal file
@@ -0,0 +1,65 @@
|
||||
[project]
|
||||
name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
"flash-attn-3>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
flash-attn-3 = { index = "pytorch-cu128" }
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "src"
|
||||
manifest-path = "rust/Cargo.toml"
|
||||
module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: tests that download large models or require pre-generated artifacts",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=5.5.0,<6",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"tiktoken>=0.12.0",
|
||||
"pydantic>=2.12.5",
|
||||
"psutil>=7.2.2",
|
||||
"modal>=1.3.5",
|
||||
"pillow",
|
||||
"flash-attn>=2.8.3",
|
||||
]
|
||||
flash-attention-4 = [
|
||||
"nvidia-cutlass-dsl==4.1.0",
|
||||
]
|
||||
44
crates/luminal_python/run_all_tests.sh
Executable file
44
crates/luminal_python/run_all_tests.sh
Executable file
@@ -0,0 +1,44 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 1: Building native backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native + ONNX ---"
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Native + PT2 ---"
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 2: Building CUDA backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA + ONNX ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: CUDA + PT2 ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " All tests passed!"
|
||||
echo "=========================================="
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user