mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
404 Commits
matmul-fla
...
tucker/cub
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a45629cece | ||
|
|
4cd47ffa45 | ||
|
|
db72cf505c | ||
|
|
766db93b08 | ||
|
|
4e93f02725 | ||
|
|
25393a9fdd | ||
|
|
81ea750e6b | ||
|
|
f94335b1b8 | ||
|
|
f62e3c50d0 | ||
|
|
eeeabd7c20 | ||
|
|
0f02466f3d | ||
|
|
156fac518e | ||
|
|
a3df68bd43 | ||
|
|
7a95e56a8b | ||
|
|
e558ce6849 | ||
|
|
c898b7fd53 | ||
|
|
6cfbf538d0 | ||
|
|
966f6f8147 | ||
|
|
8ea9a71747 | ||
|
|
861c3f0419 | ||
|
|
8f17561094 | ||
|
|
d5e9001c8b | ||
|
|
6416ddb5f8 | ||
|
|
c9d4ce6217 | ||
|
|
1dcd0370ce | ||
|
|
6757a4e37b | ||
|
|
631451f8b8 | ||
|
|
70bdd75163 | ||
|
|
855f2bfd02 | ||
|
|
cf7fa2297c | ||
|
|
cd3f55a3a7 | ||
|
|
11653c6903 | ||
|
|
6d16bdba21 | ||
|
|
7bfd19fb72 | ||
|
|
42caa4750e | ||
|
|
1279dca4e6 | ||
|
|
53f7960130 | ||
|
|
5c3407c596 | ||
|
|
47530062a4 | ||
|
|
8524636d6f | ||
|
|
22e7b2da49 | ||
|
|
198bd2d76b | ||
|
|
6a86e70a19 | ||
|
|
141c06f2bf | ||
|
|
352478f63c | ||
|
|
a63a5278b9 | ||
|
|
6b5504de47 | ||
|
|
6ad13f06d3 | ||
|
|
2d736cc499 | ||
|
|
2862f7ed22 | ||
|
|
b063a6ce73 | ||
|
|
b28b3e7dc6 | ||
|
|
c745f77be7 | ||
|
|
4a1bd598b4 | ||
|
|
724d7e2975 | ||
|
|
39e593e2df | ||
|
|
cfedd80c9b | ||
|
|
84fa320b53 | ||
|
|
5748ac644e | ||
|
|
5c8c9fc95a | ||
|
|
706d24883d | ||
|
|
b7aa15a51c | ||
|
|
3361fce3dc | ||
|
|
f4739a7900 | ||
|
|
cfe27e8001 | ||
|
|
9594d41e21 | ||
|
|
a2ce18063b | ||
|
|
b6e5a71383 | ||
|
|
3a20266785 | ||
|
|
cf4d88bf48 | ||
|
|
98b9b8ac54 | ||
|
|
c0f3970feb | ||
|
|
a5ab33a680 | ||
|
|
7235a98a43 | ||
|
|
6f291c4b9a | ||
|
|
b739a21d3b | ||
|
|
88bcd12a96 | ||
|
|
8bdcae291c | ||
|
|
45ae09b1c2 | ||
|
|
8f3f2a3048 | ||
|
|
6a7cefd3b2 | ||
|
|
f94f7ca43d | ||
|
|
86800211ff | ||
|
|
08c06d440e | ||
|
|
50733ea85c | ||
|
|
5f14b1e84f | ||
|
|
b5d6daf08e | ||
|
|
cf9c27aca9 | ||
|
|
1e3dff6ee7 | ||
|
|
e3968edb1a | ||
|
|
04b407560b | ||
|
|
c2e12b666f | ||
|
|
89238d4b24 | ||
|
|
16c7345e5a | ||
|
|
2724466a3f | ||
|
|
4d1ff217be | ||
|
|
44b293bee0 | ||
|
|
f9b9657c1c | ||
|
|
6db0f716d5 | ||
|
|
d03ab816d8 | ||
|
|
61904fbc76 | ||
|
|
f461fca3da | ||
|
|
5f199e94c6 | ||
|
|
93fb02c495 | ||
|
|
16de9638fc | ||
|
|
f08d24e73f | ||
|
|
aba9627563 | ||
|
|
7d68b62aa8 | ||
|
|
13c870de86 | ||
|
|
f8b742d718 | ||
|
|
3555d169bd | ||
|
|
be74153c12 | ||
|
|
75535c93f0 | ||
|
|
84f13cae00 | ||
|
|
703c2d9ea4 | ||
|
|
44324f1c2d | ||
|
|
f6845011d8 | ||
|
|
6e7ee5581d | ||
|
|
2e3158c48e | ||
|
|
8af22776aa | ||
|
|
cd8c01f620 | ||
|
|
461b746937 | ||
|
|
38e467aa6c | ||
|
|
7429ac163b | ||
|
|
07c151dd70 | ||
|
|
c0f7f1f054 | ||
|
|
df96fe5110 | ||
|
|
18a550dd15 | ||
|
|
254680001d | ||
|
|
2920011897 | ||
|
|
d879376697 | ||
|
|
2be30c18cd | ||
|
|
48f921d2a1 | ||
|
|
f55e7e0589 | ||
|
|
db2027d345 | ||
|
|
9a5032bfc9 | ||
|
|
c665b01c4e | ||
|
|
883508e682 | ||
|
|
080b99b69e | ||
|
|
0bd19289ea | ||
|
|
a3b7f6ecc1 | ||
|
|
438ae460bf | ||
|
|
da440fdef0 | ||
|
|
586365be4d | ||
|
|
3c962a9df8 | ||
|
|
1a460bac96 | ||
|
|
ce06a901cc | ||
|
|
c97288cdae | ||
|
|
d66b3f2643 | ||
|
|
66b0807462 | ||
|
|
c24ea4a7a5 | ||
|
|
c309d9b4ed | ||
|
|
745c071ee5 | ||
|
|
56ffe8bbb3 | ||
|
|
13dbdcb53b | ||
|
|
c8ad5f8b75 | ||
|
|
51c6596f6a | ||
|
|
aef4c68537 | ||
|
|
1ac423c36c | ||
|
|
59c38b3c88 | ||
|
|
9b3b2f5244 | ||
|
|
aed7b86aad | ||
|
|
e3c6d98f36 | ||
|
|
10971d7d05 | ||
|
|
4b0bfa5669 | ||
|
|
2c0c3bb988 | ||
|
|
ca6fac8f78 | ||
|
|
900fee4d67 | ||
|
|
59901c8b12 | ||
|
|
a860a2cb6b | ||
|
|
52b2a45c62 | ||
|
|
0af1c186fd | ||
|
|
e6d13a3979 | ||
|
|
86b2784b51 | ||
|
|
773935b91b | ||
|
|
afb8d7ae4d | ||
|
|
fb23b80a01 | ||
|
|
d6a3171b7b | ||
|
|
59edd0b179 | ||
|
|
8a2fd832b6 | ||
|
|
76c0d43aa0 | ||
|
|
f99f1e10cb | ||
|
|
a5b26100ba | ||
|
|
a40f5dd386 | ||
|
|
efe746ba39 | ||
|
|
d91dce41d4 | ||
|
|
11d59a351c | ||
|
|
6d66f80340 | ||
|
|
2da5cdaa30 | ||
|
|
44520a8100 | ||
|
|
53c58576fc | ||
|
|
64e4eedcc6 | ||
|
|
cc1b448c90 | ||
|
|
63afb602b0 | ||
|
|
985e7752aa | ||
|
|
3fd7831e6d | ||
|
|
4c8bed686f | ||
|
|
cbf1ef5fc4 | ||
|
|
7a53d39852 | ||
|
|
3786977f01 | ||
|
|
1a4662ec3b | ||
|
|
2963278637 | ||
|
|
97f11a78bf | ||
|
|
27faf0819c | ||
|
|
c225d3affb | ||
|
|
ac10f82308 | ||
|
|
f2f5944f47 | ||
|
|
f9865ae2a3 | ||
|
|
46ebc58334 | ||
|
|
a28b755245 | ||
|
|
fd83534e53 | ||
|
|
b5d984c3fa | ||
|
|
64a5ca41b5 | ||
|
|
9bda47714a | ||
|
|
9e513b6589 | ||
|
|
a62d728bd7 | ||
|
|
4114714d3f | ||
|
|
6191597571 | ||
|
|
253cd95ab0 | ||
|
|
d7e396ba5b | ||
|
|
1a53626716 | ||
|
|
4329d68adc | ||
|
|
989e7e2d44 | ||
|
|
4f0a3ab102 | ||
|
|
019972cdd4 | ||
|
|
d7a3f468bd | ||
|
|
c504fbf8a1 | ||
|
|
648720caf9 | ||
|
|
625be7f4da | ||
|
|
21ed7ef31f | ||
|
|
6e94f80c9e | ||
|
|
c2a17a4854 | ||
|
|
386b3df983 | ||
|
|
5c60f1d768 | ||
|
|
4c51e3ea84 | ||
|
|
846551aa6f | ||
|
|
c26076bc75 | ||
|
|
871629b770 | ||
|
|
c6dfa9c62f | ||
|
|
90e3a915d7 | ||
|
|
56cb237aa2 | ||
|
|
a2c42b35c8 | ||
|
|
898204b2dd | ||
|
|
2c1a7f087f | ||
|
|
412147ea78 | ||
|
|
2e27c29b47 | ||
|
|
92e4260f1e | ||
|
|
662a564efc | ||
|
|
1761dc6b66 | ||
|
|
da71273d7e | ||
|
|
39122672b4 | ||
|
|
d866ba6407 | ||
|
|
9a0fb453ed | ||
|
|
dab60f0b21 | ||
|
|
1ea872bd2a | ||
|
|
90a66ac704 | ||
|
|
2b94ba0b71 | ||
|
|
2ed65d5386 | ||
|
|
336d49c147 | ||
|
|
1ff5840a76 | ||
|
|
bc94b10648 | ||
|
|
7c921d03a8 | ||
|
|
4e46051617 | ||
|
|
a55952d591 | ||
|
|
679aa7e092 | ||
|
|
3dd2be2fb2 | ||
|
|
c290e266f7 | ||
|
|
be3d8aa064 | ||
|
|
84e0c842a1 | ||
|
|
403fd36b1f | ||
|
|
651a4c2aee | ||
|
|
41ddd244ef | ||
|
|
0dbee87a8c | ||
|
|
194b8adfa5 | ||
|
|
86e616800d | ||
|
|
683205121d | ||
|
|
6d653e854d | ||
|
|
08397b566d | ||
|
|
5310335256 | ||
|
|
638765b62b | ||
|
|
3850b3a533 | ||
|
|
a4c84c6cf5 | ||
|
|
f32161d43b | ||
|
|
da83d51b27 | ||
|
|
29a3ffa3e3 | ||
|
|
97e358916a | ||
|
|
631c1b53d7 | ||
|
|
02449a6bea | ||
|
|
6c4597102e | ||
|
|
b077cfdb76 | ||
|
|
869b519e39 | ||
|
|
2b831c9f25 | ||
|
|
f35a950496 | ||
|
|
9ab0e1472c | ||
|
|
88f2601d5e | ||
|
|
b0ebdcba8c | ||
|
|
0ab124194b | ||
|
|
7f042ae615 | ||
|
|
082d9c48bd | ||
|
|
251e9526f3 | ||
|
|
41b3774ec2 | ||
|
|
3fdb464f5a | ||
|
|
a3a4fd94ec | ||
|
|
5446dccb04 | ||
|
|
8e6535563e | ||
|
|
bdc923aa50 | ||
|
|
ea67742b3b | ||
|
|
149e570f26 | ||
|
|
ac52098d5c | ||
|
|
01946ecd10 | ||
|
|
ef70fee204 | ||
|
|
a6fea110dc | ||
|
|
2d4ebb2cb6 | ||
|
|
5adb875b04 | ||
|
|
2adfcfa70e | ||
|
|
65600e8730 | ||
|
|
4c4f39b4af | ||
|
|
49b9209ad0 | ||
|
|
dea5df51dd | ||
|
|
eb6a6c2174 | ||
|
|
8864ef31fb | ||
|
|
38c98a8835 | ||
|
|
31b5fd886d | ||
|
|
04b2753aa8 | ||
|
|
dbb5282fd6 | ||
|
|
8e315c62df | ||
|
|
f53d990581 | ||
|
|
2c8ecba6a5 | ||
|
|
1644cce031 | ||
|
|
b2bb455b30 | ||
|
|
8628b1425a | ||
|
|
ca66609d6f | ||
|
|
c50e122ac1 | ||
|
|
272acabd0c | ||
|
|
f772c0529a | ||
|
|
c3e1f568ea | ||
|
|
eb3dd02836 | ||
|
|
8a9f85b0ce | ||
|
|
372501e527 | ||
|
|
cda12a6d84 | ||
|
|
566fb00ed2 | ||
|
|
ecb78a2635 | ||
|
|
f401ffb900 | ||
|
|
cd94000140 | ||
|
|
1f1636e188 | ||
|
|
371fa8491a | ||
|
|
6d1fe67b66 | ||
|
|
189d1e2594 | ||
|
|
c7acfb9794 | ||
|
|
fe6af5290a | ||
|
|
8b8669c744 | ||
|
|
7af771b999 | ||
|
|
133757f187 | ||
|
|
07ee241b25 | ||
|
|
958331ab6c | ||
|
|
340199d4a8 | ||
|
|
f17a95e673 | ||
|
|
6bb576e711 | ||
|
|
744e4d767a | ||
|
|
c940161f25 | ||
|
|
3aa2c309f5 | ||
|
|
8a0592646b | ||
|
|
68ce81e52b | ||
|
|
39789404f4 | ||
|
|
8c53234966 | ||
|
|
71eca945cb | ||
|
|
8da130ae1c | ||
|
|
fef6a45c9c | ||
|
|
c6763a69ba | ||
|
|
30caca106c | ||
|
|
6c90bb5059 | ||
|
|
82189cd602 | ||
|
|
cc5e0a639d | ||
|
|
8dc05233cb | ||
|
|
0ab9947292 | ||
|
|
f11ba3a388 | ||
|
|
a346e503db | ||
|
|
6bbf244924 | ||
|
|
a8505668ac | ||
|
|
a0b237c424 | ||
|
|
8fabacd17e | ||
|
|
df0128ad04 | ||
|
|
de55e67594 | ||
|
|
cdff26755f | ||
|
|
9f11b7e24a | ||
|
|
27344a0e45 | ||
|
|
e9e6f824a1 | ||
|
|
d894eeae50 | ||
|
|
da078b5bdd | ||
|
|
f156265ff4 | ||
|
|
b34f104cea | ||
|
|
1873e26185 | ||
|
|
384a426ba3 | ||
|
|
a49c970029 | ||
|
|
1d2db8f88f | ||
|
|
a0a162049e | ||
|
|
d63cb1a115 | ||
|
|
e0413c640a | ||
|
|
da9a45a044 | ||
|
|
a736d1aa2f | ||
|
|
0d5880296a | ||
|
|
ea3fa459ec | ||
|
|
d21969370f | ||
|
|
45a6a62909 |
@@ -1,3 +1,6 @@
|
||||
[alias]
|
||||
examples = "run --release --bin examples-perf --"
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
rustflags = [
|
||||
"-Ctarget-feature=+fp16,+fhm"
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
{
|
||||
"name": "Luminal",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:latest",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/github-cli:1": {}
|
||||
},
|
||||
"remoteEnv": {
|
||||
"GH_TOKEN": "${localEnv:GH_TOKEN}"
|
||||
},
|
||||
"name": "Luminal (CPU)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--gpus=all"
|
||||
"--env-file", ".env"
|
||||
],
|
||||
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}",
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
@@ -26,6 +39,7 @@
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
@@ -34,4 +48,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
55
.devcontainer/cuda/devcontainer.json
Normal file
55
.devcontainer/cuda/devcontainer.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"name": "Luminal (CUDA)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--env-file",
|
||||
".env",
|
||||
"--runtime=nvidia",
|
||||
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
|
||||
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: CUDA Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: CUDA Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as safe for git
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
23
.github/workflows/fmt.yml
vendored
Normal file
23
.github/workflows/fmt.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Fmt
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
25
.github/workflows/metal-clippy.yml
vendored
Normal file
25
.github/workflows/metal-clippy.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Metal Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
47
.github/workflows/modal-examples.yml
vendored
Normal file
47
.github/workflows/modal-examples.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
name: Modal Examples
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
# - { type: "H100" }
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
EXAMPLE: ${{ matrix.example }}
|
||||
GPU_TYPE: ${{ matrix.gpu.type }}
|
||||
run: modal run ci/modal_example.py
|
||||
23
.github/workflows/ruff-format.yml
vendored
Normal file
23
.github/workflows/ruff-format.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff Format
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
23
.github/workflows/ruff.yml
vendored
Normal file
23
.github/workflows/ruff.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
24
.github/workflows/test-core.yml
vendored
Normal file
24
.github/workflows/test-core.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Test Core
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release -p luminal -p luminal_nn -p luminal_tracing -p luminal_python --verbose
|
||||
37
.github/workflows/test-cuda.yml
vendored
Normal file
37
.github/workflows/test-cuda.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: Test CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_unit_test:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Test Full CUDA
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
rust_cuda_ignored_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Rust CUDA Ignored Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run ignored CUDA Rust tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
GPU_TYPE: H100
|
||||
MODAL_TIMEOUT: "14400"
|
||||
CARGO_TEST_ARGS: "--ignored --test-threads=1"
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
python_cuda_slow_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Python CUDA Slow Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run slow pytest CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100-80GB --timeout 14400 tests/ -v -s -m slow
|
||||
36
.github/workflows/test-metal.yml
vendored
Normal file
36
.github/workflows/test-metal.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Test Metal
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
llama_1b_metal_example:
|
||||
name: Llama 1B Metal Example
|
||||
runs-on: macos-14-xlarge
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Print runner hardware
|
||||
run: system_profiler SPHardwareDataType SPDisplaysDataType
|
||||
- name: Cache Hugging Face models
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: llama-1b-metal-hf-${{ runner.os }}-${{ runner.arch }}-v1
|
||||
- name: Run Llama 1B Metal example and validate output
|
||||
run: rustup update; python3 ci/metal_llama_1b_example.py
|
||||
49
.github/workflows/test-python-cuda.yml
vendored
Normal file
49
.github/workflows/test-python-cuda.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: Test Python CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_cuda_tests:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run pytest with CUDA backend on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 7200 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
retention-days: 7
|
||||
if-no-files-found: warn
|
||||
28
.github/workflows/test-python-native.yml
vendored
Normal file
28
.github/workflows/test-python-native.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Test Python Native
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --profile release
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
138
.github/workflows/test.yml
vendored
138
.github/workflows/test.yml
vendored
@@ -1,138 +0,0 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: rustup update; cargo test --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run clippy
|
||||
run: rustup update; cargo clippy --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Format
|
||||
run: cargo fmt --all --check
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:latest
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Run CUDA crate tests
|
||||
run: cargo test -p luminal_cuda --verbose -- --test-threads=1
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get install -y protobuf-compiler
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/ -v
|
||||
|
||||
python_cuda_tests:
|
||||
name: Python CUDA Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:latest
|
||||
options: --gpus all
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --features cuda
|
||||
- name: Run pytest with CUDA backend
|
||||
env:
|
||||
LUMINAL_BACKEND: cuda
|
||||
run: uv run pytest tests/ -v
|
||||
|
||||
# cuda_llama: # disabled because t4 doesn't have enough memory for full precision llama. re-enable when we can run on larger machines or use 8-bit precision
|
||||
# name: Cuda Llama
|
||||
# runs-on: cuda_t4_runner
|
||||
# timeout-minutes: 30
|
||||
# env:
|
||||
# CUDA_HOME: /usr/local/cuda-12.8
|
||||
# LD_LIBRARY_PATH: /usr/local/cuda-12.8/lib64
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Install system deps
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# protobuf-compiler \
|
||||
# cuda-nvrtc-12-8
|
||||
# - name: Install Rust
|
||||
# run: |
|
||||
# curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal
|
||||
# echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||
# - name: Update Rust
|
||||
# run: rustup update
|
||||
# - name: Install uv
|
||||
# run: curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# - name: Download Llama
|
||||
# working-directory: examples/llama
|
||||
# run: uv run --script setup/setup.py
|
||||
# - name: Run Llama
|
||||
# working-directory: examples/llama
|
||||
# run: SEARCH=1 cargo run --release
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -15,6 +15,10 @@ Cargo.lock
|
||||
*.gguf
|
||||
|
||||
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.pftrace
|
||||
*.safetensors
|
||||
*.safetensors.index.json
|
||||
|
||||
38
.pre-commit-config.yaml
Normal file
38
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.5
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
name: ruff check
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- id: ruff-format
|
||||
name: ruff format
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: cargo-fmt
|
||||
name: cargo fmt
|
||||
entry: cargo fmt --all --check
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy
|
||||
name: cargo clippy
|
||||
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy-metal
|
||||
name: cargo clippy metal
|
||||
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
- id: cargo-clippy-cuda-lite
|
||||
name: cargo clippy cuda_lite
|
||||
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
16
AGENTS.md
16
AGENTS.md
@@ -3,9 +3,19 @@
|
||||
## Structure
|
||||
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
|
||||
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
|
||||
## Debugging and Correctness
|
||||
- Treat model examples as specifications of the intended architecture. Do not change model code, prompt templates, weights, or example logic to hide compiler/runtime/search bugs unless the model code is demonstrably semantically wrong.
|
||||
- When outputs are incorrect, first root-cause the failing compiler/runtime path. Prefer isolating the bad LLIR/HLIR graph, rewrite, op lowering, shape/stride assumption, layout contract, or runtime implementation that caused the mismatch.
|
||||
- Avoid narrow special-case fixes. A fix should state and enforce the general invariant it relies on, or explicitly document why the affected operation is only valid for a restricted layout/shape and ensure rewrites enforce that restriction.
|
||||
- For e-graph/search issues, assume all selectable LLIR graphs are intended to be semantically equivalent. If two selectable graphs disagree, debug the equivalence violation rather than selecting around the bad graph.
|
||||
- Add regression tests at the level where the bug occurred. Prefer tests that compare against a semantic reference such as `NativeRuntime` or a small independent reference, and use fixed seeds for any randomized search/fuzz test so failures are reproducible.
|
||||
|
||||
## Compiler Rewrite Boundary
|
||||
- All graph pattern matching and op selection must be expressed in egglog rewrites. Do not add Rust-side LLIR graph post-passes that search for op patterns, fuse kernels, select backend ops, or otherwise rewrite extracted graphs after egglog. If a backend needs a fused/specialized op, add the match and rewrite in egglog and let extraction produce that op directly.
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -25,6 +25,7 @@ generational-box = "0.5.6"
|
||||
serde_json = "1.0.140"
|
||||
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-reports = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
|
||||
tracing = "0.1.43"
|
||||
paste = "1.0.15"
|
||||
@@ -32,13 +33,14 @@ pretty-duration = "0.1.1"
|
||||
anyhow = "1.0"
|
||||
graphviz-rust = { version = "0.9", default-features = false}
|
||||
lru = "0.16.2"
|
||||
rayon = "1.10"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
candle-nn = "0.9.2-alpha.1"
|
||||
candle-core = "0.9.2"
|
||||
candle-nn = "0.9.2"
|
||||
ordered-float = "5.1.0"
|
||||
proptest = "1.9.0"
|
||||
|
||||
@@ -46,7 +48,7 @@ proptest = "1.9.0"
|
||||
members = [
|
||||
"examples/*",
|
||||
"crates/luminal_nn",
|
||||
"crates/luminal_cuda",
|
||||
"crates/luminal_cuda_lite",
|
||||
"crates/luminal_metal",
|
||||
"crates/luminal_tracing",
|
||||
"crates/luminal_bench",
|
||||
@@ -54,4 +56,4 @@ members = [
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }
|
||||
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }
|
||||
|
||||
54
README.md
54
README.md
@@ -1,10 +1,10 @@
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/user-attachments/assets/c5832634-55d5-45b7-ba65-6efe36afce4a" />
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/luminal-ai/luminal/blob/main/docs/logo/inference_at_the_speed_of_light.png" />
|
||||
|
||||
<h3 align="center">
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
</h3>
|
||||
|
||||
[](https://github.com/jafioti/luminal/actions)
|
||||
[](https://github.com/luminal-ai/luminal/actions)
|
||||
[](https://docs.luminalai.com)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/APjuwHAbGy)
|
||||
@@ -55,23 +55,27 @@ Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100.
|
||||
|
||||
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
### PyTorch-native
|
||||
|
||||
Luminal directly integrates with PyTorch as a compiler backend. Simply do `torch.compile(model, backend=luminal_cuda)` to compile your PyTorch models. We also have an excellent tensor API in Rust.
|
||||
|
||||
### RISC-style architecture
|
||||
|
||||
Everything in Luminal boils down to 14 primitive ops:
|
||||
Everything in Luminal boils down to 15 primitive ops:
|
||||
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Scatter, Cast`
|
||||
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model.
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model in the world.
|
||||
|
||||
### Search
|
||||
|
||||
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
|
||||
The best heuristic is no heuristic. Luminal tries to search every possible decision to give the compiler the flexibility to discover complex optimizations. This allows us to automatically discover Flash Attention and other similarly complex optimizations without relying on hand-written operations or heuristics. It also allows us to stay extremely small and simple long into the future and beat the performance of far larger frameworks.
|
||||
|
||||
### Native
|
||||
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the accelerator APIs (CUDA, Metal, etc.). No indirections or abstractions, compatability layers, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
|
||||
@@ -85,39 +89,45 @@ Most deep learning libraries are eager-first, meaning each op call directly oper
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### What about XLA?
|
||||
|
||||
XLA, torch.compile, TVM, and other traditional compiler stacks suffer from complexity explosion. They are made up of a very large set of destructive (one-direction) rewrite rules that lower and optimize a graph from a high-level representation to low-level machine code. But since these rules are destructive, they are required to only fire when it's certian that there's a performance benefit. This leads to the rules becoming very complex, special-cased, and numerous. Once additional hardware backends, model architectures, and new dtypes get thrown in, they suffer from the weight of their complexity and often produce very suboptimal code, requiring DSLs like Pallas or Triton to regain performance.
|
||||
|
||||
### Compile everything
|
||||
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as a static computation graphs, compiled, and executed later.
|
||||
|
||||
### First-class dynamism
|
||||
|
||||
A fully-static world would be nice, but we live in a world of nessecary dynamism. So we model dynamic shapes natively, as symbolic dimensions. Luminal supports arbitrary symbolic dimensions, including complex expressions, to give us shapes like `(s, 4096)`, `(b, h, w + 3)`, etc. This rich representation gives the compiler full visibility into shapes and lets it still do aggressive specialization.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, Luminal has global knowledge. This means we can push most ML complexity to the compiler. For instance, devices, datatypes, and even autograd is modeled ahead of time and optimized by the compiler!
|
||||
|
||||
Now we can do:
|
||||
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
- Low-precision dtypes (mxfp4, nvfp4, fp8, etc.)
|
||||
- Complex mutli-device parallelism topologies, searched ahead-of-time
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures
|
||||
|
||||
## Where are we?
|
||||
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- Native PyTorch support
|
||||
- Many kernel libraries supported in the search space (FlashInfer, cuBLASLt, etc.)
|
||||
- Many models implemented in our Rust tensor API in `examples/`.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
|
||||
Some things on the roadmap:
|
||||
|
||||
- Expand the search space to utilize Tensor Cores more flexibly
|
||||
- Bring cuda to parity with Metal
|
||||
- Add Blackwell intrinsics, such as TMEM and TMA
|
||||
- Build a ROCm backend
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM inference _and_ training
|
||||
- More fine-grained dialects supporting thread- and warp-level intrinsics like TMA and tcgen.05
|
||||
- ROCm backend
|
||||
- More public infernce accelerator backends (coming very soon...)
|
||||
- Public benchmarking suite
|
||||
- Automatically searched model parallelism (TP, PP, EPS, EPR, SP, etc.)
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
|
||||
BIN
ci/__pycache__/modal_llama.cpython-312.pyc
Normal file
BIN
ci/__pycache__/modal_llama.cpython-312.pyc
Normal file
Binary file not shown.
85
ci/example_output.py
Normal file
85
ci/example_output.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import re
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
EXPECTED_CONCEPTS = {
|
||||
"llama": [
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "adapt"],
|
||||
["data", "patterns", "features"],
|
||||
],
|
||||
"gemma": [
|
||||
["neural network", "neural networks"],
|
||||
["nodes", "neurons"],
|
||||
["layers"],
|
||||
["weights"],
|
||||
["training", "learn", "learns"],
|
||||
],
|
||||
"qwen": [
|
||||
["neural network", "neural networks"],
|
||||
["computational model", "computational system"],
|
||||
["brain"],
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "training"],
|
||||
],
|
||||
"qwen3_moe": [
|
||||
["capital"],
|
||||
["france"],
|
||||
["paris"],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
normalized_output = normalize_output(output)
|
||||
|
||||
expected_concepts = EXPECTED_CONCEPTS.get(example)
|
||||
if expected_concepts is not None:
|
||||
missing = [
|
||||
concept_group
|
||||
for concept_group in expected_concepts
|
||||
if not any(normalize_output(term) in normalized_output for term in concept_group)
|
||||
]
|
||||
if missing:
|
||||
expected = "\n - ".join(" / ".join(group) for group in expected_concepts)
|
||||
missing_terms = "\n - ".join(" / ".join(group) for group in missing)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}.\n"
|
||||
f"Expected concept groups:\n - {expected}\n"
|
||||
f"Missing concept groups:\n - {missing_terms}"
|
||||
)
|
||||
|
||||
expected = ", ".join(" / ".join(group) for group in expected_concepts)
|
||||
print(f"\nOutput check passed for {example!r}: found concepts {expected}")
|
||||
return
|
||||
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
185
ci/examples_perf.py
Normal file
185
ci/examples_perf.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
|
||||
DEFAULT_EXAMPLES = ["llama", "gemma", "qwen", "qwen3_moe", "gemma4_moe", "whisper"]
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"llama": ["run", "--release", "-p", "llama"],
|
||||
"gemma": ["run", "--release", "-p", "gemma"],
|
||||
"qwen": ["run", "--release", "-p", "qwen", "--features", "cuda"],
|
||||
"qwen3_moe": ["run", "--release", "-p", "qwen3_moe"],
|
||||
"gemma4_moe": ["run", "--release", "-p", "gemma4_moe"],
|
||||
"whisper": ["run", "--release", "-p", "whisper"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
ttft_ms: float | None = None
|
||||
tpot_ms: float | None = None
|
||||
tps: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleResult:
|
||||
name: str
|
||||
ok: bool
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
wall_s: float = 0.0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = [arg for arg in sys.argv[1:] if arg != "--"]
|
||||
if any(arg in {"-h", "--help"} for arg in args):
|
||||
print_help()
|
||||
return
|
||||
if "--list" in args:
|
||||
print("\n".join(DEFAULT_EXAMPLES))
|
||||
return
|
||||
|
||||
examples = args or DEFAULT_EXAMPLES
|
||||
results = [run_example(example) for example in examples]
|
||||
print_table(results)
|
||||
if any(not result.ok for result in results):
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def print_help() -> None:
|
||||
print(
|
||||
"Run validated Luminal examples, validate textual output, and summarize perf.\n"
|
||||
"\n"
|
||||
"Usage:\n"
|
||||
" cargo examples\n"
|
||||
" cargo examples llama qwen whisper\n"
|
||||
"\n"
|
||||
"Options:\n"
|
||||
" --list Print the default validated examples\n"
|
||||
" -h, --help\n"
|
||||
"\n"
|
||||
f"The default set matches the Modal examples CI: {', '.join(DEFAULT_EXAMPLES)}."
|
||||
)
|
||||
|
||||
|
||||
def run_example(example: str) -> ExampleResult:
|
||||
cargo_args = EXAMPLE_CARGO_ARGS.get(example)
|
||||
if cargo_args is None:
|
||||
known = ", ".join(DEFAULT_EXAMPLES)
|
||||
return ExampleResult(example, False, error=f"unknown example; known examples: {known}")
|
||||
|
||||
print(f"\n=== Running {example} ===")
|
||||
print(f"$ cargo {' '.join(cargo_args)}")
|
||||
started = time.monotonic()
|
||||
env = os.environ.copy()
|
||||
env.setdefault("CUDARC_CUDA_VERSION", "12080")
|
||||
process = subprocess.Popen(
|
||||
["cargo", *cargo_args],
|
||||
cwd=repo_root(),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
wall_s = time.monotonic() - started
|
||||
metrics = parse_metrics(output)
|
||||
|
||||
if return_code:
|
||||
return ExampleResult(
|
||||
example,
|
||||
False,
|
||||
metrics=metrics,
|
||||
wall_s=wall_s,
|
||||
error=f"process exited with code {return_code}",
|
||||
)
|
||||
|
||||
try:
|
||||
validate_output(example, output)
|
||||
except Exception as exc:
|
||||
return ExampleResult(example, False, metrics=metrics, wall_s=wall_s, error=str(exc))
|
||||
|
||||
return ExampleResult(example, True, metrics=metrics, wall_s=wall_s)
|
||||
|
||||
|
||||
def repo_root() -> str:
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def parse_metrics(output: str) -> Metrics:
|
||||
metrics = Metrics()
|
||||
for line in output.splitlines():
|
||||
if "TTFT:" in line:
|
||||
metrics.ttft_ms = parse_number_after(line, "TTFT:")
|
||||
if "TPOT:" in line:
|
||||
metrics.tpot_ms = parse_number_after(line, "TPOT:")
|
||||
if "tok/s" in line:
|
||||
metrics.tps = parse_tok_per_second(line)
|
||||
if metrics.tps is None and metrics.tpot_ms:
|
||||
metrics.tps = 1000.0 / metrics.tpot_ms
|
||||
return metrics
|
||||
|
||||
|
||||
def parse_number_after(line: str, marker: str) -> float | None:
|
||||
tail = line.split(marker, 1)[1].lstrip()
|
||||
chars = []
|
||||
for char in tail:
|
||||
if char.isdigit() or char == ".":
|
||||
chars.append(char)
|
||||
else:
|
||||
break
|
||||
if not chars:
|
||||
return None
|
||||
return float("".join(chars))
|
||||
|
||||
|
||||
def parse_tok_per_second(line: str) -> float | None:
|
||||
head = line.split("tok/s", 1)[0].rstrip(" (")
|
||||
parts = head.split()
|
||||
if not parts:
|
||||
return None
|
||||
try:
|
||||
return float(parts[-1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def print_table(results: list[ExampleResult]) -> None:
|
||||
print("\nSummary")
|
||||
print(f"{'example':<14} {'status':<8} {'TTFT ms':>10} {'TPOT ms':>10} {'tok/s':>10} {'wall s':>10}")
|
||||
print("-" * 68)
|
||||
for result in results:
|
||||
status = "ok" if result.ok else "failed"
|
||||
print(
|
||||
f"{result.name:<14} {status:<8} "
|
||||
f"{format_metric(result.metrics.ttft_ms):>10} "
|
||||
f"{format_metric(result.metrics.tpot_ms):>10} "
|
||||
f"{format_metric(result.metrics.tps):>10} "
|
||||
f"{result.wall_s:>10.1f}"
|
||||
)
|
||||
if result.error:
|
||||
print(f" error: {result.error}")
|
||||
|
||||
|
||||
def format_metric(value: float | None) -> str:
|
||||
return "-" if value is None else f"{value:.2f}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
ci/metal_llama_1b_example.py
Normal file
48
ci/metal_llama_1b_example.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
sys.path.insert(0, os.path.join(repo_root, "ci"))
|
||||
from example_output import validate_output
|
||||
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "luminal_metal", "--example", "llama_1b"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("Llama 1B Metal example did not complete generation")
|
||||
validate_output("llama", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
46
ci/metal_qwen_example.py
Normal file
46
ci/metal_qwen_example.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "qwen", "--features", "metal"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("qwen Metal example did not complete generation")
|
||||
validate_output("qwen", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
74
ci/modal_cargo_test.py
Normal file
74
ci/modal_cargo_test.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import shlex
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
modal_timeout = int(os.environ.get("MODAL_TIMEOUT", "7200"))
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=modal_timeout,
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
# Detect GPU compute capability
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
test_args = shlex.split(os.environ.get("CARGO_TEST_ARGS", "--test-threads=1"))
|
||||
cmd = [
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
*test_args,
|
||||
]
|
||||
print("Running:", " ".join(cmd), flush=True)
|
||||
subprocess.run(
|
||||
cmd,
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"CUDA_COMPUTE_CAP": compute_cap,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_cargo_test.remote()
|
||||
103
ci/modal_example.py
Normal file
103
ci/modal_example.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import modal
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
|
||||
app = modal.App(f"luminal-ci-{example}")
|
||||
|
||||
hf_cache = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"qwen": ["--features", "cuda"],
|
||||
}
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
)
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
)
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
sys.path.insert(0, f"{WORKDIR}/ci")
|
||||
from example_output import validate_output
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", *EXAMPLE_CARGO_ARGS.get(example, [])],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
)
|
||||
validate_output(example, output)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_example.remote(example)
|
||||
@@ -39,7 +39,7 @@ fn run_metal_pattern_benchmark(
|
||||
let mut cx = Graph::default();
|
||||
pattern.build_graph(&mut cx, *size);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, 5);
|
||||
let mut rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
|
||||
@@ -41,7 +41,7 @@ struct PreparedBench {
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, 5);
|
||||
let rt = cx.search(rt, CompileOptions::new(5));
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
|
||||
@@ -41,7 +41,7 @@ mod metal_backend {
|
||||
const NAME: &'static str = "Metal";
|
||||
|
||||
fn build_search_space(cx: &mut Graph) {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -106,13 +106,13 @@ impl Case {
|
||||
let out = match self {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor(size);
|
||||
x.clone() * x
|
||||
x * x
|
||||
}
|
||||
Case::Sigmoid => cx.tensor(size).sigmoid(),
|
||||
Case::Tanh => cx.tensor(size).tanh(),
|
||||
Case::GeluInner => {
|
||||
let x = cx.tensor(size);
|
||||
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
|
||||
(0.797_884_6_f32 * x * (1. + 0.044_715_f32 * x * x)).tanh()
|
||||
}
|
||||
Case::Gelu => cx.tensor(size).gelu(),
|
||||
Case::LayerNorm => {
|
||||
@@ -447,10 +447,10 @@ where
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty() {
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty()
|
||||
&& let Some(ref backend) = backend_analysis
|
||||
{
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
|
||||
// Trace facts for explicit variables.
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
use itertools::Itertools;
|
||||
use luminal::{prelude::FxHashMap, shape::Expression};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CStructType {
|
||||
Float,
|
||||
FloatArr(usize),
|
||||
Int,
|
||||
IntArr(usize),
|
||||
Long,
|
||||
LongArr(usize),
|
||||
Bool,
|
||||
BoolArr(usize),
|
||||
Ptr,
|
||||
PtrArr(usize),
|
||||
Bytes(usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CStruct<'a> {
|
||||
buf: Vec<u8>,
|
||||
max_align: usize,
|
||||
struct_types: Vec<(String, CStructType)>,
|
||||
expressions: Option<&'a FxHashMap<Expression, i32>>,
|
||||
pub(crate) recorded_expressions: Vec<Expression>,
|
||||
}
|
||||
|
||||
impl<'a> CStruct<'a> {
|
||||
pub fn new(expressions: Option<&'a FxHashMap<Expression, i32>>) -> Self {
|
||||
Self {
|
||||
max_align: 1,
|
||||
struct_types: vec![],
|
||||
buf: vec![],
|
||||
expressions,
|
||||
recorded_expressions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn align_to(&mut self, align: usize) {
|
||||
self.max_align = self.max_align.max(align);
|
||||
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
if rem != 0 {
|
||||
let pad = align - rem;
|
||||
self.buf.extend(std::iter::repeat_n(0u8, pad));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn int(mut self, name: impl ToString, v: i32) -> Self {
|
||||
self.struct_types.push((name.to_string(), CStructType::Int));
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn int_arr(mut self, name: impl ToString, vs: &[i32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::IntArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expr(mut self, name: impl ToString, v: impl Into<Expression>) -> Self {
|
||||
if let Some(expressions) = self.expressions {
|
||||
self.struct_types.push((name.to_string(), CStructType::Int));
|
||||
let v = expressions[&v.into()];
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
} else {
|
||||
self.recorded_expressions.push(v.into());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expr_arr(mut self, name: impl ToString, vs: &[Expression]) -> Self {
|
||||
if let Some(expressions) = self.expressions {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::IntArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
let v = expressions[&v];
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
} else {
|
||||
self.recorded_expressions.extend(vs.iter().copied());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn long(mut self, name: impl ToString, v: i64) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Long));
|
||||
self.align_to(8);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn long_arr(mut self, name: impl ToString, vs: &[i64]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::LongArr(vs.len())));
|
||||
self.align_to(8);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn float(mut self, name: impl ToString, v: f32) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Float));
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn float_arr(mut self, name: impl ToString, vs: &[f32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::FloatArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn bool(mut self, name: impl ToString, v: bool) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Bool));
|
||||
self.align_to(1);
|
||||
self.buf.push(if v { 1 } else { 0 });
|
||||
self
|
||||
}
|
||||
|
||||
pub fn bool_arr(mut self, name: impl ToString, vs: &[bool]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::BoolArr(vs.len())));
|
||||
self.align_to(1);
|
||||
for &v in vs {
|
||||
self.buf.push(if v { 1 } else { 0 });
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ptr_const_f32(mut self, name: impl ToString, p: *const f32) -> Self {
|
||||
self.struct_types.push((name.to_string(), CStructType::Ptr));
|
||||
let ptr_size = std::mem::size_of::<usize>(); // usually 8
|
||||
let ptr_align = ptr_size;
|
||||
self.align_to(ptr_align);
|
||||
|
||||
let addr = p as usize;
|
||||
let bytes = addr.to_ne_bytes();
|
||||
|
||||
self.buf.extend_from_slice(&bytes[..ptr_size]);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ptr_mut_f32(self, name: impl ToString, p: *mut f32) -> Self {
|
||||
self.ptr_const_f32(name, p as *const f32)
|
||||
}
|
||||
|
||||
pub fn ptr_const_f32_arr(mut self, name: impl ToString, p: &[*const f32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::PtrArr(p.len())));
|
||||
let ptr_size = std::mem::size_of::<usize>(); // usually 8
|
||||
let ptr_align = ptr_size;
|
||||
self.align_to(ptr_align);
|
||||
|
||||
for &p in p {
|
||||
let addr = p as usize;
|
||||
let bytes = addr.to_ne_bytes();
|
||||
self.buf.extend_from_slice(&bytes[..ptr_size]);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the current size of the buffer after alignment for a pointer field.
|
||||
/// Useful for computing field offsets.
|
||||
pub fn current_size(&self) -> usize {
|
||||
let ptr_align = std::mem::size_of::<usize>();
|
||||
let len = self.buf.len();
|
||||
let rem = len % ptr_align;
|
||||
if rem != 0 {
|
||||
len + (ptr_align - rem)
|
||||
} else {
|
||||
len
|
||||
}
|
||||
}
|
||||
|
||||
/// Pad the struct size to a multiple of max_align.
|
||||
pub fn finish_struct(mut self) -> Vec<u8> {
|
||||
assert!(
|
||||
self.expressions.is_some(),
|
||||
"Can only create cstruct bytes when expression map is provided!"
|
||||
);
|
||||
let align = self.max_align;
|
||||
if align > 1 {
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
if rem != 0 {
|
||||
let pad = align - rem;
|
||||
self.buf.extend(std::iter::repeat_n(0u8, pad));
|
||||
}
|
||||
}
|
||||
self.buf
|
||||
}
|
||||
|
||||
/// Returns (size, alignment) of the struct.
|
||||
pub fn size_and_align(&self) -> (usize, usize) {
|
||||
let align = self.max_align;
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
let size = if rem != 0 { len + (align - rem) } else { len };
|
||||
(size, align)
|
||||
}
|
||||
|
||||
/// Insert a raw byte field (e.g., another struct).
|
||||
/// `align` must be the alignment of the nested struct.
|
||||
pub fn bytes(mut self, align: usize, name: impl ToString, data: &[u8]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Bytes(data.len())));
|
||||
self.align_to(align);
|
||||
self.buf.extend_from_slice(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CStruct<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = self
|
||||
.struct_types
|
||||
.iter()
|
||||
.map(|(name, ty)| match ty {
|
||||
CStructType::Bool => format!("bool {name};"),
|
||||
CStructType::BoolArr(l) => format!("bool {name}[{l}];"),
|
||||
CStructType::Float => format!("float {name};"),
|
||||
CStructType::FloatArr(l) => format!("float {name}[{l}];"),
|
||||
CStructType::Int => format!("int {name};"),
|
||||
CStructType::IntArr(l) => format!("int {name}[{l}];"),
|
||||
CStructType::Long => format!("long {name};"),
|
||||
CStructType::LongArr(l) => format!("long {name}[{l}];"),
|
||||
CStructType::Ptr => format!("float* {name};"),
|
||||
CStructType::PtrArr(l) => format!("float* {name}[{l}];"),
|
||||
CStructType::Bytes(l) => format!("char payload[{l}];"),
|
||||
})
|
||||
.join("\n");
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
const int N_OPS = 0;
|
||||
const int N_TIMING_SLOTS = 0;
|
||||
const int N_TASKS = 0; // Rendered at compile time
|
||||
//%n_barriers_const%
|
||||
|
||||
enum OpCode {
|
||||
//%extra_op_codes%
|
||||
};
|
||||
|
||||
//%extra_op_structs%
|
||||
|
||||
union Payload {
|
||||
//%extra_op_payloads%
|
||||
};
|
||||
|
||||
struct Task {
|
||||
OpCode op;
|
||||
int range;
|
||||
int remaining;
|
||||
int in_dep_a_stride;
|
||||
int in_dep_a_base;
|
||||
int in_dep_b_stride;
|
||||
int in_dep_b_base;
|
||||
int in_dep_c_stride;
|
||||
int in_dep_c_base;
|
||||
int out_dep_stride;
|
||||
int out_dep_base;
|
||||
int source_indices[6];
|
||||
int out_index;
|
||||
Payload payload;
|
||||
};
|
||||
|
||||
struct SMEvent {
|
||||
unsigned long long start;
|
||||
unsigned long long stop;
|
||||
int event;
|
||||
};
|
||||
|
||||
//%constants%
|
||||
|
||||
__device__ __noinline__ int eval_expression(int expression, int const_z) {
|
||||
switch (expression) {
|
||||
//%expr_fns%
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ unsigned long long read_globaltimer() {
|
||||
unsigned long long t;
|
||||
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
|
||||
return t;
|
||||
}
|
||||
|
||||
//%extra_op_functions%
|
||||
|
||||
//%extra_prologue_functions%
|
||||
|
||||
__device__ __forceinline__ void nanosleep(unsigned int cycles) {
|
||||
asm volatile("nanosleep.u32 %0;" ::"r"(cycles));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_load_acquire(int *addr) {
|
||||
int val;
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(val) : "l"(addr));
|
||||
return val;
|
||||
}
|
||||
|
||||
struct NextTask {
|
||||
int current;
|
||||
int task_idx;
|
||||
};
|
||||
|
||||
// Lock-free task fetching using atomicSub for claiming (reduces CAS contention)
|
||||
// remaining encoding:
|
||||
// -1 = uninitialized
|
||||
// > 0 = iterations remaining (atomicSub to claim, iteration = old - 1)
|
||||
// <= 0 = exhausted
|
||||
__device__ inline bool fetch_next_task(Task *tasks, int num_tasks, int *head,
|
||||
NextTask *out) {
|
||||
while (true) {
|
||||
int idx = atomic_load_acquire(head);
|
||||
if (idx >= num_tasks)
|
||||
return false;
|
||||
|
||||
Task *t = &tasks[idx];
|
||||
int remaining = atomicAdd(&t->remaining, 0);
|
||||
|
||||
// Handle uninitialized task - one CAS to initialize
|
||||
if (remaining == -1) {
|
||||
int range = eval_expression(t->range, 0);
|
||||
atomicCAS(&t->remaining, -1, range);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Task already exhausted, advance head
|
||||
if (remaining <= 0) {
|
||||
atomicMax(head, idx + 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Claim via atomicSub - guaranteed to make progress, no CAS retry
|
||||
int old = atomicSub(&t->remaining, 1);
|
||||
|
||||
if (old > 0) {
|
||||
out->task_idx = idx;
|
||||
out->current = old - 1;
|
||||
if (old == 1) {
|
||||
atomicMax(head, idx + 1);
|
||||
}
|
||||
// DEBUG: This path indicates successful task claim
|
||||
return true;
|
||||
}
|
||||
|
||||
// Race: exhausted between check and atomicSub, advance head
|
||||
atomicMax(head, idx + 1);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void record_event(SMEvent *__restrict__ timings,
|
||||
int *event_idx, int event_type) {
|
||||
if (*event_idx < N_TIMING_SLOTS) {
|
||||
unsigned long long now = read_globaltimer();
|
||||
if (*event_idx > 0) { // record the end of the previous op
|
||||
timings[*event_idx - 1].stop = now;
|
||||
}
|
||||
timings[*event_idx].start = now;
|
||||
timings[*event_idx].stop = 0ull;
|
||||
timings[*event_idx].event = event_type;
|
||||
(*event_idx)++;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Kernel params: internal buffers in order, then dyn_dims
|
||||
// tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
|
||||
__global__ void worker_kernel(
|
||||
Task* __restrict__ tasks,
|
||||
int* __restrict__ head,
|
||||
int* __restrict__ ready,
|
||||
int* __restrict__ queue_lock,
|
||||
SMEvent* __restrict__ timings,
|
||||
unsigned long long* __restrict__ start_times,
|
||||
float* const* buffers,
|
||||
int* __restrict__ dyn_dims
|
||||
) {
|
||||
// Constants N_TASKS and N_BARRIERS are baked into the kernel string
|
||||
|
||||
// Note: Reset is now done on host side in pre_execute
|
||||
// All buffers (head, queue_lock, ready, tasks) are pre-initialized
|
||||
|
||||
// DEBUG: Count tasks fetched (use queue_lock as counter since it's not being used)
|
||||
// Note: queue_lock is in internal_bufs[3]
|
||||
|
||||
__shared__ NextTask nt;
|
||||
__shared__ int done;
|
||||
__shared__ int dep_out;
|
||||
__shared__ bool run_a_prologue;
|
||||
__shared__ bool run_b_prologue;
|
||||
__shared__ bool run_c_prologue;
|
||||
__shared__ bool stop_wait_loop;
|
||||
__shared__ float scratchpad[8192]; // 32 KB scratchpad
|
||||
__shared__ const float* source_ptrs[6];
|
||||
__shared__ float* out_ptr;
|
||||
int recorded_event = 0;
|
||||
timings += blockIdx.x * N_TIMING_SLOTS;
|
||||
if (threadIdx.x == 0) {
|
||||
start_times[blockIdx.x] = read_globaltimer();
|
||||
}
|
||||
while (true) {
|
||||
if (threadIdx.x == 0) {
|
||||
record_event(timings, &recorded_event, 0); // Record issue start
|
||||
done = !fetch_next_task(tasks, N_TASKS, head, &nt);
|
||||
}
|
||||
__syncthreads();
|
||||
if (done)
|
||||
break;
|
||||
|
||||
const Task *t = &tasks[nt.task_idx];
|
||||
|
||||
// Resolve buffer pointers from indices
|
||||
if (threadIdx.x == 0) {
|
||||
source_ptrs[0] = buffers[t->source_indices[0]];
|
||||
source_ptrs[1] = buffers[t->source_indices[1]];
|
||||
source_ptrs[2] = buffers[t->source_indices[2]];
|
||||
source_ptrs[3] = buffers[t->source_indices[3]];
|
||||
source_ptrs[4] = buffers[t->source_indices[4]];
|
||||
source_ptrs[5] = buffers[t->source_indices[5]];
|
||||
out_ptr = buffers[t->out_index];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int dep_a = 0;
|
||||
int dep_b = 0;
|
||||
int dep_c = 0;
|
||||
|
||||
// Thread 0 calculates dependencies and waits for inputs
|
||||
if (threadIdx.x == 0) {
|
||||
// Note: atomic_load_acquire provides visibility for ready array
|
||||
dep_a = (t->in_dep_a_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_a_base, 0) +
|
||||
eval_expression(t->in_dep_a_stride, nt.current)));
|
||||
dep_b = (t->in_dep_b_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_b_base, 0) +
|
||||
eval_expression(t->in_dep_b_stride, nt.current)));
|
||||
dep_c = (t->in_dep_c_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_c_base, 0) +
|
||||
eval_expression(t->in_dep_c_stride, nt.current)));
|
||||
dep_out = eval_expression(t->out_dep_base, 0) +
|
||||
eval_expression(t->out_dep_stride, nt.current);
|
||||
|
||||
// Increment the output barrier to signal an op is in-flight
|
||||
atomicAdd(&ready[dep_out], 1);
|
||||
|
||||
record_event(timings, &recorded_event, 1); // Record wait start
|
||||
|
||||
// Wait on input dependencies and run prologues as inputs become ready
|
||||
run_a_prologue = false;
|
||||
run_b_prologue = false;
|
||||
run_c_prologue = false;
|
||||
stop_wait_loop = false;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
bool a_done = false, b_done = false, c_done = false, tmp;
|
||||
// Optimize: if deps are same, reuse atomic load result
|
||||
const bool ab_same = (dep_a == dep_b);
|
||||
const bool ac_same = (dep_a == dep_c);
|
||||
const bool bc_same = (dep_b == dep_c);
|
||||
|
||||
while (true) {
|
||||
if (threadIdx.x == 0) {
|
||||
// Derive x_done and run_x_prologue with optimized atomic loads
|
||||
if (!a_done) {
|
||||
tmp = atomic_load_acquire(&ready[dep_a]) <= 0;
|
||||
if (tmp) {
|
||||
run_a_prologue = true;
|
||||
a_done = true;
|
||||
// Propagate to same deps
|
||||
if (ab_same) {
|
||||
run_b_prologue = true;
|
||||
b_done = true;
|
||||
}
|
||||
if (ac_same) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!b_done && !ab_same) {
|
||||
tmp = atomic_load_acquire(&ready[dep_b]) <= 0;
|
||||
if (tmp) {
|
||||
run_b_prologue = true;
|
||||
b_done = true;
|
||||
if (bc_same) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!c_done && !ac_same && !bc_same) {
|
||||
tmp = atomic_load_acquire(&ready[dep_c]) <= 0;
|
||||
if (tmp) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
if (a_done && b_done && c_done)
|
||||
stop_wait_loop = true;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Early exit if all dependencies satisfied (skip prologue checks)
|
||||
if (stop_wait_loop)
|
||||
break;
|
||||
|
||||
if (run_a_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_a_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_a_prologue = false;
|
||||
}
|
||||
}
|
||||
if (run_b_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_b_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_b_prologue = false;
|
||||
}
|
||||
}
|
||||
if (run_c_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_c_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_c_prologue = false;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
record_event(timings, &recorded_event,
|
||||
t->op + 2); // Record main op, ends Wait
|
||||
|
||||
// Execute main operation
|
||||
switch (t->op) {
|
||||
//%extra_op_calls%
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Arrive at output barrier
|
||||
if (threadIdx.x == 0) {
|
||||
__threadfence();
|
||||
atomicSub(&ready[dep_out], 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0 && recorded_event > 0) {
|
||||
timings[recorded_event - 1].stop = read_globaltimer();
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,82 +0,0 @@
|
||||
//! Compiles BlockOp subgraphs into KernelOp (MegakernelOp).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
op::LLIROp,
|
||||
prelude::{
|
||||
FxHashMap, FxHashSet, NodeIndex,
|
||||
petgraph::{Direction, visit::EdgeRef},
|
||||
},
|
||||
};
|
||||
use tracing::{Level, span};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::partition_marked_convex};
|
||||
|
||||
use super::{BlockOp, MegakernelOp};
|
||||
|
||||
/// Compile all BlockOp subgraphs in the LLIR graph into MegakernelOps.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Finds all BlockOp nodes in the graph
|
||||
/// 2. Partitions them into convex subgraphs
|
||||
/// 3. For each subgraph, creates a MegakernelOp (which implements KernelOp)
|
||||
/// 4. Adds the megakernel node to the llir_graph with appropriate edges
|
||||
///
|
||||
/// Returns mappings needed for the kernel compilation phase:
|
||||
/// - `megakernel_to_blocks`: Maps each megakernel node to the BlockOp nodes it contains
|
||||
/// (used to include block op nodes in the kernel's inputs for buffer pointer collection)
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn block_to_kernel(
|
||||
llir_graph: &mut LLIRGraph,
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> FxHashMap<NodeIndex, Vec<NodeIndex>> {
|
||||
let _span = span!(Level::TRACE, "block_to_kernel").entered();
|
||||
|
||||
let block_ops_in_graph = llir_graph
|
||||
.node_indices()
|
||||
.filter(|n| llir_graph[*n].to_dialect::<dyn BlockOp>().is_some())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
if block_ops_in_graph.is_empty() {
|
||||
return FxHashMap::default();
|
||||
}
|
||||
|
||||
let mut megakernel_to_blocks: FxHashMap<NodeIndex, Vec<NodeIndex>> = FxHashMap::default();
|
||||
|
||||
for subgraph in partition_marked_convex(llir_graph, &block_ops_in_graph).unwrap() {
|
||||
// Create MegakernelOp which implements KernelOp
|
||||
let megakernel_op = MegakernelOp::new(llir_graph, &subgraph, cuda_stream, kernel_cache);
|
||||
|
||||
// Add megakernel node to llir_graph as a KernelOp
|
||||
let megakernel_node =
|
||||
llir_graph.add_node(LLIROp::new(Box::new(megakernel_op) as Box<dyn KernelOp>));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
// These edges establish exec_graph dependencies (megakernel waits for inputs)
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to megakernel node
|
||||
// Note: We don't add edges TO external consumers because the original
|
||||
// block op -> consumer edges still exist and will be used for exec_graph ordering
|
||||
for input in &external_inputs {
|
||||
llir_graph.add_edge(*input, megakernel_node, ());
|
||||
}
|
||||
|
||||
// Map megakernel node to all block op nodes it contains
|
||||
megakernel_to_blocks.insert(megakernel_node, subgraph.into_iter().collect());
|
||||
}
|
||||
|
||||
megakernel_to_blocks
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, IR, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
|
||||
|
||||
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
|
||||
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
// Strip quotes if present (egglog strings are stored with quotes)
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasSgemmV2 {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
/// Lazily initialized cuBLAS handle - created on first execute
|
||||
cublas: OnceLock<Arc<CudaBlas>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasSgemmV2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
cublas: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("a", IR),
|
||||
("b", IR),
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[children[5]].0;
|
||||
let b_layout_str = &egraph.enodes[children[6]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cublas: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, vec![children[0], children[1]])
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasSgemmV2 {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as i32;
|
||||
let n = self.n.exec(dyn_map).unwrap() as i32;
|
||||
let k = self.k.exec(dyn_map).unwrap() as i32;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i32;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * 4,
|
||||
b_buf.len(),
|
||||
k * n * 4,
|
||||
c_buf.len(),
|
||||
m * n * 4
|
||||
);
|
||||
let _sgemm_span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLAS_SGEMM_V2",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
?a_layout,
|
||||
?b_layout,
|
||||
)
|
||||
.entered();
|
||||
|
||||
// Use shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
|
||||
|
||||
// Set the stream for this operation (cuBLAS handle can work with any stream)
|
||||
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
|
||||
unsafe {
|
||||
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
|
||||
}
|
||||
|
||||
let status = unsafe {
|
||||
cublasSgemm_v2(
|
||||
*cublas.handle(),
|
||||
a_layout,
|
||||
b_layout,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha as *const f32,
|
||||
a_ptr as *const f32,
|
||||
lda,
|
||||
b_ptr as *const f32,
|
||||
ldb,
|
||||
&beta as *const f32,
|
||||
c_ptr as *mut f32,
|
||||
ldc,
|
||||
)
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLAS SGEMM TN failed with status: {:?}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
|
||||
self.output_size() * 4
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape (must be exactly 2D for cuBLAS)
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape (must be exactly 2D for cuBLAS)
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape (must be exactly 2D for cuBLAS)
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
@@ -1,72 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape (must be exactly 2D for cuBLAS)
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
@@ -1,392 +0,0 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, EXPRESSION, IR, STRING},
|
||||
extract_dtype, extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
CudaBlasLT, MatmulShared,
|
||||
sys::{
|
||||
cublasComputeType_t, cublasLtMatmul, cublasLtMatmulAlgoGetHeuristic,
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescCreate, cublasLtMatmulDescDestroy,
|
||||
cublasLtMatmulDescSetAttribute, cublasLtMatmulHeuristicResult_t,
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t,
|
||||
cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceDestroy,
|
||||
cublasLtMatmulPreferenceSetAttribute, cublasLtMatrixLayout_t,
|
||||
cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutDestroy, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasLt {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
dtype: DType,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasLt {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
dtype: DType::F32,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLt {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"cublaslt",
|
||||
&[
|
||||
("a", IR),
|
||||
("b", IR),
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[children[5]].0;
|
||||
let b_layout_str = &egraph.enodes[children[6]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
|
||||
// Extract dtype from egglog
|
||||
let dtype = extract_dtype(egraph, children[10]);
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
dtype,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, vec![children[0], children[1]])
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert DType to CUDA types for cuBLAS LT
|
||||
/// Returns (matrix_dtype, compute_type, scale_dtype)
|
||||
fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cudaDataType) {
|
||||
match dtype {
|
||||
// F64: matrix=f64, compute=f64, scale=f64
|
||||
DType::F64 => (
|
||||
cudaDataType::CUDA_R_64F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_64F,
|
||||
cudaDataType::CUDA_R_64F,
|
||||
),
|
||||
// F32: matrix=f32, compute=f32, scale=f32
|
||||
DType::F32 => (
|
||||
cudaDataType::CUDA_R_32F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// F16: matrix=f16, compute=f32 (FP32 accumulation for accuracy), scale=f32
|
||||
DType::F16 => (
|
||||
cudaDataType::CUDA_R_16F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// BF16: matrix=bf16, compute=f32 with tensor cores, scale=f32
|
||||
DType::Bf16 => (
|
||||
cudaDataType::CUDA_R_16BF,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// TF32: stored as f32, use fast TF32 tensor core path
|
||||
DType::TF32 => (
|
||||
cudaDataType::CUDA_R_32F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// FP8 E4M3: matrix=fp8_e4m3, compute=f32, scale=f32
|
||||
DType::F8E4M3 => (
|
||||
cudaDataType::CUDA_R_8F_E4M3,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
// FP8 E5M2: matrix=fp8_e5m2, compute=f32, scale=f32
|
||||
DType::F8E5M2 => (
|
||||
cudaDataType::CUDA_R_8F_E5M2,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
cudaDataType::CUDA_R_32F,
|
||||
),
|
||||
DType::Int => panic!("cuBLAS LT does not support integer matmul"),
|
||||
DType::Bool => panic!("cuBLAS LT does not support bool matmul"),
|
||||
other => todo!("cuBLAS LT matmul not yet implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as u64;
|
||||
let n = self.n.exec(dyn_map).unwrap() as u64;
|
||||
let k = self.k.exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i64;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
|
||||
|
||||
// Get CUDA types based on dtype
|
||||
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
|
||||
let element_size = (self.dtype.bits() / 8) as u64;
|
||||
assert!(
|
||||
element_size > 0,
|
||||
"cuBLAS LT does not support sub-byte dtype {}",
|
||||
self.dtype
|
||||
);
|
||||
|
||||
// Alpha/beta scale values (all dtypes use F32 scale type)
|
||||
let alpha_f32: f32 = 1.0;
|
||||
let beta_f32: f32 = 0.0;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug tracing
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * element_size,
|
||||
b_buf.len(),
|
||||
k * n * element_size,
|
||||
c_buf.len(),
|
||||
m * n * element_size
|
||||
);
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT",
|
||||
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut b_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut c_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
let mut preference: cublasLtMatmulPreference_t = std::ptr::null_mut();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
// Allocate workspace (32 MiB)
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
|
||||
unsafe {
|
||||
// Create matmul descriptor (compute_type, scale_type for alpha/beta)
|
||||
cublasLtMatmulDescCreate(&mut matmul_desc, compute_type, scale_dtype).result()?;
|
||||
|
||||
// Set transpose attributes
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&a_layout as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&b_layout as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<cublasOperation_t>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Create matrix layout descriptors
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
|
||||
cublasLtMatrixLayoutCreate(&mut a_desc, cuda_dtype, a_rows, a_cols, lda).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
|
||||
|
||||
// Create preference and set workspace size
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&WORKSPACE_SIZE as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Get heuristic (best algorithm)
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
c_desc, // D layout same as C
|
||||
preference,
|
||||
1, // Request 1 result
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
// Cleanup before returning error
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
|
||||
// All dtypes use F32 scale type for alpha/beta
|
||||
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
|
||||
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
matmul_desc,
|
||||
alpha_ptr,
|
||||
a_ptr as *const std::ffi::c_void,
|
||||
a_desc,
|
||||
b_ptr as *const std::ffi::c_void,
|
||||
b_desc,
|
||||
beta_ptr,
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
c_desc, // D layout same as C
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
// Cleanup
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
cublasLtMatrixLayoutDestroy(a_desc);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
moe::GLUMoE,
|
||||
);
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
/// Execute the operation with access to buffers via a map.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `stream` - The CUDA stream to execute on
|
||||
/// * `self_node` - The NodeIndex of this op in the llir_graph (used as output buffer)
|
||||
/// * `inputs` - NodeIndices of input nodes (in edge order from the graph)
|
||||
/// * `buffers` - Map from NodeIndex to device buffer for all allocated nodes
|
||||
/// * `dyn_map` - Dynamic dimension values
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Returns the output buffer size in elements.
|
||||
/// Return 0 if this op doesn't have a single output buffer (e.g., CudaGraphOp).
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns additional nodes (beyond graph edges) that this op needs buffers for.
|
||||
///
|
||||
/// For most ops, this returns empty (buffers determined by graph edges).
|
||||
/// For CudaGraphOp, this returns all internal kernel nodes.
|
||||
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
/// For CudaGraphOp, this returns sizes for all internal kernel output buffers.
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
FxHashMap::default()
|
||||
}
|
||||
|
||||
/// Returns the name of this host op for stats reporting, or None if not reportable.
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Iota ?gu_io ?gu_iota_base_range))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Mul ?gu_mul_base_shape ?topk_idx ?gu_mul_base_a_stride ?gu_iota_base ?gu_mul_base_b_stride ?gu_mul_base_out_stride))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Cast ?gu_mul_base ?gu_cast_base_size (F32)))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Iota (MIter) ?gu_iota_within_range))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Cast ?gu_iota_within ?gu_cast_within_size (F32)))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Add ?gu_add_shape ?gu_cast_base ?gu_add_a_stride ?gu_cast_within ?gu_add_b_stride ?gu_add_out_stride))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Cast ?gu_add_idx ?gu_cast_idx_size (Int)))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Gather ?gu_cast_idx ?gu_gather_idx_shape ?gu_gather_idx_stride ?gate_up_w ?gu_gather_data_shape ?gu_gather_data_stride))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Cast ?gu_gathered ?gu_f32_size (F32)))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Mul ?gu_matmul_mul_shape ?x ?gu_matmul_a_stride ?gu_f32 ?gu_matmul_b_stride ?gu_matmul_mul_out_stride))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_mul ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Iota ?up_iota_expr ?up_iota_range))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Gather ?up_iota ?up_gather_idx_shape ?up_gather_idx_stride ?gu_matmul ?up_gather_data_shape ?up_gather_data_stride))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Constant -1.000000))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Mul ?silu_shape1 ?gu_matmul ?silu_a_stride1 ?neg1 ?silu_b_stride1 ?silu_out_stride1))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Constant 1.442695))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Mul ?silu_shape2 ?neg_gate ?silu_a_stride2 ?log2e ?silu_b_stride2 ?silu_out_stride2))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Exp2 ?silu_shape3 ?scaled ?silu_in_stride3 ?silu_out_stride3))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Constant 1.000000))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Add ?silu_shape4 ?exp2_val ?silu_a_stride4 ?one ?silu_b_stride4 ?silu_out_stride4))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Recip ?silu_shape5 ?plus1 ?silu_in_stride5 ?silu_out_stride5))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Mul ?silu_shape6 ?gu_matmul ?silu_a_stride6 ?sigmoid ?silu_b_stride6 ?silu_out_stride6))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Mul ?swiglu_shape ?silu_out ?swiglu_a_stride ?up_slice ?swiglu_b_stride ?swiglu_out_stride))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Iota ?dn_io ?dn_iota_base_range))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Mul ?dn_mul_base_shape ?topk_idx ?dn_mul_base_a_stride ?dn_iota_base ?dn_mul_base_b_stride ?dn_mul_base_out_stride))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Cast ?dn_mul_base ?dn_cast_base_size (F32)))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Iota (MIter) ?dn_iota_within_range))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Cast ?dn_iota_within ?dn_cast_within_size (F32)))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Add ?dn_add_shape ?dn_cast_base ?dn_add_a_stride ?dn_cast_within ?dn_add_b_stride ?dn_add_out_stride))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Cast ?dn_add_idx ?dn_cast_idx_size (Int)))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Gather ?dn_cast_idx ?dn_gather_idx_shape ?dn_gather_idx_stride ?down_w ?dn_gather_data_shape ?dn_gather_data_stride))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_f32 (Cast ?dn_gathered ?dn_f32_size (F32)))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Mul ?dn_matmul_mul_shape ?swiglu_out ?dn_matmul_a_stride ?dn_f32 ?dn_matmul_b_stride ?dn_matmul_mul_out_stride))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_mul ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Mul ?weighted_shape ?dn_matmul ?weighted_a_stride ?topk_vals ?weighted_b_stride ?weighted_out_stride))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Sum ?output_shape ?output_k ?weighted ?output_in_stride ?output_k_stride ?output_out_stride))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (GLUMoE ?x ?topk_idx ?topk_vals ?gate_up_w ?down_w
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,206 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
use cudarc::{
|
||||
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
|
||||
nvrtc::CompileOptions,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, IR},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (KernelMeanReduce,);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
pub struct KernelMeanReduce {
|
||||
out_shape: Vec<Expression>,
|
||||
iters: Expression,
|
||||
in_stride: Vec<Expression>,
|
||||
iter_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
impl EgglogOp for KernelMeanReduce {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"KernelMean",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("iters", EXPRESSION),
|
||||
("inp", IR),
|
||||
("strides", ELIST),
|
||||
("iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Disabled: the e-graph union introduced by this rule can cause the search
|
||||
// to select genomes with accumulated FP precision issues over many layers.
|
||||
// The unfused Sum + Mul(Recip(Cast(Iota))) path produces equivalent results.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
{
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap();
|
||||
let iters = extract_expr(egraph, children[1], expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, children[3], list_cache, expr_cache).unwrap();
|
||||
let iter_stride = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap();
|
||||
let dtype = extract_dtype(egraph, children[6]);
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
iters,
|
||||
in_stride,
|
||||
iter_stride,
|
||||
out_stride,
|
||||
dtype,
|
||||
}) as Box<dyn KernelOp>)
|
||||
},
|
||||
vec![children[2]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelMeanReduce {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.iters.dyn_vars())
|
||||
.chain(self.iter_stride.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block = 256; // 8 warps per block
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let iter_stride_of_i = self.iter_stride.to_kernel().replace("const_z", "i");
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void reduce_mean_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = blockIdx.x;
|
||||
long long n_elements = {n_outputs};
|
||||
if (const_z >= n_elements) return;
|
||||
|
||||
long long in_start = {in_index};
|
||||
long long iters = {iters};
|
||||
|
||||
{dtype} sum = 0;
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
sum += in[in_start + {iter_stride_of_i}];
|
||||
}}
|
||||
|
||||
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
in_index = flatten_strides(&self.out_shape, &self.in_stride).to_kernel(),
|
||||
out_index = flatten_strides(&self.out_shape, &self.out_stride).to_kernel(),
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
iters = self.iters.to_kernel(),
|
||||
iter_stride_of_i = iter_stride_of_i,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_mean_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.out_shape.iter().copied().product::<Expression>() * self.iters * self.dtype.bits())
|
||||
.ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
n_outputs * self.iters + n_outputs
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MeanReduce"
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
pub mod block;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::F64 => "double",
|
||||
DType::F32 => "float",
|
||||
DType::F16 => "half",
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
DType::U8 => "unsigned char",
|
||||
DType::Bool => "unsigned char",
|
||||
DType::F8E4M3 => "__nv_fp8_e4m3",
|
||||
DType::F8E5M2 => "__nv_fp8_e5m2",
|
||||
DType::F8UE8M0 => "__nv_fp8_e8m0",
|
||||
DType::F6E2M3 => "__nv_fp6_e2m3",
|
||||
DType::F6E3M2 => "__nv_fp6_e3m2",
|
||||
DType::F4E2M1 => "__nv_fp4_e2m1",
|
||||
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 273,
|
||||
"NVIDIA H100 PCIe" => 2_000,
|
||||
"NVIDIA H100 SXM" => 3_350,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in TFLOPs
|
||||
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 125, // forced to use tf32 flops
|
||||
"NVIDIA H100 PCIe" => 756,
|
||||
"NVIDIA H100 SXM" => 989,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef},
|
||||
base::OP_SORTS,
|
||||
},
|
||||
op::EgglogOp,
|
||||
};
|
||||
|
||||
pub type Ops = (Exp, Sigmoid);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp;
|
||||
impl EgglogOp for Exp {
|
||||
fn sort(&self) -> SortDef {
|
||||
OP_SORTS.unary("Exp")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?exp_const (Constant 1.442695))
|
||||
(= ?mul (Mul ?shape ?x ?x_stride ?exp_const ?const_stride ?intermediate_stride))
|
||||
(= ?exp2 (Exp2 ?shape ?mul ?intermediate_stride ?out_stride))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?exp (Exp ?shape ?x ?x_stride ?out_stride))
|
||||
(union ?exp2 ?exp)
|
||||
(set (dtype ?exp) ?dt)
|
||||
)
|
||||
)",
|
||||
)]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Sigmoid;
|
||||
impl EgglogOp for Sigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
OP_SORTS.unary("Sigmoid")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw("(rule
|
||||
(
|
||||
(= ?neg_input (Mul ?input_range ?input ?input_stride (Constant -1.0) ?const_stride ?intermediate_stride))
|
||||
(= ?exp (Exp ?input_range ?neg_input ?intermediate_stride ?exp_stride))
|
||||
(= ?plus_one (Add ?input_range ?exp ?exp_stride (Constant 1.0) ?const_stride ?plus_one_stride))
|
||||
(= ?sig_out (Recip ?input_range ?plus_one ?plus_one_stride ?out_stride))
|
||||
(= ?dt (dtype ?input))
|
||||
)
|
||||
(
|
||||
(let ?sig (Sigmoid ?input_range ?input ?input_stride ?out_stride))
|
||||
(union ?sig_out ?sig)
|
||||
(set (dtype ?sig) ?dt)
|
||||
)
|
||||
:name \"sigmoid\"
|
||||
)")]
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +0,0 @@
|
||||
pub mod utilities;
|
||||
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
@@ -1,494 +0,0 @@
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use luminal::egglog_utils::{
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use num_traits::{Num, Signed};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::runtime::{CudaRuntime, ToCudaInput};
|
||||
|
||||
/// Safety factor multiplied with epsilon for tolerance calculations
|
||||
pub const TOLERANCE_SAFETY_FACTOR: f32 = 2.0;
|
||||
|
||||
/// Number of genomes to fuzz per op test invocation.
|
||||
pub const GENOME_FUZZ_COUNT: usize = 20;
|
||||
|
||||
/// Trait for test-compatible data types that can be used in generic test functions.
|
||||
/// Bridges luminal's runtime types with candle's tensor types.
|
||||
pub trait TestDType:
|
||||
Clone + Sized + WithDType + PartialEq + Copy + std::fmt::Debug + 'static
|
||||
where
|
||||
Vec<Self>: ToCudaInput,
|
||||
{
|
||||
/// The corresponding luminal DType
|
||||
const DTYPE: luminal::dtype::DType;
|
||||
|
||||
/// Retrieve data from the runtime in this dtype
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self>;
|
||||
/// Extract a Vec from a candle Tensor
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self>;
|
||||
/// Compare two result vectors. Float types use tolerance; exact types use equality.
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32);
|
||||
}
|
||||
|
||||
impl TestDType for f32 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F32;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_f32(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<f32>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, rtol, atol);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for f16 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::F16;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_f16(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<f16>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, f16::from_f32(rtol), f16::from_f32(atol));
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for bf16 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Bf16;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_bf16(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<bf16>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], rtol: f32, atol: f32) {
|
||||
assert_close(a, b, bf16::from_f32(rtol), bf16::from_f32(atol));
|
||||
}
|
||||
}
|
||||
|
||||
impl TestDType for i32 {
|
||||
const DTYPE: luminal::dtype::DType = luminal::dtype::DType::Int;
|
||||
|
||||
fn get_from_runtime(rt: &CudaRuntime, id: NodeIndex) -> Vec<Self> {
|
||||
rt.get_i32(id)
|
||||
}
|
||||
fn candle_to_vec(tensor: &Tensor) -> Vec<Self> {
|
||||
tensor.to_vec1::<i32>().unwrap()
|
||||
}
|
||||
fn assert_match(a: &[Self], b: &[Self], _rtol: f32, _atol: f32) {
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn random_i32_vec(n: usize, seed: u64, low: i32, high: i32) -> Vec<i32> {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
(0..n).map(|_| rng.random_range(low..=high)).collect()
|
||||
}
|
||||
|
||||
pub fn random_f32_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<f32> {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
(0..n).map(|_| rng.random_range(low..high)).collect()
|
||||
}
|
||||
|
||||
/// Assert two vectors are close following NumPy/PyTorch conventions.
|
||||
/// Formula: |a - b| <= atol + rtol * |b|
|
||||
/// Generic version that works with any Float type (f32, f16, bf16).
|
||||
pub fn assert_close<T: Num + Signed + PartialOrd + Copy + std::fmt::Display>(
|
||||
a_vec: &[T],
|
||||
b_vec: &[T],
|
||||
rtol: T,
|
||||
atol: T,
|
||||
) {
|
||||
assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match");
|
||||
for (i, (a, b)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
|
||||
let diff = (*a - *b).abs();
|
||||
let tolerance = atol + rtol * b.abs();
|
||||
|
||||
if diff > tolerance {
|
||||
panic!("{a} is not close to {b}, index {i}, diff: {diff}, tolerance: {tolerance}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
ctx.bind_to_thread().ok()?;
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
ctx.compute_capability().ok()
|
||||
}
|
||||
|
||||
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
|
||||
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
|
||||
let Some((major, _)) = gpu_compute_cap() else {
|
||||
return false;
|
||||
};
|
||||
match dtype {
|
||||
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
|
||||
luminal::dtype::DType::F4E2M1
|
||||
| luminal::dtype::DType::F8E4M3
|
||||
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Machine epsilon for each dtype (approximate)
|
||||
pub fn dtype_epsilon(dtype: luminal::dtype::DType) -> f32 {
|
||||
match dtype {
|
||||
luminal::dtype::DType::F32 => 1.19e-7, // 2^-23
|
||||
luminal::dtype::DType::F16 => 9.77e-4, // 2^-10
|
||||
luminal::dtype::DType::Bf16 => 7.81e-3, // 2^-7
|
||||
luminal::dtype::DType::Int => 0.0,
|
||||
luminal::dtype::DType::Bool => 0.0,
|
||||
other => todo!("dtype_epsilon not implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a luminal DType to the corresponding candle DType.
|
||||
pub fn to_candle_dtype(dtype: luminal::dtype::DType) -> candle_core::DType {
|
||||
match dtype {
|
||||
luminal::dtype::DType::F32 => candle_core::DType::F32,
|
||||
luminal::dtype::DType::F16 => candle_core::DType::F16,
|
||||
luminal::dtype::DType::Bf16 => candle_core::DType::BF16,
|
||||
luminal::dtype::DType::Int => candle_core::DType::I32,
|
||||
luminal::dtype::DType::Bool => candle_core::DType::U8,
|
||||
other => todo!("candle dtype mapping not implemented for {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Base unary test function with input generator (CUDA version)
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
pub fn test_unary_cuda<T: TestDType>(
|
||||
shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor) -> Tensor,
|
||||
generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
seed: u64,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let shape: Vec<usize> = shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let n_elements: usize = shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(shape.clone());
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
|
||||
// Reference using candle on CUDA
|
||||
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
|
||||
let ref_a = Tensor::from_slice(&input_data, shape, &device).unwrap();
|
||||
let ref_b = ref_func(ref_a).flatten_all().unwrap();
|
||||
let ref_vec = T::candle_to_vec(&ref_b);
|
||||
|
||||
let eps = dtype_epsilon(<T as TestDType>::DTYPE);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
T::assert_match(&result, &ref_vec, tol, tol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<T>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| rt.set_data(a, input_data.clone()),
|
||||
b.id,
|
||||
&ref_vec,
|
||||
tol,
|
||||
tol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Base binary test function with input generators
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
|
||||
pub fn test_binary_cuda<T: TestDType>(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor, Tensor) -> Tensor,
|
||||
a_generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
b_generator: impl Fn(usize, u64) -> Vec<T>,
|
||||
seed: u64,
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let a_shape: Vec<usize> = a_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let b_shape: Vec<usize> = b_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let a_elements: usize = a_shape.iter().product();
|
||||
let b_elements: usize = b_shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a: GraphTensor = cx.tensor(a_shape.clone());
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = a_generator(a_elements, seed);
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
|
||||
// Reference using candle on CUDA
|
||||
let device = Device::new_cuda(0).expect("Candle CUDA device required for test");
|
||||
let ref_a = Tensor::from_slice(&a_data, a_shape, &device).unwrap();
|
||||
let ref_b = Tensor::from_slice(&b_data, b_shape, &device).unwrap();
|
||||
let ref_c = ref_func(ref_a, ref_b).flatten_all().unwrap();
|
||||
let ref_vec = T::candle_to_vec(&ref_c);
|
||||
|
||||
T::assert_match(&result, &ref_vec, rtol, atol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<T>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
},
|
||||
c.id,
|
||||
&ref_vec,
|
||||
rtol,
|
||||
atol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Test mod operation with element-wise reference using Rust's % operator
|
||||
pub fn test_mod(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
seed: u64,
|
||||
) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let a_shape: Vec<usize> = a_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let b_shape: Vec<usize> = b_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let a_elements: usize = a_shape.iter().product();
|
||||
let b_elements: usize = b_shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(a_shape.clone());
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
|
||||
// Generate divisor values away from zero (0.1 to 0.5) to avoid division issues
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
// Reference: Rust's % operator matches CUDA's fmodf (IEEE 754 remainder)
|
||||
let expected: Vec<f32> = a_data
|
||||
.iter()
|
||||
.zip(b_data.iter())
|
||||
.map(|(x, y)| x % y)
|
||||
.collect();
|
||||
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let rtol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let atol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
assert_close(&result, &expected, rtol, atol);
|
||||
|
||||
// Fuzz genomes: verify multiple graph rewrites produce consistent results
|
||||
fuzz_genomes::<f32>(
|
||||
&cx,
|
||||
&stream,
|
||||
|rt| {
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
},
|
||||
c.id,
|
||||
&expected,
|
||||
rtol,
|
||||
atol,
|
||||
GENOME_FUZZ_COUNT,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generate a slice range for an axis of given size.
|
||||
/// If do_start is true, randomly choose a start offset (leaving at least 1 element).
|
||||
/// If do_end is true, randomly choose an end before the axis end.
|
||||
pub fn gen_slice_range(
|
||||
size: usize,
|
||||
do_start: bool,
|
||||
do_end: bool,
|
||||
rng: &mut impl Rng,
|
||||
) -> (usize, usize) {
|
||||
let start = if do_start && size > 1 {
|
||||
rng.random_range(0..size)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let remaining = size - start;
|
||||
let end = if do_end && remaining > 1 {
|
||||
start + rng.random_range(1..remaining)
|
||||
} else {
|
||||
size
|
||||
};
|
||||
(start, end)
|
||||
}
|
||||
|
||||
/// Fuzz test multiple genomes from the e-graph search space.
|
||||
///
|
||||
/// After a graph has been built and compared against a reference, this function
|
||||
/// extracts random genomes via mutation and verifies they all produce results
|
||||
/// matching the expected reference output. This catches bugs where graph rewrites
|
||||
/// produce incorrect computation.
|
||||
///
|
||||
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
|
||||
pub fn fuzz_genomes<T: TestDType>(
|
||||
cx: &Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
setup_inputs: impl Fn(&mut CudaRuntime),
|
||||
output_id: NodeIndex,
|
||||
expected: &[T],
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
num_genomes: usize,
|
||||
seed: u64,
|
||||
) where
|
||||
Vec<T>: ToCudaInput,
|
||||
{
|
||||
let Some(egraph) = cx.egraph() else {
|
||||
return;
|
||||
};
|
||||
let Some(ops) = cx.egglog_ops() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check if there are alternative genomes to explore
|
||||
let mutable_eclasses: usize = egraph
|
||||
.eclasses
|
||||
.iter()
|
||||
.filter(|(_, (label, enodes))| {
|
||||
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
|
||||
})
|
||||
.count();
|
||||
if mutable_eclasses == 0 {
|
||||
return; // Only one valid graph, nothing to fuzz
|
||||
}
|
||||
|
||||
// Use a different seed offset to avoid correlating with the search seed
|
||||
let mut rng = StdRng::seed_from_u64(seed.wrapping_add(7777));
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
|
||||
let initial = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&initial));
|
||||
|
||||
let mut base = initial;
|
||||
let mut tested = 0;
|
||||
|
||||
for _ in 0..100 {
|
||||
let offspring = extract_generation(egraph, &base, 10, 2, &mut prev_selected, &mut rng);
|
||||
|
||||
if offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
for genome in offspring {
|
||||
if validate_choice_set(egraph, &genome, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
setup_inputs(&mut rt);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = T::get_from_runtime(&rt, output_id);
|
||||
T::assert_match(&result, expected, rtol, atol);
|
||||
|
||||
tested += 1;
|
||||
base = genome;
|
||||
|
||||
if tested >= num_genomes {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "luminal_cuda"
|
||||
name = "luminal_cuda_lite"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
description = "Cuda compiler for luminal"
|
||||
@@ -10,7 +10,8 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
@@ -23,10 +24,12 @@ memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
lru = "0.16.2"
|
||||
libc = "0.2"
|
||||
libloading = "0.8"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
@@ -1,4 +1,4 @@
|
||||
## luminal_cuda
|
||||
## luminal_cuda_lite
|
||||
|
||||
This crate contains the CUDA backend for Luminal.
|
||||
|
||||
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
|
||||
|
||||
### Architecture
|
||||
|
||||
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
611
crates/luminal_cuda_lite/examples/egglog_saturation.rs
Normal file
611
crates/luminal_cuda_lite/examples/egglog_saturation.rs
Normal file
@@ -0,0 +1,611 @@
|
||||
use std::{collections::BTreeMap, sync::Arc, time::Instant};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::prelude::egglog::{ast::Span, prelude::RustSpan};
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
base::{base_cleanup_egglog, base_expression_egglog},
|
||||
hlir_to_egglog,
|
||||
},
|
||||
hlir::HLIROps,
|
||||
op::{EgglogOp, IntoEgglogOp, Runtime},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
const DEFAULT_PASSES: usize = 256;
|
||||
const EGGLOG_RULESETS: &[&str] = &[
|
||||
"matmul_flatten",
|
||||
"kernel_lower",
|
||||
"direct_kernel",
|
||||
"kernel_specialize",
|
||||
"buffer_reuse",
|
||||
"matmul_backend",
|
||||
"glumoe",
|
||||
"fusion_pair",
|
||||
"fusion_grow",
|
||||
"fusion_merge",
|
||||
];
|
||||
const MOE_SEQ: usize = 2;
|
||||
const MOE_HIDDEN: usize = 16;
|
||||
const MOE_NUM_EXPERTS: usize = 8;
|
||||
const MOE_TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const GEMMA_RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Backend {
|
||||
Native,
|
||||
Cuda,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Mode {
|
||||
Current,
|
||||
Steps,
|
||||
FullDefault,
|
||||
FullCycle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Case {
|
||||
Mul,
|
||||
UnaryChain(usize),
|
||||
Gelu,
|
||||
Softmax,
|
||||
LayerNorm,
|
||||
Matmul,
|
||||
Attention,
|
||||
QwenMoe,
|
||||
GemmaMoe,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
backend: Backend,
|
||||
mode: Mode,
|
||||
case: Case,
|
||||
passes: usize,
|
||||
cleanup: bool,
|
||||
skip_roll: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args {
|
||||
backend: Backend::Cuda,
|
||||
mode: Mode::Current,
|
||||
case: Case::Gelu,
|
||||
passes: DEFAULT_PASSES,
|
||||
cleanup: true,
|
||||
skip_roll: false,
|
||||
};
|
||||
|
||||
let mut iter = std::env::args().skip(1);
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--backend" => {
|
||||
args.backend = match iter.next().as_deref() {
|
||||
Some("native") => Backend::Native,
|
||||
Some("cuda") => Backend::Cuda,
|
||||
other => panic!("invalid --backend {other:?}; use native|cuda"),
|
||||
};
|
||||
}
|
||||
"--mode" => {
|
||||
args.mode = match iter.next().as_deref() {
|
||||
Some("current") => Mode::Current,
|
||||
Some("steps") => Mode::Steps,
|
||||
Some("full-default") => Mode::FullDefault,
|
||||
Some("full-cycle") => Mode::FullCycle,
|
||||
other => panic!(
|
||||
"invalid --mode {other:?}; use current|steps|full-default|full-cycle"
|
||||
),
|
||||
};
|
||||
}
|
||||
"--case" => {
|
||||
args.case = parse_case(&iter.next().expect("missing --case value"));
|
||||
}
|
||||
"--passes" => {
|
||||
args.passes = iter
|
||||
.next()
|
||||
.expect("missing --passes value")
|
||||
.parse()
|
||||
.expect("invalid --passes value");
|
||||
}
|
||||
"--no-cleanup" => args.cleanup = false,
|
||||
"--skip-roll" => args.skip_roll = true,
|
||||
"--help" | "-h" => {
|
||||
println!(
|
||||
"Usage: egglog_saturation [OPTIONS]\n\
|
||||
\n\
|
||||
Options:\n\
|
||||
--backend native|cuda default: cuda\n\
|
||||
--mode current|steps|full-default|full-cycle\n\
|
||||
--case mul|unary-chain:N|gelu|softmax|layer-norm|matmul|attention|qwen-moe|gemma-moe\n\
|
||||
--passes N default: 256\n\
|
||||
--no-cleanup omit backend/HLIR cleanup rules\n\
|
||||
--skip-roll skip auto loop rolling prepass"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => panic!("unknown argument {other}; use --help"),
|
||||
}
|
||||
}
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
fn parse_case(s: &str) -> Case {
|
||||
if let Some(n) = s.strip_prefix("unary-chain:") {
|
||||
return Case::UnaryChain(n.parse().expect("invalid unary-chain length"));
|
||||
}
|
||||
match s {
|
||||
"mul" => Case::Mul,
|
||||
"gelu" => Case::Gelu,
|
||||
"softmax" => Case::Softmax,
|
||||
"layer-norm" | "layer_norm" => Case::LayerNorm,
|
||||
"matmul" => Case::Matmul,
|
||||
"attention" => Case::Attention,
|
||||
"qwen-moe" | "qwen_moe" => Case::QwenMoe,
|
||||
"gemma-moe" | "gemma_moe" => Case::GemmaMoe,
|
||||
other => panic!("unknown case {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_case(case: Case) -> Graph {
|
||||
let mut cx = Graph::new();
|
||||
let out = match case {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor((64, 64));
|
||||
x * x
|
||||
}
|
||||
Case::UnaryChain(n) => {
|
||||
let mut x = cx.tensor((64, 64));
|
||||
for i in 0..n {
|
||||
x = match i % 6 {
|
||||
0 => x.sin(),
|
||||
1 => x.sqrt(),
|
||||
2 => x.reciprocal(),
|
||||
3 => x.exp2(),
|
||||
4 => x.log2(),
|
||||
_ => x * 1.125,
|
||||
};
|
||||
}
|
||||
x
|
||||
}
|
||||
Case::Gelu => cx.tensor((64, 64)).gelu(),
|
||||
Case::Softmax => cx.tensor((128, 128)).softmax(1),
|
||||
Case::LayerNorm => cx.tensor((128, 128)).layer_norm(1, 1e-5),
|
||||
Case::Matmul => {
|
||||
let a = cx.tensor((32, 64));
|
||||
let b = cx.tensor((64, 32));
|
||||
a.matmul(b)
|
||||
}
|
||||
Case::Attention => {
|
||||
let q = cx.tensor((64, 32));
|
||||
let k = cx.tensor((64, 32));
|
||||
let v = cx.tensor((64, 32));
|
||||
let scores = q.matmul(k.permute((1, 0))) * (1.0 / 32.0_f32.sqrt());
|
||||
scores.softmax(1).matmul(v)
|
||||
}
|
||||
Case::QwenMoe => build_qwen_moe(&mut cx),
|
||||
Case::GemmaMoe => build_gemma_moe(&mut cx),
|
||||
};
|
||||
let _ = out.output();
|
||||
cx
|
||||
}
|
||||
|
||||
fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let x = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let router_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let expert_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router_scale = cx.tensor(MOE_HIDDEN);
|
||||
let router_proj = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let per_expert_scale = cx.tensor(MOE_NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, GEMMA_RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (MOE_HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, MOE_TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
weights.gather(exp_base + exp_within)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
let mut ir_variants = Vec::new();
|
||||
let mut opkind_variants = Vec::new();
|
||||
for op in ops {
|
||||
let sort = op.sort();
|
||||
let variant = format!(
|
||||
"({} {})",
|
||||
sort.name,
|
||||
sort.fields.iter().map(|field| &field.sort).join(" ")
|
||||
);
|
||||
match sort.class.as_str() {
|
||||
"IR" => ir_variants.push(variant),
|
||||
"OpKind" => opkind_variants.push(variant),
|
||||
other => panic!("unknown sort class {other} for {}", sort.name),
|
||||
}
|
||||
}
|
||||
let extra_ir = ops.iter().flat_map(|op| op.ir_defs()).unique().join("\n");
|
||||
format!(
|
||||
"
|
||||
(datatype*
|
||||
(IR
|
||||
(OutputJoin IR IR)
|
||||
(Op OpKind IList)
|
||||
{extra_ir}
|
||||
{}
|
||||
)
|
||||
(OpKind
|
||||
{}
|
||||
)
|
||||
(IList
|
||||
(ICons IR IList)
|
||||
(INil)
|
||||
)
|
||||
)
|
||||
(function dtype (IR) DType :merge new)
|
||||
",
|
||||
ir_variants.join("\n"),
|
||||
opkind_variants.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
ops.iter()
|
||||
.filter(|op| op.cleanup())
|
||||
.map(|op| {
|
||||
let sort = op.sort();
|
||||
let fields = (0..sort.fields.len())
|
||||
.map(|i| (b'a' + i as u8) as char)
|
||||
.join(" ");
|
||||
if sort.class == "OpKind" {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
((delete (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m ({} {fields})))
|
||||
((delete ({} {fields})))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
}
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn setup_program(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let rewrites = ops
|
||||
.iter()
|
||||
.flat_map(|op| op.rewrites())
|
||||
.map(|rule| rule.to_egglog_string())
|
||||
.join("\n");
|
||||
[
|
||||
EGGLOG_RULESETS
|
||||
.iter()
|
||||
.map(|ruleset| format!("(ruleset {ruleset})"))
|
||||
.join("\n"),
|
||||
base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
base_cleanup_egglog(),
|
||||
rewrites,
|
||||
program.to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn producer_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run matmul_flatten)
|
||||
(run kernel_lower)
|
||||
(run direct_kernel)
|
||||
(run kernel_specialize)
|
||||
(run buffer_reuse)
|
||||
(run matmul_backend)
|
||||
(run glumoe)
|
||||
(run fusion_pair)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fusion_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run fusion_grow)
|
||||
(run fusion_merge)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn split_cycle() -> Vec<(&'static str, String)> {
|
||||
vec![
|
||||
("producers", format!("(saturate {})", producer_schedule())),
|
||||
("fusion", format!("(saturate {})", fusion_schedule())),
|
||||
]
|
||||
}
|
||||
|
||||
fn split_cycle_schedule() -> String {
|
||||
format!(
|
||||
"(seq
|
||||
(saturate {})
|
||||
(saturate {})
|
||||
)",
|
||||
producer_schedule(),
|
||||
fusion_schedule()
|
||||
)
|
||||
}
|
||||
|
||||
fn phase(egraph: &mut egglog::EGraph, name: &str, schedule: &str) -> bool {
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let command = format!("(run-schedule {schedule})");
|
||||
let outputs = egraph
|
||||
.parse_and_run_program(None, &command)
|
||||
.unwrap_or_else(|err| panic!("failed phase {name} schedule {schedule}: {err}"));
|
||||
let elapsed = start.elapsed();
|
||||
let after = egraph.num_tuples();
|
||||
let report = outputs
|
||||
.into_iter()
|
||||
.find_map(|output| match output {
|
||||
egglog::CommandOutput::RunSchedule(report) => Some(report),
|
||||
_ => None,
|
||||
})
|
||||
.expect("run-schedule did not return a report");
|
||||
let mut rules = report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(rule, time)| {
|
||||
(
|
||||
rule.to_string(),
|
||||
*time,
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.get(rule)
|
||||
.copied()
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
rules.sort_by_key(|(_, time, matches)| (std::cmp::Reverse(*time), std::cmp::Reverse(*matches)));
|
||||
let matches = report.num_matches_per_rule.values().sum::<usize>();
|
||||
println!(
|
||||
"phase {name:<18} {elapsed_ms:>8.2} ms | tuples {before} -> {after} ({delta:+}) | updated={updated} | iters={iters} | matches={matches}",
|
||||
elapsed_ms = elapsed.as_secs_f64() * 1000.0,
|
||||
delta = after as isize - before as isize,
|
||||
updated = report.updated,
|
||||
iters = report.iterations.len(),
|
||||
);
|
||||
for (rule, time, matches) in rules
|
||||
.into_iter()
|
||||
.filter(|(_, time, matches)| !time.is_zero() || *matches > 0)
|
||||
.take(8)
|
||||
{
|
||||
println!(
|
||||
" rule {rule:<82} {ms:>8.2} ms | matches {matches}",
|
||||
ms = time.as_secs_f64() * 1000.0,
|
||||
);
|
||||
}
|
||||
report.updated
|
||||
}
|
||||
|
||||
fn serialize_summary(egraph: &mut egglog::EGraph, root: &str) {
|
||||
let (sort, value) = egraph.eval_expr(&egglog::var!(root.to_string())).unwrap();
|
||||
let output = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
let mut classes = std::collections::BTreeSet::new();
|
||||
let mut top_ops = BTreeMap::<String, usize>::new();
|
||||
let mut nodes = 0usize;
|
||||
for node in output.egraph.nodes.values().filter(|node| !node.subsumed) {
|
||||
nodes += 1;
|
||||
classes.insert(node.eclass.clone());
|
||||
*top_ops.entry(node.op.clone()).or_default() += 1;
|
||||
}
|
||||
let top_ops = top_ops
|
||||
.into_iter()
|
||||
.sorted_by_key(|(_, count)| std::cmp::Reverse(*count))
|
||||
.take(12)
|
||||
.map(|(op, count)| format!("{op}={count}"))
|
||||
.join(", ");
|
||||
println!(
|
||||
"serialize nodes={nodes} classes={} roots={} top_ops={top_ops}",
|
||||
classes.len(),
|
||||
output.egraph.root_eclasses.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn run(args: Args) {
|
||||
let mut graph = build_case(args.case);
|
||||
let rolled = if args.skip_roll {
|
||||
0
|
||||
} else {
|
||||
graph.auto_roll_loops_prepass()
|
||||
};
|
||||
let (program, root) = hlir_to_egglog(&graph);
|
||||
|
||||
let mut ops = match args.backend {
|
||||
Backend::Native => <NativeRuntime as Runtime>::Ops::into_vec(),
|
||||
Backend::Cuda => <CudaRuntime as Runtime>::Ops::into_vec(),
|
||||
};
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup = args.cleanup && matches!(args.backend, Backend::Cuda);
|
||||
let setup = setup_program(&program, &ops, cleanup);
|
||||
|
||||
println!(
|
||||
"case={:?} backend={:?} mode={:?} passes={} cleanup={} rolled={} hlir_nodes={} setup_lines={} setup_bytes={} root={root}",
|
||||
args.case,
|
||||
args.backend,
|
||||
args.mode,
|
||||
args.passes,
|
||||
cleanup,
|
||||
rolled,
|
||||
graph.graph.node_count(),
|
||||
setup.lines().count(),
|
||||
setup.len(),
|
||||
);
|
||||
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let commands = egraph.parser.get_program_from_string(None, &setup).unwrap();
|
||||
egraph.run_program(commands).unwrap();
|
||||
println!(
|
||||
"setup {:>8.2} ms | tuples {before} -> {} ({:+})",
|
||||
start.elapsed().as_secs_f64() * 1000.0,
|
||||
egraph.num_tuples(),
|
||||
egraph.num_tuples() as isize - before as isize,
|
||||
);
|
||||
|
||||
match args.mode {
|
||||
Mode::Current | Mode::Steps => {
|
||||
for pass in 1..=args.passes {
|
||||
let mut updated = false;
|
||||
for (name, schedule) in split_cycle() {
|
||||
updated |= phase(&mut egraph, &format!("{pass:03} {name}"), &schedule);
|
||||
}
|
||||
if matches!(args.mode, Mode::Current) && !updated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Mode::FullDefault => {
|
||||
phase(&mut egraph, "expr", "(saturate expr)");
|
||||
phase(&mut egraph, "dtype", "(saturate dtype_prop)");
|
||||
phase(&mut egraph, "default-full", "(saturate (run))");
|
||||
}
|
||||
Mode::FullCycle => {
|
||||
phase(
|
||||
&mut egraph,
|
||||
"cycle-full",
|
||||
&format!("(saturate {})", split_cycle_schedule()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
phase(&mut egraph, "final expr", "(saturate expr)");
|
||||
if cleanup {
|
||||
phase(&mut egraph, "cleanup", "(saturate cleanup)");
|
||||
}
|
||||
phase(&mut egraph, "base cleanup", "(saturate base_cleanup)");
|
||||
serialize_summary(&mut egraph, &root);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
run(parse_args());
|
||||
}
|
||||
87
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
87
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
//! [`DynBackend`] implementation for the CUDA lite runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
pub struct CudaLiteDynBackend {
|
||||
pub runtime: CudaRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for CudaLiteDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"cuda_lite"
|
||||
}
|
||||
fn device_type(&self) -> &str {
|
||||
"cuda"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
|
||||
self.runtime.set_data(node, bytes);
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_f16(&self, node: NodeIndex) -> Vec<half::f16> {
|
||||
self.runtime.get_f16(node)
|
||||
}
|
||||
fn get_output_bf16(&self, node: NodeIndex) -> Vec<half::bf16> {
|
||||
self.runtime.get_bf16(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
|
||||
self.runtime.get_i64(node)
|
||||
}
|
||||
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
|
||||
self.runtime.get_f64(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
unsafe fn set_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_device_ptr(node, ptr, n) }
|
||||
}
|
||||
unsafe fn set_output_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_output_device_ptr(node, ptr, n) }
|
||||
}
|
||||
fn output_is_zero_copy(&self, node: NodeIndex) -> bool {
|
||||
self.runtime.output_is_zero_copy(node)
|
||||
}
|
||||
unsafe fn copy_output_to_device_ptr(&self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.copy_output_to_device_ptr(node, ptr, n) }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_lite_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
compile_backend::<CudaRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(CudaRuntime::initialize(stream)),
|
||||
|rt, node, bytes, _dtype| {
|
||||
rt.set_data(node, bytes);
|
||||
},
|
||||
Some(&|rt, node, ptr, n| unsafe { rt.set_device_ptr(node, ptr, n) }),
|
||||
|rt| Box::new(CudaLiteDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,149 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,149 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,155 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; In broadcast [batch, m, n, k] space:
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
; A strides in [batch, m, n, k]
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; B strides in [batch, m, n, k]
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?n ; ldd
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,428 @@
|
||||
; Fuse a row-major Add on top of an existing cuBLASLt matmul into
|
||||
; D = alpha * A * B + beta * C.
|
||||
;
|
||||
; The existing matmul rewrites view Luminal's row-major output [m,n] as a
|
||||
; column-major cuBLASLt matrix [n,m]. A row-major C input with logical strides
|
||||
; [row_stride, 1] therefore maps to ldc=row_stride. This lets a C slice from a
|
||||
; wider parent tensor use a larger ldc while D keeps the matmul output layout.
|
||||
; cuBLASLt requires out-of-place C and D to have the same matrix order, so these
|
||||
; beta rules only fuse C layouts that map to the current COL-ordered D layout.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched c plus matmul beta"
|
||||
)
|
||||
|
||||
; ROW-ordered D beta fusions. These pair with cublaslt_row_order_rewrite.egg,
|
||||
; where the cuBLASLt problem dimensions match Luminal's logical output [m,n].
|
||||
; A row-major C input with logical strides [row_stride, 1] maps directly to a
|
||||
; ROW-ordered cuBLASLt C[m,n] descriptor with ldc=row_stride.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched c plus matmul beta"
|
||||
)
|
||||
@@ -0,0 +1,614 @@
|
||||
; cuBLASLt epilogue rewrites.
|
||||
;
|
||||
; ReLU in the frontend lowers through maximum_f32(0.0):
|
||||
;
|
||||
; (matmul < 0) * 0 + cast(cast((-cast(matmul < 0) + 1) as bool) as f32) * matmul
|
||||
;
|
||||
; These rules fuse that expression back into CUBLASLT_EPILOGUE_RELU.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu bias epilogue"
|
||||
)
|
||||
|
||||
; Canonical tanh-approx GELU can also appear directly as:
|
||||
;
|
||||
; x * sigmoid(1.5957691216 * x * (1 + 0.044715 * x * x))
|
||||
;
|
||||
; Match that sigmoid form and fuse it into the cuBLASLt GELU epilogues.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu bias epilogue"
|
||||
)
|
||||
|
||||
; This first slice fuses column-bias adds into CUBLASLT_EPILOGUE_BIAS for the
|
||||
; older COL-ordered output view. In that view Luminal's logical [m,n] output is
|
||||
; represented as a cuBLASLt [n,m] matrix, so cuBLASLt's row-broadcast bias maps
|
||||
; to the common logical column bias of length n.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d column bias plus matmul epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column bias plus matmul epilogue"
|
||||
)
|
||||
@@ -0,0 +1,811 @@
|
||||
; FP8 support is narrower than "any FP8 x any FP8". cuBLASLt's regular FP8
|
||||
; matmul table supports these A/B descriptor pairs for F32 outputs:
|
||||
; E4M3 x E4M3
|
||||
; E4M3 x E5M2
|
||||
; E5M2 x E4M3
|
||||
; and requires TN format on Ada/Hopper-class GPUs. These rules therefore match
|
||||
; row-major x column-major Luminal matmuls, which the existing COL-order lowering
|
||||
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
|
||||
; transb=N.
|
||||
|
||||
(rule
|
||||
(
|
||||
; Match the scaled FP8 linear form directly before the unscaled FP8
|
||||
; matmul rewrite can hide the quantize/dequant scale structure.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
(= ?scaled (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(= ?cast (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
(
|
||||
(delete (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(delete (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers direct output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Fusion growth can make the live path consume a raw FP8 cuBLASLt
|
||||
; candidate through an internal CudaBinaryElementwise scale multiply,
|
||||
; instead of the original HLIR output-scale Mul. The scalar scale
|
||||
; product is tensor-wide, so the two scalar factors can be passed as
|
||||
; cuBLASLt A/B scale inputs and the internal multiply can be bypassed.
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(let ?fs_sgemm (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
(union ?fused_scale ?fs_sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
(set (dtype ?fs_sgemm) (F32))
|
||||
)
|
||||
:ruleset fusion_grow
|
||||
:name "cublaslt scaled fp8 fused output-scale f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(= ?fused_scale (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
)
|
||||
(
|
||||
(delete (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(delete (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers fused output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Batched form of the scaled FP8 linear rewrite. The scale operands are
|
||||
; scalar tensors expanded across the last three output/activation axes.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 batched row-major x column-major f32 output"
|
||||
)
|
||||
@@ -0,0 +1,78 @@
|
||||
; Mixed output dtype rewrites for cuBLASLt.
|
||||
;
|
||||
; The first mixed mode we need for low-precision matmuls is:
|
||||
;
|
||||
; D[f32] = A[fp16/bf16] * B[fp16/bf16]
|
||||
;
|
||||
; Luminal graphs express this today as a Cast(F32) around a low-precision
|
||||
; matmul. cuBLASLt can write the f32 output directly, so expose that candidate
|
||||
; before beta fusion tries to consume an f32 C input.
|
||||
;
|
||||
; `?beta = 0.0` guard: with non-zero beta the same `?inputs` C is read at
|
||||
; F32 over a low-precision buffer. Repro: tests/test_bf16_chain_block.py.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F16) (F16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt f16 matmul cast f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (Bf16) (Bf16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt bf16 matmul cast f32 output"
|
||||
)
|
||||
@@ -0,0 +1,484 @@
|
||||
; Natural cuBLASLt row-order output rewrites. These keep Luminal's logical
|
||||
; output C[m,n] as a cuBLASLt ROW-ordered D[m,n] instead of using the older
|
||||
; swapped COL-ordered D[n,m] view. A and B orders mirror their matched logical
|
||||
; layouts, so this family is the legal base for future ROW-ordered beta fusions.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x column-major"
|
||||
)
|
||||
@@ -0,0 +1,316 @@
|
||||
; Scalar alpha/beta rewrites for cuBLASLt. These rules target scalar constants
|
||||
; expanded across the matmul/add shape, i.e. zero strides on every logical axis.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; alpha=1.0 hash-conses ?fused == ?matmul; the union merges Mul into ?matmul's eclass and saturate diverges.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; See 2d alpha scale: alpha=1.0 makes (saturate ...) diverge.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c plus matmul beta"
|
||||
)
|
||||
1587
crates/luminal_cuda_lite/src/host/cublaslt/mod.rs
Normal file
1587
crates/luminal_cuda_lite/src/host/cublaslt/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
124
crates/luminal_cuda_lite/src/host/flashinfer/README.md
Normal file
124
crates/luminal_cuda_lite/src/host/flashinfer/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# FlashInfer Integration
|
||||
|
||||
FlashInfer replaces the multi-op attention pattern (Q×K^T → scale → mask → softmax → ×V) with a single fused GPU kernel via [FlashInfer](https://github.com/flashinfer-ai/flashinfer)'s batch decode and batch prefill APIs.
|
||||
|
||||
## Current State
|
||||
|
||||
**Working:**
|
||||
- Egglog rewrite rule matches any GQA paged attention pattern (model-agnostic shapes)
|
||||
- GA search selects FlashInfer when it wins profiling — verified on Llama 3 8B (32 layers) and Qwen 3 4B (36 layers)
|
||||
- **BatchDecode** (s=1): fp32 natively — FlashInfer's decode kernel uses scalar vectorized dot products, no tensor cores
|
||||
- **BatchPrefill**: template-instantiated for fp16 but **not callable from fp32** — FlashInfer's prefill kernel requires tensor core MMA (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically only operate on 16-bit types; the C API stubs return -1 for fp32; will be enabled when native fp16/bf16 pipeline is added
|
||||
- Decode handles all cases in the current fp32 pipeline (prefill uses cuBLAS attention via dim bucketing)
|
||||
- Indptr-based mask: `qo_indptr` and `kv_indptr` are computed in-graph so the egglog rule can see them in the same chunk as the attention ops
|
||||
|
||||
**Not yet implemented:**
|
||||
- Native fp16 / bf16 pipeline (would eliminate the cast overhead in prefill)
|
||||
- Page sizes > 1
|
||||
|
||||
---
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
src/host/flashinfer/
|
||||
flashinfer_attention.egg — egglog rewrite rule (pattern match → FlashInferAttention)
|
||||
mod.rs — FlashInferAttention op (EgglogOp + HostOp impl)
|
||||
jit.rs — JIT compilation: nvcc wrapper.cu → .so, dlopen, fn pointers
|
||||
find_indptrs.rs — walks the mask e-graph node to locate qo_indptr / kv_indptr inputs
|
||||
wrapper.cu — CUDA: FlashInfer template instantiation + helper kernels
|
||||
wrapper.h — C API header for wrapper.cu
|
||||
README.md — this file
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. Egglog Pattern Matching
|
||||
|
||||
The rule in `flashinfer_attention.egg` matches the structural pattern of paged GQA attention:
|
||||
|
||||
```
|
||||
Gather(K_cache, idx) → GQA broadcast (Mul×1.0) → Q×K^T → Sum → scale → mask Add → softmax → attn×V → Sum → output
|
||||
Gather(V_cache, idx) → GQA broadcast (Mul×1.0) ──────────────────────────────────────────→ attn×V → Sum → output
|
||||
```
|
||||
|
||||
Key anchors that prevent false matches on MLP or other ops:
|
||||
- Two Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
- GQA broadcast via `Mul(gathered, Constant(1.0))` with all-zero strides
|
||||
- Mask Add with zero-stride broadcast in the first (nheads) dimension
|
||||
- Two sequential matmul+Sum pairs connected through softmax
|
||||
|
||||
Shape dimensions are egglog variables, not pinned constants — the rule works for any model with GQA (Llama, Qwen, Mistral, etc.). The structural invariants (dimension count, zero-stride positions, Gather from 2D) are enough to avoid combinatorial explosion during saturation.
|
||||
|
||||
When the rule fires, it unions `FlashInferAttention` with the original attention output, making it an equivalent alternative in the e-graph. The GA search then profiles both paths and picks the faster one.
|
||||
|
||||
### 2. Extraction: Finding Indptrs
|
||||
|
||||
During `extract()` (called when egglog selects the FlashInferAttention e-node), `find_indptrs.rs` walks backward from the mask node in the e-graph to locate the `qo_indptr` and `kv_indptr` Input nodes. It validates the mask structure by checking for the `Mul(allowed, Constant(1e10))` pattern that `compute_attn_mask()` produces.
|
||||
|
||||
The indptrs are appended as inputs 5 and 6 to the FlashInferAttention op, so the runtime can build the CSR page table directly without recomputing anything.
|
||||
|
||||
### 3. JIT Compilation
|
||||
|
||||
FlashInfer requires `HEAD_DIM` as a compile-time template parameter. Rather than baking it at `cargo build` time, `jit.rs` JIT-compiles `wrapper.cu` with the model's actual HEAD_DIM:
|
||||
|
||||
1. First call to `ensure_compiled(head_dim)` runs `nvcc` with `-DLUMINAL_HEAD_DIM=<N>`
|
||||
2. The compiled `.so` is cached at `~/.cache/luminal/flashinfer/libflashinfer_hd<N>_<arch>.so`
|
||||
3. Subsequent calls load the cached library via `dlopen`
|
||||
4. Function pointers (plan, run, transpose, etc.) are resolved and stored in a `static OnceLock`
|
||||
|
||||
Supported HEAD_DIM values: 64, 128, 256.
|
||||
|
||||
### 4. Runtime Execution
|
||||
|
||||
`FlashInferAttention::execute()` dispatches to decode or prefill based on `total_q_tokens vs batch_size`:
|
||||
|
||||
**Common steps:**
|
||||
1. **Extract kv_indices** — a helper kernel converts the flat gather index `(c, KV_DIM)` to slot indices `(c,)`
|
||||
2. **Read indptrs to host** — copied to CPU for the plan phase
|
||||
3. **Plan** — queries GPU occupancy and decides split-KV decomposition
|
||||
4. **Run** — the fused kernel writes `(total_q_tokens, num_qo_heads, head_dim)`
|
||||
5. **Transpose** — transposes to `(num_qo_heads, total_q_tokens, head_dim)` to match the Sum reduction layout
|
||||
|
||||
**Decode path** (current, fp32): Always used. Runs FlashInfer's BatchDecode directly on fp32 buffers.
|
||||
|
||||
**Prefill path** (future, fp16/bf16 only): The prefill kernel templates are compiled into the JIT .so for fp16 (CTA_TILE_Q=16/64/128, causal mask). The C API stubs currently return -1 since the pipeline is fp32. When native fp16/bf16 dtype support is added, `execute()` will dispatch to prefill when `total_q_tokens > batch_size`.
|
||||
|
||||
Global workspaces (`static OnceLock`) are shared across all FlashInferAttention instances to avoid ~4ms allocation overhead per GA profiling candidate. Without this, the GA never selects FlashInfer because the first-run allocation cost dwarfs the kernel time.
|
||||
|
||||
## How the Attention Mask Enables FlashInfer
|
||||
|
||||
For the egglog rule to fire, the `qo_indptr` and `kv_indptr` tensors must be visible in the same e-graph chunk as the attention ops. This is why the mask is computed *inside* each layer (via `compute_attn_mask()` in the model) rather than passed as a pre-computed input.
|
||||
|
||||
The mask computation uses a specific structure:
|
||||
```rust
|
||||
let allowed = same_request * causal;
|
||||
allowed * 1e10 - 1e10 // → 0.0 for allowed, -1e10 for blocked
|
||||
```
|
||||
|
||||
The `Mul(allowed, Constant(1e10))` pattern is the anchor that `find_indptrs.rs` uses to walk backward and locate the indptr inputs.
|
||||
|
||||
## Roadmap
|
||||
|
||||
Items listed in priority order. Checked items are done.
|
||||
|
||||
- [x] Model-agnostic egglog rule (shape variables instead of Llama-specific constants)
|
||||
- [x] bs>1 supersequence decode
|
||||
- [x] Indptr-based attention mask (replaces CPU-computed mask)
|
||||
- [x] Multi-model support (verified on Llama 3 8B and Qwen 3 4B)
|
||||
- [x] BatchPrefill kernel compiled for fp16 (causal mask, CTA_TILE_Q=16/64/128)
|
||||
- [ ] Native fp16 / bf16 pipeline (enables prefill, reduces memory, eliminates cuBLAS prefill fallback)
|
||||
- [ ] HEAD_DIM dispatch for 64, 96 (JIT supports 64/128/256; wrapper.cu needs 96 for Phi)
|
||||
- [ ] Page sizes > 1 (currently page_size=1; larger pages reduce CSR overhead)
|
||||
- [ ] Sliding window, ALiBi, logits soft cap (FlashInfer `AttentionVariant` templates)
|
||||
- [ ] MHA / MQA / arbitrary GQA ratios beyond {1, 2, 4, 8}
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- **page_size=1**: Each KV cache slot is one "page". This simplifies the CSR page table (`kv_indices` = physical slot indices directly) and matches the flat `(num_slots, KV_DIM)` cache layout.
|
||||
|
||||
- **Pinned structural anchors**: The egglog rule pins the *structure* (number of dimensions, which dims are zero-stride, presence of Gather from 2D cache) but uses variables for the *values* (head counts, head_dim). This prevents saturation blowup while remaining model-agnostic.
|
||||
|
||||
- **Prefill requires fp16/bf16**: FlashInfer's prefill kernel uses tensor core MMA instructions (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically require 16-bit inputs — there is no fp32 tensor core matmul instruction. The prefill kernel templates are compiled into the .so for fp16 but the C API returns -1 for fp32 callers. When native fp16/bf16 is added, prefill will be enabled automatically.
|
||||
|
||||
- **Global workspaces**: Float workspace (128 MiB), int workspace (8 MiB), and a page-locked host buffer are allocated once via `static OnceLock` and shared across all instances.
|
||||
328
crates/luminal_cuda_lite/src/host/flashinfer/find_indptrs.rs
Normal file
328
crates/luminal_cuda_lite/src/host/flashinfer/find_indptrs.rs
Normal file
@@ -0,0 +1,328 @@
|
||||
//! Walk the e-graph from the mask node to find qo_indptr and kv_indptr Input nodes.
|
||||
//!
|
||||
//! The mask is produced by `compute_attn_mask(q_pos, qo_indptr, kv_indptr)` using
|
||||
//! primitive HLIR ops. This module validates the mask's structure and extracts the
|
||||
//! indptr Input node IDs so FlashInfer can use them directly.
|
||||
|
||||
use luminal::egglog_utils::{ClassId, NodeId, SerializedEGraph};
|
||||
use luminal::prelude::FxHashSet;
|
||||
|
||||
/// Result of walking the mask computation chain.
|
||||
#[derive(Debug)]
|
||||
pub struct IndptrNodes<'a> {
|
||||
pub qo_indptr: &'a NodeId,
|
||||
pub kv_indptr: &'a NodeId,
|
||||
}
|
||||
|
||||
/// Find the qo_indptr and kv_indptr Input nodes by walking backwards from the mask.
|
||||
///
|
||||
/// Validates the mask structure: `allowed * 1e10 + (-1e10)`. Then does a BFS from
|
||||
/// the `allowed` subtree to find all reachable Input nodes with names containing
|
||||
/// "qo_indptr" and "kv_indptr".
|
||||
///
|
||||
/// Panics with a diagnostic message if the structure doesn't match or the
|
||||
/// indptr inputs can't be found.
|
||||
pub fn find_indptr_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let mask_inputs = logical_binary_inputs(egraph, mask_node, "Add").unwrap_or_else(|| {
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
panic!("find_indptr_inputs: mask is not an Add (kind={mask_kind_label})");
|
||||
});
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
"find_indptr_inputs: mask Add should have 2 inputs, got {}",
|
||||
mask_inputs.len()
|
||||
);
|
||||
|
||||
// Step 2: One of the inputs should be Mul(allowed, Constant(1e10))
|
||||
let (scaled_allowed, allowed_node) = find_1e10_mul(egraph, &mask_inputs);
|
||||
|
||||
// Step 3: BFS from `allowed` to find all reachable Input nodes
|
||||
let reachable_inputs = find_reachable_inputs(egraph, allowed_node);
|
||||
|
||||
// Step 4: Match by name
|
||||
let mut qo_indptr: Option<&NodeId> = None;
|
||||
let mut kv_indptr: Option<&NodeId> = None;
|
||||
|
||||
for (node_id, name) in &reachable_inputs {
|
||||
if name.contains("qo_indptr") {
|
||||
qo_indptr = Some(node_id);
|
||||
} else if name.contains("kv_indptr") {
|
||||
kv_indptr = Some(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let qo = qo_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'qo_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
let kv = kv_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'kv_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
IndptrNodes {
|
||||
qo_indptr: qo,
|
||||
kv_indptr: kv,
|
||||
}
|
||||
}
|
||||
|
||||
fn find_1e10_mul<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let Some(mul_inputs) = logical_binary_inputs(egraph, input_node, "Mul") else {
|
||||
continue;
|
||||
};
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
for (i, &inp) in mul_inputs.iter().enumerate() {
|
||||
if is_constant(egraph, inp, 1e10) {
|
||||
let other = mul_inputs[1 - i];
|
||||
return (input_node, other);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut debug_info = String::new();
|
||||
for (i, &input_node) in mask_add_inputs.iter().enumerate() {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
debug_info.push_str(&format!("\n input[{i}]: label={label}"));
|
||||
if label == "Op" && !children.is_empty() {
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
debug_info.push_str(&format!(" kind={kind_label}"));
|
||||
for (j, kc) in egraph.enodes[kind].1.iter().enumerate() {
|
||||
let kc_node = resolve_first_node(egraph, kc);
|
||||
debug_info.push_str(&format!(" child[{j}]={}", egraph.enodes[kc_node].0));
|
||||
}
|
||||
if kind_label.contains("Mul") && children.len() >= 2 {
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for (j, &mi) in mul_inputs.iter().enumerate() {
|
||||
let (ml, mc) = &egraph.enodes[mi];
|
||||
debug_info.push_str(&format!("\n mul_input[{j}]: label={ml}"));
|
||||
if ml == "Op" && !mc.is_empty() {
|
||||
let mk = resolve_first_node(egraph, &mc[0]);
|
||||
debug_info.push_str(&format!(" kind={}", egraph.enodes[mk].0));
|
||||
for (k, mkc) in egraph.enodes[mk].1.iter().enumerate() {
|
||||
let mkc_node = resolve_first_node(egraph, mkc);
|
||||
debug_info.push_str(&format!(" ch[{k}]={}", egraph.enodes[mkc_node].0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find Mul(allowed, Constant(1e10)) in mask Add inputs.{debug_info}"
|
||||
);
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let node = resolve_op_with_kind(egraph, node, "Constant").unwrap_or(node);
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
if !kind_label.contains("Constant") {
|
||||
return false;
|
||||
}
|
||||
let val_children = &egraph.enodes[kind].1;
|
||||
if val_children.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let val_node = resolve_first_node(egraph, &val_children[0]);
|
||||
let val_str = &egraph.enodes[val_node].0;
|
||||
if let Ok(val) = val_str.parse::<f64>() {
|
||||
(val as f32 - expected).abs() < 1.0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn find_reachable_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
start: &'a NodeId,
|
||||
) -> Vec<(&'a NodeId, String)> {
|
||||
let mut found = Vec::new();
|
||||
let mut visited = FxHashSet::default();
|
||||
let mut stack = vec![start];
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
if !visited.insert(node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
|
||||
if label == "Input" {
|
||||
if children.len() >= 2 {
|
||||
let name_node = resolve_first_node(egraph, &children[1]);
|
||||
let name = egraph.enodes[name_node].0.trim_matches('"').to_string();
|
||||
found.push((node, name));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if label == "Op" && children.len() >= 2 {
|
||||
let ir_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for inp in ir_inputs {
|
||||
stack.push(inp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found
|
||||
}
|
||||
|
||||
fn walk_ilist_simple<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
ilist_eclass: &'a ClassId,
|
||||
) -> Vec<&'a NodeId> {
|
||||
let mut inputs = Vec::new();
|
||||
let mut current = resolve_first_node(egraph, ilist_eclass);
|
||||
|
||||
loop {
|
||||
let (label, children) = &egraph.enodes[current];
|
||||
if label == "INil" {
|
||||
break;
|
||||
}
|
||||
if label != "ICons" {
|
||||
break;
|
||||
}
|
||||
let ir_node = resolve_first_ir_node(egraph, &children[0]);
|
||||
inputs.push(ir_node);
|
||||
current = resolve_first_node(egraph, &children[1]);
|
||||
}
|
||||
|
||||
inputs
|
||||
}
|
||||
|
||||
fn resolve_first_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
&egraph.eclasses[eclass].1[0]
|
||||
}
|
||||
|
||||
fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
let nodes = &egraph.eclasses[eclass].1;
|
||||
for node in nodes {
|
||||
let label = &egraph.enodes[node].0;
|
||||
if label == "Op" || label == "Input" {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
|
||||
fn resolve_op_with_kind<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
kind_substr: &str,
|
||||
) -> Option<&'a NodeId> {
|
||||
let class = egraph.node_to_class.get(node)?;
|
||||
for candidate in &egraph.eclasses[class].1 {
|
||||
let (label, children) = &egraph.enodes[candidate];
|
||||
if label != "Op" || children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains(kind_substr) {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn logical_binary_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
op_name: &str,
|
||||
) -> Option<Vec<&'a NodeId>> {
|
||||
if let Some(op_node) = resolve_op_with_kind(egraph, node, op_name) {
|
||||
let (_, children) = &egraph.enodes[op_node];
|
||||
return Some(walk_ilist_simple(egraph, &children[1]));
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains("CudaBinaryElementwise") {
|
||||
let opcode_class = egraph.enodes[kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
return Some(
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
if !egraph.enodes[kind].0.contains("FusionEnd") {
|
||||
return None;
|
||||
}
|
||||
let fe_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
let elem = *fe_inputs.first()?;
|
||||
let (elem_label, elem_children) = &egraph.enodes[elem];
|
||||
if elem_label != "Op" || elem_children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let elem_kind = resolve_first_node(egraph, &elem_children[0]);
|
||||
if !egraph.enodes[elem_kind].0.contains("CudaBinaryElementwise") {
|
||||
return None;
|
||||
}
|
||||
let opcode_class = egraph.enodes[elem_kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
walk_ilist_simple(egraph, &elem_children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn unwrap_fusion_start<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> &'a NodeId {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return node;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("FusionStart") {
|
||||
return node;
|
||||
}
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.first()
|
||||
.copied()
|
||||
.unwrap_or(node)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
; FlashInfer batch decode attention rewrite rule.
|
||||
;
|
||||
; Matches the paged attention pattern for ANY model with GQA:
|
||||
; Gather(K_cache) → GQA broadcast → Q*K^T matmul → scale → add mask → softmax → attn*V matmul
|
||||
; Gather(V_cache) → GQA broadcast ──────────────────────────────────────────→ attn*V matmul
|
||||
;
|
||||
; Structural anchors (prevent false matches on MLP/other ops):
|
||||
; - Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
; - GQA broadcast via Mul(gathered, Constant(1.0)) with all-zero strides
|
||||
; - Scale Mul(QK, constant) connecting QK scores to mask Add
|
||||
; - Mask Add with zero-stride broadcast in first dim (nheads broadcast)
|
||||
; - Data flow: two sequential matmul+reduce pairs connected through softmax
|
||||
;
|
||||
; The egglog rule captures the mask as 5th input. During extract(), a Rust
|
||||
; function walks the mask's computation chain in the e-graph to locate the
|
||||
; qo_indptr and kv_indptr Input nodes (validated via the Constant(1e10) anchor
|
||||
; and structural checks). These are appended as inputs 5 and 6 so FlashInfer
|
||||
; can build the CSR page table directly — no runtime derivation needed.
|
||||
;
|
||||
; Shape dimensions are egglog variables, not pinned constants.
|
||||
; Dynamic dims "s" (batch/seq) and "c" (context) stay pinned as MVar.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ── Second matmul: Mul(softmax_out, V_gqa) ──
|
||||
; Shape: (nheads, s, hdim, c) — 4D
|
||||
(= ?mul2 (Op (Mul
|
||||
(ECons ?nheads (ECons (MVar "s") (ECons ?hdim (ECons (MVar "c") (ENil)))))
|
||||
?mul2_a_strides
|
||||
?mul2_b_strides
|
||||
?mul2_out_strides)
|
||||
(ICons ?soft (ICons ?v_gqa (INil)))))
|
||||
|
||||
; ── Second matmul: Sum (reduction over c) → output ──
|
||||
; Shape: (nheads, s, hdim) — reduces c
|
||||
(= ?output (Op (Sum
|
||||
(ECons ?nheads2 (ECons (MVar "s") (ECons ?hdim2 (ENil))))
|
||||
(MVar "c")
|
||||
?out_in_strides
|
||||
(MIter)
|
||||
?out_out_strides)
|
||||
(ICons ?mul2 (INil))))
|
||||
|
||||
; ── V GQA broadcast: Mul(V_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, c, hdim) — 3D
|
||||
(= ?v_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?v_gqa (Op (Mul
|
||||
(ECons ?nheads3 (ECons (MVar "c") (ECons ?hdim3 (ENil))))
|
||||
?v_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?v_gqa_out_strides)
|
||||
(ICons ?v_gathered (ICons ?v_gqa_const (INil)))))
|
||||
|
||||
; ── V Gather: rows from V_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?v_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim (ENil)))
|
||||
?v_gather_strides
|
||||
(ECons ?num_slots_v (ECons ?kvdim2 (ENil)))
|
||||
?v_src_strides)
|
||||
(ICons ?v_idx (ICons ?v_cache (INil)))))
|
||||
|
||||
; ── First matmul: Mul(Q, K_gqa) ──
|
||||
; Shape: (nheads, s, c, hdim) — 4D
|
||||
(= ?mul1 (Op (Mul
|
||||
(ECons ?nheads4 (ECons (MVar "s") (ECons (MVar "c") (ECons ?hdim4 (ENil)))))
|
||||
?mul1_a_strides
|
||||
?mul1_b_strides
|
||||
?mul1_out_strides)
|
||||
(ICons ?q (ICons ?k_gqa (INil)))))
|
||||
|
||||
; ── First matmul: Sum (reduction over hdim) → QK scores ──
|
||||
; Shape: (nheads, s, c) — reduces hdim
|
||||
(= ?qk (Op (Sum
|
||||
(ECons ?nheads5 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?hdim5
|
||||
?qk_in_strides
|
||||
(MIter)
|
||||
?qk_out_strides)
|
||||
(ICons ?mul1 (INil))))
|
||||
|
||||
; ── Mask Add: Add(scaled_QK, mask) ──
|
||||
; Shape: (nheads, s, c) — 3D
|
||||
; Mask is broadcast from (s, c) via zero-stride in first dim (nheads).
|
||||
(= ?masked (Op (Add
|
||||
(ECons ?nheads8 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?mask_add_a_strides
|
||||
(ECons (MNum 0) ?mask_rest_strides)
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; FlashInfer needs qo_indptr/kv_indptr to be recoverable from the mask
|
||||
; expression. Do not match examples that pass a precomputed mask Input.
|
||||
(= ?mask (Op (Add ?inner_mask_shape ?inner_mask_a_strides ?inner_mask_b_strides ?inner_mask_out_strides)
|
||||
(ICons ?mask_scaled_allowed (ICons ?mask_offset (INil)))))
|
||||
(= ?mask_scaled_allowed (Op (Mul ?allowed_shape ?allowed_strides ?scale_const_strides ?scaled_allowed_strides)
|
||||
(ICons ?mask_allowed (ICons ?mask_scale_const (INil)))))
|
||||
(= ?mask_scale_const (Op (Constant ?mask_scale_val) (INil)))
|
||||
(> ?mask_scale_val 9999999999.0)
|
||||
(< ?mask_scale_val 10000000001.0)
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?k_gqa (Op (Mul
|
||||
(ECons ?nheads6 (ECons ?hdim6 (ECons (MVar "c") (ENil))))
|
||||
?k_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?k_gqa_out_strides)
|
||||
(ICons ?k_gathered (ICons ?k_gqa_const (INil)))))
|
||||
|
||||
; ── K Gather: rows from K_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?k_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim3 (ENil)))
|
||||
?k_gather_strides
|
||||
(ECons ?num_slots_k (ECons ?kvdim4 (ENil)))
|
||||
?k_src_strides)
|
||||
(ICons ?k_idx (ICons ?k_cache (INil)))))
|
||||
|
||||
; ── Dtype consistency ──
|
||||
(= ?dt (dtype ?q))
|
||||
(= ?dt (dtype ?k_cache))
|
||||
(= ?dt (dtype ?v_cache))
|
||||
)
|
||||
(
|
||||
(let ?fi (Op (FlashInferAttention
|
||||
?nheads (MDiv ?kvdim ?hdim) ?hdim (MNum 1) (MVar "s"))
|
||||
(ICons ?q (ICons ?k_cache (ICons ?v_cache (ICons ?k_idx (ICons ?mask (INil))))))))
|
||||
(union ?output ?fi)
|
||||
(set (dtype ?fi) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "FlashInfer batch decode attention"
|
||||
)
|
||||
504
crates/luminal_cuda_lite/src/host/flashinfer/jit.rs
Normal file
504
crates/luminal_cuda_lite/src/host/flashinfer/jit.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
//! JIT compilation and dynamic loading of FlashInfer kernels.
|
||||
//!
|
||||
//! Everything runs at compile / profiling time — there is no `build.rs`.
|
||||
//! `wrapper.cu` and `wrapper.h` are embedded via `include_str!()` and
|
||||
//! extracted to the cache directory on first use. The FlashInfer + CUTLASS
|
||||
//! header trees are located by probing `LUMINAL_FLASHINFER_DIR`, a small set
|
||||
//! of default paths, and (as a last resort) by `git clone`-ing FlashInfer at
|
||||
//! a pinned commit into the cache. `nvcc` is then invoked with the model's
|
||||
//! actual `HEAD_DIM` and the resulting `.so` is `dlopen`'d.
|
||||
//!
|
||||
//! `ensure_compiled` is called from `FlashInferAttention::extract()`, i.e.
|
||||
//! during luminal's compile / GA-profiling phase, not from `execute()`. After
|
||||
//! the first call the `OnceLock` makes subsequent lookups free.
|
||||
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
hash::{Hash, Hasher},
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
// ── Function pointer types matching wrapper.h ──
|
||||
|
||||
pub type PlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
indptr_h: *mut i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type RunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
pub type ExtractFn = unsafe extern "C" fn(
|
||||
flat_idx: *const i32,
|
||||
out: *mut i32,
|
||||
c: i32,
|
||||
kv_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type DeriveIndptrFn =
|
||||
unsafe extern "C" fn(mask: *const f32, indptr: *mut i32, s: i32, c: i32, stream: *mut c_void);
|
||||
|
||||
pub type TransposeOutputFn = unsafe extern "C" fn(
|
||||
src: *const f32,
|
||||
dst: *mut f32,
|
||||
batch: i32,
|
||||
heads: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type PrefillPlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
qo_indptr_h: *mut i32,
|
||||
kv_indptr_h: *mut i32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type PrefillRunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
qo_indptr: *mut i32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
// ── Embedded CUDA sources ──
|
||||
|
||||
const WRAPPER_CU: &str = include_str!("wrapper.cu");
|
||||
const WRAPPER_H: &str = include_str!("wrapper.h");
|
||||
|
||||
// ── Loaded library handle ──
|
||||
|
||||
pub struct FlashInferLib {
|
||||
// Keep the handle alive so the dlopen'd .so remains mapped.
|
||||
_lib: libloading::Library,
|
||||
pub plan: PlanFn,
|
||||
pub run: RunFn,
|
||||
pub extract_slot_indices: ExtractFn,
|
||||
pub derive_indptr_from_mask: DeriveIndptrFn,
|
||||
pub transpose_output: TransposeOutputFn,
|
||||
pub prefill_plan: PrefillPlanFn,
|
||||
pub prefill_run: PrefillRunFn,
|
||||
}
|
||||
|
||||
// SAFETY: The library handle and function pointers are valid for the lifetime
|
||||
// of the process. All functions are called with proper CUDA stream serialization.
|
||||
unsafe impl Send for FlashInferLib {}
|
||||
unsafe impl Sync for FlashInferLib {}
|
||||
|
||||
static FLASHINFER_LIB: OnceLock<FlashInferLib> = OnceLock::new();
|
||||
|
||||
/// Ensure the FlashInfer library is compiled and loaded for the given HEAD_DIM.
|
||||
/// Returns a reference to the loaded library. Thread-safe via OnceLock.
|
||||
pub fn ensure_compiled(head_dim: usize) -> &'static FlashInferLib {
|
||||
FLASHINFER_LIB.get_or_init(|| {
|
||||
assert!(
|
||||
matches!(head_dim, 64 | 128 | 256),
|
||||
"FlashInfer: unsupported HEAD_DIM={} (must be 64, 128, or 256 for f32)",
|
||||
head_dim
|
||||
);
|
||||
let so_path = compile_or_cache(head_dim);
|
||||
unsafe {
|
||||
FlashInferLib::load(&so_path)
|
||||
.unwrap_or_else(|e| panic!("Failed to load FlashInfer library: {e}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl FlashInferLib {
|
||||
/// Load a compiled FlashInfer .so and resolve function pointers.
|
||||
///
|
||||
/// # Safety
|
||||
/// The .so must be a valid FlashInfer wrapper compiled from wrapper.cu.
|
||||
unsafe fn load(path: &Path) -> Result<Self, libloading::Error> {
|
||||
let lib = unsafe { libloading::Library::new(path)? };
|
||||
let plan: PlanFn = unsafe { *lib.get::<PlanFn>(b"flashinfer_batch_decode_plan\0")? };
|
||||
let run: RunFn = unsafe { *lib.get::<RunFn>(b"flashinfer_batch_decode_run\0")? };
|
||||
let extract_slot_indices: ExtractFn =
|
||||
unsafe { *lib.get::<ExtractFn>(b"flashinfer_extract_slot_indices\0")? };
|
||||
let derive_indptr_from_mask: DeriveIndptrFn =
|
||||
unsafe { *lib.get::<DeriveIndptrFn>(b"flashinfer_derive_indptr_from_mask\0")? };
|
||||
let transpose_output: TransposeOutputFn =
|
||||
unsafe { *lib.get::<TransposeOutputFn>(b"flashinfer_transpose_output\0")? };
|
||||
let prefill_plan: PrefillPlanFn =
|
||||
unsafe { *lib.get::<PrefillPlanFn>(b"flashinfer_batch_prefill_plan\0")? };
|
||||
let prefill_run: PrefillRunFn =
|
||||
unsafe { *lib.get::<PrefillRunFn>(b"flashinfer_batch_prefill_run\0")? };
|
||||
Ok(Self {
|
||||
_lib: lib,
|
||||
plan,
|
||||
run,
|
||||
extract_slot_indices,
|
||||
derive_indptr_from_mask,
|
||||
transpose_output,
|
||||
prefill_plan,
|
||||
prefill_run,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile wrapper.cu for the given HEAD_DIM, or return cached .so path.
|
||||
fn compile_or_cache(head_dim: usize) -> PathBuf {
|
||||
let cache_dir = cache_directory();
|
||||
std::fs::create_dir_all(&cache_dir).expect("Failed to create FlashInfer cache directory");
|
||||
|
||||
// Extract bundled wrapper sources to the cache so nvcc can compile them.
|
||||
let (wrapper_cu_path, wrapper_h_dir) = extract_wrapper_sources(&cache_dir);
|
||||
|
||||
let arch = detect_cuda_arch();
|
||||
// Bake a hash of the embedded wrapper into the .so name so old caches are
|
||||
// discarded automatically when wrapper.cu or wrapper.h change.
|
||||
let wrapper_hash = wrapper_source_hash();
|
||||
let so_name = format!(
|
||||
"libflashinfer_hd{}_{}_w{:016x}.so",
|
||||
head_dim, arch, wrapper_hash
|
||||
);
|
||||
let so_path = cache_dir.join(&so_name);
|
||||
|
||||
if so_path.exists() {
|
||||
eprintln!(
|
||||
"FlashInfer: using cached library for HEAD_DIM={} ({})",
|
||||
head_dim,
|
||||
so_path.display()
|
||||
);
|
||||
return so_path;
|
||||
}
|
||||
|
||||
let Some((flashinfer_include, cutlass_include)) = locate_flashinfer_includes() else {
|
||||
panic!(
|
||||
"FlashInfer: could not locate header tree. Set LUMINAL_FLASHINFER_DIR to the \
|
||||
FlashInfer source root (the directory containing `include/` and \
|
||||
`3rdparty/cutlass/include/`)."
|
||||
);
|
||||
};
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: JIT compiling for HEAD_DIM={}, arch={} ...",
|
||||
head_dim, arch
|
||||
);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let output = Command::new("nvcc")
|
||||
.args([
|
||||
"-shared",
|
||||
"-o",
|
||||
so_path.to_str().unwrap(),
|
||||
&format!("-DLUMINAL_HEAD_DIM={}", head_dim),
|
||||
wrapper_cu_path.to_str().unwrap(),
|
||||
"-I",
|
||||
flashinfer_include.to_str().unwrap(),
|
||||
"-I",
|
||||
cutlass_include.to_str().unwrap(),
|
||||
"-I",
|
||||
wrapper_h_dir.to_str().unwrap(),
|
||||
"-std=c++17",
|
||||
&format!("-arch={}", arch),
|
||||
"-O3",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-w",
|
||||
"-rdc=true",
|
||||
"--compiler-options",
|
||||
"-fPIC",
|
||||
])
|
||||
.output()
|
||||
.expect("Failed to run nvcc. Is the CUDA toolkit installed?");
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let _ = std::fs::remove_file(&so_path);
|
||||
panic!(
|
||||
"FlashInfer JIT compilation failed (HEAD_DIM={}, arch={}):\nstdout: {}\nstderr: {}",
|
||||
head_dim, arch, stdout, stderr
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
"FlashInfer: compiled in {:.1}s → {}",
|
||||
elapsed.as_secs_f64(),
|
||||
so_path.display()
|
||||
);
|
||||
|
||||
so_path
|
||||
}
|
||||
|
||||
/// Returns ~/.cache/luminal/flashinfer/
|
||||
fn cache_directory() -> PathBuf {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||
PathBuf::from(home)
|
||||
.join(".cache")
|
||||
.join("luminal")
|
||||
.join("flashinfer")
|
||||
}
|
||||
|
||||
/// Drop the embedded wrapper.cu/wrapper.h into the cache dir so nvcc has files
|
||||
/// on disk to compile. Returns (wrapper.cu path, directory containing wrapper.h).
|
||||
fn extract_wrapper_sources(cache_dir: &Path) -> (PathBuf, PathBuf) {
|
||||
let cu = cache_dir.join("wrapper.cu");
|
||||
let h = cache_dir.join("wrapper.h");
|
||||
write_if_changed(&cu, WRAPPER_CU.as_bytes());
|
||||
write_if_changed(&h, WRAPPER_H.as_bytes());
|
||||
(cu, cache_dir.to_path_buf())
|
||||
}
|
||||
|
||||
fn write_if_changed(path: &Path, contents: &[u8]) {
|
||||
if let Ok(existing) = std::fs::read(path)
|
||||
&& existing == contents
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::fs::write(path, contents).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"FlashInfer: failed to write wrapper source to {}: {e}",
|
||||
path.display()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn wrapper_source_hash() -> u64 {
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
WRAPPER_CU.hash(&mut hasher);
|
||||
WRAPPER_H.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// ── Pinned FlashInfer source ──
|
||||
//
|
||||
// Bumping this constant invalidates the cached source tree AND the cached .so
|
||||
// (the .so cache key incorporates the wrapper hash, which is rebuilt against
|
||||
// these headers, so different headers compile to a different .so file even at
|
||||
// the same head_dim). If you change `FLASHINFER_GIT_REV`, also re-check
|
||||
// `wrapper.cu` against the new FlashInfer API.
|
||||
|
||||
const FLASHINFER_GIT_URL: &str = "https://github.com/flashinfer-ai/flashinfer.git";
|
||||
const CUTLASS_GIT_URL: &str = "https://github.com/NVIDIA/cutlass.git";
|
||||
const FLASHINFER_GIT_REV: &str = "f1e6fdcb8f65104047697f022b5d055ef022d763";
|
||||
const CUTLASS_GIT_REV: &str = "f3fde58372d33e9a5650ba7b80fc48b3b49d40c8";
|
||||
|
||||
fn locate_flashinfer_includes() -> Option<(PathBuf, PathBuf)> {
|
||||
if let Ok(path) = std::env::var("LUMINAL_FLASHINFER_DIR")
|
||||
&& !path.is_empty()
|
||||
{
|
||||
let root = PathBuf::from(path);
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
eprintln!(
|
||||
"FlashInfer: LUMINAL_FLASHINFER_DIR={} did not contain include/ and \
|
||||
3rdparty/cutlass/include/ — falling back to default locations",
|
||||
root.display()
|
||||
);
|
||||
}
|
||||
|
||||
let home = std::env::var("HOME").unwrap_or_default();
|
||||
let candidates = [
|
||||
PathBuf::from(&home).join("luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
PathBuf::from(&home).join("luminal_cuda/flashinfer"),
|
||||
PathBuf::from("/opt/luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
];
|
||||
for root in candidates {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: fetch the pinned commit into the cache directory.
|
||||
fetch_flashinfer_source().ok().map(|root| {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
(inc, cutlass)
|
||||
})
|
||||
}
|
||||
|
||||
/// Clone FlashInfer at `FLASHINFER_GIT_REV` + CUTLASS at `CUTLASS_GIT_REV`
|
||||
/// into `~/.cache/luminal/flashinfer-src/<short_rev>/` if absent, then return
|
||||
/// the FlashInfer root directory. ~50 MB one-time download; subsequent calls
|
||||
/// short-circuit on the directory check.
|
||||
fn fetch_flashinfer_source() -> Result<PathBuf, String> {
|
||||
let short = &FLASHINFER_GIT_REV[..12];
|
||||
let cache_root = cache_directory().join("flashinfer-src").join(short);
|
||||
let inc = cache_root.join("include");
|
||||
let cutlass_inc = cache_root.join("3rdparty/cutlass/include");
|
||||
|
||||
if inc.exists() && cutlass_inc.exists() {
|
||||
return Ok(cache_root);
|
||||
}
|
||||
|
||||
let parent = cache_root.parent().unwrap();
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("failed to create {}: {e}", parent.display()))?;
|
||||
|
||||
// Clone into a staging dir, then atomic rename. Protects against multiple
|
||||
// processes racing to fetch the same source.
|
||||
let staging = parent.join(format!(".staging-{}-{}", short, std::process::id()));
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: cloning {FLASHINFER_GIT_URL} @ {short} into {} (one-time fetch, ~50 MB) …",
|
||||
cache_root.display()
|
||||
);
|
||||
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
FLASHINFER_GIT_URL,
|
||||
staging.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&staging, &["checkout", FLASHINFER_GIT_REV])?;
|
||||
|
||||
// Init only the CUTLASS submodule (skip spdlog — we don't need it for kernels).
|
||||
let cutlass_path = staging.join("3rdparty/cutlass");
|
||||
let _ = std::fs::remove_dir_all(&cutlass_path);
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
CUTLASS_GIT_URL,
|
||||
cutlass_path.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&cutlass_path, &["checkout", CUTLASS_GIT_REV])?;
|
||||
|
||||
if !staging.join("include").exists() {
|
||||
return Err(format!(
|
||||
"FlashInfer clone succeeded but include/ missing at {}",
|
||||
staging.display()
|
||||
));
|
||||
}
|
||||
if !staging.join("3rdparty/cutlass/include").exists() {
|
||||
return Err(format!(
|
||||
"CUTLASS clone succeeded but include/ missing at {}",
|
||||
staging.join("3rdparty/cutlass").display()
|
||||
));
|
||||
}
|
||||
|
||||
// Atomic-ish rename. If another process beat us to it, just keep theirs.
|
||||
match std::fs::rename(&staging, &cache_root) {
|
||||
Ok(()) => {}
|
||||
Err(_) if cache_root.exists() => {
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
}
|
||||
Err(e) => return Err(format!("rename to {} failed: {e}", cache_root.display())),
|
||||
}
|
||||
|
||||
Ok(cache_root)
|
||||
}
|
||||
|
||||
fn run_git(args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}. Is git installed?"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` failed: {}",
|
||||
args.join(" "),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_git_in(cwd: &Path, args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` in {} failed: {}",
|
||||
args.join(" "),
|
||||
cwd.display(),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detect CUDA arch via env override → nvidia-smi → default sm_80.
|
||||
fn detect_cuda_arch() -> String {
|
||||
if let Ok(arch) = std::env::var("FLASHINFER_CUDA_ARCH") {
|
||||
return arch;
|
||||
}
|
||||
|
||||
if let Ok(output) = Command::new("nvidia-smi")
|
||||
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
|
||||
.output()
|
||||
&& output.status.success()
|
||||
{
|
||||
let cap = String::from_utf8_lossy(&output.stdout);
|
||||
let cap = cap.trim().lines().next().unwrap_or("8.0");
|
||||
let sm = cap.replace('.', "");
|
||||
if !sm.is_empty() {
|
||||
return format!("sm_{}", sm);
|
||||
}
|
||||
}
|
||||
|
||||
"sm_80".to_string()
|
||||
}
|
||||
424
crates/luminal_cuda_lite/src/host/flashinfer/mod.rs
Normal file
424
crates/luminal_cuda_lite/src/host/flashinfer/mod.rs
Normal file
@@ -0,0 +1,424 @@
|
||||
pub mod find_indptrs;
|
||||
pub mod jit;
|
||||
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaSlice, CudaStream, DevicePtr, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// FlashInfer attention op (batch decode, fp32).
|
||||
///
|
||||
/// Replaces the full paged-GQA attention pattern (gather → broadcast → Q*K^T →
|
||||
/// scale → mask → softmax → *V) with a single FlashInfer fused kernel.
|
||||
///
|
||||
/// Graph inputs (7): Q, K_pool, V_pool, flat_gather_idx, mask, qo_indptr, kv_indptr.
|
||||
/// The egglog rule captures the first 5; `extract()` appends qo/kv indptrs after
|
||||
/// walking the e-graph from the mask. `batch_size` is derived at runtime from the
|
||||
/// indptr length (= num_sequences + 1).
|
||||
#[derive(Debug)]
|
||||
pub struct FlashInferAttention {
|
||||
pub num_qo_heads: usize,
|
||||
pub num_kv_heads: usize,
|
||||
pub head_dim: usize,
|
||||
pub page_size: usize,
|
||||
pub batch_dim: Expression,
|
||||
|
||||
pub plan_info: Mutex<Vec<i64>>,
|
||||
}
|
||||
|
||||
// SAFETY: PAGE_LOCKED_WORKSPACE holds a raw pointer to page-locked CUDA memory
|
||||
// allocated once and serialized via the CUDA stream that owns it.
|
||||
unsafe impl Send for FlashInferAttention {}
|
||||
unsafe impl Sync for FlashInferAttention {}
|
||||
|
||||
const FLOAT_WORKSPACE_SIZE: usize = 128 * 1024 * 1024; // 128 MiB
|
||||
const INT_WORKSPACE_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
|
||||
|
||||
static PAGE_LOCKED_WORKSPACE: OnceLock<PageLockedPtr> = OnceLock::new();
|
||||
|
||||
struct PageLockedPtr(*mut u8);
|
||||
|
||||
// SAFETY: The pointer is page-locked CUDA memory allocated once via
|
||||
// posix_memalign + cudaHostRegister and only mutated during OnceLock
|
||||
// initialization.
|
||||
unsafe impl Send for PageLockedPtr {}
|
||||
unsafe impl Sync for PageLockedPtr {}
|
||||
|
||||
impl std::fmt::Debug for PageLockedPtr {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PageLockedPtr({:p})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FlashInferAttention {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_qo_heads: 0,
|
||||
num_kv_heads: 0,
|
||||
head_dim: 0,
|
||||
page_size: 0,
|
||||
batch_dim: Expression::default(),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for FlashInferAttention {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FlashInferAttention",
|
||||
&[
|
||||
("num_qo_heads", EXPRESSION),
|
||||
("num_kv_heads", EXPRESSION),
|
||||
("head_dim", EXPRESSION),
|
||||
("page_size", EXPRESSION),
|
||||
("batch_dim", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
// Q, K_pool, V_pool, flat_gather_idx, mask (egglog IList).
|
||||
// extract() appends qo_indptr + kv_indptr → 7 actual inputs at runtime.
|
||||
5
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["flashinfer_attention.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let num_qo_heads = extract_expr(egraph, kind_children[0], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let num_kv_heads = extract_expr(egraph, kind_children[1], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let head_dim = extract_expr(egraph, kind_children[2], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let page_size = extract_expr(egraph, kind_children[3], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let batch_dim = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
|
||||
let extracted = Self {
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
batch_dim,
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Trigger JIT compilation (or .so cache hit) at extract time, not at
|
||||
// first execute. Pays the ~30s cold-cache nvcc cost during compile
|
||||
// rather than during the GA profiling loop, where it would dominate
|
||||
// the candidate's measured runtime and make the GA reject FlashInfer.
|
||||
let _ = jit::ensure_compiled(head_dim);
|
||||
|
||||
// Walk the mask e-graph chain to recover qo_indptr / kv_indptr Input nodes.
|
||||
// input_enodes: [Q, K_cache, V_cache, gather_idx, mask]
|
||||
let mask_node = input_enodes[4];
|
||||
let indptrs = find_indptrs::find_indptr_inputs(egraph, mask_node);
|
||||
|
||||
// Build final inputs: [Q, K_cache, V_cache, gather_idx, mask, qo_indptr, kv_indptr]
|
||||
let mut final_inputs = input_enodes;
|
||||
final_inputs.push(indptrs.qo_indptr);
|
||||
final_inputs.push(indptrs.kv_indptr);
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
(op, final_inputs)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for FlashInferAttention {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let lib = jit::ensure_compiled(self.head_dim);
|
||||
|
||||
let total_q_tokens = self
|
||||
.batch_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention batch_dim is unresolved"))?;
|
||||
let c = *dyn_map
|
||||
.get(&'c')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'c'"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'r'"))?;
|
||||
|
||||
if inputs.len() < 7 {
|
||||
anyhow::bail!(
|
||||
"FlashInferAttention expects 7 inputs (Q, K, V, flat_idx, mask, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("FlashInferAttention missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_buf = get_buf("Q", inputs[0])?;
|
||||
let k_buf = get_buf("K_cache", inputs[1])?;
|
||||
let v_buf = get_buf("V_cache", inputs[2])?;
|
||||
let flat_idx_buf = get_buf("flat_gather_idx", inputs[3])?;
|
||||
// inputs[4] = mask (unused by FlashInfer — indptrs replace it)
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[6])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
// Derive batch_size (num sequences) from r = indptr length.
|
||||
let batch_size = r.saturating_sub(1);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"FlashInferAttention",
|
||||
total_q_tokens,
|
||||
batch_size,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let kv_dim = self.num_kv_heads * self.head_dim;
|
||||
let cu_stream = stream.cu_stream() as *mut std::ffi::c_void;
|
||||
|
||||
// Extract slot indices (one per context page) from the flat gather index.
|
||||
let indices_buf = unsafe { stream.alloc::<u8>(c.max(1) * std::mem::size_of::<i32>())? };
|
||||
let (indices_ptr, _idx_guard) = indices_buf.device_ptr(stream);
|
||||
|
||||
if c > 0 {
|
||||
unsafe {
|
||||
(lib.extract_slot_indices)(
|
||||
flat_idx_buf.ptr() as *const i32,
|
||||
indices_ptr as *mut i32,
|
||||
c as i32,
|
||||
kv_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Read kv_indptr to host for the plan phase.
|
||||
let kv_indptr_bytes = r * 4;
|
||||
let mut kv_indptr_host_bytes = vec![0u8; kv_indptr_bytes];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(
|
||||
&mut kv_indptr_host_bytes,
|
||||
kv_indptr_buf.ptr(),
|
||||
stream.cu_stream(),
|
||||
)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let kv_indptr_host: Vec<i32> = unsafe {
|
||||
let mut v = std::mem::ManuallyDrop::new(kv_indptr_host_bytes);
|
||||
Vec::from_raw_parts(v.as_mut_ptr() as *mut i32, r, r)
|
||||
};
|
||||
|
||||
// kv_last_page_len = [1; batch_size] when page_size=1.
|
||||
let last_page_host: Vec<i32> = vec![1; batch_size];
|
||||
let last_page_dev: CudaSlice<u8> = if batch_size > 0 {
|
||||
stream.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
last_page_host.as_ptr() as *const u8,
|
||||
last_page_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
unsafe { stream.alloc::<u8>(1)? }
|
||||
};
|
||||
let (last_page_ptr, _lp_guard) = last_page_dev.device_ptr(stream);
|
||||
|
||||
// Global shared workspaces (allocated once across all op instances to
|
||||
// amortize the ~4ms first-allocation cost during GA profiling).
|
||||
static FLOAT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
static INT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
let float_ws = FLOAT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(FLOAT_WORKSPACE_SIZE).unwrap() });
|
||||
let int_ws = INT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(INT_WORKSPACE_SIZE).unwrap() });
|
||||
let page_locked_ws = PAGE_LOCKED_WORKSPACE.get_or_init(|| unsafe {
|
||||
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
let status = libc::posix_memalign(&mut ptr, 4096, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(status, 0, "Failed to allocate page-locked workspace");
|
||||
let cuda_status = cuda_pin_memory(ptr, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(cuda_status, 0, "Failed to pin memory");
|
||||
PageLockedPtr(ptr as *mut u8)
|
||||
});
|
||||
|
||||
let (float_ws_ptr, _fws_guard) = float_ws.device_ptr(stream);
|
||||
let (int_ws_ptr, _iws_guard) = int_ws.device_ptr(stream);
|
||||
|
||||
// FlashInfer decode writes (total_q_tokens, heads, dim);
|
||||
// luminal expects (heads, total_q_tokens, dim) — transpose at the end.
|
||||
let output_elems = total_q_tokens * self.num_qo_heads * self.head_dim;
|
||||
let temp_out_buf =
|
||||
unsafe { stream.alloc::<u8>(output_elems * std::mem::size_of::<f32>())? };
|
||||
let (temp_out_ptr, _tmp_guard) = temp_out_buf.device_ptr(stream);
|
||||
|
||||
// PrefillPlanInfo has 15 entries, DecodePlanInfo fewer — 16 is enough.
|
||||
let mut plan_info_buf = [0i64; 16];
|
||||
let mut plan_info_len: i32 = 0;
|
||||
|
||||
// ── BatchDecode path ──
|
||||
// Prefill kernels require fp16/bf16 tensor-core MMA; the C API returns -1
|
||||
// when called from the fp32 pipeline. We only use decode here.
|
||||
let plan_ret = unsafe {
|
||||
(lib.plan)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
INT_WORKSPACE_SIZE,
|
||||
page_locked_ws.0 as *mut std::ffi::c_void,
|
||||
kv_indptr_host.as_ptr() as *mut i32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
plan_info_buf.as_mut_ptr(),
|
||||
&mut plan_info_len,
|
||||
)
|
||||
};
|
||||
if plan_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode plan failed with error code {plan_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
let mut plan_info = self.plan_info.lock().unwrap();
|
||||
plan_info.clear();
|
||||
plan_info.extend_from_slice(&plan_info_buf[..plan_info_len as usize]);
|
||||
|
||||
let run_ret = unsafe {
|
||||
(lib.run)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
plan_info.as_mut_ptr(),
|
||||
plan_info.len() as i32,
|
||||
q_buf.ptr() as *mut f32,
|
||||
k_buf.ptr() as *mut f32,
|
||||
v_buf.ptr() as *mut f32,
|
||||
kv_indptr_buf.ptr() as *mut i32,
|
||||
indices_ptr as *mut i32,
|
||||
last_page_ptr as *mut i32,
|
||||
temp_out_ptr as *mut f32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
)
|
||||
};
|
||||
drop(plan_info);
|
||||
|
||||
if run_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode run failed with error code {run_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
// Transpose (total_q_tokens, heads, dim) → (heads, total_q_tokens, dim)
|
||||
unsafe {
|
||||
(lib.transpose_output)(
|
||||
temp_out_ptr as *const f32,
|
||||
out_buf.ptr() as *mut f32,
|
||||
total_q_tokens as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.batch_dim * self.num_qo_heads * self.head_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("FlashInferAttention")
|
||||
}
|
||||
}
|
||||
|
||||
/// Pin host memory for CUDA async memcpy.
|
||||
///
|
||||
/// `cudaHostRegister` lives in libcudart, which cudarc doesn't link to our
|
||||
/// binary. Resolve it via `dlopen`/`dlsym` so we don't need a build script or
|
||||
/// a `#[link]` directive — keeping the crate buildable without any nvcc-side
|
||||
/// dependencies.
|
||||
unsafe fn cuda_pin_memory(ptr: *mut std::ffi::c_void, size: usize) -> i32 {
|
||||
type HostRegisterFn = unsafe extern "C" fn(*mut std::ffi::c_void, usize, u32) -> i32;
|
||||
static FN: OnceLock<usize> = OnceLock::new();
|
||||
|
||||
let raw = *FN.get_or_init(|| unsafe {
|
||||
let lib = [
|
||||
"libcudart.so",
|
||||
"libcudart.so.13",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
]
|
||||
.iter()
|
||||
.find_map(|n| libloading::Library::new(*n).ok())
|
||||
.expect("FlashInfer: could not dlopen libcudart for cudaHostRegister");
|
||||
let sym: libloading::Symbol<HostRegisterFn> = lib
|
||||
.get(b"cudaHostRegister\0")
|
||||
.expect("FlashInfer: libcudart missing cudaHostRegister symbol");
|
||||
let ptr = *sym as *const () as usize;
|
||||
// Keep libcudart resident for the process lifetime so the function
|
||||
// pointer remains valid.
|
||||
std::mem::forget(lib);
|
||||
ptr
|
||||
});
|
||||
let f: HostRegisterFn = unsafe { std::mem::transmute(raw) };
|
||||
// cudaHostRegisterDefault = 0
|
||||
unsafe { f(ptr, size, 0) }
|
||||
}
|
||||
357
crates/luminal_cuda_lite/src/host/flashinfer/wrapper.cu
Normal file
357
crates/luminal_cuda_lite/src/host/flashinfer/wrapper.cu
Normal file
@@ -0,0 +1,357 @@
|
||||
// FlashInfer batch decode + prefill wrapper for luminal_cuda.
|
||||
// JIT-compiled at runtime with -DLUMINAL_HEAD_DIM=N.
|
||||
//
|
||||
// Decode: instantiated for f32 (scalar vectorized dot products, no tensor cores).
|
||||
// Prefill: instantiated for f16 (requires tensor core MMA + ldmatrix).
|
||||
// The C API accepts fp32 buffers; cast kernels convert fp32↔fp16 at the boundary.
|
||||
//
|
||||
// NHD layout. GQA group_size and page_size are runtime parameters.
|
||||
|
||||
#ifndef LUMINAL_HEAD_DIM
|
||||
#error "LUMINAL_HEAD_DIM must be defined (e.g. -DLUMINAL_HEAD_DIM=128)"
|
||||
#endif
|
||||
|
||||
// Include utils.cuh first to get the original DISPATCH_HEAD_DIM, then override it
|
||||
// to only instantiate our specific HEAD_DIM. This avoids a compile error in
|
||||
// cascade.cuh where HEAD_DIM=512 + f32 triggers vec_size=16, vec_bits=512
|
||||
// which exceeds cp_async's 256-bit limit.
|
||||
#include <flashinfer/utils.cuh>
|
||||
#undef DISPATCH_HEAD_DIM
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
{ \
|
||||
constexpr size_t HEAD_DIM = LUMINAL_HEAD_DIM; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#include <flashinfer/attention/scheduler.cuh>
|
||||
#include <flashinfer/attention/decode.cuh>
|
||||
#include <flashinfer/attention/default_decode_params.cuh>
|
||||
#include <flashinfer/attention/prefill.cuh>
|
||||
#include <flashinfer/attention/default_prefill_params.cuh>
|
||||
#include <flashinfer/attention/mask.cuh>
|
||||
#include <flashinfer/attention/variants.cuh>
|
||||
#include <flashinfer/page.cuh>
|
||||
#include <flashinfer/pos_enc.cuh>
|
||||
|
||||
#include "wrapper.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// ── Decode types (f32) ──
|
||||
using DTypeQ = float;
|
||||
using DTypeKV = float;
|
||||
using DTypeO = float;
|
||||
using IdType = int32_t;
|
||||
|
||||
// ── Prefill types (f16 compute, fp32 external interface) ──
|
||||
using PrefillDTypeQ = half;
|
||||
using PrefillDTypeKV = half;
|
||||
using PrefillDTypeO = half;
|
||||
|
||||
constexpr uint32_t HEAD_DIM = LUMINAL_HEAD_DIM;
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;
|
||||
|
||||
// Attention variants
|
||||
using Variant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
using CausalVariant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
// Decode params (f32)
|
||||
using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
|
||||
|
||||
// Prefill params (f16)
|
||||
using PrefillParams = BatchPrefillPagedParams<PrefillDTypeQ, PrefillDTypeKV, PrefillDTypeO, IdType>;
|
||||
|
||||
// Forward declarations
|
||||
namespace flashinfer {
|
||||
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
|
||||
typename Params>
|
||||
cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
|
||||
PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
|
||||
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
|
||||
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
}
|
||||
|
||||
// Explicit instantiation: decode kernel (f32)
|
||||
template cudaError_t flashinfer::BatchDecodeWithPagedKVCacheDispatched<
|
||||
HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
DecodeParams params, DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// Explicit instantiation: prefill kernels (f16, causal mask, CTA_TILE_Q=16/64/128)
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
16, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
64, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
128, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// ── fp32 ↔ fp16 cast kernels ──
|
||||
|
||||
__global__ void cast_f32_to_f16_kernel(const float* src, half* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __float2half(src[i]);
|
||||
}
|
||||
|
||||
__global__ void cast_f16_to_f32_kernel(const half* src, float* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __half2float(src[i]);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
uint32_t group_size = num_qo_heads / num_kv_heads;
|
||||
|
||||
// We need to dispatch on GROUP_SIZE to get the right work estimation function
|
||||
cudaError_t status = cudaSuccess;
|
||||
|
||||
// Use a lambda to dispatch on group size
|
||||
auto do_plan = [&]<uint32_t GROUP_SIZE>() -> cudaError_t {
|
||||
auto work_estimation_func =
|
||||
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
|
||||
GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>;
|
||||
return DecodePlan<HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
float_workspace, float_ws_size,
|
||||
int_workspace, page_locked_int_workspace,
|
||||
int_ws_size, plan_info, indptr_h,
|
||||
(uint32_t)batch_size, (uint32_t)num_qo_heads,
|
||||
(uint32_t)page_size, /*enable_cuda_graph=*/false,
|
||||
stream, work_estimation_func);
|
||||
};
|
||||
|
||||
switch (group_size) {
|
||||
case 1: status = do_plan.operator()<1>(); break;
|
||||
case 2: status = do_plan.operator()<2>(); break;
|
||||
case 4: status = do_plan.operator()<4>(); break;
|
||||
case 8: status = do_plan.operator()<8>(); break;
|
||||
default: return -1; // unsupported group size
|
||||
}
|
||||
|
||||
if (status != cudaSuccess) return (int)status;
|
||||
|
||||
auto vec = plan_info.ToVector();
|
||||
*plan_info_len_out = (int)vec.size();
|
||||
std::memcpy(plan_info_out, vec.data(), vec.size() * sizeof(int64_t));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q,
|
||||
float* k_cache,
|
||||
float* v_cache,
|
||||
int32_t* kv_indptr,
|
||||
int32_t* kv_indices,
|
||||
int32_t* kv_last_page_len,
|
||||
float* output,
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
plan_info.FromVector(std::vector<int64_t>(plan_info_vec, plan_info_vec + plan_info_len));
|
||||
|
||||
// Construct paged_kv_t with NHD layout
|
||||
paged_kv_t<DTypeKV, IdType> paged_kv(
|
||||
(uint32_t)num_kv_heads,
|
||||
(uint32_t)page_size,
|
||||
HEAD_DIM,
|
||||
(uint32_t)batch_size,
|
||||
QKVLayout::kNHD,
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
kv_last_page_len);
|
||||
|
||||
DecodeParams params;
|
||||
params.q = q;
|
||||
params.q_rope_offset = nullptr;
|
||||
params.paged_kv = paged_kv;
|
||||
params.o = output;
|
||||
params.lse = nullptr;
|
||||
params.maybe_alibi_slopes = nullptr;
|
||||
params.padded_batch_size = plan_info.padded_batch_size;
|
||||
params.num_qo_heads = (uint32_t)num_qo_heads;
|
||||
// Q buffer is (batch, num_qo_heads * head_dim) flat — the graph's split_dims + transpose
|
||||
// are stride tricks, no data movement. So the actual memory layout is (batch, heads, dim).
|
||||
params.q_stride_n = num_qo_heads * HEAD_DIM;
|
||||
params.q_stride_h = HEAD_DIM;
|
||||
params.window_left = -1; // no sliding window
|
||||
params.logits_soft_cap = 0.0f;
|
||||
params.sm_scale = 1.0f / sqrtf((float)HEAD_DIM);
|
||||
params.rope_rcp_scale = 1.0f;
|
||||
params.rope_rcp_theta = 1.0f;
|
||||
|
||||
// Set plan info pointers
|
||||
params.request_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.request_indices_offset);
|
||||
params.kv_tile_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_tile_indices_offset);
|
||||
params.o_indptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.o_indptr_offset);
|
||||
params.kv_chunk_size_ptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_chunk_size_ptr_offset);
|
||||
params.block_valid_mask = nullptr;
|
||||
params.partition_kv = false;
|
||||
|
||||
DTypeO* tmp_v = nullptr;
|
||||
float* tmp_s = nullptr;
|
||||
|
||||
if (plan_info.split_kv) {
|
||||
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_workspace, plan_info.v_offset);
|
||||
tmp_s = GetPtrFromBaseOffset<float>(float_workspace, plan_info.s_offset);
|
||||
if (plan_info.enable_cuda_graph) {
|
||||
params.block_valid_mask =
|
||||
GetPtrFromBaseOffset<bool>(int_workspace, plan_info.block_valid_mask_offset);
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t status =
|
||||
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE, Variant>(
|
||||
params, tmp_v, tmp_s, /*enable_pdl=*/false, stream);
|
||||
|
||||
return (int)status;
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
// BatchPrefill (fp16/bf16 only — tensor core MMA requires 16-bit inputs)
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
//
|
||||
// The prefill kernel templates are instantiated above for fp16. These C API
|
||||
// functions accept fp32 pointers (matching the current luminal pipeline) but
|
||||
// return -1 to indicate that fp32 prefill is not supported. When native fp16
|
||||
// support is added, these will accept fp16 pointers and call through to the
|
||||
// instantiated templates.
|
||||
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void*, size_t, void*, size_t, void*,
|
||||
int32_t*, int32_t*, int, int,
|
||||
int, int, int, int, cudaStream_t,
|
||||
int64_t*, int*)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
int flashinfer_batch_prefill_run(
|
||||
void*, size_t, void*,
|
||||
int64_t*, int,
|
||||
float*, float*, float*,
|
||||
int32_t*, int32_t*, int32_t*, int32_t*,
|
||||
float*, int, int, int, int, int, int, cudaStream_t)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
// ── Slot index extraction kernel (outside extern "C" for __global__) ──
|
||||
|
||||
__global__ void extract_slot_indices_kernel(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < c) out[i] = flat_idx[i * kv_dim] / kv_dim;
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream) {
|
||||
if (c == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (c + threads - 1) / threads;
|
||||
extract_slot_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
flat_idx, out, c, kv_dim);
|
||||
}
|
||||
|
||||
// ── Derive CSR indptr from attention mask ──
|
||||
// Mask is (s, c) f32. Entries > -1e9 are "valid" (0.0), rest are -inf.
|
||||
// Per-row count of valid entries = context length for that sequence.
|
||||
// Output: indptr[0..=s] with indptr[0]=0 and indptr[i+1] = indptr[i] + ctx_len[i].
|
||||
// Single thread is fine since s is tiny (batch_size during decode, typically 1-8).
|
||||
|
||||
__global__ void derive_indptr_kernel(
|
||||
const float* mask, int32_t* indptr, int s, int c) {
|
||||
if (threadIdx.x != 0 || blockIdx.x != 0) return;
|
||||
indptr[0] = 0;
|
||||
for (int i = 0; i < s; i++) {
|
||||
int count = 0;
|
||||
for (int j = 0; j < c; j++) {
|
||||
if (mask[i * c + j] > -1e9f) count++;
|
||||
}
|
||||
indptr[i + 1] = indptr[i] + count;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream) {
|
||||
if (s == 0) return;
|
||||
derive_indptr_kernel<<<1, 1, 0, stream>>>(mask, indptr, s, c);
|
||||
}
|
||||
|
||||
// ── Output transpose: (batch, heads, dim) → (heads, batch, dim) ──
|
||||
// FlashInfer writes output as (batch, heads, dim) but Luminal expects (heads, batch, dim).
|
||||
// For batch=1 these are identical; for batch>1 we need an explicit transpose.
|
||||
|
||||
__global__ void transpose_bhd_to_hbd_kernel(
|
||||
const float* src, float* dst, int batch, int heads, int dim) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = batch * heads * dim;
|
||||
if (idx >= total) return;
|
||||
|
||||
// Decompose linear index into (b, h, d) for src layout
|
||||
int d = idx % dim;
|
||||
int h = (idx / dim) % heads;
|
||||
int b = idx / (heads * dim);
|
||||
|
||||
// Write to (h, b, d) layout in dst
|
||||
dst[h * batch * dim + b * dim + d] = src[idx];
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream) {
|
||||
int total = batch * heads * dim;
|
||||
if (total == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (total + threads - 1) / threads;
|
||||
transpose_bhd_to_hbd_kernel<<<blocks, threads, 0, stream>>>(
|
||||
src, dst, batch, heads, dim);
|
||||
}
|
||||
93
crates/luminal_cuda_lite/src/host/flashinfer/wrapper.h
Normal file
93
crates/luminal_cuda_lite/src/host/flashinfer/wrapper.h
Normal file
@@ -0,0 +1,93 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Plan phase: CPU-side scheduling. Must call before each new batch config.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase: GPU kernel launch.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [batch_size, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* kv_indptr, // [batch_size + 1]
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [batch_size, num_qo_heads, head_dim]
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Extract slot indices from a flat gather index tensor.
|
||||
// flat_idx shape: (c, kv_dim) i32, out shape: (c,) i32.
|
||||
// out[i] = flat_idx[i * kv_dim] / kv_dim
|
||||
void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Derive CSR indptr from attention mask.
|
||||
// mask shape: (s, c) f32. Entries > -1e9 are valid.
|
||||
// indptr shape: (s + 1,) i32. indptr[0] = 0, indptr[i+1] = cumsum of valid counts.
|
||||
void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Transpose output from (batch, heads, dim) to (heads, batch, dim).
|
||||
void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// ── BatchPrefill with Paged KV Cache ──
|
||||
|
||||
// Plan phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* qo_indptr_h, int32_t* kv_indptr_h,
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [total_num_rows, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* qo_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [total_num_rows, num_qo_heads, head_dim]
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
182
crates/luminal_cuda_lite/src/host/mod.rs
Normal file
182
crates/luminal_cuda_lite/src/host/mod.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTypeTuple = (
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
&'static str,
|
||||
luminal::dtype::DType,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_type_tuple(op: &dyn HostOp) -> Option<CublasLtTypeTuple> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::type_tuple)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtScaleValues = (f64, f64);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_scale_values(op: &dyn HostOp) -> Option<CublasLtScaleValues> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::scale_values)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_epilogue(op: &dyn HostOp) -> Option<&'static str> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::epilogue)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtMatrixOrders = (&'static str, &'static str, &'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_matrix_orders(op: &dyn HostOp) -> Option<CublasLtMatrixOrders> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::matrix_orders)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTransposeOps = (&'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_transpose_ops(op: &dyn HostOp) -> Option<CublasLtTransposeOps> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::transpose_ops)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTensorScaleInputs = (bool, bool);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_tensor_scale_inputs(op: &dyn HostOp) -> Option<CublasLtTensorScaleInputs> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::tensor_scale_inputs)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
/// the reusable arena, or an external pointer. Host ops only need the pointer
|
||||
/// and the logical byte length.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct DeviceBuffer {
|
||||
ptr: u64,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl DeviceBuffer {
|
||||
pub fn new(ptr: u64, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
pub fn ptr(self) -> u64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn len(self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
|
||||
let mut host = vec![0u8; self.len];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(host)
|
||||
}
|
||||
}
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
/// Execute the operation with access to buffers via a map.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `stream` - The CUDA stream to execute on
|
||||
/// * `self_node` - The NodeIndex of this op in the llir_graph (used as output buffer)
|
||||
/// * `inputs` - NodeIndices of input nodes (in edge order from the graph)
|
||||
/// * `buffers` - Map from NodeIndex to device buffer for all allocated nodes
|
||||
/// * `dyn_map` - Dynamic dimension values
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Returns the output buffer size in elements.
|
||||
/// Return 0 if this op doesn't have a single output buffer (e.g., CudaGraphOp).
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns additional nodes (beyond graph edges) that this op needs buffers for.
|
||||
///
|
||||
/// For most ops, this returns empty (buffers determined by graph edges).
|
||||
/// For CudaGraphOp, this returns all internal kernel nodes.
|
||||
fn extra_buffer_nodes(&self) -> Vec<NodeIndex> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns relative lifetimes for extra buffer nodes within this host op.
|
||||
///
|
||||
/// The tuple is `(node, first_step, last_step)`, where steps are local to
|
||||
/// this host op's execution. Returning `None` tells the runtime to treat
|
||||
/// every extra buffer as live for the whole host op.
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
/// For CudaGraphOp, this returns sizes for all internal kernel output buffers.
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
FxHashMap::default()
|
||||
}
|
||||
|
||||
/// Returns the name of this host op for stats reporting, or None if not reportable.
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
}
|
||||
281
crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg
Normal file
281
crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg
Normal file
@@ -0,0 +1,281 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a gated MoE.
|
||||
;
|
||||
; One fused op supports two activation modes:
|
||||
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared expert index/gather markers
|
||||
; 2) Shared gate-up matmul marker
|
||||
; 3) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 5) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEExpertIndexState
|
||||
(MkGLUMoEExpertIndexState Expression Expression IR)
|
||||
)
|
||||
(GLUMoEExpertGatherState
|
||||
(MkGLUMoEExpertGatherState Expression Expression IR IR)
|
||||
)
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
(GLUMoESwiGLUState
|
||||
(MkGLUMoESwiGLUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoEGemmaGELUState
|
||||
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoESwiGLUDownState
|
||||
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
|
||||
)
|
||||
(GLUMoEGemmaDownState
|
||||
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_expert_index (IR) GLUMoEExpertIndexState :merge new)
|
||||
(function glumoe_expert_gather (IR) GLUMoEExpertGatherState :merge new)
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
|
||||
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?iota_base (Op (Iota ?io ?iota_base_range) (INil)))
|
||||
(= ?mul_base (Op (Mul ?mul_base_shape ?mul_base_a_stride ?mul_base_b_stride ?mul_base_out_stride) (ICons ?topk_idx (ICons ?iota_base (INil)))))
|
||||
(= ?iota_within (Op (Iota (MIter) ?iota_within_range) (INil)))
|
||||
(= ?add_idx (Op (Add ?add_shape ?add_a_stride ?add_b_stride ?add_out_stride) (ICons ?mul_base (ICons ?iota_within (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_index ?add_idx)
|
||||
(MkGLUMoEExpertIndexState ?io ?iota_within_range ?topk_idx))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert index marker"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?index_state (glumoe_expert_index ?idx))
|
||||
(= ?index_state (MkGLUMoEExpertIndexState ?io ?within_range ?topk_idx))
|
||||
(= ?gathered (Op (Gather ?gather_idx_shape ?gather_idx_stride ?gather_data_shape ?gather_data_stride) (ICons ?idx (ICons ?weights (INil)))))
|
||||
(= ?f32 (Op (Cast ?f32_size (F32)) (ICons ?gathered (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_gather ?f32)
|
||||
(MkGLUMoEExpertGatherState ?io ?within_range ?topk_idx ?weights))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert gather marker"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?gather_state (glumoe_expert_gather ?gu_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?gu_io ?gu_iota_within_range ?topk_idx ?gate_up_w))
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 0 (SwiGLU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 2 (SwiGLU with row-normalized top-k weights) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?normed_topk (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 2))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (normalized swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_gemma_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
|
||||
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, IR},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
@@ -12,6 +12,7 @@ use luminal::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device,
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
@@ -30,17 +31,17 @@ use crate::{
|
||||
driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
nvrtc::{CompileOptions, compile_ptx_with_opts},
|
||||
},
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
|
||||
/// activation + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU/SwiGLUNormalized: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
pub(crate) mode: GLUMoEMode,
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
@@ -69,9 +74,37 @@ pub struct GLUMoE {
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
SwiGLUNormalized,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
fn from_mode_id(mode_id: usize) -> Self {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
2 => Self::SwiGLUNormalized,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU | Self::SwiGLUNormalized => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GLUMoEMode::SwiGLU,
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
@@ -88,6 +121,7 @@ impl Default for GLUMoE {
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("mode", &self.mode)
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
@@ -100,6 +134,7 @@ impl std::fmt::Debug for GLUMoE {
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
mode: self.mode,
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
@@ -114,9 +149,15 @@ impl Clone for GLUMoE {
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
@@ -134,33 +175,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
extern "C" __global__ void glu_activation_bf16(
|
||||
unsigned long long gate_up_ptr,
|
||||
unsigned long long out_ptr,
|
||||
int intermediate,
|
||||
int mode
|
||||
) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
float activated;
|
||||
if (mode == 0) {
|
||||
activated = gate / (1.0f + expf(-gate));
|
||||
} else {
|
||||
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
|
||||
activated = gate / (1.0f + expf(-scaled));
|
||||
}
|
||||
out[i] = __float2bfloat16(activated * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_ptx_with_opts(
|
||||
src,
|
||||
CompileOptions {
|
||||
include_paths: vec![
|
||||
"/usr/local/cuda/include".to_string(),
|
||||
"/usr/include".to_string(),
|
||||
],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
let activation = module.load_function("glu_activation_bf16").unwrap();
|
||||
(module, f32_to_bf16, activation)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -168,14 +210,9 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
|
||||
impl EgglogOp for GLUMoE {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
OP_KIND,
|
||||
"GLUMoE",
|
||||
&[
|
||||
("x", IR),
|
||||
("topk_idx", IR),
|
||||
("topk_vals", IR),
|
||||
("gate_up_w", IR),
|
||||
("down_w", IR),
|
||||
("gu_io", EXPRESSION),
|
||||
("dn_io", EXPRESSION),
|
||||
("gu_matmul_k", EXPRESSION),
|
||||
@@ -183,30 +220,55 @@ impl EgglogOp for GLUMoE {
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
("mode", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
(
|
||||
(set (dtype ?e) (F32))
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
),
|
||||
Rule::raw(include_str!["glumoe_rewrite.egg"]),
|
||||
]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
6
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let gu_io = extract_expr(egraph, children[5], expr_cache).unwrap();
|
||||
let dn_io = extract_expr(egraph, children[6], expr_cache).unwrap();
|
||||
let gu_matmul_k = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let dn_matmul_k = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let output_k = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, children[10], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, children[11], expr_cache).unwrap();
|
||||
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let mode_id = mode_expr
|
||||
.to_usize()
|
||||
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
|
||||
let mode = GLUMoEMode::from_mode_id(mode_id);
|
||||
|
||||
let extracted = GLUMoE {
|
||||
mode,
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
@@ -219,17 +281,8 @@ impl EgglogOp for GLUMoE {
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
(
|
||||
op,
|
||||
vec![
|
||||
children[0],
|
||||
children[1],
|
||||
children[2],
|
||||
children[3],
|
||||
children[4],
|
||||
],
|
||||
)
|
||||
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -243,26 +296,140 @@ impl HostOp for GLUMoE {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
if inputs.len() < 6 {
|
||||
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
|
||||
}
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
// Resolve dimensions
|
||||
let hidden = self
|
||||
.gu_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
|
||||
let intermediate = self
|
||||
.dn_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
|
||||
let top_k = self
|
||||
.output_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
|
||||
let gu_io = self
|
||||
.gu_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
|
||||
let dn_io = self
|
||||
.dn_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
|
||||
|
||||
if hidden == 0 || intermediate == 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
|
||||
);
|
||||
}
|
||||
if top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
if gu_io % hidden != 0 {
|
||||
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
|
||||
}
|
||||
if dn_io % intermediate != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
|
||||
);
|
||||
}
|
||||
|
||||
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
|
||||
let down_hidden = dn_io / intermediate;
|
||||
if gate_up_dim != intermediate * 2 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
|
||||
gate_up_dim,
|
||||
intermediate * 2
|
||||
);
|
||||
}
|
||||
if down_hidden != hidden {
|
||||
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
|
||||
}
|
||||
|
||||
let output_bytes = self
|
||||
.output_bytes()
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
|
||||
if output_bytes % (hidden * 4) != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
|
||||
hidden * 4
|
||||
);
|
||||
}
|
||||
let seq = output_bytes / (hidden * 4);
|
||||
if seq == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
// Get input/output buffers
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
|
||||
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
|
||||
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
|
||||
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let min_topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
if output_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
|
||||
output_buf.len()
|
||||
);
|
||||
}
|
||||
|
||||
let gu_stride_bytes = gate_up_dim * hidden * 2;
|
||||
let down_stride_bytes = hidden * intermediate * 2;
|
||||
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
|
||||
gate_up_buf.len()
|
||||
);
|
||||
}
|
||||
let num_experts = gate_up_buf.len() / gu_stride_bytes;
|
||||
if num_experts == 0 {
|
||||
anyhow::bail!("GLUMoE has no expert weights");
|
||||
}
|
||||
if down_buf.len() < num_experts * down_stride_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down weight buffer too small: have {} bytes, need {}",
|
||||
down_buf.len(),
|
||||
num_experts * down_stride_bytes
|
||||
);
|
||||
}
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
@@ -270,25 +437,131 @@ impl HostOp for GLUMoE {
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
|
||||
if !topk_idx_i32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index element count {} is not divisible by seq {seq}",
|
||||
topk_idx_i32.len()
|
||||
);
|
||||
}
|
||||
if !topk_vals_f32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value element count {} is not divisible by seq {seq}",
|
||||
topk_vals_f32.len()
|
||||
);
|
||||
}
|
||||
let topk_idx_row_stride = topk_idx_i32.len() / seq;
|
||||
let topk_vals_row_stride = topk_vals_f32.len() / seq;
|
||||
if topk_idx_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index row stride {topk_idx_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
if topk_vals_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value row stride {topk_vals_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
|
||||
let topk_idx_at = |token: usize, expert: usize| -> i32 {
|
||||
topk_idx_i32[token * topk_idx_row_stride + expert]
|
||||
};
|
||||
let topk_val_at = |token: usize, expert: usize| -> f32 {
|
||||
topk_vals_f32[token * topk_vals_row_stride + expert]
|
||||
};
|
||||
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i);
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at token {t} top-k position {i} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - SwiGLUNormalized: normalize topk values row-wise
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => {
|
||||
if topk_vals_row_stride == top_k {
|
||||
topk_vals_f32
|
||||
} else {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i);
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
}
|
||||
GLUMoEMode::SwiGLUNormalized => {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i) * inv_norm;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
if per_expert_scale_host.len() < per_expert_scale_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
|
||||
per_expert_scale_host.len()
|
||||
);
|
||||
}
|
||||
let per_expert_scale_f32: &[f32] =
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
expert_idx,
|
||||
per_expert_scale_f32.len()
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[t * top_k + i] =
|
||||
topk_val_at(t, i) * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
let gate_up_out_buf = unsafe { stream.alloc::<u8>(gate_up_dim * 2)? }; // BF16 per-token
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = slice_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = slice_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
@@ -307,35 +580,21 @@ impl HostOp for GLUMoE {
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
let expert_idx = expert_idx as usize;
|
||||
for (i, &weight) in weights.iter().enumerate() {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
|
||||
// a. Gate+Up matmul (BF16 in, BF16 out)
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
@@ -354,17 +613,19 @@ impl HostOp for GLUMoE {
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
// b. Mode-specific gated activation (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
let activation_mode = self.mode.activation_kernel_mode();
|
||||
let activation_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.launch_builder(activation_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.arg(&activation_mode)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
grid_dim: (activation_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
@@ -377,7 +638,7 @@ impl HostOp for GLUMoE {
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
@@ -420,7 +681,11 @@ impl HostOp for GLUMoE {
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
|
||||
buf.ptr()
|
||||
}
|
||||
|
||||
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
738
crates/luminal_cuda_lite/src/kernel/conv2d.rs
Normal file
738
crates/luminal_cuda_lite/src/kernel/conv2d.rs
Normal file
@@ -0,0 +1,738 @@
|
||||
//! CUDA conv2d-with-bias backend rewrite.
|
||||
//!
|
||||
//! `KernelConv2D` is selected by egglog from pure HLIR conv graphs and lowers
|
||||
//! to a one-thread-per-output CUDA kernel. It avoids materializing unfold/im2col
|
||||
//! intermediates while keeping model code free of custom ops.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::FxHashMap;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::FxHashSet,
|
||||
shape::{Expression, flatten_strides},
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::{KernelOp, hlir::generate_dyn_dims_defines};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConv2D {
|
||||
out_shape: Vec<Expression>,
|
||||
input_shape: Vec<Expression>,
|
||||
input_stride: Vec<Expression>,
|
||||
weight_co_stride: Expression,
|
||||
weight_inner_stride: Expression,
|
||||
bias_c_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
kernel_h: Expression,
|
||||
kernel_w: Expression,
|
||||
stride_h: Expression,
|
||||
stride_w: Expression,
|
||||
dilation_h: Expression,
|
||||
dilation_w: Expression,
|
||||
pad_h: Expression,
|
||||
pad_w: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelConv2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelConv2D",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("input_shape", ELIST),
|
||||
("input_stride", ELIST),
|
||||
("weight_co_stride", EXPRESSION),
|
||||
("weight_inner_stride", EXPRESSION),
|
||||
("bias_c_stride", EXPRESSION),
|
||||
("out_stride", ELIST),
|
||||
("kernel_h", EXPRESSION),
|
||||
("kernel_w", EXPRESSION),
|
||||
("stride_h", EXPRESSION),
|
||||
("stride_w", EXPRESSION),
|
||||
("dilation_h", EXPRESSION),
|
||||
("dilation_w", EXPRESSION),
|
||||
("pad_h", EXPRESSION),
|
||||
("pad_w", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// 1x1 convs in Flux2's VAE are represented without `unfold`:
|
||||
//
|
||||
// input.permute([H,W,C]).merge(H,W)
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute back to [C_out,H,W]
|
||||
// -> + channel bias
|
||||
//
|
||||
// The lowered form is still the same Mul -> KernelSum -> Add
|
||||
// matmul skeleton, but the lhs FusionStart reads directly from the
|
||||
// original input instead of a KernelGather window tensor.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered bias matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the same conv after generic CUDA lowering has normalized
|
||||
// the elementwise pieces into fusion regions:
|
||||
//
|
||||
// KernelGather(input windows)
|
||||
// -> CudaBinaryElementwise("Mul", weight)
|
||||
// -> KernelSum(reduce K)
|
||||
// -> CudaBinaryElementwise("Add", bias)
|
||||
//
|
||||
// This is the form that survives long enough for CUDA search in
|
||||
// real models. The KernelConv2D op consumes the pre-gather input
|
||||
// and avoids materializing both the im2col window tensor and the
|
||||
// elementwise product tensor.
|
||||
//
|
||||
// TODO(egglog-shapes): the current e-graph does not reliably prove
|
||||
// the derived arithmetic equalities for this chain after CUDA
|
||||
// normalization:
|
||||
// * `M == H_out * W_out`
|
||||
// * `K == C_in * KH * KW`
|
||||
// * separately-derived but structurally identical stride
|
||||
// expressions, e.g. the Mul output stride and KernelSum input
|
||||
// stride, belong to the same e-class.
|
||||
// Keep the rewrite anchored on the stable conv layout facts the
|
||||
// graph does carry today: six-axis unfold window shape, flattened
|
||||
// `[M, C_out, K]` product, reduction over `K`, the three-axis
|
||||
// `[C_out, H_out, W_out]` output view, and channel-only bias
|
||||
// broadcast. Once expression/list canonicalization can prove those
|
||||
// equalities, tighten this rule and its regression tests.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the im2col-style HLIR conv used by Flux2:
|
||||
//
|
||||
// input.unfold([1, kh, kw], [1, 1, 1], [1, 1, 1])
|
||||
// -> squeeze/permute/merge view
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute view
|
||||
// -> + bias.expand_dim(1, h_out).expand_dim(2, w_out)
|
||||
//
|
||||
// The kernel consumes the pre-unfold input directly. That input may
|
||||
// already be a padded HLIR tensor, so the rewrite is still correct
|
||||
// for Flux2's padded convs while removing the large patch matrix.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
; This rewrite is for stride=1, dilation=1 over the
|
||||
; tensor passed to unfold. Padded HLIR inputs are already
|
||||
; represented as their own tensor, so padding is 0 here.
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?add (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (Add ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?fe (Op (FusionEnd ?shape ?os ?dt) ?inputs))
|
||||
(= ?fe (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?conv_dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (FusionEnd ?shape ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a luminal::egglog_utils::NodeId],
|
||||
input_enodes: Vec<&'a luminal::egglog_utils::NodeId>,
|
||||
list_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::egglog_utils::NodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
weight_co_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
weight_inner_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
|
||||
bias_c_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
kernel_h: extract_expr(egraph, kind_children[7], expr_cache).unwrap(),
|
||||
kernel_w: extract_expr(egraph, kind_children[8], expr_cache).unwrap(),
|
||||
stride_h: extract_expr(egraph, kind_children[9], expr_cache).unwrap(),
|
||||
stride_w: extract_expr(egraph, kind_children[10], expr_cache).unwrap(),
|
||||
dilation_h: extract_expr(egraph, kind_children[11], expr_cache).unwrap(),
|
||||
dilation_w: extract_expr(egraph, kind_children[12], expr_cache).unwrap(),
|
||||
pad_h: extract_expr(egraph, kind_children[13], expr_cache).unwrap(),
|
||||
pad_w: extract_expr(egraph, kind_children[14], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[15]),
|
||||
}) as Box<dyn KernelOp>),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelConv2D {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert_eq!(self.dtype, DType::F32, "KernelConv2D currently emits F32");
|
||||
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let c_out = self.out_shape[0].to_kernel();
|
||||
let h_out = self.out_shape[1].to_kernel();
|
||||
let w_out = self.out_shape[2].to_kernel();
|
||||
let c_in = self.input_shape[0].to_kernel();
|
||||
let h_in = self.input_shape[1].to_kernel();
|
||||
let w_in = self.input_shape[2].to_kernel();
|
||||
let weight_co_stride = self
|
||||
.weight_co_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let weight_inner_stride = self
|
||||
.weight_inner_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let bias_c_stride = self
|
||||
.bias_c_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let kh = self.kernel_h.to_kernel();
|
||||
let kw = self.kernel_w.to_kernel();
|
||||
let stride_h = self.stride_h.to_kernel();
|
||||
let stride_w = self.stride_w.to_kernel();
|
||||
let dilation_h = self.dilation_h.to_kernel();
|
||||
let dilation_w = self.dilation_w.to_kernel();
|
||||
let pad_h = self.pad_h.to_kernel();
|
||||
let pad_w = self.pad_w.to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let input_idx = flatten_strides(&self.input_shape, &self.input_stride)
|
||||
.to_kernel()
|
||||
.replace("const_z", "input_linear");
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_conv2d_bias(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias{dyn_dims_param}
|
||||
) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long long total = {total};
|
||||
if (const_z >= total) return;
|
||||
|
||||
const long long COUT = {c_out};
|
||||
const long long HOUT = {h_out};
|
||||
const long long WOUT = {w_out};
|
||||
const long long CIN = {c_in};
|
||||
const long long HIN = {h_in};
|
||||
const long long WIN = {w_in};
|
||||
const long long KH = {kh};
|
||||
const long long KW = {kw};
|
||||
const long long SH = {stride_h};
|
||||
const long long SW = {stride_w};
|
||||
const long long DH = {dilation_h};
|
||||
const long long DW = {dilation_w};
|
||||
const long long PH = {pad_h};
|
||||
const long long PW = {pad_w};
|
||||
const long long W_CO_STRIDE = {weight_co_stride};
|
||||
const long long W_INNER_STRIDE = {weight_inner_stride};
|
||||
const long long BIAS_C_STRIDE = {bias_c_stride};
|
||||
|
||||
long long co = const_z / (HOUT * WOUT);
|
||||
long long rem = const_z - co * HOUT * WOUT;
|
||||
long long oh = rem / WOUT;
|
||||
long long ow = rem - oh * WOUT;
|
||||
|
||||
float acc = bias[co * BIAS_C_STRIDE];
|
||||
for (long long ci = 0; ci < CIN; ++ci) {{
|
||||
for (long long r = 0; r < KH; ++r) {{
|
||||
long long ih = oh * SH + r * DH - PH;
|
||||
if (ih < 0 || ih >= HIN) continue;
|
||||
for (long long s = 0; s < KW; ++s) {{
|
||||
long long iw = ow * SW + s * DW - PW;
|
||||
if (iw < 0 || iw >= WIN) continue;
|
||||
long long input_linear = (ci * HIN + ih) * WIN + iw;
|
||||
long long input_idx = {input_idx};
|
||||
long long inner = (ci * KH + r) * KW + s;
|
||||
long long weight_idx = co * W_CO_STRIDE + inner * W_INNER_STRIDE;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
out[{out_idx}] = acc;
|
||||
}}
|
||||
}}",
|
||||
total = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_conv2d_bias").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs.ceil_div(256), 1.into(), 1.into()),
|
||||
(n_outputs.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2 * 4 + self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericConv2D"
|
||||
}
|
||||
}
|
||||
@@ -425,7 +425,7 @@ mod tests {
|
||||
fn test_raw_function_extraction() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
@@ -448,7 +448,7 @@ mod tests {
|
||||
use cudarc::driver::{CudaSlice, DevicePtr};
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
@@ -492,13 +492,14 @@ mod tests {
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = ((a + b) * a + b).output();
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -523,13 +524,14 @@ mod tests {
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = (a + b + a + b).output();
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -559,14 +561,15 @@ mod tests {
|
||||
let b = cx.tensor('s');
|
||||
let c = (a + b).output();
|
||||
let d = (c * a).output();
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
@@ -601,13 +604,14 @@ mod tests {
|
||||
let a = cx.tensor(size);
|
||||
let b = cx.tensor(size);
|
||||
let c = (a + b).output();
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
@@ -631,13 +635,14 @@ mod tests {
|
||||
result *= b;
|
||||
}
|
||||
let output = result.output();
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
@@ -648,4 +653,53 @@ mod tests {
|
||||
}
|
||||
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test that CUDA graphs produce correct results when dynamic dimensions
|
||||
/// change incrementally across many executions (simulating a decode loop
|
||||
/// where position offset increments each step).
|
||||
#[test]
|
||||
fn test_cuda_graph_incremental_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor('s');
|
||||
let b = cx.tensor('s');
|
||||
let c = ((a + b) * a).output();
|
||||
|
||||
let initial_size = 128;
|
||||
cx.set_dim('s', initial_size);
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(initial_size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
|
||||
// Incrementally change the dynamic dimension 10 times,
|
||||
// simulating decode steps where position offset grows.
|
||||
for step in 1..=10usize {
|
||||
let size = initial_size + step;
|
||||
cx.set_dim('s', size);
|
||||
let da = random_f32_vec(size, 100 + step as u64, -0.5, 0.5);
|
||||
let db = random_f32_vec(size, 200 + step as u64, -0.5, 0.5);
|
||||
rt.set_data(a, da.clone());
|
||||
rt.set_data(b, db.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = da.iter().zip(&db).map(|(a, b)| (a + b) * a).collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
393
crates/luminal_cuda_lite/src/kernel/fusion/elementwise.rs
Normal file
393
crates/luminal_cuda_lite/src/kernel/fusion/elementwise.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
// =========================================================================
|
||||
// Generic CUDA elementwise ops used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// CUDA elementwise execution is represented as a FusionEnd-rooted region even
|
||||
// for a single op. These ops are therefore region-internal only; standalone
|
||||
// compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (CudaUnaryElementwise, CudaBinaryElementwise);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
fn extract_string_label(egraph: &SerializedEGraph, node: &ENodeId) -> String {
|
||||
egraph.enodes[node].0.trim_matches('"').to_string()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaUnaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaUnaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaUnaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
for (hlir, opcode) in [
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?out_s) (ICons ?x (INil))))
|
||||
(= ?dt (dtype ?u))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?out_s ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?sqrt (Op (Sqrt ?shape ?x_stride ?sqrt_stride) (ICons ?x (INil))))
|
||||
(= ?recip (Op (Recip ?shape ?sqrt_stride ?out_stride) (ICons ?sqrt (INil))))
|
||||
(= ?dt (dtype ?recip))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Rsqrt\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?recip ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-rsqrt-from-sqrt-recip\")",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Exp\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?exp2 ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(datatype*
|
||||
(CudaSigmoidScaledState
|
||||
(MkCudaSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function cuda_sigmoid_scaled (IR) CudaSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (cuda_sigmoid_scaled ?scaled)
|
||||
(MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-region-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (cuda_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Sigmoid\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?sig_out ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaUnaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaUnaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaUnaryElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaBinaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaBinaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaBinaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Add ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?bin))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Add\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Add\")",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Mul ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?a))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Mul\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Mul\")",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[4], list_cache, expr_cache).unwrap();
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype: extract_dtype(egraph, kind_children[5]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaBinaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaBinaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes() * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaBinaryElementwise"
|
||||
}
|
||||
}
|
||||
414
crates/luminal_cuda_lite/src/kernel/fusion/markers.rs
Normal file
414
crates/luminal_cuda_lite/src/kernel/fusion/markers.rs
Normal file
@@ -0,0 +1,414 @@
|
||||
// =========================================================================
|
||||
// Fusion boundary markers — FusionStart and FusionEnd.
|
||||
//
|
||||
// Tag-like LLIR ops that bracket a region of elementwise ops destined to
|
||||
// be emitted as a single CUDA kernel:
|
||||
// - N FusionStart nodes per region (one per FS leaf — distinct external
|
||||
// reads),
|
||||
// - exactly 1 FusionEnd per region.
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Both markers' `compile()` is
|
||||
// `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (FusionStart, FusionEnd);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// FusionStart
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionStart {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionStart",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No idempotence rule. `FusionStart(FusionStart(x)) ≡ FusionStart(x)`
|
||||
// would unify nested markers and create eclass cycles via the
|
||||
// pair-fuse rules; without it, occasional re-firings produce extra
|
||||
// semantically-correct identity layers, bounded by the run schedule.
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionStart {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionStart must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// FusionEnd
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionEnd {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionEnd",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Generic region growth works directly from HLIR elementwise ops into
|
||||
// `Cuda*Elementwise` region nodes. The concrete HLIR op still appears in
|
||||
// the egraph, so fusion remains a normal nondestructive alternative, but
|
||||
// the region-internal representation is arity based instead of one
|
||||
// dedicated fused sort per operation.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
];
|
||||
let binaries: &[(&str, &str)] = &[("Add", "Add"), ("Mul", "Mul")];
|
||||
|
||||
// Grow FE → unary consumer: U(FE(inner)) → FE(CudaUnary(inner)).
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Grow FE → binary consumer, left and right orientations.
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Absorb an elementwise producer through a FusionStart boundary. This
|
||||
// makes a region that initially treats `producer(...)` as an external
|
||||
// input able to pull that producer inside later.
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?x (INil))))
|
||||
(= ?fs_u (Op (FusionStart ?shape ?s ?dt) (ICons ?u (INil))))
|
||||
) (
|
||||
(let ?fs_x (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?fs_x (INil))))
|
||||
(union ?fs_u ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-U-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?bad_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?bad_fs (INil))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-unary-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?fs_bin (Op (FusionStart ?shape ?out_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(union ?fs_bin ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-B-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?bad_fs (ICons ?fs_b (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?bad_fs (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(CudaBinary(ia, ib)).
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// No dissolve rule (`FS(FE(x)) → x`): unioning FS's eclass with FE's
|
||||
// inner eclass creates self-referential eclasses after grow rules
|
||||
// extend the downstream region, and extraction then panics with
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-strides\")",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionEnd {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionEnd must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionEnd"
|
||||
}
|
||||
}
|
||||
22
crates/luminal_cuda_lite/src/kernel/fusion/mod.rs
Normal file
22
crates/luminal_cuda_lite/src/kernel/fusion/mod.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Binary-inclusive elementwise kernel fusion.
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `elementwise` — generic region-internal CUDA elementwise op variants.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / generic elementwise / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod elementwise;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use elementwise::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior generic elementwise variants). Combined into a flat
|
||||
/// tuple for the `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, elementwise::Ops);
|
||||
640
crates/luminal_cuda_lite/src/kernel/fusion/region_codegen.rs
Normal file
640
crates/luminal_cuda_lite/src/kernel/fusion/region_codegen.rs
Normal file
@@ -0,0 +1,640 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// Older fusion lowering left elementwise / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — unfused non-region kernels, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior elementwise DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all elementwise bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior Cuda*Elementwise / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use as_any::Downcast;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::elementwise::{CudaBinaryElementwise, CudaUnaryElementwise},
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Compile units — what `kernel_to_host` iterates over instead of nodes.
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior Cuda*Elementwise nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub elementwise_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
/// strides; the host launch passes the same device pointer twice.
|
||||
pub fs_nodes: Vec<NodeIndex>,
|
||||
/// External producer NodeIndices, one per `fs_nodes` entry in the same
|
||||
/// order. Becomes the `inputs` field of the FE's `CompiledKernel`, and
|
||||
/// the kernel function's `in0`, `in1`, ... parameters in that order.
|
||||
pub external_inputs: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum CompileUnit {
|
||||
Single(NodeIndex),
|
||||
Region(RegionUnit),
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region detection.
|
||||
// =========================================================================
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// Cuda*Elementwise and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / Cuda*Elementwise nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
/// subgraph, because `partition_marked_convex` may put a shared FS leaf
|
||||
/// (one whose e-graph congruence-deduplicated it across multiple
|
||||
/// regions) into a different subgraph than the FE that absorbs it.
|
||||
/// Without this global view, `build_compile_units` running on the FS's
|
||||
/// subgraph would not see any FE walking back to the FS and would emit the
|
||||
/// FS as `CompileUnit::Single`; marker standalone compilation is not supported.
|
||||
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
for fe in llir_graph.node_indices() {
|
||||
if name_of(fe) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = vec![fe];
|
||||
visited.insert(fe);
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
absorbed.insert(pred);
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
absorbed
|
||||
}
|
||||
|
||||
pub(crate) fn build_compile_units(
|
||||
topo_order: &[NodeIndex],
|
||||
llir_graph: &LLIRGraph,
|
||||
globally_absorbed: &FxHashSet<NodeIndex>,
|
||||
) -> Vec<CompileUnit> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
// First pass: every FusionEnd in the subgraph anchors a region; gather
|
||||
// the region's interior + FS leaves by walking incoming edges
|
||||
// backward, stopping at FusionStart (a leaf — its predecessor is the
|
||||
// external producer, outside the region).
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut regions: FxHashMap<NodeIndex, RegionUnit> = FxHashMap::default();
|
||||
|
||||
for &node in topo_order {
|
||||
if name_of(node) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut interior: Vec<NodeIndex> = Vec::new();
|
||||
let mut fs_nodes: Vec<NodeIndex> = Vec::new();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = Vec::new();
|
||||
stack.push(node);
|
||||
visited.insert(node);
|
||||
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
fs_nodes.push(pred);
|
||||
// Don't recurse past FS — its predecessor is
|
||||
// external (outside the region).
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// A nested FE inside a region. Under the current
|
||||
// rule design these are cascade artifacts — treat
|
||||
// them as transparent (walk through) rather than
|
||||
// as a separate region. The outer region absorbs
|
||||
// them. They do not become CompileUnit::Region
|
||||
// anchors because their eclass is already the
|
||||
// outer region's.
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-elementwise predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
// malformed and we likely should not have a
|
||||
// region at all; caller will see incomplete
|
||||
// interior.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Topological order on the interior + FS nodes (so the kernel
|
||||
// emits `let v = ...;` lines after their inputs are bound). We
|
||||
// use the parent graph's toposort filtered to in-region nodes.
|
||||
let mut region_set: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
region_set.extend(interior.iter().copied());
|
||||
region_set.extend(fs_nodes.iter().copied());
|
||||
let topo = toposort(llir_graph, None).expect("LLIR cycle in region detection");
|
||||
let interior_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && interior.contains(n))
|
||||
.collect();
|
||||
let fs_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && fs_nodes.contains(n))
|
||||
.collect();
|
||||
|
||||
// External producer for each FS leaf, in the same order.
|
||||
let external_inputs: Vec<NodeIndex> = fs_topo
|
||||
.iter()
|
||||
.map(|&fs| {
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap_or_else(|| {
|
||||
// Dump the malformed structure: which FE
|
||||
// triggered the walk, every node in fs_topo and
|
||||
// interior_topo, and each FS's incoming /
|
||||
// outgoing degree. Helps localize whether the
|
||||
// missing edge came from extraction or a
|
||||
// downstream LLIR transform.
|
||||
if std::env::var("LUMINAL_DEBUG_FUSION_PANIC").is_ok() {
|
||||
eprintln!(
|
||||
"FusionStart panic: fe={} (kernel={:?})",
|
||||
node.index(),
|
||||
llir_graph.node_weight(node).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
}),
|
||||
);
|
||||
eprintln!(" fs_topo ({}):", fs_topo.len());
|
||||
for &f in &fs_topo {
|
||||
let in_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Incoming)
|
||||
.count();
|
||||
let out_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Outgoing)
|
||||
.count();
|
||||
let kn = llir_graph
|
||||
.node_weight(f)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(
|
||||
" fs={} kind={} in_deg={} out_deg={}",
|
||||
f.index(),
|
||||
kn,
|
||||
in_deg,
|
||||
out_deg,
|
||||
);
|
||||
}
|
||||
eprintln!(" interior_topo ({}):", interior_topo.len());
|
||||
for &i in &interior_topo {
|
||||
let kn = llir_graph
|
||||
.node_weight(i)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(" interior={} kind={}", i.index(), kn);
|
||||
}
|
||||
}
|
||||
panic!("FusionStart with no predecessor")
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
absorbed.extend(interior_topo.iter().copied());
|
||||
absorbed.extend(fs_topo.iter().copied());
|
||||
|
||||
regions.insert(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
elementwise_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Second pass: emit compile units in original topo order, replacing
|
||||
// FE nodes with their RegionUnit and skipping anything absorbed —
|
||||
// either by a region in *this* subgraph (`absorbed`) or by any
|
||||
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
|
||||
// latter prevents shared FS markers whose consumers live in other
|
||||
// convex subgraphs from being emitted as standalone compile units:
|
||||
// those FSes are absorbed by some other region, and the consuming
|
||||
// region reads from FS's external producer.
|
||||
let mut units: Vec<CompileUnit> = Vec::new();
|
||||
for &node in topo_order {
|
||||
if let Some(region) = regions.remove(&node) {
|
||||
units.push(CompileUnit::Region(region));
|
||||
} else if absorbed.contains(&node) || globally_absorbed.contains(&node) {
|
||||
continue;
|
||||
} else {
|
||||
units.push(CompileUnit::Single(node));
|
||||
}
|
||||
}
|
||||
units
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-elementwise body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
|
||||
llir_graph
|
||||
.node_weight(node)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>())
|
||||
.is_some_and(|op| {
|
||||
(***op).downcast_ref::<CudaUnaryElementwise>().is_some()
|
||||
|| (***op).downcast_ref::<CudaBinaryElementwise>().is_some()
|
||||
})
|
||||
}
|
||||
|
||||
fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("static_cast<float>({local})")
|
||||
} else {
|
||||
local.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("{cuda_ty}({expr})")
|
||||
} else {
|
||||
expr.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
|
||||
let a = || elementwise_value(locals[0], dtype);
|
||||
let b = || elementwise_value(locals[1], dtype);
|
||||
match op {
|
||||
"Sin" => format!("sinf({})", a()),
|
||||
"Sqrt" => format!("sqrtf({})", a()),
|
||||
"Rsqrt" => format!("rsqrtf({})", a()),
|
||||
"Exp" => format!("expf({})", a()),
|
||||
"Exp2" => format!("exp2f({})", a()),
|
||||
"Log2" => format!("log2f({})", a()),
|
||||
"Recip" => format!("1.0f / {}", a()),
|
||||
"Sigmoid" => format!("1.0f / (1.0f + expf(-{}))", a()),
|
||||
"Add" => format!("{} + {}", a(), b()),
|
||||
"Mul" => format!("{} * {}", a(), b()),
|
||||
other => panic!("region_codegen: unknown elementwise op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region compilation — emit one CUDA kernel for the whole region.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct CompiledRegion {
|
||||
pub function: CudaFunction,
|
||||
pub module: Arc<CudaModule>,
|
||||
pub kernel_str: String,
|
||||
pub grid: (Expression, Expression, Expression),
|
||||
pub block: (Expression, Expression, Expression),
|
||||
pub shared_mem: Expression,
|
||||
pub constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn compile_region(
|
||||
region: &RegionUnit,
|
||||
llir_graph: &LLIRGraph,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompiledRegion {
|
||||
// Resolve FE: shape, strides (for the write), dtype.
|
||||
let fe_op = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.expect("FE node must be a KernelOp");
|
||||
let fe_struct: &FusionEnd = (***fe_op)
|
||||
.downcast_ref::<FusionEnd>()
|
||||
.expect("region root must be FusionEnd");
|
||||
let out_shape: &[Expression] = &fe_struct.shape;
|
||||
let out_strides: &[Expression] = &fe_struct.strides;
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides and elementwise shapes.
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
for &fs_idx in ®ion.fs_nodes {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
for &elem_idx in ®ion.elementwise_topo {
|
||||
let elem_op = llir_graph[elem_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
if let Some(elem) = (***elem_op).downcast_ref::<CudaUnaryElementwise>() {
|
||||
all_vars.extend(elem.shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.in_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
} else if let Some(elem) = (***elem_op).downcast_ref::<CudaBinaryElementwise>() {
|
||||
all_vars.extend(elem.out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.a_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.b_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&all_vars);
|
||||
let dyn_dims_param = if all_vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
|
||||
// Build kernel signature: out, then one input per FS leaf in
|
||||
// `region.fs_nodes` order. The `external_inputs` list (parallel to
|
||||
// `fs_nodes`) is what the host wires into the launch params.
|
||||
let mut signature_params: Vec<String> = vec![format!("{cuda_ty} *out")];
|
||||
for i in 0..region.fs_nodes.len() {
|
||||
signature_params.push(format!("const {cuda_ty} *in{i}"));
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk elementwise nodes in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), elementwise nodes get fs_nodes.len()..(+ elementwise_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.elementwise_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
|
||||
let mut body = String::new();
|
||||
body.push_str(&format!(
|
||||
" long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n"
|
||||
));
|
||||
|
||||
// FS leaves: each reads from its corresponding `in_i` parameter using
|
||||
// its own strides.
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
let read_idx = flatten_strides(out_shape, &fs_struct.strides).to_kernel();
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = in{i}[{read_idx}];\n",
|
||||
name = local_name(fs_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// Elementwise ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.elementwise_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let (elem_name, elem_dtype) =
|
||||
if let Some(elem) = (***op_ref).downcast_ref::<CudaUnaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else if let Some(elem) = (***op_ref).downcast_ref::<CudaBinaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else {
|
||||
panic!(
|
||||
"region_codegen: expected Cuda*Elementwise op, got {}",
|
||||
op_ref.kernel_name()
|
||||
);
|
||||
};
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|(_, src)| local_name(src))
|
||||
.collect();
|
||||
// Sort by edge id like the rest of the codegen does for stable
|
||||
// input ordering.
|
||||
let mut edges: Vec<(_, NodeIndex)> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect();
|
||||
edges.sort_by_key(|(eid, _)| *eid);
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = elementwise_body(elem_name, &inputs_ref, elem_dtype);
|
||||
let expr = elementwise_init_expr(&expr, elem_dtype, cuda_ty);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the elementwise node feeding FE (its single incoming edge in
|
||||
// the region — an elementwise node or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionEnd with no predecessor");
|
||||
let fe_input_local = local_name(fe_input);
|
||||
let write_idx = flatten_strides(out_shape, out_strides).to_kernel();
|
||||
body.push_str(&format!(" out[{write_idx}] = {fe_input_local};\n"));
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}\n\
|
||||
{dyn_defines}\n\
|
||||
extern \"C\" {{\n\
|
||||
\x20 __global__ void fused_region_k({signature}{dyn_dims_param}) {{\n\
|
||||
{body}\
|
||||
\x20 }}\n\
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, function) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("region kernel PTX compile failed");
|
||||
let module = stream
|
||||
.context()
|
||||
.load_module(ptx)
|
||||
.expect("module load failed");
|
||||
let function = module
|
||||
.load_function("fused_region_k")
|
||||
.expect("region kernel function not found");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), function.clone()));
|
||||
(module, function)
|
||||
};
|
||||
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
|
||||
CompiledRegion {
|
||||
function,
|
||||
module,
|
||||
kernel_str: kernel,
|
||||
grid: (out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
block: (out_size.min(256), 1.into(), 1.into()),
|
||||
shared_mem: 0.into(),
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::kernel::fusion::elementwise::CudaBinaryElementwise;
|
||||
use luminal::op::LLIROp;
|
||||
use luminal::prelude::petgraph::algo::toposort;
|
||||
|
||||
/// Helper: wrap a `KernelOp` in an `LLIROp` of the kernel dialect.
|
||||
fn llir_of(op: impl KernelOp + 'static) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(op) as Box<dyn KernelOp>)
|
||||
}
|
||||
|
||||
/// Reproducer for the `FusionStart with no predecessor` panic at
|
||||
/// `region_codegen.rs:232`. The egglog rolling pass + iterated mode
|
||||
/// (`LUMINAL_LOOP_ROLL_ITERATE=1`) has been observed to produce LLIR
|
||||
/// graphs where a `FusionStart` marker is reached as a region leaf
|
||||
/// during the FE→FS walk but has no incoming edge — meaning the
|
||||
/// region has nothing to read from. `build_compile_units` then
|
||||
/// panics when constructing `external_inputs` because every FS leaf
|
||||
/// is required to have exactly one external producer.
|
||||
///
|
||||
/// Until that path is fixed, this test pins the failure mode so a
|
||||
/// regression doesn't silently change the panic message or location.
|
||||
/// `should_panic` rather than `ignore` so it stays runnable in CI
|
||||
/// and surfaces if the panic ever moves.
|
||||
#[test]
|
||||
#[should_panic(expected = "FusionStart with no predecessor")]
|
||||
fn fusion_start_with_no_predecessor_panics() {
|
||||
// Minimal reproducer:
|
||||
//
|
||||
// (no input) ──▶ FusionStart ──▶ CudaBinaryElementwise ──▶ FusionEnd
|
||||
//
|
||||
// CudaBinaryElementwise is a binary op (n_inputs = 2) so a real region would
|
||||
// have two FS leaves. For this panic-shape test only the *first*
|
||||
// FS leaf needs a missing predecessor — `build_compile_units`
|
||||
// panics in `expect("FusionStart with no predecessor")` as soon
|
||||
// as any FS in `fs_topo` lacks one. We add only one FS edge so
|
||||
// CudaBinaryElementwise has a dangling second input slot, but that's fine:
|
||||
// we're testing the specific panic path inside `build_compile_units`,
|
||||
// not full kernel codegen.
|
||||
let mut llir: LLIRGraph = LLIRGraph::default();
|
||||
|
||||
let fs_node = llir.add_node(llir_of(FusionStart::default()));
|
||||
let fadd_node = llir.add_node(llir_of(CudaBinaryElementwise::default()));
|
||||
let fe_node = llir.add_node(llir_of(FusionEnd::default()));
|
||||
|
||||
// FusionStart → CudaBinaryElementwise → FusionEnd.
|
||||
llir.add_edge(fs_node, fadd_node, ());
|
||||
llir.add_edge(fadd_node, fe_node, ());
|
||||
|
||||
let topo = toposort(&llir, None).expect("LLIR cycle in test setup");
|
||||
let absorbed = globally_absorbed_markers(&llir);
|
||||
|
||||
// This is the call that panics with `FusionStart with no
|
||||
// predecessor` because `fs_node`'s incoming-edges iterator is
|
||||
// empty.
|
||||
let _ = build_compile_units(&topo, &llir, &absorbed);
|
||||
}
|
||||
}
|
||||
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{
|
||||
KernelOp,
|
||||
hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
shape::flatten_strides,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct GenericMatmul {
|
||||
out_shape: Vec<Expression>,
|
||||
mul_shape: Vec<Expression>,
|
||||
k: Expression,
|
||||
lhs_strides: Vec<Expression>,
|
||||
rhs_strides: Vec<Expression>,
|
||||
sum_input_strides: Vec<Expression>,
|
||||
sum_iter_stride: Expression,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for GenericMatmul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"GenericMatmul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("mul_shape", ELIST),
|
||||
("k", EXPRESSION),
|
||||
("lhs_strides", ELIST),
|
||||
("rhs_strides", ELIST),
|
||||
("sum_input_strides", ELIST),
|
||||
("sum_iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?sum))
|
||||
)
|
||||
(
|
||||
(let ?generic (Op (GenericMatmul
|
||||
?out_shape
|
||||
?mul_shape
|
||||
?k
|
||||
?lhs_strides
|
||||
?rhs_strides
|
||||
?sum_input_strides
|
||||
?sum_iter_stride
|
||||
?out_strides
|
||||
?dt)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(union ?sum ?generic)
|
||||
(set (dtype ?generic) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"generic-matmul-cuda-mul-sum\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
(
|
||||
(delete (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?kernel_sum (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs))
|
||||
(= ?kernel_sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
((delete (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
mul_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
lhs_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
rhs_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
sum_input_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[5],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
sum_iter_stride: extract_expr(egraph, kind_children[6], expr_cache).unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[7], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[8]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for GenericMatmul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.all_dyn_vars();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_outputs = self.output_size();
|
||||
let sum_base_idx = flatten_strides(&self.out_shape, &self.sum_input_strides).to_kernel();
|
||||
let iter_offset = self.sum_iter_stride.to_kernel().replace("const_z", "i");
|
||||
let lhs_idx = flatten_strides(&self.mul_shape, &self.lhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let rhs_idx = flatten_strides(&self.mul_shape, &self.rhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
|
||||
let k = self.k.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_matmul({dtype} *out, const {dtype} *lhs, const {dtype} *rhs{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
if (const_z >= {n_outputs}) return;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long base_idx = {sum_base_idx};
|
||||
long long iters = {k};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
long long mul_idx = base_idx + {iter_offset};
|
||||
partial += static_cast<float>(lhs[{lhs_idx}]) * static_cast<float>(rhs[{rhs_idx}]);
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
float block_sum = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s >>= 1) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = ({dtype})block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}",
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.mul_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k.dyn_vars())
|
||||
.chain(self.lhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.rhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_input_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_iter_stride.dyn_vars())
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.output_size() * self.k * self.dtype.bits() * 2).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericMatmul"
|
||||
}
|
||||
}
|
||||
2186
crates/luminal_cuda_lite/src/kernel/hlir.rs
Normal file
2186
crates/luminal_cuda_lite/src/kernel/hlir.rs
Normal file
File diff suppressed because it is too large
Load Diff
427
crates/luminal_cuda_lite/src/kernel/matmul2d.rs
Normal file
427
crates/luminal_cuda_lite/src/kernel/matmul2d.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
//! Direct 2D matmul kernel — bypasses egglog rewrites, used as a custom op
|
||||
//! for matmul shapes where the cublaslt egg rules don't reliably fire.
|
||||
//!
|
||||
//! The cublaslt 2D rules in `host/cublaslt/cublaslt_*Cm_rewrite.egg` /
|
||||
//! `cublaslt_Rm*_rewrite.egg` are *supposed* to match any 2D matmul whose
|
||||
//! Mul + SumReduce broadcast lowering has the expected stride patterns,
|
||||
//! and the conditional matmul cleanup is *supposed* to delete the
|
||||
//! elementwise Mul + KernelSumReduce fallback whenever a cublaslt alternative
|
||||
//! exists. In practice both fail to fire reliably for the VAE's mid-block
|
||||
//! `AttnBlock` matmuls — at 1024² that lets the search occasionally pick
|
||||
//! the broadcast-Mul path for `q @ kᵀ`, generating a `(HW, HW, C) =
|
||||
//! (16384, 16384, 512)` ≈ 524 GiB single intermediate that OOMs the GPU.
|
||||
//!
|
||||
//! Same approach as `kernel::conv2d`: define a `KernelOp`, wrap it in a
|
||||
//! `CustomOp`, expose a tiny `pub fn` so callers don't see the
|
||||
//! `cx.custom_op` plumbing. This is opaque to egglog by design — we
|
||||
//! aren't trying to fuse with surrounding ops, just guarantee a sane
|
||||
//! lowering for the matmuls we know are problematic.
|
||||
//!
|
||||
//! The CUDA implementation is a textbook 2D-blocked SGEMM:
|
||||
//! * 16×16 output tile per block (256 threads)
|
||||
//! * Tiled load of A and B into shared memory in K-size chunks
|
||||
//! * Each thread accumulates one output element across all K-tiles
|
||||
//! * Optional bias broadcast along the M axis at write-out
|
||||
//! * `transpose_b` toggles between row-major B `(K, N)` and row-major
|
||||
//! B `(N, K)` (i.e. the `A @ Bᵀ` pattern that linear/projection
|
||||
//! layers use).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct 2D matmul `(M, K) × {(K, N) | (N, K)} → (M, N)` with optional
|
||||
/// per-output-column bias and an optional batch axis. A and output are
|
||||
/// always F32. B can be F32 or BF16; BF16 is converted to F32 on each
|
||||
/// load, which avoids materializing the cast as a separate intermediate
|
||||
/// tensor (important for the text encoder / transformer where the F32-
|
||||
/// cast weights would not fit in GPU memory). All shape parameters are
|
||||
/// static (baked into the CUDA source via #defines).
|
||||
///
|
||||
/// When `batch > 1` the kernel does `batch` independent 2D matmuls in
|
||||
/// parallel: A is `(batch, M, K)`, B is `(batch, *, *)` with the same
|
||||
/// per-batch shape, output is `(batch, M, N)`. All three are assumed
|
||||
/// contiguous row-major across batches (i.e. `a_batch_stride = M*K`,
|
||||
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
|
||||
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
|
||||
/// `(N,)` and broadcast across batches.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DKernel {
|
||||
pub m: usize,
|
||||
pub n: usize,
|
||||
pub k: usize,
|
||||
pub batch: usize,
|
||||
/// If `true`, B is interpreted as `(N, K)` row-major and accessed as
|
||||
/// `B[n][k]` (i.e. `A @ Bᵀ`). If `false`, B is `(K, N)` row-major and
|
||||
/// accessed as `B[k][n]` (i.e. `A @ B`).
|
||||
pub transpose_b: bool,
|
||||
pub has_bias: bool,
|
||||
/// Storage dtype of B. Currently F32 or BF16 are supported.
|
||||
pub weight_dtype: DType,
|
||||
}
|
||||
|
||||
const TILE: usize = 16;
|
||||
|
||||
impl KernelOp for Matmul2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let bias_param = if self.has_bias {
|
||||
", const float* __restrict__ bias"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let bias_add = if self.has_bias {
|
||||
" acc += bias[n];\n"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
|
||||
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
|
||||
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// Plus the per-batch offset (`b_batch_off`).
|
||||
let b_index_expr = if self.transpose_b {
|
||||
"b_batch_off + (b_n_base + tx) * K + (k0 + ty)"
|
||||
} else {
|
||||
"b_batch_off + (k0 + ty) * N + (b_n_base + tx)"
|
||||
};
|
||||
// Convert B's element to float on load. For BF16 we declare B as
|
||||
// `__nv_bfloat16*` and use `__bfloat162float`; for F32 it's a no-op.
|
||||
let (b_param_type, b_load_expr, bf16_include) = match self.weight_dtype {
|
||||
DType::F32 => (
|
||||
"const float* __restrict__ B",
|
||||
format!("B[{b_index_expr}]"),
|
||||
"",
|
||||
),
|
||||
DType::Bf16 => (
|
||||
"const __nv_bfloat16* __restrict__ B",
|
||||
format!("__bfloat162float(B[{b_index_expr}])"),
|
||||
"#include <cuda_bf16.h>\n",
|
||||
),
|
||||
other => panic!("Matmul2DKernel: unsupported weight_dtype {other:?}"),
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ A,
|
||||
{b_param_type}{bias_param}
|
||||
) {{
|
||||
const int M = {m};
|
||||
const int N = {n};
|
||||
const int K = {k};
|
||||
const int TILE = {tile};
|
||||
|
||||
__shared__ float As[{tile}][{tile}];
|
||||
__shared__ float Bs[{tile}][{tile}];
|
||||
|
||||
int bx = blockIdx.x; // tile column (n)
|
||||
int by = blockIdx.y; // tile row (m)
|
||||
int batch = blockIdx.z; // batch index (0..BATCH-1)
|
||||
int tx = threadIdx.x; // 0..TILE-1, output col within tile
|
||||
int ty = threadIdx.y; // 0..TILE-1, output row within tile
|
||||
|
||||
int m_global = by * TILE + ty;
|
||||
int n_global = bx * TILE + tx;
|
||||
|
||||
int a_m_base = by * TILE;
|
||||
int b_n_base = bx * TILE;
|
||||
|
||||
// Per-batch base pointer offsets (contiguous row-major across batches).
|
||||
int a_batch_off = batch * (M * K);
|
||||
int b_batch_off = batch * (K * N);
|
||||
int c_batch_off = batch * (M * N);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
int n_tiles = (K + TILE - 1) / TILE;
|
||||
for (int t = 0; t < n_tiles; ++t) {{
|
||||
int k0 = t * TILE;
|
||||
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
|
||||
int a_m = a_m_base + ty;
|
||||
int a_k = k0 + tx;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
|
||||
|
||||
// Load B tile depending on transpose_b
|
||||
int b_n_or_k = b_n_base + tx; // for transpose_b=true this is N; for =false this is N
|
||||
int b_k_or_k = k0 + ty; // similarly
|
||||
// We compute Bs[ty][tx] such that the inner loop reads Bs[k_local][n_local] = B[k][n].
|
||||
// For transpose_b=true (B is (N,K)): B[k][n] in math = B_storage[n][k] = B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// For transpose_b=false (B is (K,N)): B[k][n] in math = B_storage[k][n] = B[(k0+ty)*N + (b_n_base+tx)]
|
||||
bool b_in_bounds = ({transpose_b} ? (b_n_or_k < N && b_k_or_k < K)
|
||||
: (b_k_or_k < K && b_n_or_k < N));
|
||||
Bs[ty][tx] = b_in_bounds ? ({b_load_expr}) : 0.0f;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < {tile}; ++kk) {{
|
||||
acc += As[ty][kk] * Bs[kk][tx];
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
|
||||
if (m_global < M && n_global < N) {{
|
||||
int n = n_global;
|
||||
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
m = self.m,
|
||||
n = self.n,
|
||||
k = self.k,
|
||||
tile = TILE,
|
||||
transpose_b = self.transpose_b,
|
||||
b_load_expr = b_load_expr,
|
||||
b_param_type = b_param_type,
|
||||
bias_param = bias_param,
|
||||
bias_add = bias_add,
|
||||
bf16_include = bf16_include,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("matmul_2d_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let grid_x = self.n.div_ceil(TILE);
|
||||
let grid_y = self.m.div_ceil(TILE);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(grid_y),
|
||||
Expression::from(self.batch),
|
||||
),
|
||||
(
|
||||
Expression::from(TILE),
|
||||
Expression::from(TILE),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.m * self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// K elements from A (F32) + K elements from B (F32 or BF16) + maybe bias (F32).
|
||||
let b_bytes = match self.weight_dtype {
|
||||
DType::F32 => 4,
|
||||
DType::Bf16 => 2,
|
||||
_ => 4,
|
||||
};
|
||||
let bias_bytes = if self.has_bias { 4 } else { 0 };
|
||||
Expression::from(
|
||||
self.batch * self.m * self.n * (self.k * 4 + self.k * b_bytes + bias_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let per_out = self.k * 2 + if self.has_bias { 1 } else { 0 };
|
||||
Expression::from(self.batch * self.m * self.n * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Matmul2D"
|
||||
}
|
||||
}
|
||||
|
||||
/// CustomOp wrapper for [`Matmul2DKernel`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DCustom(pub Matmul2DKernel);
|
||||
|
||||
impl CustomOp for Matmul2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
|
||||
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
|
||||
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
|
||||
/// pattern produced by linear / projection layers (`x @ w.t()`).
|
||||
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
|
||||
/// `(N,)`, row-major F32 throughout.
|
||||
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
|
||||
}
|
||||
|
||||
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
|
||||
///
|
||||
/// Lowers as plain HLIR — `Cast(A, BF16) @ permute(B_bf16) → Cast(F32)`.
|
||||
/// The activation cast and output cast are tiny (M*K and M*N elements;
|
||||
/// the K=hidden weight stays BF16). The inner BF16 matmul matches the
|
||||
/// existing cublaslt rewrite rules and runs as
|
||||
/// `CUBLAS_COMPUTE_32F_FAST_16BF` — Hopper's native 2× BF16 path.
|
||||
pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "linear_no_bias_bf16_w expects F32 A");
|
||||
assert_eq!(
|
||||
b_bf16.dtype,
|
||||
DType::Bf16,
|
||||
"linear_no_bias_bf16_w expects BF16 B"
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b_bf16.dims();
|
||||
assert_eq!(a_dims.len(), 2);
|
||||
assert_eq!(b_dims.len(), 2);
|
||||
let a_bf16 = a.cast(DType::Bf16);
|
||||
let b_kn = b_bf16.permute((1, 0));
|
||||
a_bf16.matmul(b_kn).cast(DType::F32)
|
||||
}
|
||||
|
||||
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
|
||||
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
|
||||
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
fn matmul_inner(
|
||||
a: GraphTensor,
|
||||
b: GraphTensor,
|
||||
transpose_b: bool,
|
||||
bias: Option<GraphTensor>,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
|
||||
let weight_dtype = b.dtype;
|
||||
assert!(
|
||||
matches!(weight_dtype, DType::F32 | DType::Bf16),
|
||||
"matmul B must be F32 or BF16, got {weight_dtype:?}",
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b.dims();
|
||||
assert_eq!(
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
"matmul A/B rank mismatch: {} vs {}",
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
);
|
||||
assert!(
|
||||
a_dims.len() == 2 || a_dims.len() == 3,
|
||||
"matmul expects rank 2 or 3, got rank {}",
|
||||
a_dims.len(),
|
||||
);
|
||||
|
||||
let (batch, a_off) = if a_dims.len() == 3 {
|
||||
let ba = a_dims[0].to_usize().expect("batch dim must be static");
|
||||
let bb = b_dims[0].to_usize().expect("batch dim must be static");
|
||||
assert_eq!(
|
||||
ba, bb,
|
||||
"matmul batch dim mismatch: A batch={ba}, B batch={bb}"
|
||||
);
|
||||
(ba, 1)
|
||||
} else {
|
||||
(1, 0)
|
||||
};
|
||||
|
||||
let m = a_dims[a_off].to_usize().expect("M must be a static dim");
|
||||
let k_a = a_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (A) must be a static dim");
|
||||
let (n, k_b) = if transpose_b {
|
||||
// B per-batch is (N, K)
|
||||
let n = b_dims[a_off].to_usize().expect("N must be a static dim");
|
||||
let k = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
(n, k)
|
||||
} else {
|
||||
// B per-batch is (K, N)
|
||||
let k = b_dims[a_off]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
let n = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("N must be a static dim");
|
||||
(n, k)
|
||||
};
|
||||
assert_eq!(k_a, k_b, "matmul K mismatch: A K={k_a}, B K={k_b}");
|
||||
let k = k_a;
|
||||
|
||||
let has_bias = bias.is_some();
|
||||
if let Some(bias) = bias {
|
||||
let bdims = bias.dims();
|
||||
assert_eq!(bdims.len(), 1, "matmul bias must be 1D");
|
||||
assert_eq!(
|
||||
bdims[0].to_usize().expect("bias dim must be static"),
|
||||
n,
|
||||
"matmul bias size must equal N"
|
||||
);
|
||||
assert_eq!(bias.dtype, DType::F32, "matmul bias must be F32");
|
||||
}
|
||||
|
||||
let kern = Matmul2DKernel {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch,
|
||||
transpose_b,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
};
|
||||
let cx = unsafe { &mut *a.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
vec![a, b, bias]
|
||||
} else {
|
||||
vec![a, b]
|
||||
};
|
||||
if batch == 1 {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
|
||||
} else {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
|
||||
}
|
||||
}
|
||||
@@ -9,13 +9,31 @@ use luminal_tracing::schema::{
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod generic_matmul;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use generic_matmul::GenericMatmul;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops);
|
||||
pub type Ops = (
|
||||
hlir::Ops,
|
||||
other_ops::Ops,
|
||||
conv2d::KernelConv2D,
|
||||
GenericMatmul,
|
||||
fusion::Ops,
|
||||
);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
@@ -173,9 +191,23 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
|
||||
/// Returns the output buffer size in elements.
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
|
||||
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
|
||||
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.output_size().dyn_vars().into_iter().collect()
|
||||
}
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns the DType of this kernel's output buffer.
|
||||
/// Used by has_nan_outputs to interpret buffer bytes correctly.
|
||||
/// Default: F32 (most kernels output float).
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this kernel will load from global memory.
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
@@ -244,18 +276,21 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
|
||||
) {
|
||||
}
|
||||
|
||||
/// Called before each CUDA graph launch. Runs stream-level work outside the graph.
|
||||
/// Used by ops like KernelScatter that need a copy kernel before the main graph kernel.
|
||||
/// Default: no-op.
|
||||
fn pre_launch(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_output_ptr: u64,
|
||||
_input_ptrs: &[u64],
|
||||
_dyn_dims_ptr: u64,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
|
||||
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
/// If this kernel's output is derived from one of its inputs (copy-then-modify
|
||||
/// or in-place write), return that input index. Used by `resolve_data_node` to
|
||||
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
|
||||
/// roundtrip pattern.
|
||||
///
|
||||
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
|
||||
/// (like Scatter which copies dest→output then scatters into it).
|
||||
fn output_data_input(&self) -> Option<usize> {
|
||||
self.output_aliases_input()
|
||||
}
|
||||
|
||||
/// Returns indices of internal buffers containing timing data, if any.
|
||||
880
crates/luminal_cuda_lite/src/kernel/other_ops.rs
Normal file
880
crates/luminal_cuda_lite/src/kernel/other_ops.rs
Normal file
@@ -0,0 +1,880 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines, kernel_rewrite},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (KernelMeanReduce, KernelScatterNoCopy, KernelSoftmax);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
pub struct KernelMeanReduce {
|
||||
out_shape: Vec<Expression>,
|
||||
iters: Expression,
|
||||
in_stride: Vec<Expression>,
|
||||
iter_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
impl EgglogOp for KernelMeanReduce {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelMean",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("iters", EXPRESSION),
|
||||
("strides", ELIST),
|
||||
("iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Disabled: the e-graph union introduced by this rule can cause the search
|
||||
// to select genomes with accumulated FP precision issues over many layers.
|
||||
// The unfused Sum + Mul(Recip(Cast(Iota))) path produces equivalent results.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
{
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let iters = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let iter_stride = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, kind_children[4], list_cache, expr_cache).unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[5]);
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
iters,
|
||||
in_stride,
|
||||
iter_stride,
|
||||
out_stride,
|
||||
dtype,
|
||||
}) as Box<dyn KernelOp>)
|
||||
},
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelMeanReduce {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.iters.dyn_vars())
|
||||
.chain(self.iter_stride.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block: usize = 256; // 8 warps per block
|
||||
let n_warps = threads_per_block / 32;
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void reduce_mean_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = blockIdx.x;
|
||||
long long n_elements = {n_outputs};
|
||||
if (const_z >= n_elements) return;
|
||||
|
||||
long long in_start = {in_index};
|
||||
long long iters = {iters};
|
||||
long long iter_stride = {iter_stride};
|
||||
|
||||
float thread_sum = 0.0f;
|
||||
for (long long i = threadIdx.x; i < iters; i += {threads_per_block})
|
||||
thread_sum += (float)in[in_start + i * iter_stride];
|
||||
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
|
||||
|
||||
__shared__ float warp_sums[{n_warps}];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp = threadIdx.x >> 5;
|
||||
if (lane == 0) warp_sums[warp] = thread_sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {{
|
||||
float sum = 0.0f;
|
||||
for (int w = 0; w < {n_warps}; w++) sum += warp_sums[w];
|
||||
out[{out_index}] = ({dtype})(sum / (float)iters);
|
||||
}}
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
in_index = flatten_strides(&self.out_shape, &self.in_stride).to_kernel(),
|
||||
out_index = flatten_strides(&self.out_shape, &self.out_stride).to_kernel(),
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
iters = self.iters.to_kernel(),
|
||||
iter_stride = self
|
||||
.iter_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel(),
|
||||
threads_per_block = threads_per_block,
|
||||
n_warps = n_warps,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_mean_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(threads_per_block.into(), 1.into(), 1.into()), // block
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.out_shape.iter().copied().product::<Expression>() * self.iters * self.dtype.bits())
|
||||
.ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
n_outputs * self.iters + n_outputs
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MeanReduce"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelScatterNoCopy: In-place scatter that writes directly to dest buffer
|
||||
// without copying. The output buffer aliases the dest buffer.
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KernelScatterNoCopy {
|
||||
dest_shape: Vec<Expression>,
|
||||
dest_strides: Vec<Expression>,
|
||||
index_shape: Vec<Expression>,
|
||||
index_strides: Vec<Expression>,
|
||||
src_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Default for KernelScatterNoCopy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dest_shape: Vec::new(),
|
||||
dest_strides: Vec::new(),
|
||||
index_shape: Vec::new(),
|
||||
index_strides: Vec::new(),
|
||||
src_strides: Vec::new(),
|
||||
out_strides: Vec::new(),
|
||||
dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelScatterNoCopy {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelScatterNoCopy",
|
||||
&[
|
||||
("dest_shape", ELIST),
|
||||
("dest_strides", ELIST),
|
||||
("index_shape", ELIST),
|
||||
("index_strides", ELIST),
|
||||
("src_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn ir_defs(&self) -> Vec<String> {
|
||||
vec!["(ConsumedBuffer IR)".to_string()]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match KernelScatter and rewrite to KernelScatterNoCopy with ConsumedBuffer on dest.
|
||||
// ConsumedBuffer wraps dest to signal in-place modification.
|
||||
// This is only valid when the destination buffer can also represent
|
||||
// the scatter output layout. If dest is a strided/broadcast view,
|
||||
// regular Scatter must first materialize a contiguous output copy.
|
||||
//
|
||||
// Two-phase resolution:
|
||||
// 1. During (run): cleanup rules delete ConsumedBuffer if dest is shared (another op uses it)
|
||||
// 2. During (saturate base_cleanup): surviving ConsumedBuffers are valid — union with
|
||||
// source and delete. This merges the ConsumedBuffer eclass into the source eclass,
|
||||
// making KernelScatterNoCopy's input resolve directly to the source buffer.
|
||||
//
|
||||
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
|
||||
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
|
||||
let mut rules = vec![
|
||||
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail)))
|
||||
((consumed_buffer_ilist_contains ?list ?head))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-head\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail))
|
||||
(consumed_buffer_ilist_contains ?tail ?item))
|
||||
((consumed_buffer_ilist_contains ?list ?item))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-tail\"
|
||||
)",
|
||||
),
|
||||
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?dst ?os)
|
||||
(= ?dty (dtype ?src))
|
||||
)
|
||||
(
|
||||
(let ?consumed (ConsumedBuffer ?dest))
|
||||
(let ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?consumed (ICons ?indexes (ICons ?src (INil))))))
|
||||
(union ?scatter ?nocopy)
|
||||
(set (dtype ?nocopy) ?dty)
|
||||
)
|
||||
:ruleset buffer_reuse
|
||||
:name \"scatter to scatter-no-copy\"
|
||||
)",
|
||||
),
|
||||
// Dtype propagation for ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?dt (dtype ?a)))
|
||||
((set (dtype ?cb) ?dt))
|
||||
:ruleset dtype_prop
|
||||
:name \"consumed-buffer-dtype\"
|
||||
)",
|
||||
),
|
||||
];
|
||||
// Cleanup: delete ConsumedBuffer when inner buffer is used by a DIFFERENT Op.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?op1 (Op ?k1 ?ilist1))
|
||||
(consumed_buffer_ilist_contains ?ilist1 ?cb)
|
||||
(= ?op2 (Op ?k2 ?ilist2))
|
||||
(!= ?op1 ?op2)
|
||||
(consumed_buffer_ilist_contains ?ilist2 ?a))
|
||||
((delete (ConsumedBuffer ?a)))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-cleanup-shared-op-use\"
|
||||
)",
|
||||
));
|
||||
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
|
||||
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?dest))
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
|
||||
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
|
||||
:ruleset post_cleanup
|
||||
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
|
||||
)",
|
||||
));
|
||||
// Surviving ConsumedBuffers are valid — union with source and delete.
|
||||
// Runs in base_cleanup (after all (run) iterations).
|
||||
// TODO: figure out how to validate this is a valid ConsumedBuffer independantly so we can run it in the cleanup ruleset, rather than base_cleanup
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a)))
|
||||
((union ?cb ?a)
|
||||
(delete (ConsumedBuffer ?a)))
|
||||
:ruleset base_cleanup
|
||||
:name \"consumed-buffer-resolve\"
|
||||
)",
|
||||
));
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
dest_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dest_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
index_shape: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
index_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
src_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[5], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[6]),
|
||||
})),
|
||||
input_enodes, // dest, indexes, src
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelScatterNoCopy {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let all_vars: FxHashSet<char> = self
|
||||
.dest_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.dest_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.index_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.index_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.src_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&all_vars);
|
||||
let dyn_dims_param = if all_vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_src_elements = self
|
||||
.index_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let n_dest_elements = self
|
||||
.dest_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let scatter_idx_idx = flatten_strides(&self.index_shape, &self.index_strides).to_kernel();
|
||||
let scatter_src_idx = flatten_strides(&self.index_shape, &self.src_strides).to_kernel();
|
||||
let scatter_kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void scatter_nocopy({dtype} *dest, const int *indexes, const {dtype} *src{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_src_elements}) return;
|
||||
int idx = indexes[{scatter_idx_idx}];
|
||||
if (idx >= 0 && idx < {n_dest_elements}) {{
|
||||
dest[idx] = src[{scatter_src_idx}];
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&scatter_kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx =
|
||||
compile_module_image_for_current_device(stream.context(), &scatter_kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("scatter_nocopy").unwrap();
|
||||
compile_cache.insert(scatter_kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let n_src: Expression = self.index_shape.iter().copied().product();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
scatter_kernel,
|
||||
(n_src.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.dest_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.dest_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.dest_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.index_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.index_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.src_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
let elem_size: Expression = match self.dtype {
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
| DType::I8
|
||||
| DType::U8
|
||||
| DType::F8UE8M0
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2 => 1,
|
||||
other => panic!("Unsupported dtype for scatter output_bytes: {other:?}"),
|
||||
}
|
||||
.into();
|
||||
self.output_size() * elem_size
|
||||
}
|
||||
|
||||
fn build_params(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
_internal_bufs: &[CudaSlice<u8>],
|
||||
dyn_dims_ptr: u64,
|
||||
) -> Vec<u64> {
|
||||
// scatter_nocopy kernel: (dest, indexes, src [, dyn_dims])
|
||||
// Write directly to dest buffer (input_ptrs[0]), NOT to output_ptr
|
||||
let mut params = vec![input_ptrs[0], input_ptrs[1], input_ptrs[2]];
|
||||
if dyn_dims_ptr != 0 {
|
||||
params.push(dyn_dims_ptr);
|
||||
}
|
||||
params
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let data_elem_size: Expression = match self.dtype {
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
| DType::I8
|
||||
| DType::U8
|
||||
| DType::F8UE8M0
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2 => 1,
|
||||
other => panic!("Unsupported dtype for scatter bytes_loaded: {other:?}"),
|
||||
}
|
||||
.into();
|
||||
let n_src: Expression = self.index_shape.iter().copied().product();
|
||||
// Only load indices + src (no dest copy!)
|
||||
n_src * 4 + n_src * data_elem_size
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
let data_elem_size: Expression = match self.dtype {
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
| DType::I8
|
||||
| DType::U8
|
||||
| DType::F8UE8M0
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2 => 1,
|
||||
other => panic!("Unsupported dtype for scatter bytes_stored: {other:?}"),
|
||||
}
|
||||
.into();
|
||||
let n_src: Expression = self.index_shape.iter().copied().product();
|
||||
// Only store the scattered elements
|
||||
n_src * data_elem_size
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0) // output aliases dest (input 0)
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"ScatterNoCopy"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelSoftmax: Fused softmax over last dimension
|
||||
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))
|
||||
// Replaces 5+ kernel launches with a single fused kernel
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSoftmax {
|
||||
out_shape: Vec<Expression>, // shape of output (same as input)
|
||||
in_stride: Vec<Expression>, // input strides
|
||||
out_stride: Vec<Expression>, // output strides
|
||||
reduce_dim: Expression, // size of the softmax dimension (last dim)
|
||||
reduce_stride: Expression, // stride along softmax dimension in input
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSoftmax {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelSoftmax",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("in_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("reduce_dim", EXPRESSION),
|
||||
("reduce_stride", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
kernel_rewrite::<luminal::hlir::Softmax, Self>(),
|
||||
// Also add a direct rewrite that assumes F32 dtype, in case dtype
|
||||
// propagation hasn't reached the Softmax node yet.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?sm (Op (Softmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride) ?inputs))
|
||||
)
|
||||
(
|
||||
(let ?ksm (Op (KernelSoftmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride (F32)) ?inputs))
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
in_stride,
|
||||
out_stride,
|
||||
reduce_dim,
|
||||
reduce_stride,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelSoftmax {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.reduce_dim.dyn_vars())
|
||||
.chain(self.reduce_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
// n_rows = product of all dims except the last (reduce dim)
|
||||
let n_rows: Expression = self.out_shape[..self.out_shape.len() - 1]
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
// Each block handles one row. 256 threads cooperatively compute softmax.
|
||||
let in_idx = flatten_strides(
|
||||
&self.out_shape[..self.out_shape.len() - 1],
|
||||
&self.in_stride[..self.in_stride.len() - 1],
|
||||
)
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(
|
||||
&self.out_shape[..self.out_shape.len() - 1],
|
||||
&self.out_stride[..self.out_stride.len() - 1],
|
||||
)
|
||||
.to_kernel();
|
||||
let reduce_dim_expr = self.reduce_dim.to_kernel();
|
||||
let in_reduce_stride = self
|
||||
.reduce_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let out_reduce_stride = self.out_stride[self.out_stride.len() - 1]
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
#define NEG_INF_F __int_as_float(0xff800000)
|
||||
{dyn_defines}
|
||||
#define LOG2E 1.4426950408889634f
|
||||
|
||||
extern \"C\" {{
|
||||
// Online normalizer calculation for softmax (Milakov & Gimelshein 2018).
|
||||
|
||||
// Merge two partial (max, sum) pairs using the online softmax rule.
|
||||
__device__ __forceinline__ void merge_md(float *m, float *d, float m2, float d2) {{
|
||||
float new_m = fmaxf(*m, m2);
|
||||
*d = *d * exp2f((*m - new_m) * LOG2E) + d2 * exp2f((m2 - new_m) * LOG2E);
|
||||
*m = new_m;
|
||||
}}
|
||||
|
||||
__global__ void fused_softmax(float *out, const float *inp{dyn_dims_param}) {{
|
||||
__shared__ float sh_m[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
__shared__ float sh_d[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long in_base = {in_idx};
|
||||
long long out_base = {out_idx};
|
||||
long long N = {reduce_dim_expr};
|
||||
long long in_stride = {in_reduce_stride};
|
||||
long long out_stride = {out_reduce_stride};
|
||||
|
||||
// Pass 1: one read of inp produces (global_max, global_sum).
|
||||
float m = NEG_INF_F, d = 0.0f;
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
merge_md(&m, &d, inp[in_base + i * in_stride], 1.0f);
|
||||
}}
|
||||
// Warp reduce: collapse 32 threads within each warp down to lane 0.
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
if (lane_id == 0) {{ sh_m[warp_id] = m; sh_d[warp_id] = d; }}
|
||||
__syncthreads();
|
||||
// Block reduce: warp 0 collapses the 8 warp results down to one.
|
||||
if (warp_id == 0) {{
|
||||
m = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_m[tid] : NEG_INF_F;
|
||||
d = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_d[tid] : 0.0f;
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
sh_m[0] = m;
|
||||
sh_d[0] = d;
|
||||
}}
|
||||
__syncthreads();
|
||||
float global_max = sh_m[0];
|
||||
float inv_sum = 1.0f / sh_d[0];
|
||||
|
||||
// Pass 2: write final softmax values.
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
out[out_base + i * out_stride] = exp2f((inp[in_base + i * in_stride] - global_max) * LOG2E) * inv_sum;
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_softmax").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_rows, 1.into(), 1.into()), // grid: one block per row
|
||||
(256.into(), 1.into(), 1.into()), // block: 256 threads
|
||||
32.into(), // shared mem
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// 3 passes over input (max, exp+sum, normalize reads from output)
|
||||
self.output_size() * 4 * 3
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// 2 writes: exp values, then normalized values
|
||||
self.output_size() * 4 * 2
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Per element: sub, exp2, add (sum), div = ~4 ops
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Softmax"
|
||||
}
|
||||
}
|
||||
189
crates/luminal_cuda_lite/src/kernel/rope.rs
Normal file
189
crates/luminal_cuda_lite/src/kernel/rope.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Fused RoPE (rotary position embedding) — interleaved-pair convention.
|
||||
//!
|
||||
//! Replaces flux2's 6-op RoPE chain (split / slice / squeeze / neg / concat /
|
||||
//! merge_dims / 4× cast / mul / add) with a single kernel launch per call.
|
||||
//! ~120 RoPE calls per forward pass at full DiT depth.
|
||||
//!
|
||||
//! Convention: `repeat_interleave_real=True` (Flux 2 / diffusers), so adjacent
|
||||
//! dim pairs rotate together. For an input `[a0, b0, a1, b1, ...]` and per-
|
||||
//! position `(cos, sin)`, the output is
|
||||
//! `out[2j] = x[2j] * cos[2j] - x[2j+1] * sin[2j]`
|
||||
//! `out[2j+1] = x[2j+1] * cos[2j+1] + x[2j] * sin[2j+1]`
|
||||
//!
|
||||
//! Layout: x `(S, H, D)`, cos/sin `(S, D)` (broadcast across H).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPEKernel {
|
||||
pub s: usize,
|
||||
pub h: usize,
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
const TPB: usize = 64;
|
||||
|
||||
impl KernelOp for RoPEKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let s = self.s;
|
||||
let h = self.h;
|
||||
let d = self.d;
|
||||
assert!(d.is_multiple_of(2), "RoPE head_dim must be even");
|
||||
let kernel = format!(
|
||||
r#"
|
||||
extern "C" __global__ void rope_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ cos_,
|
||||
const float* __restrict__ sin_
|
||||
) {{
|
||||
const int S = {s};
|
||||
const int H = {h};
|
||||
const int D = {d};
|
||||
int sh = blockIdx.x; // 0..S*H
|
||||
int s_idx = sh / H;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const float* xr = x + sh * D;
|
||||
const float* cosr = cos_ + s_idx * D;
|
||||
const float* sinr = sin_ + s_idx * D;
|
||||
float* yr = out + sh * D;
|
||||
|
||||
for (int i = tid; i < D; i += {TPB}) {{
|
||||
float xi = xr[i];
|
||||
float xpair;
|
||||
if ((i & 1) == 0) {{
|
||||
// even: paired with i+1, rotated value is -x[i+1]
|
||||
xpair = -xr[i + 1];
|
||||
}} else {{
|
||||
// odd: paired with i-1, rotated value is +x[i-1]
|
||||
xpair = xr[i - 1];
|
||||
}}
|
||||
yr[i] = xi * cosr[i] + xpair * sinr[i];
|
||||
}}
|
||||
}}
|
||||
"#
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("rope_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
"rope_kernel".to_string(),
|
||||
(
|
||||
Expression::from(s * h),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(TPB),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.s * self.h * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// x: full (S,H,D); cos/sin: (S,D) read H times each but cached.
|
||||
Expression::from(self.s * self.h * self.d * 4 + self.s * self.d * 4 * 2)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 4 per output element (mul, neg/load, mul, add).
|
||||
Expression::from(self.s * self.h * self.d * 4)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"RoPE"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPECustom(pub RoPEKernel);
|
||||
|
||||
impl CustomOp for RoPECustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE: `x` shape `(S, H, D)` F32, `cos`/`sin` shape `(S, D)` F32.
|
||||
/// Returns `(S, H, D)` F32.
|
||||
pub fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(x.dtype, DType::F32, "RoPE x must be F32");
|
||||
let cos = if cos.dtype == DType::F32 {
|
||||
cos
|
||||
} else {
|
||||
cos.cast(DType::F32)
|
||||
};
|
||||
let sin = if sin.dtype == DType::F32 {
|
||||
sin
|
||||
} else {
|
||||
sin.cast(DType::F32)
|
||||
};
|
||||
let x_dims = x.dims();
|
||||
assert_eq!(x_dims.len(), 3, "RoPE x must be 3-D (S, H, D)");
|
||||
let s = x_dims[0].to_usize().expect("RoPE: S must be static");
|
||||
let h = x_dims[1].to_usize().expect("RoPE: H must be static");
|
||||
let d = x_dims[2].to_usize().expect("RoPE: D must be static");
|
||||
let cos_dims = cos.dims();
|
||||
let sin_dims = sin.dims();
|
||||
assert_eq!(cos_dims.len(), 2, "RoPE cos must be 2-D (S, D)");
|
||||
assert_eq!(sin_dims.len(), 2, "RoPE sin must be 2-D (S, D)");
|
||||
assert_eq!(cos_dims[0].to_usize().unwrap(), s, "RoPE cos S mismatch");
|
||||
assert_eq!(cos_dims[1].to_usize().unwrap(), d, "RoPE cos D mismatch");
|
||||
assert_eq!(sin_dims[0].to_usize().unwrap(), s, "RoPE sin S mismatch");
|
||||
assert_eq!(sin_dims[1].to_usize().unwrap(), d, "RoPE sin D mismatch");
|
||||
|
||||
let kern = RoPEKernel { s, h, d };
|
||||
let cx = unsafe { &mut *x.graph_ref };
|
||||
cx.custom_op(RoPECustom(kern), vec![x, cos, sin], (s, h, d), DType::F32)
|
||||
}
|
||||
@@ -11,8 +11,9 @@ use cudarc::driver::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::IR},
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
@@ -22,10 +23,12 @@ use luminal::{
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::HostOp,
|
||||
host::{DeviceBuffer, HostOp},
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
fusion::region_codegen::{self, CompileUnit},
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
};
|
||||
@@ -45,8 +48,12 @@ struct CompiledKernel {
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Human-readable labels for input nodes, for launch diagnostics.
|
||||
input_labels: Vec<String>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
|
||||
has_dyn_dims_param: bool,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
@@ -66,7 +73,9 @@ impl CompiledKernel {
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
input_labels: Vec<String>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
has_dyn_dims_param: bool,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
@@ -77,7 +86,9 @@ impl CompiledKernel {
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
@@ -181,6 +192,32 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
/// LLIR node IDs of every kernel in this CudaGraphOp, in the order
|
||||
/// they execute inside the compiled CUDA graph. This is the
|
||||
/// toposort `kernel_to_host` used at compile time, preserved here
|
||||
/// so the runtime can compute live ranges that match real
|
||||
/// execution order: each kernel in `state.kernels` was added to
|
||||
/// the CUDA graph with `prev_graph_node` as its sole dependency,
|
||||
/// which serializes them.
|
||||
pub fn kernel_topo_order(&self) -> Vec<NodeIndex> {
|
||||
self.state.borrow().kernels.iter().map(|k| k.node).collect()
|
||||
}
|
||||
|
||||
/// Direct LLIR-node inputs of one kernel inside this CudaGraphOp.
|
||||
/// Used by the runtime's live-range pass to refine intra-graph
|
||||
/// consumer positions: a kernel's input can stop being live as
|
||||
/// soon as that specific kernel finishes, not when the whole
|
||||
/// CudaGraphOp finishes.
|
||||
pub fn kernel_inputs(&self, kernel_node: NodeIndex) -> Vec<NodeIndex> {
|
||||
self.state
|
||||
.borrow()
|
||||
.kernels
|
||||
.iter()
|
||||
.find(|k| k.node == kernel_node)
|
||||
.map(|k| k.inputs.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -195,7 +232,7 @@ impl std::fmt::Debug for CudaGraphOp {
|
||||
|
||||
impl EgglogOp for CudaGraphOp {
|
||||
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
|
||||
luminal::egglog_utils::api::sort(IR, "CudaGraphOp", &[])
|
||||
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
@@ -205,7 +242,8 @@ impl EgglogOp for CudaGraphOp {
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
_children: &[&'a luminal::prelude::ENodeId],
|
||||
_kind_children: &[&'a luminal::prelude::ENodeId],
|
||||
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
|
||||
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
|
||||
@@ -223,7 +261,7 @@ impl HostOp for CudaGraphOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
@@ -255,6 +293,40 @@ impl HostOp for CudaGraphOp {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
let state = self.state.borrow();
|
||||
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
|
||||
let max_step = state.kernels.len().saturating_sub(1);
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
lifetimes
|
||||
.entry(node)
|
||||
.and_modify(|(first, last)| {
|
||||
*first = (*first).min(step);
|
||||
*last = (*last).max(step);
|
||||
})
|
||||
.or_insert((step, step));
|
||||
};
|
||||
|
||||
for (step, kernel) in state.kernels.iter().enumerate() {
|
||||
for &input in &kernel.inputs {
|
||||
touch(input, step);
|
||||
}
|
||||
touch(kernel.node, step);
|
||||
}
|
||||
|
||||
for node in self.extra_buffer_nodes() {
|
||||
lifetimes.entry(node).or_insert((0, max_step));
|
||||
}
|
||||
|
||||
Some(
|
||||
lifetimes
|
||||
.into_iter()
|
||||
.map(|(node, (start, end))| (node, start, end))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
@@ -265,11 +337,63 @@ impl HostOp for CudaGraphOp {
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
|
||||
match kernel_name {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "Embed" | "Gather" | "GenericMatmul" | "LessThan" | "Mod" | "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_requires_output_buffer(
|
||||
kernel: &CompiledKernel,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> bool {
|
||||
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
|
||||
&& kernel.kernel_op.output_aliases_input().is_none()
|
||||
}
|
||||
|
||||
fn validate_kernel_pointers(
|
||||
kernel: &CompiledKernel,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
|
||||
if *input_ptr == 0 {
|
||||
let input_label = kernel
|
||||
.input_labels
|
||||
.get(idx)
|
||||
.map(String::as_str)
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!(
|
||||
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
input_node,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
@@ -299,7 +423,11 @@ impl CudaGraphOp {
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
// Internal buffer pointers changed, need to rebuild CUDA graph
|
||||
}
|
||||
// Only force full rebuild when internal buffer sizes change.
|
||||
// Dim-only changes (e.g. position offset `p` incrementing each decode step) are
|
||||
// handled by updating the dyn_dims device buffer + kernel node params in-place.
|
||||
if needs_internal_realloc {
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
@@ -336,7 +464,16 @@ impl CudaGraphOp {
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
current_buffer_ptrs.insert(node, buf.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Apply output-aliases-input
|
||||
for kernel in state.kernels.iter() {
|
||||
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
|
||||
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
|
||||
{
|
||||
current_buffer_ptrs.insert(kernel.node, input_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,13 +512,26 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
kernel_dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
@@ -408,6 +558,19 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
@@ -425,43 +588,9 @@ impl CudaGraphOp {
|
||||
state.last_buffer_ptrs = current_buffer_ptrs;
|
||||
}
|
||||
|
||||
// Call pre_launch for each kernel (e.g., KernelScatter copies dest→output before graph)
|
||||
{
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
for kernel in state.kernels.iter() {
|
||||
let output_ptr = state
|
||||
.last_buffer_ptrs
|
||||
.get(&kernel.node)
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| state.last_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
kernel.kernel_op.pre_launch(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
dyn_dims_ptr,
|
||||
dyn_map,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Sync before launch
|
||||
stream.synchronize()?;
|
||||
|
||||
// Launch the graph
|
||||
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
|
||||
|
||||
// Sync after launch
|
||||
stream.synchronize()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -470,7 +599,7 @@ impl CudaGraphOp {
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
@@ -492,7 +621,7 @@ impl CudaGraphOp {
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
buffer_ptrs.insert(node, buf.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,6 +668,19 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
@@ -547,18 +689,41 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
kernel_dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
|
||||
eprintln!(
|
||||
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
kernel.inputs.len(),
|
||||
kernel.has_dyn_dims_param,
|
||||
params.values.len(),
|
||||
);
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
@@ -618,7 +783,7 @@ impl Drop for CudaGraphOp {
|
||||
fn drop(&mut self) {
|
||||
let mut state = self.state.borrow_mut();
|
||||
|
||||
// Destroy timing events - extract ctx first to avoid borrow issues
|
||||
// Destroy timing events first
|
||||
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
|
||||
if let Some(ctx) = ctx {
|
||||
for event in state.timing_events.drain(..) {
|
||||
@@ -626,22 +791,22 @@ impl Drop for CudaGraphOp {
|
||||
}
|
||||
}
|
||||
|
||||
// Forget dyn_dims buffer (managed by runtime)
|
||||
if let Some(buf) = state.dyn_dims_buffer.take() {
|
||||
std::mem::forget(buf);
|
||||
}
|
||||
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
|
||||
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
|
||||
// so it must be destroyed first to avoid dangling pointer issues.
|
||||
drop(state.cuda_graph_exec.take());
|
||||
drop(state.cuda_graph.take());
|
||||
|
||||
// Handle kernel resources
|
||||
// Now safe to free dynamically allocated GPU buffers
|
||||
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
|
||||
|
||||
// Constants point to __constant__ memory in the CUDA module,
|
||||
// not dynamically allocated — must not be freed.
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
// Forget constants (they point to __constant__ memory)
|
||||
let constants = std::mem::take(&mut kernel.constants);
|
||||
for (_k, v) in constants {
|
||||
std::mem::forget(v);
|
||||
}
|
||||
// Forget internal buffers (managed by runtime)
|
||||
for buf in kernel.internal_bufs.drain(..) {
|
||||
std::mem::forget(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -661,7 +826,6 @@ pub fn kernel_to_host(
|
||||
llir_graph: &mut LLIRGraph,
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
megakernel_to_blocks: &FxHashMap<NodeIndex, Vec<NodeIndex>>,
|
||||
) {
|
||||
let _span = span!(Level::TRACE, "kernel_to_host").entered();
|
||||
|
||||
@@ -675,6 +839,41 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / Cuda*Elementwise nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
|
||||
graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
|
||||
name_of(graph, node) == Some("FusionStart")
|
||||
|| graph[node].to_op::<LoopStart>().is_some()
|
||||
|| graph[node].to_op::<LoopEnd>().is_some()
|
||||
|| graph[node].to_op::<LoopInput>().is_some()
|
||||
|| graph[node].to_op::<LoopInputStatic>().is_some()
|
||||
|| graph[node].to_op::<LoopOutput>().is_some()
|
||||
|| graph[node].to_op::<LoopOutputSelect>().is_some()
|
||||
};
|
||||
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
|
||||
let mut visited = FxHashSet::default();
|
||||
while visited.insert(node) && is_transparent_input(graph, node) {
|
||||
let Some(pred) = graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.next()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
node = pred;
|
||||
}
|
||||
node
|
||||
};
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
@@ -689,68 +888,183 @@ pub fn kernel_to_host(
|
||||
.filter(|n| subgraph.contains(n))
|
||||
.collect();
|
||||
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
let mut external_inputs = FxHashSet::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// If this is a megakernel, include all its block op nodes for buffer access
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node_idx) {
|
||||
inputs.extend(block_nodes.iter().copied());
|
||||
}
|
||||
|
||||
// Collect dyn dims used by this kernel
|
||||
all_dyn_dims.extend(grid.0.dyn_vars());
|
||||
all_dyn_dims.extend(grid.1.dyn_vars());
|
||||
all_dyn_dims.extend(grid.2.dyn_vars());
|
||||
all_dyn_dims.extend(block.0.dyn_vars());
|
||||
all_dyn_dims.extend(block.1.dyn_vars());
|
||||
all_dyn_dims.extend(block.2.dyn_vars());
|
||||
all_dyn_dims.extend(shared_mem.dyn_vars());
|
||||
all_dyn_dims.extend(kernel_op_ref.output_size().dyn_vars());
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
|
||||
}
|
||||
|
||||
// Sort dyn dims alphabetically for consistent buffer layout
|
||||
let mut dyn_dims_order: Vec<char> = all_dyn_dims.into_iter().collect();
|
||||
dyn_dims_order.sort();
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
// CUDA kernel for the whole DAG); everything else stays as
|
||||
// CompileUnit::Single (the existing per-op compile path).
|
||||
let compile_units =
|
||||
region_codegen::build_compile_units(&topo_order, llir_graph, &globally_absorbed);
|
||||
|
||||
// Compile all units with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(compile_units.len());
|
||||
for unit in &compile_units {
|
||||
match unit {
|
||||
CompileUnit::Single(kernel_node_idx) => {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect_vec();
|
||||
if let Some(expected_inputs) =
|
||||
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
|
||||
{
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
expected_inputs,
|
||||
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel_op_ref.kernel_name(),
|
||||
kernel_node_idx,
|
||||
);
|
||||
}
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op.clone(),
|
||||
has_dyn_dims_param,
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
CompileUnit::Region(region) => {
|
||||
// Generate one fused CUDA kernel for the whole region.
|
||||
let compiled = region_codegen::compile_region(
|
||||
region,
|
||||
llir_graph,
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior elementwise nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region
|
||||
.external_inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect();
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(region.fe_node);
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
region.fe_node,
|
||||
compiled.function,
|
||||
compiled.grid,
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
let final_global = get_global_dyn_dims();
|
||||
// Clear global ordering now that all kernels are compiled
|
||||
clear_global_dyn_dims();
|
||||
|
||||
// Use the final global ordering if it was extended during compilation
|
||||
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
|
||||
final_order
|
||||
} else {
|
||||
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
|
||||
dims.sort();
|
||||
dims
|
||||
};
|
||||
|
||||
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
|
||||
|
||||
@@ -773,26 +1087,19 @@ pub fn kernel_to_host(
|
||||
for kernel_node in &subgraph {
|
||||
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
|
||||
}
|
||||
// Also track block op nodes inside megakernels
|
||||
for kernel_node in &subgraph {
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
|
||||
for block_node in block_nodes {
|
||||
kernel_to_cuda_graph.insert(*block_node, cuda_graph_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
// Find external inputs: nodes outside subgraph that have edges into
|
||||
// subgraph. Also include normalized FusionStart predecessors, because
|
||||
// the compiled kernels read from the concrete producer buffer rather
|
||||
// than the marker node.
|
||||
external_inputs.extend(subgraph.iter().flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
}));
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
@@ -808,23 +1115,15 @@ pub fn kernel_to_host(
|
||||
|
||||
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
|
||||
// This ensures proper execution ordering when a kernel in one CudaGraphOp
|
||||
// produces output consumed by a kernel (or BlockOp inside a megakernel) in another CudaGraphOp.
|
||||
// produces output consumed by a kernel in another CudaGraphOp.
|
||||
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
|
||||
|
||||
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
|
||||
// Find all nodes that this subgraph produces output for (including BlockOp nodes in megakernels)
|
||||
let mut all_producer_nodes: FxHashSet<NodeIndex> = subgraph.clone();
|
||||
for kernel_node in subgraph {
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
|
||||
all_producer_nodes.extend(block_nodes.iter().copied());
|
||||
}
|
||||
}
|
||||
|
||||
// Find external consumers that are kernels belonging to other CudaGraphOps
|
||||
for producer_node in &all_producer_nodes {
|
||||
for producer_node in subgraph {
|
||||
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
|
||||
let consumer = edge.target();
|
||||
if all_producer_nodes.contains(&consumer) {
|
||||
if subgraph.contains(&consumer) {
|
||||
continue; // Same subgraph
|
||||
}
|
||||
// Check if consumer is a kernel in another CudaGraphOp
|
||||
@@ -844,22 +1143,41 @@ pub fn kernel_to_host(
|
||||
}
|
||||
}
|
||||
|
||||
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
|
||||
// Add each cross-CudaGraphOp dep edge iff it would carry new ordering
|
||||
// information without closing a cycle. The previous topo-position gate
|
||||
// ("skip when src_pos >= dst_pos") was too coarse: it dropped edges
|
||||
// whose src happened to land later in the toposort than their dst even
|
||||
// when no path dst→src actually existed, leaving consumers free to run
|
||||
// before the producer wrote their input buffer (wrong outputs); and it
|
||||
// also added edges that were already implied by an existing src→dst
|
||||
// path (extra serialization, no new info).
|
||||
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
|
||||
let topo = toposort(&*llir_graph, None).unwrap();
|
||||
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, n) in topo.iter().enumerate() {
|
||||
topo_pos.insert(*n, i);
|
||||
}
|
||||
use petgraph::algo::has_path_connecting;
|
||||
for (src, dst) in edges_to_add {
|
||||
// Only add forward edges (src before dst in topo order) to avoid creating cycles
|
||||
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
|
||||
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
|
||||
if src_pos >= dst_pos {
|
||||
continue; // Skip back-edges
|
||||
if has_path_connecting(&*llir_graph, src, dst, None) {
|
||||
continue; // already ordered src→dst by some path; edge redundant
|
||||
}
|
||||
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
if has_path_connecting(&*llir_graph, dst, src, None) {
|
||||
continue; // adding src→dst would close a cycle
|
||||
}
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// Cuda*Elementwise) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
|
||||
// chewing through dead nodes every decode token.
|
||||
//
|
||||
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
|
||||
// walks' starting points), so we keep them — they're the kernel
|
||||
// anchor for the region's compiled kernel.
|
||||
for node in globally_absorbed {
|
||||
// Defensive: only remove if the node still exists.
|
||||
if llir_graph.node_weight(node).is_some() {
|
||||
llir_graph.remove_node(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
333
crates/luminal_cuda_lite/src/lib.rs
Normal file
333
crates/luminal_cuda_lite/src/lib.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::{
|
||||
driver::{CudaContext, DriverError, sys as driver_sys},
|
||||
nvrtc::{
|
||||
Ptx,
|
||||
result::{self as nvrtc_result, NvrtcError},
|
||||
sys as nvrtc_sys,
|
||||
},
|
||||
};
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::F64 => "double",
|
||||
DType::F32 => "float",
|
||||
DType::F16 => "half",
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I64 => "long long",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
DType::U8 => "unsigned char",
|
||||
DType::Bool => "unsigned char",
|
||||
DType::F8E4M3 => "__nv_fp8_e4m3",
|
||||
DType::F8E5M2 => "__nv_fp8_e5m2",
|
||||
DType::F8UE8M0 => "__nv_fp8_e8m0",
|
||||
DType::F6E2M3 => "__nv_fp6_e2m3",
|
||||
DType::F6E3M2 => "__nv_fp6_e3m2",
|
||||
DType::F4E2M1 => "__nv_fp4_e2m1",
|
||||
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
|
||||
}
|
||||
}
|
||||
|
||||
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum CudaModuleImageCompileFailure {
|
||||
ComputeCapability(DriverError),
|
||||
Nvrtc {
|
||||
stage: &'static str,
|
||||
error: NvrtcError,
|
||||
},
|
||||
NoModuleImageProduced,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CudaModuleImageCompileError {
|
||||
pub target_arch: Option<String>,
|
||||
pub driver_version: Option<i32>,
|
||||
pub runtime_version: Option<i32>,
|
||||
pub nvrtc_options: Vec<String>,
|
||||
pub nvrtc_log: Option<String>,
|
||||
pub failure: CudaModuleImageCompileFailure,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CudaModuleImageCompileError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to compile CUDA module image")?;
|
||||
if let Some(target_arch) = &self.target_arch {
|
||||
write!(f, " for {target_arch}")?;
|
||||
}
|
||||
match &self.failure {
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error) => {
|
||||
write!(f, ": failed to query compute capability: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
|
||||
write!(f, ": NVRTC {stage} failed: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced => {
|
||||
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
|
||||
}
|
||||
}
|
||||
if let Some(version) = self.driver_version {
|
||||
write!(f, " | driver {}", format_cuda_version(version))?;
|
||||
}
|
||||
if let Some(version) = self.runtime_version {
|
||||
write!(f, " | runtime {}", format_cuda_version(version))?;
|
||||
}
|
||||
if !self.nvrtc_options.is_empty() {
|
||||
write!(f, " | options {:?}", self.nvrtc_options)?;
|
||||
}
|
||||
if let Some(log) = &self.nvrtc_log {
|
||||
write!(f, " | log: {log}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CudaModuleImageCompileError {}
|
||||
|
||||
fn format_cuda_version(version: i32) -> String {
|
||||
format!("{}.{}", version / 1000, (version % 1000) / 10)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_include_paths() -> Vec<String> {
|
||||
let mut include_paths = Vec::new();
|
||||
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
|
||||
if let Ok(root) = std::env::var(env_var) {
|
||||
let path = format!("{root}/include");
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
for path in CUDA_NVRTC_INCLUDE_PATHS {
|
||||
let path = path.to_string();
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
include_paths
|
||||
}
|
||||
|
||||
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
let mut driver_version = 0;
|
||||
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
|
||||
.result()
|
||||
.ok()
|
||||
.map(|_| driver_version);
|
||||
|
||||
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
|
||||
// resolves newer libcudart symbols that may not exist in the installed runtime.
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
pub(crate) fn try_create_cublaslt(
|
||||
stream: Arc<CudaStream>,
|
||||
) -> std::result::Result<Arc<CudaBlasLT>, String> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
|
||||
Ok(Ok(handle)) => Ok(Arc::new(handle)),
|
||||
Ok(Err(err)) => Err(err.to_string()),
|
||||
Err(payload) => {
|
||||
let message = if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
message.to_string()
|
||||
} else {
|
||||
"cuBLASLt initialization panicked".to_string()
|
||||
};
|
||||
Err(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
.map(|path| format!("--include-path={path}"))
|
||||
.collect::<Vec<_>>();
|
||||
options.push(format!("--gpu-architecture={target_arch}"));
|
||||
options
|
||||
}
|
||||
|
||||
fn build_module_image_compile_error(
|
||||
target_arch: Option<String>,
|
||||
driver_version: Option<i32>,
|
||||
runtime_version: Option<i32>,
|
||||
nvrtc_options: &[String],
|
||||
nvrtc_log: Option<String>,
|
||||
failure: CudaModuleImageCompileFailure,
|
||||
) -> CudaModuleImageCompileError {
|
||||
CudaModuleImageCompileError {
|
||||
target_arch,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
nvrtc_options: nvrtc_options.to_vec(),
|
||||
nvrtc_log,
|
||||
failure,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
|
||||
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
|
||||
.to_string_lossy()
|
||||
.trim_end_matches('\0')
|
||||
.trim()
|
||||
.to_string();
|
||||
if log.is_empty() { None } else { Some(log) }
|
||||
}
|
||||
|
||||
#[allow(clippy::slow_vector_initialization)]
|
||||
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
let mut cubin_size = 0usize;
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
|
||||
if cubin_size == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0u8);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
|
||||
Ok(cubin)
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
ctx: &Arc<CudaContext>,
|
||||
src: S,
|
||||
) -> Result<Ptx, CudaModuleImageCompileError> {
|
||||
let (driver_version, runtime_version) = cuda_driver_diagnostics();
|
||||
let (major, minor) = ctx.compute_capability().map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
None,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&[],
|
||||
None,
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error),
|
||||
)
|
||||
})?;
|
||||
let target_arch = format!("sm_{major}{minor}");
|
||||
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
|
||||
|
||||
let source = CString::new(src.as_ref().as_bytes())
|
||||
.expect("CUDA source code cannot contain null terminators");
|
||||
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
Some(target_arch.clone()),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
None,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "create_program",
|
||||
error,
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "compile_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let cubin = match get_cubin(program) {
|
||||
Ok(cubin) => cubin,
|
||||
Err(error) => {
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "get_cubin",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "destroy_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
if cubin.is_empty() {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ptx::from_binary(cubin))
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 273,
|
||||
"NVIDIA H100 PCIe" => 2_000,
|
||||
"NVIDIA H100 SXM" => 3_350,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in TFLOPs
|
||||
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 125, // forced to use tf32 flops
|
||||
"NVIDIA H100 PCIe" => 756,
|
||||
"NVIDIA H100 SXM" => 989,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
1888
crates/luminal_cuda_lite/src/memory_analysis.rs
Normal file
1888
crates/luminal_cuda_lite/src/memory_analysis.rs
Normal file
File diff suppressed because it is too large
Load Diff
2494
crates/luminal_cuda_lite/src/runtime.rs
Normal file
2494
crates/luminal_cuda_lite/src/runtime.rs
Normal file
File diff suppressed because it is too large
Load Diff
349
crates/luminal_cuda_lite/src/tests/bucket_tests.rs
Normal file
349
crates/luminal_cuda_lite/src/tests/bucket_tests.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::*;
|
||||
use luminal::prelude::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
/// Helper: build a simple graph with dynamic dim 's' that does element-wise computation.
|
||||
/// Returns (cx, input_node, output_node).
|
||||
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', 4));
|
||||
let b = (a + a).output();
|
||||
(cx, a.id, b.id)
|
||||
}
|
||||
|
||||
/// Helper: build a matmul graph with dynamic dim 's'.
|
||||
/// Computes (s, K) @ (K, N) -> (s, N)
|
||||
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let c = a.matmul(b).output();
|
||||
(cx, a.id, b.id, c.id)
|
||||
}
|
||||
|
||||
fn bucket_options(buckets: &[DimBucket]) -> CompileOptions {
|
||||
CompileOptions::default().dim_buckets('s', buckets)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_dispatch_simple() {
|
||||
// Tests that bucketed compilation produces correct results for different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Set dummy input for search
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..4], &expected, 1e-5, 1e-5);
|
||||
|
||||
// Test bucket 2: s=3
|
||||
cx.set_dim('s', 3);
|
||||
let input_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..12], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_matmul_dynamic() {
|
||||
// Tests matmul with bucketed dynamic dim
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 8;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 8),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 100, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s1 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=1 (1xK @ KxN -> 1xN)
|
||||
let mut expected_s1 = vec![0.0f32; n];
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s1[j] += a_data[i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);
|
||||
|
||||
// Execute at s=4
|
||||
cx.set_dim('s', 4);
|
||||
let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_4.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s4 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=4 (4xK @ KxN -> 4xN)
|
||||
let mut expected_s4 = vec![0.0f32; 4 * n];
|
||||
for row in 0..4 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s4[row * n + j] += a_data_4[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_results_match_unbucketed() {
|
||||
// Tests that bucketed results match non-bucketed results for the same graph
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let seed = 42u64;
|
||||
|
||||
// Non-bucketed run
|
||||
let (mut cx1, a1, b1) = build_dynamic_add_graph();
|
||||
cx1.set_dim('s', 3);
|
||||
cx1.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt1 = CudaRuntime::initialize(stream.clone());
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
|
||||
// Bucketed run with bucket that covers s=3
|
||||
let (mut cx2, a2, b2) = build_dynamic_add_graph();
|
||||
cx2.set_dim('s', 3);
|
||||
cx2.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 4)]));
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
|
||||
// Results should match — same graph, same search seed, same dyn_map
|
||||
assert_eq!(result_unbucketed.len(), result_bucketed.len());
|
||||
assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "No bucket matches")]
|
||||
fn test_bucket_out_of_range_panics() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
// Can't trigger panic without GPU, skip gracefully
|
||||
panic!("No bucket matches dyn_map");
|
||||
};
|
||||
|
||||
let (mut cx, a, _b) = build_dynamic_add_graph();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
rt.set_data(a, vec![1.0f32; 40]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_no_buckets_backward_compat() {
|
||||
// No buckets set → should behave identically to old path
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
cx.set_dim('s', 2);
|
||||
|
||||
// No bucket options
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..8], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_representative_override() {
|
||||
// Tests that custom representative works
|
||||
let bucket = DimBucket::new(2, 32).representative(16);
|
||||
assert_eq!(bucket.representative_value(), 16);
|
||||
|
||||
let bucket_default = DimBucket::new(2, 32);
|
||||
assert_eq!(bucket_default.representative_value(), 17); // (2+32)/2 = 17
|
||||
|
||||
let exact = DimBucket::new(1, 1);
|
||||
assert_eq!(exact.representative_value(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_switch_preserves_weights() {
|
||||
// Tests that switching between buckets still sees the correct weight data
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 4;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 300, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 301, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1a = rt.get_f32(c);
|
||||
|
||||
// Switch to bucket 2 (s=3)
|
||||
cx.set_dim('s', 3);
|
||||
let a_data_3 = random_f32_vec(3 * k, 302, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_3.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_3 = rt.get_f32(c);
|
||||
|
||||
// Switch back to bucket 1 (s=1) — weights should still work
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1b = rt.get_f32(c);
|
||||
|
||||
// First and last s=1 results should match exactly
|
||||
assert_close(&result_1a[..n], &result_1b[..n], 1e-6, 1e-6);
|
||||
|
||||
// Verify s=3 result correctness
|
||||
let mut expected_3 = vec![0.0f32; 3 * n];
|
||||
for row in 0..3 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_3[row * n + j] += a_data_3[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_3[..3 * n], &expected_3, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_multiple_executions_same_bucket() {
|
||||
// Tests multiple executions within the same bucket with different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 8)]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
cx.set_dim('s', s);
|
||||
let n = s * 4;
|
||||
let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..n], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Overlapping buckets")]
|
||||
fn test_bucket_overlapping_ranges_panics() {
|
||||
let _ = bucket_options(&[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dim_bucket_contains() {
|
||||
let b = DimBucket::new(2, 10);
|
||||
assert!(!b.contains(1));
|
||||
assert!(b.contains(2));
|
||||
assert!(b.contains(5));
|
||||
assert!(b.contains(10));
|
||||
assert!(!b.contains(11));
|
||||
|
||||
// Exact bucket
|
||||
let exact = DimBucket::new(3, 3);
|
||||
assert!(!exact.contains(2));
|
||||
assert!(exact.contains(3));
|
||||
assert!(!exact.contains(4));
|
||||
}
|
||||
1516
crates/luminal_cuda_lite/src/tests/consumed_buffer_tests.rs
Normal file
1516
crates/luminal_cuda_lite/src/tests/consumed_buffer_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
fn conv2d_bias_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
let output_spatial_dims = unfolded.dims()[1..3].to_vec();
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out = out
|
||||
.split_dims(0, output_spatial_dims[1])
|
||||
.permute(&[2, 0, 1]);
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(1, out_dims[1]).expand_dim(2, out_dims[2])
|
||||
}
|
||||
|
||||
fn build_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_hlir(x, weight, bias, 3, 2).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_bias_padded_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
let zero = Expression::from(0);
|
||||
let pad = Expression::from(padding);
|
||||
let padded = x.pad(vec![(zero, zero), (pad, pad), (pad, pad)], 0.0);
|
||||
conv2d_bias_hlir(padded, weight, bias, kernel, kernel)
|
||||
}
|
||||
|
||||
fn build_padded_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_padded_hlir(x, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn nearest_upsample_2x_hlir(x: GraphTensor) -> GraphTensor {
|
||||
let stage1 = x.expand_dim(2, 2usize).merge_dims(1, 2);
|
||||
stage1.expand_dim(3, 2usize).merge_dims(2, 3)
|
||||
}
|
||||
|
||||
fn build_upsample_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 3usize, 4usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let up = nearest_upsample_2x_hlir(x);
|
||||
let out = conv2d_bias_padded_hlir(up, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv1x1_bias_hlir(x: GraphTensor, weight: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
let h = dims[1];
|
||||
let w = dims[2];
|
||||
let xt = x.permute(&[1, 2, 0]).merge_dims(0, 1);
|
||||
let out = xt.matmul(weight.t());
|
||||
let out = out.split_dims(0, w).permute(&[2, 0, 1]);
|
||||
out + bias.expand_dim(1, h).expand_dim(2, w)
|
||||
}
|
||||
|
||||
fn build_conv1x1_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv1x1_bias_hlir(x, weight, bias).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_matmul_without_conv_output_shape(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(0, out_dims[0])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_unfold_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate"
|
||||
);
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "Add").is_empty(),
|
||||
"generic conv2d cleanup should prune the final bias Add fallback"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_conv1x1_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate for 1x1 conv"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_requires_conv_output_shape() {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
conv2d_matmul_without_conv_output_shape(x, weight, bias, 3, 2).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"matmul+bias without [C_out,H_out,W_out] conv output shape should not match KernelConv2D"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 5 * 6).map(|i| i as f32 * 0.03 - 0.4).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 2)
|
||||
.map(|i| (i as f32 % 11.0) * 0.04 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.25_f32, -0.15, 0.05];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 5,
|
||||
w: 6,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 2,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_conv1x1_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.07 - 1.0).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2).map(|i| (i as f32 % 5.0) * 0.11 - 0.2).collect();
|
||||
let biases = vec![0.2_f32, -0.1, 0.4];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 1,
|
||||
kw: 1,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_padded_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_padded_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.05 - 0.5).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 13.0) * 0.03 - 0.17)
|
||||
.collect();
|
||||
let biases = vec![0.15_f32, -0.25, 0.35];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_upsample_view_input() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_upsample_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 * 0.09 - 0.8).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 17.0) * 0.025 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.05_f32, -0.1, 0.2];
|
||||
let upsampled = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
let expected = reference_conv2d(
|
||||
&upsampled,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 6,
|
||||
w: 8,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
struct ConvCase {
|
||||
c_in: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
c_out: usize,
|
||||
kh: usize,
|
||||
kw: usize,
|
||||
padding_h: usize,
|
||||
padding_w: usize,
|
||||
}
|
||||
|
||||
fn reference_nearest_upsample_2x(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0_f32; c * h * 2 * w * 2];
|
||||
for ci in 0..c {
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let value = input[ci * h * w + y * w + x];
|
||||
for dy in 0..2 {
|
||||
for dx in 0..2 {
|
||||
let oy = y * 2 + dy;
|
||||
let ox = x * 2 + dx;
|
||||
out[ci * h * 2 * w * 2 + oy * w * 2 + ox] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn reference_conv2d(input: &[f32], weight: &[f32], bias: &[f32], case: ConvCase) -> Vec<f32> {
|
||||
let ConvCase {
|
||||
c_in,
|
||||
h,
|
||||
w,
|
||||
c_out,
|
||||
kh,
|
||||
kw,
|
||||
padding_h,
|
||||
padding_w,
|
||||
} = case;
|
||||
let h_out = h + 2 * padding_h - kh + 1;
|
||||
let w_out = w + 2 * padding_w - kw + 1;
|
||||
let mut out = vec![0.0; c_out * h_out * w_out];
|
||||
for co in 0..c_out {
|
||||
for oh in 0..h_out {
|
||||
for ow in 0..w_out {
|
||||
let mut acc = bias[co];
|
||||
for ci in 0..c_in {
|
||||
for r in 0..kh {
|
||||
for s in 0..kw {
|
||||
let Some(ih) = (oh + r).checked_sub(padding_h) else {
|
||||
continue;
|
||||
};
|
||||
let Some(iw) = (ow + s).checked_sub(padding_w) else {
|
||||
continue;
|
||||
};
|
||||
if ih >= h || iw >= w {
|
||||
continue;
|
||||
}
|
||||
let input_idx = ci * h * w + ih * w + iw;
|
||||
let weight_idx = co * c_in * kh * kw + (ci * kh + r) * kw + s;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
out[co * h_out * w_out + oh * w_out + ow] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, "KernelConv2D");
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0xC0_2D00 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user