mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
6 Commits
rust-examp
...
codex-lumi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79d00a4827 | ||
|
|
acad3a625a | ||
|
|
07ad11d101 | ||
|
|
98f4f2102b | ||
|
|
896c4b7c7e | ||
|
|
0134aa425a |
4
.github/workflows/modal-examples.yml
vendored
4
.github/workflows/modal-examples.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 70
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
|
||||
example: [llama, gemma, qwen, qwen3_moe]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
|
||||
2
.github/workflows/test-core.yml
vendored
2
.github/workflows/test-core.yml
vendored
@@ -21,4 +21,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
|
||||
2
.github/workflows/test-cuda.yml
vendored
2
.github/workflows/test-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/test-metal.yml
vendored
2
.github/workflows/test-metal.yml
vendored
@@ -16,4 +16,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
4
.github/workflows/test-python-cuda.yml
vendored
4
.github/workflows/test-python-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 7200 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
2
.github/workflows/test-python-native.yml
vendored
2
.github/workflows/test-python-native.yml
vendored
@@ -23,6 +23,6 @@ jobs:
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --profile release
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
12
AGENTS.md
12
AGENTS.md
@@ -8,14 +8,4 @@ All other functionality is split into crates in the `crates/` directory. For ins
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
|
||||
## Debugging and Correctness
|
||||
- Treat model examples as specifications of the intended architecture. Do not change model code, prompt templates, weights, or example logic to hide compiler/runtime/search bugs unless the model code is demonstrably semantically wrong.
|
||||
- When outputs are incorrect, first root-cause the failing compiler/runtime path. Prefer isolating the bad LLIR/HLIR graph, rewrite, op lowering, shape/stride assumption, layout contract, or runtime implementation that caused the mismatch.
|
||||
- Avoid narrow special-case fixes. A fix should state and enforce the general invariant it relies on, or explicitly document why the affected operation is only valid for a restricted layout/shape and ensure rewrites enforce that restriction.
|
||||
- For e-graph/search issues, assume all selectable LLIR graphs are intended to be semantically equivalent. If two selectable graphs disagree, debug the equivalence violation rather than selecting around the bad graph.
|
||||
- Add regression tests at the level where the bug occurred. Prefer tests that compare against a semantic reference such as `NativeRuntime` or a small independent reference, and use fixed seeds for any randomized search/fuzz test so failures are reproducible.
|
||||
|
||||
## Compiler Rewrite Boundary
|
||||
- All graph pattern matching and op selection must be expressed in egglog rewrites. Do not add Rust-side LLIR graph post-passes that search for op patterns, fuse kernels, select backend ops, or otherwise rewrite extracted graphs after egglog. If a backend needs a fused/specialized op, add the match and rewrite in egglog and let extraction produce that op directly.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
@@ -25,7 +25,6 @@ generational-box = "0.5.6"
|
||||
serde_json = "1.0.140"
|
||||
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-reports = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
|
||||
tracing = "0.1.43"
|
||||
paste = "1.0.15"
|
||||
|
||||
50
README.md
50
README.md
@@ -55,27 +55,23 @@ Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100.
|
||||
|
||||
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
### PyTorch-native
|
||||
|
||||
Luminal directly integrates with PyTorch as a compiler backend. Simply do `torch.compile(model, backend=luminal_cuda)` to compile your PyTorch models. We also have an excellent tensor API in Rust.
|
||||
|
||||
### RISC-style architecture
|
||||
|
||||
Everything in Luminal boils down to 15 primitive ops:
|
||||
Everything in Luminal boils down to 14 primitive ops:
|
||||
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Scatter, Cast`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
|
||||
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model in the world.
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model.
|
||||
|
||||
### Search
|
||||
|
||||
The best heuristic is no heuristic. Luminal tries to search every possible decision to give the compiler the flexibility to discover complex optimizations. This allows us to automatically discover Flash Attention and other similarly complex optimizations without relying on hand-written operations or heuristics. It also allows us to stay extremely small and simple long into the future and beat the performance of far larger frameworks.
|
||||
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
|
||||
|
||||
### Native
|
||||
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the accelerator APIs (CUDA, Metal, etc.). No indirections or abstractions, compatability layers, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
|
||||
@@ -89,45 +85,39 @@ Most deep learning libraries are eager-first, meaning each op call directly oper
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### What about XLA?
|
||||
|
||||
XLA, torch.compile, TVM, and other traditional compiler stacks suffer from complexity explosion. They are made up of a very large set of destructive (one-direction) rewrite rules that lower and optimize a graph from a high-level representation to low-level machine code. But since these rules are destructive, they are required to only fire when it's certian that there's a performance benefit. This leads to the rules becoming very complex, special-cased, and numerous. Once additional hardware backends, model architectures, and new dtypes get thrown in, they suffer from the weight of their complexity and often produce very suboptimal code, requiring DSLs like Pallas or Triton to regain performance.
|
||||
|
||||
### Compile everything
|
||||
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as a static computation graphs, compiled, and executed later.
|
||||
|
||||
### First-class dynamism
|
||||
|
||||
A fully-static world would be nice, but we live in a world of nessecary dynamism. So we model dynamic shapes natively, as symbolic dimensions. Luminal supports arbitrary symbolic dimensions, including complex expressions, to give us shapes like `(s, 4096)`, `(b, h, w + 3)`, etc. This rich representation gives the compiler full visibility into shapes and lets it still do aggressive specialization.
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, Luminal has global knowledge. This means we can push most ML complexity to the compiler. For instance, devices, datatypes, and even autograd is modeled ahead of time and optimized by the compiler!
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
|
||||
|
||||
Now we can do:
|
||||
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Low-precision dtypes (mxfp4, nvfp4, fp8, etc.)
|
||||
- Complex mutli-device parallelism topologies, searched ahead-of-time
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
|
||||
## Where are we?
|
||||
|
||||
- Native PyTorch support
|
||||
- Many kernel libraries supported in the search space (FlashInfer, cuBLASLt, etc.)
|
||||
- Many models implemented in our Rust tensor API in `examples/`.
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
|
||||
Some things on the roadmap:
|
||||
|
||||
- More fine-grained dialects supporting thread- and warp-level intrinsics like TMA and tcgen.05
|
||||
- ROCm backend
|
||||
- More public infernce accelerator backends (coming very soon...)
|
||||
- Public benchmarking suite
|
||||
- Automatically searched model parallelism (TP, PP, EPS, EPR, SP, etc.)
|
||||
- Expand the search space to utilize Tensor Cores more flexibly
|
||||
- Bring cuda to parity with Metal
|
||||
- Add Blackwell intrinsics, such as TMEM and TMA
|
||||
- Build a ROCm backend
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM inference _and_ training
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import re
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
EXPECTED_CONCEPTS = {
|
||||
"llama": [
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "adapt"],
|
||||
["data", "patterns", "features"],
|
||||
],
|
||||
"gemma": [
|
||||
["neural network", "neural networks"],
|
||||
["nodes", "neurons"],
|
||||
["layers"],
|
||||
["weights"],
|
||||
["training", "learn", "learns"],
|
||||
],
|
||||
"qwen": [
|
||||
["neural network", "neural networks"],
|
||||
["computational model", "computational system"],
|
||||
["brain"],
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "training"],
|
||||
],
|
||||
"qwen3_moe": [
|
||||
["capital"],
|
||||
["france"],
|
||||
["paris"],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
normalized_output = normalize_output(output)
|
||||
|
||||
expected_concepts = EXPECTED_CONCEPTS.get(example)
|
||||
if expected_concepts is not None:
|
||||
missing = [
|
||||
concept_group
|
||||
for concept_group in expected_concepts
|
||||
if not any(normalize_output(term) in normalized_output for term in concept_group)
|
||||
]
|
||||
if missing:
|
||||
expected = "\n - ".join(" / ".join(group) for group in expected_concepts)
|
||||
missing_terms = "\n - ".join(" / ".join(group) for group in missing)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}.\n"
|
||||
f"Expected concept groups:\n - {expected}\n"
|
||||
f"Missing concept groups:\n - {missing_terms}"
|
||||
)
|
||||
|
||||
expected = ", ".join(" / ".join(group) for group in expected_concepts)
|
||||
print(f"\nOutput check passed for {example!r}: found concepts {expected}")
|
||||
return
|
||||
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "qwen", "--features", "metal"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("qwen Metal example did not complete generation")
|
||||
validate_output("qwen", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -28,7 +28,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -47,7 +47,6 @@ def run_cargo_test():
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
@@ -20,37 +18,6 @@ hf_cache = modal.Volume.from_name(
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"qwen": ["--features", "cuda"],
|
||||
}
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
@@ -72,7 +39,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
@@ -80,20 +47,17 @@ cuda_image = (
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
sys.path.insert(0, f"{WORKDIR}/ci")
|
||||
from example_output import validate_output
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", *EXAMPLE_CARGO_ARGS.get(example, [])],
|
||||
subprocess.run(
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
validate_output(example, output)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
@@ -24,12 +23,10 @@ memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
lru = "0.16.2"
|
||||
libc = "0.2"
|
||||
libloading = "0.8"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
use std::{collections::BTreeMap, sync::Arc, time::Instant};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::prelude::egglog::{ast::Span, prelude::RustSpan};
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
base::{base_cleanup_egglog, base_expression_egglog},
|
||||
hlir_to_egglog,
|
||||
},
|
||||
hlir::HLIROps,
|
||||
op::{EgglogOp, IntoEgglogOp, Runtime},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
const DEFAULT_PASSES: usize = 256;
|
||||
const EGGLOG_RULESETS: &[&str] = &[
|
||||
"matmul_flatten",
|
||||
"kernel_lower",
|
||||
"direct_kernel",
|
||||
"kernel_specialize",
|
||||
"buffer_reuse",
|
||||
"matmul_backend",
|
||||
"glumoe",
|
||||
"fusion_pair",
|
||||
"fusion_grow",
|
||||
"fusion_merge",
|
||||
];
|
||||
const MOE_SEQ: usize = 2;
|
||||
const MOE_HIDDEN: usize = 16;
|
||||
const MOE_NUM_EXPERTS: usize = 8;
|
||||
const MOE_TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const GEMMA_RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Backend {
|
||||
Native,
|
||||
Cuda,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Mode {
|
||||
Current,
|
||||
Steps,
|
||||
FullDefault,
|
||||
FullCycle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Case {
|
||||
Mul,
|
||||
UnaryChain(usize),
|
||||
Gelu,
|
||||
Softmax,
|
||||
LayerNorm,
|
||||
Matmul,
|
||||
Attention,
|
||||
QwenMoe,
|
||||
GemmaMoe,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
backend: Backend,
|
||||
mode: Mode,
|
||||
case: Case,
|
||||
passes: usize,
|
||||
cleanup: bool,
|
||||
skip_roll: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args {
|
||||
backend: Backend::Cuda,
|
||||
mode: Mode::Current,
|
||||
case: Case::Gelu,
|
||||
passes: DEFAULT_PASSES,
|
||||
cleanup: true,
|
||||
skip_roll: false,
|
||||
};
|
||||
|
||||
let mut iter = std::env::args().skip(1);
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--backend" => {
|
||||
args.backend = match iter.next().as_deref() {
|
||||
Some("native") => Backend::Native,
|
||||
Some("cuda") => Backend::Cuda,
|
||||
other => panic!("invalid --backend {other:?}; use native|cuda"),
|
||||
};
|
||||
}
|
||||
"--mode" => {
|
||||
args.mode = match iter.next().as_deref() {
|
||||
Some("current") => Mode::Current,
|
||||
Some("steps") => Mode::Steps,
|
||||
Some("full-default") => Mode::FullDefault,
|
||||
Some("full-cycle") => Mode::FullCycle,
|
||||
other => panic!(
|
||||
"invalid --mode {other:?}; use current|steps|full-default|full-cycle"
|
||||
),
|
||||
};
|
||||
}
|
||||
"--case" => {
|
||||
args.case = parse_case(&iter.next().expect("missing --case value"));
|
||||
}
|
||||
"--passes" => {
|
||||
args.passes = iter
|
||||
.next()
|
||||
.expect("missing --passes value")
|
||||
.parse()
|
||||
.expect("invalid --passes value");
|
||||
}
|
||||
"--no-cleanup" => args.cleanup = false,
|
||||
"--skip-roll" => args.skip_roll = true,
|
||||
"--help" | "-h" => {
|
||||
println!(
|
||||
"Usage: egglog_saturation [OPTIONS]\n\
|
||||
\n\
|
||||
Options:\n\
|
||||
--backend native|cuda default: cuda\n\
|
||||
--mode current|steps|full-default|full-cycle\n\
|
||||
--case mul|unary-chain:N|gelu|softmax|layer-norm|matmul|attention|qwen-moe|gemma-moe\n\
|
||||
--passes N default: 256\n\
|
||||
--no-cleanup omit backend/HLIR cleanup rules\n\
|
||||
--skip-roll skip auto loop rolling prepass"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => panic!("unknown argument {other}; use --help"),
|
||||
}
|
||||
}
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
fn parse_case(s: &str) -> Case {
|
||||
if let Some(n) = s.strip_prefix("unary-chain:") {
|
||||
return Case::UnaryChain(n.parse().expect("invalid unary-chain length"));
|
||||
}
|
||||
match s {
|
||||
"mul" => Case::Mul,
|
||||
"gelu" => Case::Gelu,
|
||||
"softmax" => Case::Softmax,
|
||||
"layer-norm" | "layer_norm" => Case::LayerNorm,
|
||||
"matmul" => Case::Matmul,
|
||||
"attention" => Case::Attention,
|
||||
"qwen-moe" | "qwen_moe" => Case::QwenMoe,
|
||||
"gemma-moe" | "gemma_moe" => Case::GemmaMoe,
|
||||
other => panic!("unknown case {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_case(case: Case) -> Graph {
|
||||
let mut cx = Graph::new();
|
||||
let out = match case {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor((64, 64));
|
||||
x * x
|
||||
}
|
||||
Case::UnaryChain(n) => {
|
||||
let mut x = cx.tensor((64, 64));
|
||||
for i in 0..n {
|
||||
x = match i % 6 {
|
||||
0 => x.sin(),
|
||||
1 => x.sqrt(),
|
||||
2 => x.reciprocal(),
|
||||
3 => x.exp2(),
|
||||
4 => x.log2(),
|
||||
_ => x * 1.125,
|
||||
};
|
||||
}
|
||||
x
|
||||
}
|
||||
Case::Gelu => cx.tensor((64, 64)).gelu(),
|
||||
Case::Softmax => cx.tensor((128, 128)).softmax(1),
|
||||
Case::LayerNorm => cx.tensor((128, 128)).layer_norm(1, 1e-5),
|
||||
Case::Matmul => {
|
||||
let a = cx.tensor((32, 64));
|
||||
let b = cx.tensor((64, 32));
|
||||
a.matmul(b)
|
||||
}
|
||||
Case::Attention => {
|
||||
let q = cx.tensor((64, 32));
|
||||
let k = cx.tensor((64, 32));
|
||||
let v = cx.tensor((64, 32));
|
||||
let scores = q.matmul(k.permute((1, 0))) * (1.0 / 32.0_f32.sqrt());
|
||||
scores.softmax(1).matmul(v)
|
||||
}
|
||||
Case::QwenMoe => build_qwen_moe(&mut cx),
|
||||
Case::GemmaMoe => build_gemma_moe(&mut cx),
|
||||
};
|
||||
let _ = out.output();
|
||||
cx
|
||||
}
|
||||
|
||||
fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let x = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let router_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let expert_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router_scale = cx.tensor(MOE_HIDDEN);
|
||||
let router_proj = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let per_expert_scale = cx.tensor(MOE_NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, GEMMA_RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (MOE_HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, MOE_TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
weights.gather(exp_base + exp_within)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
let mut ir_variants = Vec::new();
|
||||
let mut opkind_variants = Vec::new();
|
||||
for op in ops {
|
||||
let sort = op.sort();
|
||||
let variant = format!(
|
||||
"({} {})",
|
||||
sort.name,
|
||||
sort.fields.iter().map(|field| &field.sort).join(" ")
|
||||
);
|
||||
match sort.class.as_str() {
|
||||
"IR" => ir_variants.push(variant),
|
||||
"OpKind" => opkind_variants.push(variant),
|
||||
other => panic!("unknown sort class {other} for {}", sort.name),
|
||||
}
|
||||
}
|
||||
let extra_ir = ops.iter().flat_map(|op| op.ir_defs()).unique().join("\n");
|
||||
format!(
|
||||
"
|
||||
(datatype*
|
||||
(IR
|
||||
(OutputJoin IR IR)
|
||||
(Op OpKind IList)
|
||||
{extra_ir}
|
||||
{}
|
||||
)
|
||||
(OpKind
|
||||
{}
|
||||
)
|
||||
(IList
|
||||
(ICons IR IList)
|
||||
(INil)
|
||||
)
|
||||
)
|
||||
(function dtype (IR) DType :merge new)
|
||||
",
|
||||
ir_variants.join("\n"),
|
||||
opkind_variants.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
ops.iter()
|
||||
.filter(|op| op.cleanup())
|
||||
.map(|op| {
|
||||
let sort = op.sort();
|
||||
let fields = (0..sort.fields.len())
|
||||
.map(|i| (b'a' + i as u8) as char)
|
||||
.join(" ");
|
||||
if sort.class == "OpKind" {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
((delete (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m ({} {fields})))
|
||||
((delete ({} {fields})))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
}
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn setup_program(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let rewrites = ops
|
||||
.iter()
|
||||
.flat_map(|op| op.rewrites())
|
||||
.map(|rule| rule.to_egglog_string())
|
||||
.join("\n");
|
||||
[
|
||||
EGGLOG_RULESETS
|
||||
.iter()
|
||||
.map(|ruleset| format!("(ruleset {ruleset})"))
|
||||
.join("\n"),
|
||||
base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
base_cleanup_egglog(),
|
||||
rewrites,
|
||||
program.to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn producer_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run matmul_flatten)
|
||||
(run kernel_lower)
|
||||
(run direct_kernel)
|
||||
(run kernel_specialize)
|
||||
(run buffer_reuse)
|
||||
(run matmul_backend)
|
||||
(run glumoe)
|
||||
(run fusion_pair)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fusion_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run fusion_grow)
|
||||
(run fusion_merge)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn split_cycle() -> Vec<(&'static str, String)> {
|
||||
vec![
|
||||
("producers", format!("(saturate {})", producer_schedule())),
|
||||
("fusion", format!("(saturate {})", fusion_schedule())),
|
||||
]
|
||||
}
|
||||
|
||||
fn split_cycle_schedule() -> String {
|
||||
format!(
|
||||
"(seq
|
||||
(saturate {})
|
||||
(saturate {})
|
||||
)",
|
||||
producer_schedule(),
|
||||
fusion_schedule()
|
||||
)
|
||||
}
|
||||
|
||||
fn phase(egraph: &mut egglog::EGraph, name: &str, schedule: &str) -> bool {
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let command = format!("(run-schedule {schedule})");
|
||||
let outputs = egraph
|
||||
.parse_and_run_program(None, &command)
|
||||
.unwrap_or_else(|err| panic!("failed phase {name} schedule {schedule}: {err}"));
|
||||
let elapsed = start.elapsed();
|
||||
let after = egraph.num_tuples();
|
||||
let report = outputs
|
||||
.into_iter()
|
||||
.find_map(|output| match output {
|
||||
egglog::CommandOutput::RunSchedule(report) => Some(report),
|
||||
_ => None,
|
||||
})
|
||||
.expect("run-schedule did not return a report");
|
||||
let mut rules = report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(rule, time)| {
|
||||
(
|
||||
rule.to_string(),
|
||||
*time,
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.get(rule)
|
||||
.copied()
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
rules.sort_by_key(|(_, time, matches)| (std::cmp::Reverse(*time), std::cmp::Reverse(*matches)));
|
||||
let matches = report.num_matches_per_rule.values().sum::<usize>();
|
||||
println!(
|
||||
"phase {name:<18} {elapsed_ms:>8.2} ms | tuples {before} -> {after} ({delta:+}) | updated={updated} | iters={iters} | matches={matches}",
|
||||
elapsed_ms = elapsed.as_secs_f64() * 1000.0,
|
||||
delta = after as isize - before as isize,
|
||||
updated = report.updated,
|
||||
iters = report.iterations.len(),
|
||||
);
|
||||
for (rule, time, matches) in rules
|
||||
.into_iter()
|
||||
.filter(|(_, time, matches)| !time.is_zero() || *matches > 0)
|
||||
.take(8)
|
||||
{
|
||||
println!(
|
||||
" rule {rule:<82} {ms:>8.2} ms | matches {matches}",
|
||||
ms = time.as_secs_f64() * 1000.0,
|
||||
);
|
||||
}
|
||||
report.updated
|
||||
}
|
||||
|
||||
fn serialize_summary(egraph: &mut egglog::EGraph, root: &str) {
|
||||
let (sort, value) = egraph.eval_expr(&egglog::var!(root.to_string())).unwrap();
|
||||
let output = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
let mut classes = std::collections::BTreeSet::new();
|
||||
let mut top_ops = BTreeMap::<String, usize>::new();
|
||||
let mut nodes = 0usize;
|
||||
for node in output.egraph.nodes.values().filter(|node| !node.subsumed) {
|
||||
nodes += 1;
|
||||
classes.insert(node.eclass.clone());
|
||||
*top_ops.entry(node.op.clone()).or_default() += 1;
|
||||
}
|
||||
let top_ops = top_ops
|
||||
.into_iter()
|
||||
.sorted_by_key(|(_, count)| std::cmp::Reverse(*count))
|
||||
.take(12)
|
||||
.map(|(op, count)| format!("{op}={count}"))
|
||||
.join(", ");
|
||||
println!(
|
||||
"serialize nodes={nodes} classes={} roots={} top_ops={top_ops}",
|
||||
classes.len(),
|
||||
output.egraph.root_eclasses.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn run(args: Args) {
|
||||
let mut graph = build_case(args.case);
|
||||
let rolled = if args.skip_roll {
|
||||
0
|
||||
} else {
|
||||
graph.auto_roll_loops_prepass()
|
||||
};
|
||||
let (program, root) = hlir_to_egglog(&graph);
|
||||
|
||||
let mut ops = match args.backend {
|
||||
Backend::Native => <NativeRuntime as Runtime>::Ops::into_vec(),
|
||||
Backend::Cuda => <CudaRuntime as Runtime>::Ops::into_vec(),
|
||||
};
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup = args.cleanup && matches!(args.backend, Backend::Cuda);
|
||||
let setup = setup_program(&program, &ops, cleanup);
|
||||
|
||||
println!(
|
||||
"case={:?} backend={:?} mode={:?} passes={} cleanup={} rolled={} hlir_nodes={} setup_lines={} setup_bytes={} root={root}",
|
||||
args.case,
|
||||
args.backend,
|
||||
args.mode,
|
||||
args.passes,
|
||||
cleanup,
|
||||
rolled,
|
||||
graph.graph.node_count(),
|
||||
setup.lines().count(),
|
||||
setup.len(),
|
||||
);
|
||||
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let commands = egraph.parser.get_program_from_string(None, &setup).unwrap();
|
||||
egraph.run_program(commands).unwrap();
|
||||
println!(
|
||||
"setup {:>8.2} ms | tuples {before} -> {} ({:+})",
|
||||
start.elapsed().as_secs_f64() * 1000.0,
|
||||
egraph.num_tuples(),
|
||||
egraph.num_tuples() as isize - before as isize,
|
||||
);
|
||||
|
||||
match args.mode {
|
||||
Mode::Current | Mode::Steps => {
|
||||
for pass in 1..=args.passes {
|
||||
let mut updated = false;
|
||||
for (name, schedule) in split_cycle() {
|
||||
updated |= phase(&mut egraph, &format!("{pass:03} {name}"), &schedule);
|
||||
}
|
||||
if matches!(args.mode, Mode::Current) && !updated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Mode::FullDefault => {
|
||||
phase(&mut egraph, "expr", "(saturate expr)");
|
||||
phase(&mut egraph, "dtype", "(saturate dtype_prop)");
|
||||
phase(&mut egraph, "default-full", "(saturate (run))");
|
||||
}
|
||||
Mode::FullCycle => {
|
||||
phase(
|
||||
&mut egraph,
|
||||
"cycle-full",
|
||||
&format!("(saturate {})", split_cycle_schedule()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
phase(&mut egraph, "final expr", "(saturate expr)");
|
||||
if cleanup {
|
||||
phase(&mut egraph, "cleanup", "(saturate cleanup)");
|
||||
}
|
||||
phase(&mut egraph, "base cleanup", "(saturate base_cleanup)");
|
||||
serialize_summary(&mut egraph, &root);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
run(parse_args());
|
||||
}
|
||||
@@ -19,9 +19,9 @@ use crate::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
)
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
)
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
@@ -121,7 +116,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
@@ -129,21 +123,17 @@
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?n ; ldd
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
|
||||
@@ -1,428 +0,0 @@
|
||||
; Fuse a row-major Add on top of an existing cuBLASLt matmul into
|
||||
; D = alpha * A * B + beta * C.
|
||||
;
|
||||
; The existing matmul rewrites view Luminal's row-major output [m,n] as a
|
||||
; column-major cuBLASLt matrix [n,m]. A row-major C input with logical strides
|
||||
; [row_stride, 1] therefore maps to ldc=row_stride. This lets a C slice from a
|
||||
; wider parent tensor use a larger ldc while D keeps the matmul output layout.
|
||||
; cuBLASLt requires out-of-place C and D to have the same matrix order, so these
|
||||
; beta rules only fuse C layouts that map to the current COL-ordered D layout.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched c plus matmul beta"
|
||||
)
|
||||
|
||||
; ROW-ordered D beta fusions. These pair with cublaslt_row_order_rewrite.egg,
|
||||
; where the cuBLASLt problem dimensions match Luminal's logical output [m,n].
|
||||
; A row-major C input with logical strides [row_stride, 1] maps directly to a
|
||||
; ROW-ordered cuBLASLt C[m,n] descriptor with ldc=row_stride.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched c plus matmul beta"
|
||||
)
|
||||
@@ -1,614 +0,0 @@
|
||||
; cuBLASLt epilogue rewrites.
|
||||
;
|
||||
; ReLU in the frontend lowers through maximum_f32(0.0):
|
||||
;
|
||||
; (matmul < 0) * 0 + cast(cast((-cast(matmul < 0) + 1) as bool) as f32) * matmul
|
||||
;
|
||||
; These rules fuse that expression back into CUBLASLT_EPILOGUE_RELU.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu bias epilogue"
|
||||
)
|
||||
|
||||
; Canonical tanh-approx GELU can also appear directly as:
|
||||
;
|
||||
; x * sigmoid(1.5957691216 * x * (1 + 0.044715 * x * x))
|
||||
;
|
||||
; Match that sigmoid form and fuse it into the cuBLASLt GELU epilogues.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu bias epilogue"
|
||||
)
|
||||
|
||||
; This first slice fuses column-bias adds into CUBLASLT_EPILOGUE_BIAS for the
|
||||
; older COL-ordered output view. In that view Luminal's logical [m,n] output is
|
||||
; represented as a cuBLASLt [n,m] matrix, so cuBLASLt's row-broadcast bias maps
|
||||
; to the common logical column bias of length n.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d column bias plus matmul epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column bias plus matmul epilogue"
|
||||
)
|
||||
@@ -1,775 +0,0 @@
|
||||
; FP8 support is narrower than "any FP8 x any FP8". cuBLASLt's regular FP8
|
||||
; matmul table supports these A/B descriptor pairs for F32 outputs:
|
||||
; E4M3 x E4M3
|
||||
; E4M3 x E5M2
|
||||
; E5M2 x E4M3
|
||||
; and requires TN format on Ada/Hopper-class GPUs. These rules therefore match
|
||||
; row-major x column-major Luminal matmuls, which the existing COL-order lowering
|
||||
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
|
||||
; transb=N.
|
||||
|
||||
(rule
|
||||
(
|
||||
; Match the scaled FP8 linear form directly before the unscaled FP8
|
||||
; matmul rewrite can hide the quantize/dequant scale structure.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
(= ?scaled (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(= ?cast (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
(
|
||||
(delete (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(delete (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers direct output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Fusion growth can make the live path consume a raw FP8 cuBLASLt
|
||||
; candidate through an internal CudaBinaryElementwise scale multiply,
|
||||
; instead of the original HLIR output-scale Mul. The scalar scale
|
||||
; product is tensor-wide, so the two scalar factors can be passed as
|
||||
; cuBLASLt A/B scale inputs and the internal multiply can be bypassed.
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(let ?fs_sgemm (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
(union ?fused_scale ?fs_sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
(set (dtype ?fs_sgemm) (F32))
|
||||
)
|
||||
:ruleset fusion_grow
|
||||
:name "cublaslt scaled fp8 fused output-scale f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(= ?fused_scale (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
)
|
||||
(
|
||||
(delete (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(delete (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers fused output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Batched form of the scaled FP8 linear rewrite. The scale operands are
|
||||
; scalar tensors expanded across the last three output/activation axes.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 batched row-major x column-major f32 output"
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
; Mixed output dtype rewrites for cuBLASLt.
|
||||
;
|
||||
; The first mixed mode we need for low-precision matmuls is:
|
||||
;
|
||||
; D[f32] = A[fp16/bf16] * B[fp16/bf16]
|
||||
;
|
||||
; Luminal graphs express this today as a Cast(F32) around a low-precision
|
||||
; matmul. cuBLASLt can write the f32 output directly, so expose that candidate
|
||||
; before beta fusion tries to consume an f32 C input.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F16) (F16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt f16 matmul cast f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (Bf16) (Bf16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt bf16 matmul cast f32 output"
|
||||
)
|
||||
@@ -1,452 +0,0 @@
|
||||
; Natural cuBLASLt row-order output rewrites. These keep Luminal's logical
|
||||
; output C[m,n] as a cuBLASLt ROW-ordered D[m,n] instead of using the older
|
||||
; swapped COL-ordered D[n,m] view. A and B orders mirror their matched logical
|
||||
; layouts, so this family is the legal base for future ROW-ordered beta fusions.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x column-major"
|
||||
)
|
||||
@@ -1,316 +0,0 @@
|
||||
; Scalar alpha/beta rewrites for cuBLASLt. These rules target scalar constants
|
||||
; expanded across the matmul/add shape, i.e. zero strides on every logical axis.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; alpha=1.0 hash-conses ?fused == ?matmul; the union merges Mul into ?matmul's eclass and saturate diverges.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; See 2d alpha scale: alpha=1.0 makes (saturate ...) diverge.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c plus matmul beta"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,124 +0,0 @@
|
||||
# FlashInfer Integration
|
||||
|
||||
FlashInfer replaces the multi-op attention pattern (Q×K^T → scale → mask → softmax → ×V) with a single fused GPU kernel via [FlashInfer](https://github.com/flashinfer-ai/flashinfer)'s batch decode and batch prefill APIs.
|
||||
|
||||
## Current State
|
||||
|
||||
**Working:**
|
||||
- Egglog rewrite rule matches any GQA paged attention pattern (model-agnostic shapes)
|
||||
- GA search selects FlashInfer when it wins profiling — verified on Llama 3 8B (32 layers) and Qwen 3 4B (36 layers)
|
||||
- **BatchDecode** (s=1): fp32 natively — FlashInfer's decode kernel uses scalar vectorized dot products, no tensor cores
|
||||
- **BatchPrefill**: template-instantiated for fp16 but **not callable from fp32** — FlashInfer's prefill kernel requires tensor core MMA (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically only operate on 16-bit types; the C API stubs return -1 for fp32; will be enabled when native fp16/bf16 pipeline is added
|
||||
- Decode handles all cases in the current fp32 pipeline (prefill uses cuBLAS attention via dim bucketing)
|
||||
- Indptr-based mask: `qo_indptr` and `kv_indptr` are computed in-graph so the egglog rule can see them in the same chunk as the attention ops
|
||||
|
||||
**Not yet implemented:**
|
||||
- Native fp16 / bf16 pipeline (would eliminate the cast overhead in prefill)
|
||||
- Page sizes > 1
|
||||
|
||||
---
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
src/host/flashinfer/
|
||||
flashinfer_attention.egg — egglog rewrite rule (pattern match → FlashInferAttention)
|
||||
mod.rs — FlashInferAttention op (EgglogOp + HostOp impl)
|
||||
jit.rs — JIT compilation: nvcc wrapper.cu → .so, dlopen, fn pointers
|
||||
find_indptrs.rs — walks the mask e-graph node to locate qo_indptr / kv_indptr inputs
|
||||
wrapper.cu — CUDA: FlashInfer template instantiation + helper kernels
|
||||
wrapper.h — C API header for wrapper.cu
|
||||
README.md — this file
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. Egglog Pattern Matching
|
||||
|
||||
The rule in `flashinfer_attention.egg` matches the structural pattern of paged GQA attention:
|
||||
|
||||
```
|
||||
Gather(K_cache, idx) → GQA broadcast (Mul×1.0) → Q×K^T → Sum → scale → mask Add → softmax → attn×V → Sum → output
|
||||
Gather(V_cache, idx) → GQA broadcast (Mul×1.0) ──────────────────────────────────────────→ attn×V → Sum → output
|
||||
```
|
||||
|
||||
Key anchors that prevent false matches on MLP or other ops:
|
||||
- Two Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
- GQA broadcast via `Mul(gathered, Constant(1.0))` with all-zero strides
|
||||
- Mask Add with zero-stride broadcast in the first (nheads) dimension
|
||||
- Two sequential matmul+Sum pairs connected through softmax
|
||||
|
||||
Shape dimensions are egglog variables, not pinned constants — the rule works for any model with GQA (Llama, Qwen, Mistral, etc.). The structural invariants (dimension count, zero-stride positions, Gather from 2D) are enough to avoid combinatorial explosion during saturation.
|
||||
|
||||
When the rule fires, it unions `FlashInferAttention` with the original attention output, making it an equivalent alternative in the e-graph. The GA search then profiles both paths and picks the faster one.
|
||||
|
||||
### 2. Extraction: Finding Indptrs
|
||||
|
||||
During `extract()` (called when egglog selects the FlashInferAttention e-node), `find_indptrs.rs` walks backward from the mask node in the e-graph to locate the `qo_indptr` and `kv_indptr` Input nodes. It validates the mask structure by checking for the `Mul(allowed, Constant(1e10))` pattern that `compute_attn_mask()` produces.
|
||||
|
||||
The indptrs are appended as inputs 5 and 6 to the FlashInferAttention op, so the runtime can build the CSR page table directly without recomputing anything.
|
||||
|
||||
### 3. JIT Compilation
|
||||
|
||||
FlashInfer requires `HEAD_DIM` as a compile-time template parameter. Rather than baking it at `cargo build` time, `jit.rs` JIT-compiles `wrapper.cu` with the model's actual HEAD_DIM:
|
||||
|
||||
1. First call to `ensure_compiled(head_dim)` runs `nvcc` with `-DLUMINAL_HEAD_DIM=<N>`
|
||||
2. The compiled `.so` is cached at `~/.cache/luminal/flashinfer/libflashinfer_hd<N>_<arch>.so`
|
||||
3. Subsequent calls load the cached library via `dlopen`
|
||||
4. Function pointers (plan, run, transpose, etc.) are resolved and stored in a `static OnceLock`
|
||||
|
||||
Supported HEAD_DIM values: 64, 128, 256.
|
||||
|
||||
### 4. Runtime Execution
|
||||
|
||||
`FlashInferAttention::execute()` dispatches to decode or prefill based on `total_q_tokens vs batch_size`:
|
||||
|
||||
**Common steps:**
|
||||
1. **Extract kv_indices** — a helper kernel converts the flat gather index `(c, KV_DIM)` to slot indices `(c,)`
|
||||
2. **Read indptrs to host** — copied to CPU for the plan phase
|
||||
3. **Plan** — queries GPU occupancy and decides split-KV decomposition
|
||||
4. **Run** — the fused kernel writes `(total_q_tokens, num_qo_heads, head_dim)`
|
||||
5. **Transpose** — transposes to `(num_qo_heads, total_q_tokens, head_dim)` to match the Sum reduction layout
|
||||
|
||||
**Decode path** (current, fp32): Always used. Runs FlashInfer's BatchDecode directly on fp32 buffers.
|
||||
|
||||
**Prefill path** (future, fp16/bf16 only): The prefill kernel templates are compiled into the JIT .so for fp16 (CTA_TILE_Q=16/64/128, causal mask). The C API stubs currently return -1 since the pipeline is fp32. When native fp16/bf16 dtype support is added, `execute()` will dispatch to prefill when `total_q_tokens > batch_size`.
|
||||
|
||||
Global workspaces (`static OnceLock`) are shared across all FlashInferAttention instances to avoid ~4ms allocation overhead per GA profiling candidate. Without this, the GA never selects FlashInfer because the first-run allocation cost dwarfs the kernel time.
|
||||
|
||||
## How the Attention Mask Enables FlashInfer
|
||||
|
||||
For the egglog rule to fire, the `qo_indptr` and `kv_indptr` tensors must be visible in the same e-graph chunk as the attention ops. This is why the mask is computed *inside* each layer (via `compute_attn_mask()` in the model) rather than passed as a pre-computed input.
|
||||
|
||||
The mask computation uses a specific structure:
|
||||
```rust
|
||||
let allowed = same_request * causal;
|
||||
allowed * 1e10 - 1e10 // → 0.0 for allowed, -1e10 for blocked
|
||||
```
|
||||
|
||||
The `Mul(allowed, Constant(1e10))` pattern is the anchor that `find_indptrs.rs` uses to walk backward and locate the indptr inputs.
|
||||
|
||||
## Roadmap
|
||||
|
||||
Items listed in priority order. Checked items are done.
|
||||
|
||||
- [x] Model-agnostic egglog rule (shape variables instead of Llama-specific constants)
|
||||
- [x] bs>1 supersequence decode
|
||||
- [x] Indptr-based attention mask (replaces CPU-computed mask)
|
||||
- [x] Multi-model support (verified on Llama 3 8B and Qwen 3 4B)
|
||||
- [x] BatchPrefill kernel compiled for fp16 (causal mask, CTA_TILE_Q=16/64/128)
|
||||
- [ ] Native fp16 / bf16 pipeline (enables prefill, reduces memory, eliminates cuBLAS prefill fallback)
|
||||
- [ ] HEAD_DIM dispatch for 64, 96 (JIT supports 64/128/256; wrapper.cu needs 96 for Phi)
|
||||
- [ ] Page sizes > 1 (currently page_size=1; larger pages reduce CSR overhead)
|
||||
- [ ] Sliding window, ALiBi, logits soft cap (FlashInfer `AttentionVariant` templates)
|
||||
- [ ] MHA / MQA / arbitrary GQA ratios beyond {1, 2, 4, 8}
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- **page_size=1**: Each KV cache slot is one "page". This simplifies the CSR page table (`kv_indices` = physical slot indices directly) and matches the flat `(num_slots, KV_DIM)` cache layout.
|
||||
|
||||
- **Pinned structural anchors**: The egglog rule pins the *structure* (number of dimensions, which dims are zero-stride, presence of Gather from 2D cache) but uses variables for the *values* (head counts, head_dim). This prevents saturation blowup while remaining model-agnostic.
|
||||
|
||||
- **Prefill requires fp16/bf16**: FlashInfer's prefill kernel uses tensor core MMA instructions (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically require 16-bit inputs — there is no fp32 tensor core matmul instruction. The prefill kernel templates are compiled into the .so for fp16 but the C API returns -1 for fp32 callers. When native fp16/bf16 is added, prefill will be enabled automatically.
|
||||
|
||||
- **Global workspaces**: Float workspace (128 MiB), int workspace (8 MiB), and a page-locked host buffer are allocated once via `static OnceLock` and shared across all instances.
|
||||
@@ -1,328 +0,0 @@
|
||||
//! Walk the e-graph from the mask node to find qo_indptr and kv_indptr Input nodes.
|
||||
//!
|
||||
//! The mask is produced by `compute_attn_mask(q_pos, qo_indptr, kv_indptr)` using
|
||||
//! primitive HLIR ops. This module validates the mask's structure and extracts the
|
||||
//! indptr Input node IDs so FlashInfer can use them directly.
|
||||
|
||||
use luminal::egglog_utils::{ClassId, NodeId, SerializedEGraph};
|
||||
use luminal::prelude::FxHashSet;
|
||||
|
||||
/// Result of walking the mask computation chain.
|
||||
#[derive(Debug)]
|
||||
pub struct IndptrNodes<'a> {
|
||||
pub qo_indptr: &'a NodeId,
|
||||
pub kv_indptr: &'a NodeId,
|
||||
}
|
||||
|
||||
/// Find the qo_indptr and kv_indptr Input nodes by walking backwards from the mask.
|
||||
///
|
||||
/// Validates the mask structure: `allowed * 1e10 + (-1e10)`. Then does a BFS from
|
||||
/// the `allowed` subtree to find all reachable Input nodes with names containing
|
||||
/// "qo_indptr" and "kv_indptr".
|
||||
///
|
||||
/// Panics with a diagnostic message if the structure doesn't match or the
|
||||
/// indptr inputs can't be found.
|
||||
pub fn find_indptr_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let mask_inputs = logical_binary_inputs(egraph, mask_node, "Add").unwrap_or_else(|| {
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
panic!("find_indptr_inputs: mask is not an Add (kind={mask_kind_label})");
|
||||
});
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
"find_indptr_inputs: mask Add should have 2 inputs, got {}",
|
||||
mask_inputs.len()
|
||||
);
|
||||
|
||||
// Step 2: One of the inputs should be Mul(allowed, Constant(1e10))
|
||||
let (scaled_allowed, allowed_node) = find_1e10_mul(egraph, &mask_inputs);
|
||||
|
||||
// Step 3: BFS from `allowed` to find all reachable Input nodes
|
||||
let reachable_inputs = find_reachable_inputs(egraph, allowed_node);
|
||||
|
||||
// Step 4: Match by name
|
||||
let mut qo_indptr: Option<&NodeId> = None;
|
||||
let mut kv_indptr: Option<&NodeId> = None;
|
||||
|
||||
for (node_id, name) in &reachable_inputs {
|
||||
if name.contains("qo_indptr") {
|
||||
qo_indptr = Some(node_id);
|
||||
} else if name.contains("kv_indptr") {
|
||||
kv_indptr = Some(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let qo = qo_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'qo_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
let kv = kv_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'kv_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
IndptrNodes {
|
||||
qo_indptr: qo,
|
||||
kv_indptr: kv,
|
||||
}
|
||||
}
|
||||
|
||||
fn find_1e10_mul<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let Some(mul_inputs) = logical_binary_inputs(egraph, input_node, "Mul") else {
|
||||
continue;
|
||||
};
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
for (i, &inp) in mul_inputs.iter().enumerate() {
|
||||
if is_constant(egraph, inp, 1e10) {
|
||||
let other = mul_inputs[1 - i];
|
||||
return (input_node, other);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut debug_info = String::new();
|
||||
for (i, &input_node) in mask_add_inputs.iter().enumerate() {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
debug_info.push_str(&format!("\n input[{i}]: label={label}"));
|
||||
if label == "Op" && !children.is_empty() {
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
debug_info.push_str(&format!(" kind={kind_label}"));
|
||||
for (j, kc) in egraph.enodes[kind].1.iter().enumerate() {
|
||||
let kc_node = resolve_first_node(egraph, kc);
|
||||
debug_info.push_str(&format!(" child[{j}]={}", egraph.enodes[kc_node].0));
|
||||
}
|
||||
if kind_label.contains("Mul") && children.len() >= 2 {
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for (j, &mi) in mul_inputs.iter().enumerate() {
|
||||
let (ml, mc) = &egraph.enodes[mi];
|
||||
debug_info.push_str(&format!("\n mul_input[{j}]: label={ml}"));
|
||||
if ml == "Op" && !mc.is_empty() {
|
||||
let mk = resolve_first_node(egraph, &mc[0]);
|
||||
debug_info.push_str(&format!(" kind={}", egraph.enodes[mk].0));
|
||||
for (k, mkc) in egraph.enodes[mk].1.iter().enumerate() {
|
||||
let mkc_node = resolve_first_node(egraph, mkc);
|
||||
debug_info.push_str(&format!(" ch[{k}]={}", egraph.enodes[mkc_node].0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find Mul(allowed, Constant(1e10)) in mask Add inputs.{debug_info}"
|
||||
);
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let node = resolve_op_with_kind(egraph, node, "Constant").unwrap_or(node);
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
if !kind_label.contains("Constant") {
|
||||
return false;
|
||||
}
|
||||
let val_children = &egraph.enodes[kind].1;
|
||||
if val_children.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let val_node = resolve_first_node(egraph, &val_children[0]);
|
||||
let val_str = &egraph.enodes[val_node].0;
|
||||
if let Ok(val) = val_str.parse::<f64>() {
|
||||
(val as f32 - expected).abs() < 1.0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn find_reachable_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
start: &'a NodeId,
|
||||
) -> Vec<(&'a NodeId, String)> {
|
||||
let mut found = Vec::new();
|
||||
let mut visited = FxHashSet::default();
|
||||
let mut stack = vec![start];
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
if !visited.insert(node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
|
||||
if label == "Input" {
|
||||
if children.len() >= 2 {
|
||||
let name_node = resolve_first_node(egraph, &children[1]);
|
||||
let name = egraph.enodes[name_node].0.trim_matches('"').to_string();
|
||||
found.push((node, name));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if label == "Op" && children.len() >= 2 {
|
||||
let ir_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for inp in ir_inputs {
|
||||
stack.push(inp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found
|
||||
}
|
||||
|
||||
fn walk_ilist_simple<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
ilist_eclass: &'a ClassId,
|
||||
) -> Vec<&'a NodeId> {
|
||||
let mut inputs = Vec::new();
|
||||
let mut current = resolve_first_node(egraph, ilist_eclass);
|
||||
|
||||
loop {
|
||||
let (label, children) = &egraph.enodes[current];
|
||||
if label == "INil" {
|
||||
break;
|
||||
}
|
||||
if label != "ICons" {
|
||||
break;
|
||||
}
|
||||
let ir_node = resolve_first_ir_node(egraph, &children[0]);
|
||||
inputs.push(ir_node);
|
||||
current = resolve_first_node(egraph, &children[1]);
|
||||
}
|
||||
|
||||
inputs
|
||||
}
|
||||
|
||||
fn resolve_first_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
&egraph.eclasses[eclass].1[0]
|
||||
}
|
||||
|
||||
fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
let nodes = &egraph.eclasses[eclass].1;
|
||||
for node in nodes {
|
||||
let label = &egraph.enodes[node].0;
|
||||
if label == "Op" || label == "Input" {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
|
||||
fn resolve_op_with_kind<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
kind_substr: &str,
|
||||
) -> Option<&'a NodeId> {
|
||||
let class = egraph.node_to_class.get(node)?;
|
||||
for candidate in &egraph.eclasses[class].1 {
|
||||
let (label, children) = &egraph.enodes[candidate];
|
||||
if label != "Op" || children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains(kind_substr) {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn logical_binary_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
op_name: &str,
|
||||
) -> Option<Vec<&'a NodeId>> {
|
||||
if let Some(op_node) = resolve_op_with_kind(egraph, node, op_name) {
|
||||
let (_, children) = &egraph.enodes[op_node];
|
||||
return Some(walk_ilist_simple(egraph, &children[1]));
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains("CudaBinaryElementwise") {
|
||||
let opcode_class = egraph.enodes[kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
return Some(
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
if !egraph.enodes[kind].0.contains("FusionEnd") {
|
||||
return None;
|
||||
}
|
||||
let fe_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
let elem = *fe_inputs.first()?;
|
||||
let (elem_label, elem_children) = &egraph.enodes[elem];
|
||||
if elem_label != "Op" || elem_children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let elem_kind = resolve_first_node(egraph, &elem_children[0]);
|
||||
if !egraph.enodes[elem_kind].0.contains("CudaBinaryElementwise") {
|
||||
return None;
|
||||
}
|
||||
let opcode_class = egraph.enodes[elem_kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
walk_ilist_simple(egraph, &elem_children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn unwrap_fusion_start<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> &'a NodeId {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return node;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("FusionStart") {
|
||||
return node;
|
||||
}
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.first()
|
||||
.copied()
|
||||
.unwrap_or(node)
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
; FlashInfer batch decode attention rewrite rule.
|
||||
;
|
||||
; Matches the paged attention pattern for ANY model with GQA:
|
||||
; Gather(K_cache) → GQA broadcast → Q*K^T matmul → scale → add mask → softmax → attn*V matmul
|
||||
; Gather(V_cache) → GQA broadcast ──────────────────────────────────────────→ attn*V matmul
|
||||
;
|
||||
; Structural anchors (prevent false matches on MLP/other ops):
|
||||
; - Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
; - GQA broadcast via Mul(gathered, Constant(1.0)) with all-zero strides
|
||||
; - Scale Mul(QK, constant) connecting QK scores to mask Add
|
||||
; - Mask Add with zero-stride broadcast in first dim (nheads broadcast)
|
||||
; - Data flow: two sequential matmul+reduce pairs connected through softmax
|
||||
;
|
||||
; The egglog rule captures the mask as 5th input. During extract(), a Rust
|
||||
; function walks the mask's computation chain in the e-graph to locate the
|
||||
; qo_indptr and kv_indptr Input nodes (validated via the Constant(1e10) anchor
|
||||
; and structural checks). These are appended as inputs 5 and 6 so FlashInfer
|
||||
; can build the CSR page table directly — no runtime derivation needed.
|
||||
;
|
||||
; Shape dimensions are egglog variables, not pinned constants.
|
||||
; Dynamic dims "s" (batch/seq) and "c" (context) stay pinned as MVar.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ── Second matmul: Mul(softmax_out, V_gqa) ──
|
||||
; Shape: (nheads, s, hdim, c) — 4D
|
||||
(= ?mul2 (Op (Mul
|
||||
(ECons ?nheads (ECons (MVar "s") (ECons ?hdim (ECons (MVar "c") (ENil)))))
|
||||
?mul2_a_strides
|
||||
?mul2_b_strides
|
||||
?mul2_out_strides)
|
||||
(ICons ?soft (ICons ?v_gqa (INil)))))
|
||||
|
||||
; ── Second matmul: Sum (reduction over c) → output ──
|
||||
; Shape: (nheads, s, hdim) — reduces c
|
||||
(= ?output (Op (Sum
|
||||
(ECons ?nheads2 (ECons (MVar "s") (ECons ?hdim2 (ENil))))
|
||||
(MVar "c")
|
||||
?out_in_strides
|
||||
(MIter)
|
||||
?out_out_strides)
|
||||
(ICons ?mul2 (INil))))
|
||||
|
||||
; ── V GQA broadcast: Mul(V_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, c, hdim) — 3D
|
||||
(= ?v_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?v_gqa (Op (Mul
|
||||
(ECons ?nheads3 (ECons (MVar "c") (ECons ?hdim3 (ENil))))
|
||||
?v_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?v_gqa_out_strides)
|
||||
(ICons ?v_gathered (ICons ?v_gqa_const (INil)))))
|
||||
|
||||
; ── V Gather: rows from V_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?v_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim (ENil)))
|
||||
?v_gather_strides
|
||||
(ECons ?num_slots_v (ECons ?kvdim2 (ENil)))
|
||||
?v_src_strides)
|
||||
(ICons ?v_idx (ICons ?v_cache (INil)))))
|
||||
|
||||
; ── First matmul: Mul(Q, K_gqa) ──
|
||||
; Shape: (nheads, s, c, hdim) — 4D
|
||||
(= ?mul1 (Op (Mul
|
||||
(ECons ?nheads4 (ECons (MVar "s") (ECons (MVar "c") (ECons ?hdim4 (ENil)))))
|
||||
?mul1_a_strides
|
||||
?mul1_b_strides
|
||||
?mul1_out_strides)
|
||||
(ICons ?q (ICons ?k_gqa (INil)))))
|
||||
|
||||
; ── First matmul: Sum (reduction over hdim) → QK scores ──
|
||||
; Shape: (nheads, s, c) — reduces hdim
|
||||
(= ?qk (Op (Sum
|
||||
(ECons ?nheads5 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?hdim5
|
||||
?qk_in_strides
|
||||
(MIter)
|
||||
?qk_out_strides)
|
||||
(ICons ?mul1 (INil))))
|
||||
|
||||
; ── Mask Add: Add(scaled_QK, mask) ──
|
||||
; Shape: (nheads, s, c) — 3D
|
||||
; Mask is broadcast from (s, c) via zero-stride in first dim (nheads).
|
||||
(= ?masked (Op (Add
|
||||
(ECons ?nheads8 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?mask_add_a_strides
|
||||
(ECons (MNum 0) ?mask_rest_strides)
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; FlashInfer needs qo_indptr/kv_indptr to be recoverable from the mask
|
||||
; expression. Do not match examples that pass a precomputed mask Input.
|
||||
(= ?mask (Op (Add ?inner_mask_shape ?inner_mask_a_strides ?inner_mask_b_strides ?inner_mask_out_strides)
|
||||
(ICons ?mask_scaled_allowed (ICons ?mask_offset (INil)))))
|
||||
(= ?mask_scaled_allowed (Op (Mul ?allowed_shape ?allowed_strides ?scale_const_strides ?scaled_allowed_strides)
|
||||
(ICons ?mask_allowed (ICons ?mask_scale_const (INil)))))
|
||||
(= ?mask_scale_const (Op (Constant ?mask_scale_val) (INil)))
|
||||
(> ?mask_scale_val 9999999999.0)
|
||||
(< ?mask_scale_val 10000000001.0)
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?k_gqa (Op (Mul
|
||||
(ECons ?nheads6 (ECons ?hdim6 (ECons (MVar "c") (ENil))))
|
||||
?k_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?k_gqa_out_strides)
|
||||
(ICons ?k_gathered (ICons ?k_gqa_const (INil)))))
|
||||
|
||||
; ── K Gather: rows from K_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?k_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim3 (ENil)))
|
||||
?k_gather_strides
|
||||
(ECons ?num_slots_k (ECons ?kvdim4 (ENil)))
|
||||
?k_src_strides)
|
||||
(ICons ?k_idx (ICons ?k_cache (INil)))))
|
||||
|
||||
; ── Dtype consistency ──
|
||||
(= ?dt (dtype ?q))
|
||||
(= ?dt (dtype ?k_cache))
|
||||
(= ?dt (dtype ?v_cache))
|
||||
)
|
||||
(
|
||||
(let ?fi (Op (FlashInferAttention
|
||||
?nheads (MDiv ?kvdim ?hdim) ?hdim (MNum 1) (MVar "s"))
|
||||
(ICons ?q (ICons ?k_cache (ICons ?v_cache (ICons ?k_idx (ICons ?mask (INil))))))))
|
||||
(union ?output ?fi)
|
||||
(set (dtype ?fi) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "FlashInfer batch decode attention"
|
||||
)
|
||||
@@ -1,504 +0,0 @@
|
||||
//! JIT compilation and dynamic loading of FlashInfer kernels.
|
||||
//!
|
||||
//! Everything runs at compile / profiling time — there is no `build.rs`.
|
||||
//! `wrapper.cu` and `wrapper.h` are embedded via `include_str!()` and
|
||||
//! extracted to the cache directory on first use. The FlashInfer + CUTLASS
|
||||
//! header trees are located by probing `LUMINAL_FLASHINFER_DIR`, a small set
|
||||
//! of default paths, and (as a last resort) by `git clone`-ing FlashInfer at
|
||||
//! a pinned commit into the cache. `nvcc` is then invoked with the model's
|
||||
//! actual `HEAD_DIM` and the resulting `.so` is `dlopen`'d.
|
||||
//!
|
||||
//! `ensure_compiled` is called from `FlashInferAttention::extract()`, i.e.
|
||||
//! during luminal's compile / GA-profiling phase, not from `execute()`. After
|
||||
//! the first call the `OnceLock` makes subsequent lookups free.
|
||||
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
hash::{Hash, Hasher},
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
// ── Function pointer types matching wrapper.h ──
|
||||
|
||||
pub type PlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
indptr_h: *mut i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type RunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
pub type ExtractFn = unsafe extern "C" fn(
|
||||
flat_idx: *const i32,
|
||||
out: *mut i32,
|
||||
c: i32,
|
||||
kv_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type DeriveIndptrFn =
|
||||
unsafe extern "C" fn(mask: *const f32, indptr: *mut i32, s: i32, c: i32, stream: *mut c_void);
|
||||
|
||||
pub type TransposeOutputFn = unsafe extern "C" fn(
|
||||
src: *const f32,
|
||||
dst: *mut f32,
|
||||
batch: i32,
|
||||
heads: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type PrefillPlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
qo_indptr_h: *mut i32,
|
||||
kv_indptr_h: *mut i32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type PrefillRunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
qo_indptr: *mut i32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
// ── Embedded CUDA sources ──
|
||||
|
||||
const WRAPPER_CU: &str = include_str!("wrapper.cu");
|
||||
const WRAPPER_H: &str = include_str!("wrapper.h");
|
||||
|
||||
// ── Loaded library handle ──
|
||||
|
||||
pub struct FlashInferLib {
|
||||
// Keep the handle alive so the dlopen'd .so remains mapped.
|
||||
_lib: libloading::Library,
|
||||
pub plan: PlanFn,
|
||||
pub run: RunFn,
|
||||
pub extract_slot_indices: ExtractFn,
|
||||
pub derive_indptr_from_mask: DeriveIndptrFn,
|
||||
pub transpose_output: TransposeOutputFn,
|
||||
pub prefill_plan: PrefillPlanFn,
|
||||
pub prefill_run: PrefillRunFn,
|
||||
}
|
||||
|
||||
// SAFETY: The library handle and function pointers are valid for the lifetime
|
||||
// of the process. All functions are called with proper CUDA stream serialization.
|
||||
unsafe impl Send for FlashInferLib {}
|
||||
unsafe impl Sync for FlashInferLib {}
|
||||
|
||||
static FLASHINFER_LIB: OnceLock<FlashInferLib> = OnceLock::new();
|
||||
|
||||
/// Ensure the FlashInfer library is compiled and loaded for the given HEAD_DIM.
|
||||
/// Returns a reference to the loaded library. Thread-safe via OnceLock.
|
||||
pub fn ensure_compiled(head_dim: usize) -> &'static FlashInferLib {
|
||||
FLASHINFER_LIB.get_or_init(|| {
|
||||
assert!(
|
||||
matches!(head_dim, 64 | 128 | 256),
|
||||
"FlashInfer: unsupported HEAD_DIM={} (must be 64, 128, or 256 for f32)",
|
||||
head_dim
|
||||
);
|
||||
let so_path = compile_or_cache(head_dim);
|
||||
unsafe {
|
||||
FlashInferLib::load(&so_path)
|
||||
.unwrap_or_else(|e| panic!("Failed to load FlashInfer library: {e}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl FlashInferLib {
|
||||
/// Load a compiled FlashInfer .so and resolve function pointers.
|
||||
///
|
||||
/// # Safety
|
||||
/// The .so must be a valid FlashInfer wrapper compiled from wrapper.cu.
|
||||
unsafe fn load(path: &Path) -> Result<Self, libloading::Error> {
|
||||
let lib = unsafe { libloading::Library::new(path)? };
|
||||
let plan: PlanFn = unsafe { *lib.get::<PlanFn>(b"flashinfer_batch_decode_plan\0")? };
|
||||
let run: RunFn = unsafe { *lib.get::<RunFn>(b"flashinfer_batch_decode_run\0")? };
|
||||
let extract_slot_indices: ExtractFn =
|
||||
unsafe { *lib.get::<ExtractFn>(b"flashinfer_extract_slot_indices\0")? };
|
||||
let derive_indptr_from_mask: DeriveIndptrFn =
|
||||
unsafe { *lib.get::<DeriveIndptrFn>(b"flashinfer_derive_indptr_from_mask\0")? };
|
||||
let transpose_output: TransposeOutputFn =
|
||||
unsafe { *lib.get::<TransposeOutputFn>(b"flashinfer_transpose_output\0")? };
|
||||
let prefill_plan: PrefillPlanFn =
|
||||
unsafe { *lib.get::<PrefillPlanFn>(b"flashinfer_batch_prefill_plan\0")? };
|
||||
let prefill_run: PrefillRunFn =
|
||||
unsafe { *lib.get::<PrefillRunFn>(b"flashinfer_batch_prefill_run\0")? };
|
||||
Ok(Self {
|
||||
_lib: lib,
|
||||
plan,
|
||||
run,
|
||||
extract_slot_indices,
|
||||
derive_indptr_from_mask,
|
||||
transpose_output,
|
||||
prefill_plan,
|
||||
prefill_run,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile wrapper.cu for the given HEAD_DIM, or return cached .so path.
|
||||
fn compile_or_cache(head_dim: usize) -> PathBuf {
|
||||
let cache_dir = cache_directory();
|
||||
std::fs::create_dir_all(&cache_dir).expect("Failed to create FlashInfer cache directory");
|
||||
|
||||
// Extract bundled wrapper sources to the cache so nvcc can compile them.
|
||||
let (wrapper_cu_path, wrapper_h_dir) = extract_wrapper_sources(&cache_dir);
|
||||
|
||||
let arch = detect_cuda_arch();
|
||||
// Bake a hash of the embedded wrapper into the .so name so old caches are
|
||||
// discarded automatically when wrapper.cu or wrapper.h change.
|
||||
let wrapper_hash = wrapper_source_hash();
|
||||
let so_name = format!(
|
||||
"libflashinfer_hd{}_{}_w{:016x}.so",
|
||||
head_dim, arch, wrapper_hash
|
||||
);
|
||||
let so_path = cache_dir.join(&so_name);
|
||||
|
||||
if so_path.exists() {
|
||||
eprintln!(
|
||||
"FlashInfer: using cached library for HEAD_DIM={} ({})",
|
||||
head_dim,
|
||||
so_path.display()
|
||||
);
|
||||
return so_path;
|
||||
}
|
||||
|
||||
let Some((flashinfer_include, cutlass_include)) = locate_flashinfer_includes() else {
|
||||
panic!(
|
||||
"FlashInfer: could not locate header tree. Set LUMINAL_FLASHINFER_DIR to the \
|
||||
FlashInfer source root (the directory containing `include/` and \
|
||||
`3rdparty/cutlass/include/`)."
|
||||
);
|
||||
};
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: JIT compiling for HEAD_DIM={}, arch={} ...",
|
||||
head_dim, arch
|
||||
);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let output = Command::new("nvcc")
|
||||
.args([
|
||||
"-shared",
|
||||
"-o",
|
||||
so_path.to_str().unwrap(),
|
||||
&format!("-DLUMINAL_HEAD_DIM={}", head_dim),
|
||||
wrapper_cu_path.to_str().unwrap(),
|
||||
"-I",
|
||||
flashinfer_include.to_str().unwrap(),
|
||||
"-I",
|
||||
cutlass_include.to_str().unwrap(),
|
||||
"-I",
|
||||
wrapper_h_dir.to_str().unwrap(),
|
||||
"-std=c++17",
|
||||
&format!("-arch={}", arch),
|
||||
"-O3",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-w",
|
||||
"-rdc=true",
|
||||
"--compiler-options",
|
||||
"-fPIC",
|
||||
])
|
||||
.output()
|
||||
.expect("Failed to run nvcc. Is the CUDA toolkit installed?");
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let _ = std::fs::remove_file(&so_path);
|
||||
panic!(
|
||||
"FlashInfer JIT compilation failed (HEAD_DIM={}, arch={}):\nstdout: {}\nstderr: {}",
|
||||
head_dim, arch, stdout, stderr
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
"FlashInfer: compiled in {:.1}s → {}",
|
||||
elapsed.as_secs_f64(),
|
||||
so_path.display()
|
||||
);
|
||||
|
||||
so_path
|
||||
}
|
||||
|
||||
/// Returns ~/.cache/luminal/flashinfer/
|
||||
fn cache_directory() -> PathBuf {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||
PathBuf::from(home)
|
||||
.join(".cache")
|
||||
.join("luminal")
|
||||
.join("flashinfer")
|
||||
}
|
||||
|
||||
/// Drop the embedded wrapper.cu/wrapper.h into the cache dir so nvcc has files
|
||||
/// on disk to compile. Returns (wrapper.cu path, directory containing wrapper.h).
|
||||
fn extract_wrapper_sources(cache_dir: &Path) -> (PathBuf, PathBuf) {
|
||||
let cu = cache_dir.join("wrapper.cu");
|
||||
let h = cache_dir.join("wrapper.h");
|
||||
write_if_changed(&cu, WRAPPER_CU.as_bytes());
|
||||
write_if_changed(&h, WRAPPER_H.as_bytes());
|
||||
(cu, cache_dir.to_path_buf())
|
||||
}
|
||||
|
||||
fn write_if_changed(path: &Path, contents: &[u8]) {
|
||||
if let Ok(existing) = std::fs::read(path)
|
||||
&& existing == contents
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::fs::write(path, contents).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"FlashInfer: failed to write wrapper source to {}: {e}",
|
||||
path.display()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn wrapper_source_hash() -> u64 {
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
WRAPPER_CU.hash(&mut hasher);
|
||||
WRAPPER_H.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// ── Pinned FlashInfer source ──
|
||||
//
|
||||
// Bumping this constant invalidates the cached source tree AND the cached .so
|
||||
// (the .so cache key incorporates the wrapper hash, which is rebuilt against
|
||||
// these headers, so different headers compile to a different .so file even at
|
||||
// the same head_dim). If you change `FLASHINFER_GIT_REV`, also re-check
|
||||
// `wrapper.cu` against the new FlashInfer API.
|
||||
|
||||
const FLASHINFER_GIT_URL: &str = "https://github.com/flashinfer-ai/flashinfer.git";
|
||||
const CUTLASS_GIT_URL: &str = "https://github.com/NVIDIA/cutlass.git";
|
||||
const FLASHINFER_GIT_REV: &str = "f1e6fdcb8f65104047697f022b5d055ef022d763";
|
||||
const CUTLASS_GIT_REV: &str = "f3fde58372d33e9a5650ba7b80fc48b3b49d40c8";
|
||||
|
||||
fn locate_flashinfer_includes() -> Option<(PathBuf, PathBuf)> {
|
||||
if let Ok(path) = std::env::var("LUMINAL_FLASHINFER_DIR")
|
||||
&& !path.is_empty()
|
||||
{
|
||||
let root = PathBuf::from(path);
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
eprintln!(
|
||||
"FlashInfer: LUMINAL_FLASHINFER_DIR={} did not contain include/ and \
|
||||
3rdparty/cutlass/include/ — falling back to default locations",
|
||||
root.display()
|
||||
);
|
||||
}
|
||||
|
||||
let home = std::env::var("HOME").unwrap_or_default();
|
||||
let candidates = [
|
||||
PathBuf::from(&home).join("luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
PathBuf::from(&home).join("luminal_cuda/flashinfer"),
|
||||
PathBuf::from("/opt/luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
];
|
||||
for root in candidates {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: fetch the pinned commit into the cache directory.
|
||||
fetch_flashinfer_source().ok().map(|root| {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
(inc, cutlass)
|
||||
})
|
||||
}
|
||||
|
||||
/// Clone FlashInfer at `FLASHINFER_GIT_REV` + CUTLASS at `CUTLASS_GIT_REV`
|
||||
/// into `~/.cache/luminal/flashinfer-src/<short_rev>/` if absent, then return
|
||||
/// the FlashInfer root directory. ~50 MB one-time download; subsequent calls
|
||||
/// short-circuit on the directory check.
|
||||
fn fetch_flashinfer_source() -> Result<PathBuf, String> {
|
||||
let short = &FLASHINFER_GIT_REV[..12];
|
||||
let cache_root = cache_directory().join("flashinfer-src").join(short);
|
||||
let inc = cache_root.join("include");
|
||||
let cutlass_inc = cache_root.join("3rdparty/cutlass/include");
|
||||
|
||||
if inc.exists() && cutlass_inc.exists() {
|
||||
return Ok(cache_root);
|
||||
}
|
||||
|
||||
let parent = cache_root.parent().unwrap();
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("failed to create {}: {e}", parent.display()))?;
|
||||
|
||||
// Clone into a staging dir, then atomic rename. Protects against multiple
|
||||
// processes racing to fetch the same source.
|
||||
let staging = parent.join(format!(".staging-{}-{}", short, std::process::id()));
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: cloning {FLASHINFER_GIT_URL} @ {short} into {} (one-time fetch, ~50 MB) …",
|
||||
cache_root.display()
|
||||
);
|
||||
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
FLASHINFER_GIT_URL,
|
||||
staging.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&staging, &["checkout", FLASHINFER_GIT_REV])?;
|
||||
|
||||
// Init only the CUTLASS submodule (skip spdlog — we don't need it for kernels).
|
||||
let cutlass_path = staging.join("3rdparty/cutlass");
|
||||
let _ = std::fs::remove_dir_all(&cutlass_path);
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
CUTLASS_GIT_URL,
|
||||
cutlass_path.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&cutlass_path, &["checkout", CUTLASS_GIT_REV])?;
|
||||
|
||||
if !staging.join("include").exists() {
|
||||
return Err(format!(
|
||||
"FlashInfer clone succeeded but include/ missing at {}",
|
||||
staging.display()
|
||||
));
|
||||
}
|
||||
if !staging.join("3rdparty/cutlass/include").exists() {
|
||||
return Err(format!(
|
||||
"CUTLASS clone succeeded but include/ missing at {}",
|
||||
staging.join("3rdparty/cutlass").display()
|
||||
));
|
||||
}
|
||||
|
||||
// Atomic-ish rename. If another process beat us to it, just keep theirs.
|
||||
match std::fs::rename(&staging, &cache_root) {
|
||||
Ok(()) => {}
|
||||
Err(_) if cache_root.exists() => {
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
}
|
||||
Err(e) => return Err(format!("rename to {} failed: {e}", cache_root.display())),
|
||||
}
|
||||
|
||||
Ok(cache_root)
|
||||
}
|
||||
|
||||
fn run_git(args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}. Is git installed?"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` failed: {}",
|
||||
args.join(" "),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_git_in(cwd: &Path, args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` in {} failed: {}",
|
||||
args.join(" "),
|
||||
cwd.display(),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detect CUDA arch via env override → nvidia-smi → default sm_80.
|
||||
fn detect_cuda_arch() -> String {
|
||||
if let Ok(arch) = std::env::var("FLASHINFER_CUDA_ARCH") {
|
||||
return arch;
|
||||
}
|
||||
|
||||
if let Ok(output) = Command::new("nvidia-smi")
|
||||
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
|
||||
.output()
|
||||
&& output.status.success()
|
||||
{
|
||||
let cap = String::from_utf8_lossy(&output.stdout);
|
||||
let cap = cap.trim().lines().next().unwrap_or("8.0");
|
||||
let sm = cap.replace('.', "");
|
||||
if !sm.is_empty() {
|
||||
return format!("sm_{}", sm);
|
||||
}
|
||||
}
|
||||
|
||||
"sm_80".to_string()
|
||||
}
|
||||
@@ -1,424 +0,0 @@
|
||||
pub mod find_indptrs;
|
||||
pub mod jit;
|
||||
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaSlice, CudaStream, DevicePtr, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// FlashInfer attention op (batch decode, fp32).
|
||||
///
|
||||
/// Replaces the full paged-GQA attention pattern (gather → broadcast → Q*K^T →
|
||||
/// scale → mask → softmax → *V) with a single FlashInfer fused kernel.
|
||||
///
|
||||
/// Graph inputs (7): Q, K_pool, V_pool, flat_gather_idx, mask, qo_indptr, kv_indptr.
|
||||
/// The egglog rule captures the first 5; `extract()` appends qo/kv indptrs after
|
||||
/// walking the e-graph from the mask. `batch_size` is derived at runtime from the
|
||||
/// indptr length (= num_sequences + 1).
|
||||
#[derive(Debug)]
|
||||
pub struct FlashInferAttention {
|
||||
pub num_qo_heads: usize,
|
||||
pub num_kv_heads: usize,
|
||||
pub head_dim: usize,
|
||||
pub page_size: usize,
|
||||
pub batch_dim: Expression,
|
||||
|
||||
pub plan_info: Mutex<Vec<i64>>,
|
||||
}
|
||||
|
||||
// SAFETY: PAGE_LOCKED_WORKSPACE holds a raw pointer to page-locked CUDA memory
|
||||
// allocated once and serialized via the CUDA stream that owns it.
|
||||
unsafe impl Send for FlashInferAttention {}
|
||||
unsafe impl Sync for FlashInferAttention {}
|
||||
|
||||
const FLOAT_WORKSPACE_SIZE: usize = 128 * 1024 * 1024; // 128 MiB
|
||||
const INT_WORKSPACE_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
|
||||
|
||||
static PAGE_LOCKED_WORKSPACE: OnceLock<PageLockedPtr> = OnceLock::new();
|
||||
|
||||
struct PageLockedPtr(*mut u8);
|
||||
|
||||
// SAFETY: The pointer is page-locked CUDA memory allocated once via
|
||||
// posix_memalign + cudaHostRegister and only mutated during OnceLock
|
||||
// initialization.
|
||||
unsafe impl Send for PageLockedPtr {}
|
||||
unsafe impl Sync for PageLockedPtr {}
|
||||
|
||||
impl std::fmt::Debug for PageLockedPtr {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PageLockedPtr({:p})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FlashInferAttention {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_qo_heads: 0,
|
||||
num_kv_heads: 0,
|
||||
head_dim: 0,
|
||||
page_size: 0,
|
||||
batch_dim: Expression::default(),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for FlashInferAttention {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FlashInferAttention",
|
||||
&[
|
||||
("num_qo_heads", EXPRESSION),
|
||||
("num_kv_heads", EXPRESSION),
|
||||
("head_dim", EXPRESSION),
|
||||
("page_size", EXPRESSION),
|
||||
("batch_dim", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
// Q, K_pool, V_pool, flat_gather_idx, mask (egglog IList).
|
||||
// extract() appends qo_indptr + kv_indptr → 7 actual inputs at runtime.
|
||||
5
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["flashinfer_attention.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let num_qo_heads = extract_expr(egraph, kind_children[0], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let num_kv_heads = extract_expr(egraph, kind_children[1], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let head_dim = extract_expr(egraph, kind_children[2], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let page_size = extract_expr(egraph, kind_children[3], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let batch_dim = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
|
||||
let extracted = Self {
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
batch_dim,
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Trigger JIT compilation (or .so cache hit) at extract time, not at
|
||||
// first execute. Pays the ~30s cold-cache nvcc cost during compile
|
||||
// rather than during the GA profiling loop, where it would dominate
|
||||
// the candidate's measured runtime and make the GA reject FlashInfer.
|
||||
let _ = jit::ensure_compiled(head_dim);
|
||||
|
||||
// Walk the mask e-graph chain to recover qo_indptr / kv_indptr Input nodes.
|
||||
// input_enodes: [Q, K_cache, V_cache, gather_idx, mask]
|
||||
let mask_node = input_enodes[4];
|
||||
let indptrs = find_indptrs::find_indptr_inputs(egraph, mask_node);
|
||||
|
||||
// Build final inputs: [Q, K_cache, V_cache, gather_idx, mask, qo_indptr, kv_indptr]
|
||||
let mut final_inputs = input_enodes;
|
||||
final_inputs.push(indptrs.qo_indptr);
|
||||
final_inputs.push(indptrs.kv_indptr);
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
(op, final_inputs)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for FlashInferAttention {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let lib = jit::ensure_compiled(self.head_dim);
|
||||
|
||||
let total_q_tokens = self
|
||||
.batch_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention batch_dim is unresolved"))?;
|
||||
let c = *dyn_map
|
||||
.get(&'c')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'c'"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'r'"))?;
|
||||
|
||||
if inputs.len() < 7 {
|
||||
anyhow::bail!(
|
||||
"FlashInferAttention expects 7 inputs (Q, K, V, flat_idx, mask, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("FlashInferAttention missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_buf = get_buf("Q", inputs[0])?;
|
||||
let k_buf = get_buf("K_cache", inputs[1])?;
|
||||
let v_buf = get_buf("V_cache", inputs[2])?;
|
||||
let flat_idx_buf = get_buf("flat_gather_idx", inputs[3])?;
|
||||
// inputs[4] = mask (unused by FlashInfer — indptrs replace it)
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[6])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
// Derive batch_size (num sequences) from r = indptr length.
|
||||
let batch_size = r.saturating_sub(1);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"FlashInferAttention",
|
||||
total_q_tokens,
|
||||
batch_size,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let kv_dim = self.num_kv_heads * self.head_dim;
|
||||
let cu_stream = stream.cu_stream() as *mut std::ffi::c_void;
|
||||
|
||||
// Extract slot indices (one per context page) from the flat gather index.
|
||||
let indices_buf = unsafe { stream.alloc::<u8>(c.max(1) * std::mem::size_of::<i32>())? };
|
||||
let (indices_ptr, _idx_guard) = indices_buf.device_ptr(stream);
|
||||
|
||||
if c > 0 {
|
||||
unsafe {
|
||||
(lib.extract_slot_indices)(
|
||||
flat_idx_buf.ptr() as *const i32,
|
||||
indices_ptr as *mut i32,
|
||||
c as i32,
|
||||
kv_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Read kv_indptr to host for the plan phase.
|
||||
let kv_indptr_bytes = r * 4;
|
||||
let mut kv_indptr_host_bytes = vec![0u8; kv_indptr_bytes];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(
|
||||
&mut kv_indptr_host_bytes,
|
||||
kv_indptr_buf.ptr(),
|
||||
stream.cu_stream(),
|
||||
)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let kv_indptr_host: Vec<i32> = unsafe {
|
||||
let mut v = std::mem::ManuallyDrop::new(kv_indptr_host_bytes);
|
||||
Vec::from_raw_parts(v.as_mut_ptr() as *mut i32, r, r)
|
||||
};
|
||||
|
||||
// kv_last_page_len = [1; batch_size] when page_size=1.
|
||||
let last_page_host: Vec<i32> = vec![1; batch_size];
|
||||
let last_page_dev: CudaSlice<u8> = if batch_size > 0 {
|
||||
stream.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
last_page_host.as_ptr() as *const u8,
|
||||
last_page_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
unsafe { stream.alloc::<u8>(1)? }
|
||||
};
|
||||
let (last_page_ptr, _lp_guard) = last_page_dev.device_ptr(stream);
|
||||
|
||||
// Global shared workspaces (allocated once across all op instances to
|
||||
// amortize the ~4ms first-allocation cost during GA profiling).
|
||||
static FLOAT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
static INT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
let float_ws = FLOAT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(FLOAT_WORKSPACE_SIZE).unwrap() });
|
||||
let int_ws = INT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(INT_WORKSPACE_SIZE).unwrap() });
|
||||
let page_locked_ws = PAGE_LOCKED_WORKSPACE.get_or_init(|| unsafe {
|
||||
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
let status = libc::posix_memalign(&mut ptr, 4096, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(status, 0, "Failed to allocate page-locked workspace");
|
||||
let cuda_status = cuda_pin_memory(ptr, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(cuda_status, 0, "Failed to pin memory");
|
||||
PageLockedPtr(ptr as *mut u8)
|
||||
});
|
||||
|
||||
let (float_ws_ptr, _fws_guard) = float_ws.device_ptr(stream);
|
||||
let (int_ws_ptr, _iws_guard) = int_ws.device_ptr(stream);
|
||||
|
||||
// FlashInfer decode writes (total_q_tokens, heads, dim);
|
||||
// luminal expects (heads, total_q_tokens, dim) — transpose at the end.
|
||||
let output_elems = total_q_tokens * self.num_qo_heads * self.head_dim;
|
||||
let temp_out_buf =
|
||||
unsafe { stream.alloc::<u8>(output_elems * std::mem::size_of::<f32>())? };
|
||||
let (temp_out_ptr, _tmp_guard) = temp_out_buf.device_ptr(stream);
|
||||
|
||||
// PrefillPlanInfo has 15 entries, DecodePlanInfo fewer — 16 is enough.
|
||||
let mut plan_info_buf = [0i64; 16];
|
||||
let mut plan_info_len: i32 = 0;
|
||||
|
||||
// ── BatchDecode path ──
|
||||
// Prefill kernels require fp16/bf16 tensor-core MMA; the C API returns -1
|
||||
// when called from the fp32 pipeline. We only use decode here.
|
||||
let plan_ret = unsafe {
|
||||
(lib.plan)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
INT_WORKSPACE_SIZE,
|
||||
page_locked_ws.0 as *mut std::ffi::c_void,
|
||||
kv_indptr_host.as_ptr() as *mut i32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
plan_info_buf.as_mut_ptr(),
|
||||
&mut plan_info_len,
|
||||
)
|
||||
};
|
||||
if plan_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode plan failed with error code {plan_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
let mut plan_info = self.plan_info.lock().unwrap();
|
||||
plan_info.clear();
|
||||
plan_info.extend_from_slice(&plan_info_buf[..plan_info_len as usize]);
|
||||
|
||||
let run_ret = unsafe {
|
||||
(lib.run)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
plan_info.as_mut_ptr(),
|
||||
plan_info.len() as i32,
|
||||
q_buf.ptr() as *mut f32,
|
||||
k_buf.ptr() as *mut f32,
|
||||
v_buf.ptr() as *mut f32,
|
||||
kv_indptr_buf.ptr() as *mut i32,
|
||||
indices_ptr as *mut i32,
|
||||
last_page_ptr as *mut i32,
|
||||
temp_out_ptr as *mut f32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
)
|
||||
};
|
||||
drop(plan_info);
|
||||
|
||||
if run_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode run failed with error code {run_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
// Transpose (total_q_tokens, heads, dim) → (heads, total_q_tokens, dim)
|
||||
unsafe {
|
||||
(lib.transpose_output)(
|
||||
temp_out_ptr as *const f32,
|
||||
out_buf.ptr() as *mut f32,
|
||||
total_q_tokens as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.batch_dim * self.num_qo_heads * self.head_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("FlashInferAttention")
|
||||
}
|
||||
}
|
||||
|
||||
/// Pin host memory for CUDA async memcpy.
|
||||
///
|
||||
/// `cudaHostRegister` lives in libcudart, which cudarc doesn't link to our
|
||||
/// binary. Resolve it via `dlopen`/`dlsym` so we don't need a build script or
|
||||
/// a `#[link]` directive — keeping the crate buildable without any nvcc-side
|
||||
/// dependencies.
|
||||
unsafe fn cuda_pin_memory(ptr: *mut std::ffi::c_void, size: usize) -> i32 {
|
||||
type HostRegisterFn = unsafe extern "C" fn(*mut std::ffi::c_void, usize, u32) -> i32;
|
||||
static FN: OnceLock<usize> = OnceLock::new();
|
||||
|
||||
let raw = *FN.get_or_init(|| unsafe {
|
||||
let lib = [
|
||||
"libcudart.so",
|
||||
"libcudart.so.13",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
]
|
||||
.iter()
|
||||
.find_map(|n| libloading::Library::new(*n).ok())
|
||||
.expect("FlashInfer: could not dlopen libcudart for cudaHostRegister");
|
||||
let sym: libloading::Symbol<HostRegisterFn> = lib
|
||||
.get(b"cudaHostRegister\0")
|
||||
.expect("FlashInfer: libcudart missing cudaHostRegister symbol");
|
||||
let ptr = *sym as *const () as usize;
|
||||
// Keep libcudart resident for the process lifetime so the function
|
||||
// pointer remains valid.
|
||||
std::mem::forget(lib);
|
||||
ptr
|
||||
});
|
||||
let f: HostRegisterFn = unsafe { std::mem::transmute(raw) };
|
||||
// cudaHostRegisterDefault = 0
|
||||
unsafe { f(ptr, size, 0) }
|
||||
}
|
||||
@@ -1,357 +0,0 @@
|
||||
// FlashInfer batch decode + prefill wrapper for luminal_cuda.
|
||||
// JIT-compiled at runtime with -DLUMINAL_HEAD_DIM=N.
|
||||
//
|
||||
// Decode: instantiated for f32 (scalar vectorized dot products, no tensor cores).
|
||||
// Prefill: instantiated for f16 (requires tensor core MMA + ldmatrix).
|
||||
// The C API accepts fp32 buffers; cast kernels convert fp32↔fp16 at the boundary.
|
||||
//
|
||||
// NHD layout. GQA group_size and page_size are runtime parameters.
|
||||
|
||||
#ifndef LUMINAL_HEAD_DIM
|
||||
#error "LUMINAL_HEAD_DIM must be defined (e.g. -DLUMINAL_HEAD_DIM=128)"
|
||||
#endif
|
||||
|
||||
// Include utils.cuh first to get the original DISPATCH_HEAD_DIM, then override it
|
||||
// to only instantiate our specific HEAD_DIM. This avoids a compile error in
|
||||
// cascade.cuh where HEAD_DIM=512 + f32 triggers vec_size=16, vec_bits=512
|
||||
// which exceeds cp_async's 256-bit limit.
|
||||
#include <flashinfer/utils.cuh>
|
||||
#undef DISPATCH_HEAD_DIM
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
{ \
|
||||
constexpr size_t HEAD_DIM = LUMINAL_HEAD_DIM; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#include <flashinfer/attention/scheduler.cuh>
|
||||
#include <flashinfer/attention/decode.cuh>
|
||||
#include <flashinfer/attention/default_decode_params.cuh>
|
||||
#include <flashinfer/attention/prefill.cuh>
|
||||
#include <flashinfer/attention/default_prefill_params.cuh>
|
||||
#include <flashinfer/attention/mask.cuh>
|
||||
#include <flashinfer/attention/variants.cuh>
|
||||
#include <flashinfer/page.cuh>
|
||||
#include <flashinfer/pos_enc.cuh>
|
||||
|
||||
#include "wrapper.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// ── Decode types (f32) ──
|
||||
using DTypeQ = float;
|
||||
using DTypeKV = float;
|
||||
using DTypeO = float;
|
||||
using IdType = int32_t;
|
||||
|
||||
// ── Prefill types (f16 compute, fp32 external interface) ──
|
||||
using PrefillDTypeQ = half;
|
||||
using PrefillDTypeKV = half;
|
||||
using PrefillDTypeO = half;
|
||||
|
||||
constexpr uint32_t HEAD_DIM = LUMINAL_HEAD_DIM;
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;
|
||||
|
||||
// Attention variants
|
||||
using Variant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
using CausalVariant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
// Decode params (f32)
|
||||
using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
|
||||
|
||||
// Prefill params (f16)
|
||||
using PrefillParams = BatchPrefillPagedParams<PrefillDTypeQ, PrefillDTypeKV, PrefillDTypeO, IdType>;
|
||||
|
||||
// Forward declarations
|
||||
namespace flashinfer {
|
||||
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
|
||||
typename Params>
|
||||
cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
|
||||
PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
|
||||
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
|
||||
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
}
|
||||
|
||||
// Explicit instantiation: decode kernel (f32)
|
||||
template cudaError_t flashinfer::BatchDecodeWithPagedKVCacheDispatched<
|
||||
HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
DecodeParams params, DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// Explicit instantiation: prefill kernels (f16, causal mask, CTA_TILE_Q=16/64/128)
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
16, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
64, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
128, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// ── fp32 ↔ fp16 cast kernels ──
|
||||
|
||||
__global__ void cast_f32_to_f16_kernel(const float* src, half* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __float2half(src[i]);
|
||||
}
|
||||
|
||||
__global__ void cast_f16_to_f32_kernel(const half* src, float* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __half2float(src[i]);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
uint32_t group_size = num_qo_heads / num_kv_heads;
|
||||
|
||||
// We need to dispatch on GROUP_SIZE to get the right work estimation function
|
||||
cudaError_t status = cudaSuccess;
|
||||
|
||||
// Use a lambda to dispatch on group size
|
||||
auto do_plan = [&]<uint32_t GROUP_SIZE>() -> cudaError_t {
|
||||
auto work_estimation_func =
|
||||
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
|
||||
GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>;
|
||||
return DecodePlan<HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
float_workspace, float_ws_size,
|
||||
int_workspace, page_locked_int_workspace,
|
||||
int_ws_size, plan_info, indptr_h,
|
||||
(uint32_t)batch_size, (uint32_t)num_qo_heads,
|
||||
(uint32_t)page_size, /*enable_cuda_graph=*/false,
|
||||
stream, work_estimation_func);
|
||||
};
|
||||
|
||||
switch (group_size) {
|
||||
case 1: status = do_plan.operator()<1>(); break;
|
||||
case 2: status = do_plan.operator()<2>(); break;
|
||||
case 4: status = do_plan.operator()<4>(); break;
|
||||
case 8: status = do_plan.operator()<8>(); break;
|
||||
default: return -1; // unsupported group size
|
||||
}
|
||||
|
||||
if (status != cudaSuccess) return (int)status;
|
||||
|
||||
auto vec = plan_info.ToVector();
|
||||
*plan_info_len_out = (int)vec.size();
|
||||
std::memcpy(plan_info_out, vec.data(), vec.size() * sizeof(int64_t));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q,
|
||||
float* k_cache,
|
||||
float* v_cache,
|
||||
int32_t* kv_indptr,
|
||||
int32_t* kv_indices,
|
||||
int32_t* kv_last_page_len,
|
||||
float* output,
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
plan_info.FromVector(std::vector<int64_t>(plan_info_vec, plan_info_vec + plan_info_len));
|
||||
|
||||
// Construct paged_kv_t with NHD layout
|
||||
paged_kv_t<DTypeKV, IdType> paged_kv(
|
||||
(uint32_t)num_kv_heads,
|
||||
(uint32_t)page_size,
|
||||
HEAD_DIM,
|
||||
(uint32_t)batch_size,
|
||||
QKVLayout::kNHD,
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
kv_last_page_len);
|
||||
|
||||
DecodeParams params;
|
||||
params.q = q;
|
||||
params.q_rope_offset = nullptr;
|
||||
params.paged_kv = paged_kv;
|
||||
params.o = output;
|
||||
params.lse = nullptr;
|
||||
params.maybe_alibi_slopes = nullptr;
|
||||
params.padded_batch_size = plan_info.padded_batch_size;
|
||||
params.num_qo_heads = (uint32_t)num_qo_heads;
|
||||
// Q buffer is (batch, num_qo_heads * head_dim) flat — the graph's split_dims + transpose
|
||||
// are stride tricks, no data movement. So the actual memory layout is (batch, heads, dim).
|
||||
params.q_stride_n = num_qo_heads * HEAD_DIM;
|
||||
params.q_stride_h = HEAD_DIM;
|
||||
params.window_left = -1; // no sliding window
|
||||
params.logits_soft_cap = 0.0f;
|
||||
params.sm_scale = 1.0f / sqrtf((float)HEAD_DIM);
|
||||
params.rope_rcp_scale = 1.0f;
|
||||
params.rope_rcp_theta = 1.0f;
|
||||
|
||||
// Set plan info pointers
|
||||
params.request_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.request_indices_offset);
|
||||
params.kv_tile_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_tile_indices_offset);
|
||||
params.o_indptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.o_indptr_offset);
|
||||
params.kv_chunk_size_ptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_chunk_size_ptr_offset);
|
||||
params.block_valid_mask = nullptr;
|
||||
params.partition_kv = false;
|
||||
|
||||
DTypeO* tmp_v = nullptr;
|
||||
float* tmp_s = nullptr;
|
||||
|
||||
if (plan_info.split_kv) {
|
||||
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_workspace, plan_info.v_offset);
|
||||
tmp_s = GetPtrFromBaseOffset<float>(float_workspace, plan_info.s_offset);
|
||||
if (plan_info.enable_cuda_graph) {
|
||||
params.block_valid_mask =
|
||||
GetPtrFromBaseOffset<bool>(int_workspace, plan_info.block_valid_mask_offset);
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t status =
|
||||
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE, Variant>(
|
||||
params, tmp_v, tmp_s, /*enable_pdl=*/false, stream);
|
||||
|
||||
return (int)status;
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
// BatchPrefill (fp16/bf16 only — tensor core MMA requires 16-bit inputs)
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
//
|
||||
// The prefill kernel templates are instantiated above for fp16. These C API
|
||||
// functions accept fp32 pointers (matching the current luminal pipeline) but
|
||||
// return -1 to indicate that fp32 prefill is not supported. When native fp16
|
||||
// support is added, these will accept fp16 pointers and call through to the
|
||||
// instantiated templates.
|
||||
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void*, size_t, void*, size_t, void*,
|
||||
int32_t*, int32_t*, int, int,
|
||||
int, int, int, int, cudaStream_t,
|
||||
int64_t*, int*)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
int flashinfer_batch_prefill_run(
|
||||
void*, size_t, void*,
|
||||
int64_t*, int,
|
||||
float*, float*, float*,
|
||||
int32_t*, int32_t*, int32_t*, int32_t*,
|
||||
float*, int, int, int, int, int, int, cudaStream_t)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
// ── Slot index extraction kernel (outside extern "C" for __global__) ──
|
||||
|
||||
__global__ void extract_slot_indices_kernel(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < c) out[i] = flat_idx[i * kv_dim] / kv_dim;
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream) {
|
||||
if (c == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (c + threads - 1) / threads;
|
||||
extract_slot_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
flat_idx, out, c, kv_dim);
|
||||
}
|
||||
|
||||
// ── Derive CSR indptr from attention mask ──
|
||||
// Mask is (s, c) f32. Entries > -1e9 are "valid" (0.0), rest are -inf.
|
||||
// Per-row count of valid entries = context length for that sequence.
|
||||
// Output: indptr[0..=s] with indptr[0]=0 and indptr[i+1] = indptr[i] + ctx_len[i].
|
||||
// Single thread is fine since s is tiny (batch_size during decode, typically 1-8).
|
||||
|
||||
__global__ void derive_indptr_kernel(
|
||||
const float* mask, int32_t* indptr, int s, int c) {
|
||||
if (threadIdx.x != 0 || blockIdx.x != 0) return;
|
||||
indptr[0] = 0;
|
||||
for (int i = 0; i < s; i++) {
|
||||
int count = 0;
|
||||
for (int j = 0; j < c; j++) {
|
||||
if (mask[i * c + j] > -1e9f) count++;
|
||||
}
|
||||
indptr[i + 1] = indptr[i] + count;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream) {
|
||||
if (s == 0) return;
|
||||
derive_indptr_kernel<<<1, 1, 0, stream>>>(mask, indptr, s, c);
|
||||
}
|
||||
|
||||
// ── Output transpose: (batch, heads, dim) → (heads, batch, dim) ──
|
||||
// FlashInfer writes output as (batch, heads, dim) but Luminal expects (heads, batch, dim).
|
||||
// For batch=1 these are identical; for batch>1 we need an explicit transpose.
|
||||
|
||||
__global__ void transpose_bhd_to_hbd_kernel(
|
||||
const float* src, float* dst, int batch, int heads, int dim) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = batch * heads * dim;
|
||||
if (idx >= total) return;
|
||||
|
||||
// Decompose linear index into (b, h, d) for src layout
|
||||
int d = idx % dim;
|
||||
int h = (idx / dim) % heads;
|
||||
int b = idx / (heads * dim);
|
||||
|
||||
// Write to (h, b, d) layout in dst
|
||||
dst[h * batch * dim + b * dim + d] = src[idx];
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream) {
|
||||
int total = batch * heads * dim;
|
||||
if (total == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (total + threads - 1) / threads;
|
||||
transpose_bhd_to_hbd_kernel<<<blocks, threads, 0, stream>>>(
|
||||
src, dst, batch, heads, dim);
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Plan phase: CPU-side scheduling. Must call before each new batch config.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase: GPU kernel launch.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [batch_size, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* kv_indptr, // [batch_size + 1]
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [batch_size, num_qo_heads, head_dim]
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Extract slot indices from a flat gather index tensor.
|
||||
// flat_idx shape: (c, kv_dim) i32, out shape: (c,) i32.
|
||||
// out[i] = flat_idx[i * kv_dim] / kv_dim
|
||||
void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Derive CSR indptr from attention mask.
|
||||
// mask shape: (s, c) f32. Entries > -1e9 are valid.
|
||||
// indptr shape: (s + 1,) i32. indptr[0] = 0, indptr[i+1] = cumsum of valid counts.
|
||||
void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Transpose output from (batch, heads, dim) to (heads, batch, dim).
|
||||
void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// ── BatchPrefill with Paged KV Cache ──
|
||||
|
||||
// Plan phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* qo_indptr_h, int32_t* kv_indptr_h,
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [total_num_rows, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* qo_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [total_num_rows, num_qo_heads, head_dim]
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,129 +1,17 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTypeTuple = (
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
&'static str,
|
||||
luminal::dtype::DType,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_type_tuple(op: &dyn HostOp) -> Option<CublasLtTypeTuple> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::type_tuple)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtScaleValues = (f64, f64);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_scale_values(op: &dyn HostOp) -> Option<CublasLtScaleValues> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::scale_values)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_epilogue(op: &dyn HostOp) -> Option<&'static str> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::epilogue)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtMatrixOrders = (&'static str, &'static str, &'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_matrix_orders(op: &dyn HostOp) -> Option<CublasLtMatrixOrders> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::matrix_orders)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTransposeOps = (&'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_transpose_ops(op: &dyn HostOp) -> Option<CublasLtTransposeOps> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::transpose_ops)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTensorScaleInputs = (bool, bool);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_tensor_scale_inputs(op: &dyn HostOp) -> Option<CublasLtTensorScaleInputs> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::tensor_scale_inputs)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
/// the reusable arena, or an external pointer. Host ops only need the pointer
|
||||
/// and the logical byte length.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct DeviceBuffer {
|
||||
ptr: u64,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl DeviceBuffer {
|
||||
pub fn new(ptr: u64, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
pub fn ptr(self) -> u64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn len(self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
|
||||
let mut host = vec![0u8; self.len];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(host)
|
||||
}
|
||||
}
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
@@ -141,7 +29,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
@@ -160,15 +48,6 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns relative lifetimes for extra buffer nodes within this host op.
|
||||
///
|
||||
/// The tuple is `(node, first_step, last_step)`, where steps are local to
|
||||
/// this host op's execution. Returning `None` tells the runtime to treat
|
||||
/// every extra buffer as live for the whole host op.
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -5,19 +5,12 @@
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared expert index/gather markers
|
||||
; 2) Shared gate-up matmul marker
|
||||
; 3) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 5) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEExpertIndexState
|
||||
(MkGLUMoEExpertIndexState Expression Expression IR)
|
||||
)
|
||||
(GLUMoEExpertGatherState
|
||||
(MkGLUMoEExpertGatherState Expression Expression IR IR)
|
||||
)
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
@@ -35,8 +28,6 @@
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_expert_index (IR) GLUMoEExpertIndexState :merge new)
|
||||
(function glumoe_expert_gather (IR) GLUMoEExpertGatherState :merge new)
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
@@ -45,38 +36,17 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?iota_base (Op (Iota ?io ?iota_base_range) (INil)))
|
||||
(= ?mul_base (Op (Mul ?mul_base_shape ?mul_base_a_stride ?mul_base_b_stride ?mul_base_out_stride) (ICons ?topk_idx (ICons ?iota_base (INil)))))
|
||||
(= ?iota_within (Op (Iota (MIter) ?iota_within_range) (INil)))
|
||||
(= ?add_idx (Op (Add ?add_shape ?add_a_stride ?add_b_stride ?add_out_stride) (ICons ?mul_base (ICons ?iota_within (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_index ?add_idx)
|
||||
(MkGLUMoEExpertIndexState ?io ?iota_within_range ?topk_idx))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert index marker"
|
||||
)
|
||||
; ===== Gate-up expert gather =====
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?index_state (glumoe_expert_index ?idx))
|
||||
(= ?index_state (MkGLUMoEExpertIndexState ?io ?within_range ?topk_idx))
|
||||
(= ?gathered (Op (Gather ?gather_idx_shape ?gather_idx_stride ?gather_data_shape ?gather_data_stride) (ICons ?idx (ICons ?weights (INil)))))
|
||||
(= ?f32 (Op (Cast ?f32_size (F32)) (ICons ?gathered (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_gather ?f32)
|
||||
(MkGLUMoEExpertGatherState ?io ?within_range ?topk_idx ?weights))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert gather marker"
|
||||
)
|
||||
; ===== Cast BF16→F32 =====
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?gather_state (glumoe_expert_gather ?gu_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?gu_io ?gu_iota_within_range ?topk_idx ?gate_up_w))
|
||||
; ===== Gate-up batched matmul =====
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
@@ -84,7 +54,6 @@
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
@@ -111,7 +80,6 @@
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
@@ -145,7 +113,6 @@
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
@@ -155,8 +122,12 @@
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -164,7 +135,6 @@
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
@@ -174,8 +144,12 @@
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -183,7 +157,6 @@
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
@@ -195,10 +168,6 @@
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
@@ -208,44 +177,10 @@
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 2 (SwiGLU with row-normalized top-k weights) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?normed_topk (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 2))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (normalized swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
@@ -273,9 +208,6 @@
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
@@ -50,7 +50,7 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU/SwiGLUNormalized: ignored (rewriter wires `topk_values` again)
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
@@ -78,7 +78,6 @@ pub struct GLUMoE {
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
SwiGLUNormalized,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
@@ -86,7 +85,6 @@ impl GLUMoEMode {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
2 => Self::SwiGLUNormalized,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
@@ -95,7 +93,7 @@ impl GLUMoEMode {
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU | Self::SwiGLUNormalized => 0,
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
@@ -226,9 +224,8 @@ impl EgglogOp for GLUMoE {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
@@ -237,15 +234,17 @@ impl EgglogOp for GLUMoE {
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
),
|
||||
Rule::raw(include_str!["glumoe_rewrite.egg"]),
|
||||
]
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
@@ -296,140 +295,27 @@ impl HostOp for GLUMoE {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 6 {
|
||||
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
|
||||
}
|
||||
|
||||
// Resolve dimensions
|
||||
let hidden = self
|
||||
.gu_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
|
||||
let intermediate = self
|
||||
.dn_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
|
||||
let top_k = self
|
||||
.output_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
|
||||
let gu_io = self
|
||||
.gu_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
|
||||
let dn_io = self
|
||||
.dn_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
if hidden == 0 || intermediate == 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
|
||||
);
|
||||
}
|
||||
if top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
if gu_io % hidden != 0 {
|
||||
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
|
||||
}
|
||||
if dn_io % intermediate != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
|
||||
);
|
||||
}
|
||||
|
||||
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
|
||||
let down_hidden = dn_io / intermediate;
|
||||
if gate_up_dim != intermediate * 2 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
|
||||
gate_up_dim,
|
||||
intermediate * 2
|
||||
);
|
||||
}
|
||||
if down_hidden != hidden {
|
||||
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
|
||||
}
|
||||
|
||||
let output_bytes = self
|
||||
.output_bytes()
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
|
||||
if output_bytes % (hidden * 4) != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
|
||||
hidden * 4
|
||||
);
|
||||
}
|
||||
let seq = output_bytes / (hidden * 4);
|
||||
if seq == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
|
||||
})
|
||||
};
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
|
||||
// Get input/output buffers
|
||||
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
|
||||
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
|
||||
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
|
||||
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let min_topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
if output_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
|
||||
output_buf.len()
|
||||
);
|
||||
}
|
||||
|
||||
let gu_stride_bytes = gate_up_dim * hidden * 2;
|
||||
let down_stride_bytes = hidden * intermediate * 2;
|
||||
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
|
||||
gate_up_buf.len()
|
||||
);
|
||||
}
|
||||
let num_experts = gate_up_buf.len() / gu_stride_bytes;
|
||||
if num_experts == 0 {
|
||||
anyhow::bail!("GLUMoE has no expert weights");
|
||||
}
|
||||
if down_buf.len() < num_experts * down_stride_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down weight buffer too small: have {} bytes, need {}",
|
||||
down_buf.len(),
|
||||
num_experts * down_stride_bytes
|
||||
);
|
||||
}
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
@@ -441,101 +327,41 @@ impl HostOp for GLUMoE {
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
|
||||
if !topk_idx_i32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index element count {} is not divisible by seq {seq}",
|
||||
topk_idx_i32.len()
|
||||
);
|
||||
}
|
||||
if !topk_vals_f32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value element count {} is not divisible by seq {seq}",
|
||||
topk_vals_f32.len()
|
||||
);
|
||||
}
|
||||
let topk_idx_row_stride = topk_idx_i32.len() / seq;
|
||||
let topk_vals_row_stride = topk_vals_f32.len() / seq;
|
||||
if topk_idx_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index row stride {topk_idx_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
if topk_vals_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value row stride {topk_vals_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
|
||||
let topk_idx_at = |token: usize, expert: usize| -> i32 {
|
||||
topk_idx_i32[token * topk_idx_row_stride + expert]
|
||||
};
|
||||
let topk_val_at = |token: usize, expert: usize| -> f32 {
|
||||
topk_vals_f32[token * topk_vals_row_stride + expert]
|
||||
};
|
||||
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i);
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at token {t} top-k position {i} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
}
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - SwiGLUNormalized: normalize topk values row-wise
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => {
|
||||
if topk_vals_row_stride == top_k {
|
||||
topk_vals_f32
|
||||
} else {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i);
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
}
|
||||
GLUMoEMode::SwiGLUNormalized => {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i) * inv_norm;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
if per_expert_scale_host.len() < per_expert_scale_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
|
||||
per_expert_scale_host.len()
|
||||
);
|
||||
}
|
||||
let per_expert_scale_f32: &[f32] =
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
@@ -544,8 +370,7 @@ impl HostOp for GLUMoE {
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[t * top_k + i] =
|
||||
topk_val_at(t, i) * inv_norm * scale;
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
@@ -558,10 +383,10 @@ impl HostOp for GLUMoE {
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = slice_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = slice_ptr(&workspace, stream);
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
@@ -580,15 +405,17 @@ impl HostOp for GLUMoE {
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, &weight) in weights.iter().enumerate() {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
let expert_idx = expert_idx as usize;
|
||||
|
||||
// a. Gate+Up matmul (BF16 in, BF16 out)
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
@@ -681,11 +508,7 @@ impl HostOp for GLUMoE {
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
|
||||
buf.ptr()
|
||||
}
|
||||
|
||||
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
|
||||
@@ -1,289 +0,0 @@
|
||||
//! Direct conv2d_bias kernel — fuses unfold + matmul + bias into one
|
||||
//! CUDA kernel with no `(H_out*W_out, C_in*K*K)` intermediate matrix.
|
||||
//!
|
||||
//! This is exposed as a luminal `CustomOp`, not a standard egglog-rewritten
|
||||
//! `KernelOp`, because the conv has no useful fusion opportunities with
|
||||
//! surrounding ops in the graphs it's used in (the VAE's resnet blocks),
|
||||
//! and pattern-matching the unfold+permute+merge_dims+matmul+bias chain
|
||||
//! reliably from egglog is significantly more work than just bypassing
|
||||
//! the egglog rewrite path entirely.
|
||||
//!
|
||||
//! The kernel is one-thread-per-output: each thread computes
|
||||
//! `out[co, ho, wo] = bias[co] + sum_{ci,ki,kj} input[ci, ho*S+ki-P, wo*S+kj-P] * weight[co, ci, ki, kj]`
|
||||
//! with bounds checks on the spatial dims for padding. This is far from
|
||||
//! peak FLOPs (no shared-memory tiling, no warp-level reduction over K)
|
||||
//! but it's correct and the memory footprint is just the input + weight +
|
||||
//! bias + output buffers — no `(M, K)` or `(M, N, K)` intermediate, so it
|
||||
//! scales linearly with the actual conv FLOPs rather than blowing up at
|
||||
//! large H/W like the unfold-based formulation.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::FxHashMap;
|
||||
use luminal::{
|
||||
dtype::DType, graph::Graph, op::CustomOp, op::LLIROp, prelude::GraphTensor, shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct conv2d-with-bias kernel. All shape/kernel params are static
|
||||
/// (baked into the CUDA source via #defines), so each conv shape gets
|
||||
/// its own compiled kernel. Inputs (in order): input `(C_in, H_in, W_in)`,
|
||||
/// weight `(C_out, C_in*K*K)` (i.e. flattened `(C_out, C_in, K, K)`), bias
|
||||
/// `(C_out,)`. Output: `(C_out, H_out, W_out)`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DKernel {
|
||||
pub c_in: usize,
|
||||
pub h_in: usize,
|
||||
pub w_in: usize,
|
||||
pub c_out: usize,
|
||||
pub kernel: usize,
|
||||
pub stride: usize,
|
||||
pub padding: usize,
|
||||
pub h_out: usize,
|
||||
pub w_out: usize,
|
||||
}
|
||||
|
||||
impl Conv2DKernel {
|
||||
fn output_elements(&self) -> usize {
|
||||
self.c_out * self.h_out * self.w_out
|
||||
}
|
||||
}
|
||||
|
||||
const THREADS_PER_BLOCK: usize = 256;
|
||||
|
||||
impl KernelOp for Conv2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let total = self.output_elements();
|
||||
let grid = total.div_ceil(THREADS_PER_BLOCK);
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void conv2d_bias_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias
|
||||
) {{
|
||||
const int TOTAL = {total};
|
||||
const int CIN = {c_in};
|
||||
const int H = {h_in};
|
||||
const int W = {w_in};
|
||||
const int HOUT = {h_out};
|
||||
const int WOUT = {w_out};
|
||||
const int K = {k};
|
||||
const int S = {s};
|
||||
const int P = {p};
|
||||
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= TOTAL) return;
|
||||
int hw = HOUT * WOUT;
|
||||
int co = idx / hw;
|
||||
int rem = idx - co * hw;
|
||||
int ho = rem / WOUT;
|
||||
int wo = rem - ho * WOUT;
|
||||
|
||||
float acc = bias[co];
|
||||
int weight_co_base = co * (CIN * K * K);
|
||||
for (int ci = 0; ci < CIN; ci++) {{
|
||||
int input_ci_base = ci * (H * W);
|
||||
int weight_ci_base = weight_co_base + ci * (K * K);
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < K; ki++) {{
|
||||
int hi = ho * S + ki - P;
|
||||
if (hi < 0 || hi >= H) continue;
|
||||
int input_row_base = input_ci_base + hi * W;
|
||||
int weight_row_base = weight_ci_base + ki * K;
|
||||
#pragma unroll
|
||||
for (int kj = 0; kj < K; kj++) {{
|
||||
int wj = wo * S + kj - P;
|
||||
if (wj < 0 || wj >= W) continue;
|
||||
acc += input[input_row_base + wj] * weight[weight_row_base + kj];
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
out[idx] = acc;
|
||||
}}
|
||||
",
|
||||
total = total,
|
||||
c_in = self.c_in,
|
||||
h_in = self.h_in,
|
||||
w_in = self.w_in,
|
||||
h_out = self.h_out,
|
||||
w_out = self.w_out,
|
||||
k = self.kernel,
|
||||
s = self.stride,
|
||||
p = self.padding,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("conv2d_bias_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(THREADS_PER_BLOCK),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.output_elements())
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per output: C_in * K * K input loads + same many weight loads + 1 bias load.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 2 * C_in * K * K mul-adds per output, plus the bias add = +1.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Conv2DBias"
|
||||
}
|
||||
}
|
||||
|
||||
/// luminal `CustomOp` that wraps `Conv2DKernel`. Lets us drop the kernel
|
||||
/// straight into an HLIR graph via `cx.custom_op(...)` without going
|
||||
/// through egglog rewrites.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DCustom(pub Conv2DKernel);
|
||||
|
||||
impl CustomOp for Conv2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// 2D conv-with-bias on a `(C_in, H, W)` F32 input tensor, with weights
|
||||
/// stored as `(C_out, C_in*K*K)` and bias as `(C_out,)`. Stride/padding/kernel
|
||||
/// are static. Output: `(C_out, H_out, W_out)`.
|
||||
///
|
||||
/// This is a thin wrapper over [`Conv2DKernel`] that hides the
|
||||
/// `cx.custom_op` plumbing. All inputs MUST be `DType::F32` and contiguous
|
||||
/// row-major; pass `tensor * 1.0_f32` first if you have a strided view.
|
||||
pub fn conv2d_bias(
|
||||
input: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(input.dtype, DType::F32, "conv2d_bias requires F32 input");
|
||||
assert_eq!(weight.dtype, DType::F32, "conv2d_bias requires F32 weight");
|
||||
assert_eq!(bias.dtype, DType::F32, "conv2d_bias requires F32 bias");
|
||||
|
||||
let dims = input.dims();
|
||||
assert_eq!(dims.len(), 3, "conv2d_bias expects (C_in, H, W) input");
|
||||
let c_in = dims[0].to_usize().expect("C_in must be a static dim");
|
||||
let h_in = dims[1].to_usize().expect("H must be a static dim");
|
||||
let w_in = dims[2].to_usize().expect("W must be a static dim");
|
||||
|
||||
let w_dims = weight.dims();
|
||||
assert_eq!(
|
||||
w_dims.len(),
|
||||
2,
|
||||
"conv2d_bias expects weight (C_out, C_in*K*K)"
|
||||
);
|
||||
let c_out = w_dims[0].to_usize().expect("C_out must be a static dim");
|
||||
let w_kk = w_dims[1]
|
||||
.to_usize()
|
||||
.expect("weight inner dim must be static");
|
||||
assert_eq!(
|
||||
w_kk,
|
||||
c_in * kernel * kernel,
|
||||
"weight inner dim {w_kk} != C_in*K*K = {}",
|
||||
c_in * kernel * kernel,
|
||||
);
|
||||
|
||||
let b_dims = bias.dims();
|
||||
assert_eq!(b_dims.len(), 1, "conv2d_bias expects bias (C_out,)");
|
||||
assert_eq!(
|
||||
b_dims[0].to_usize().expect("bias dim must be static"),
|
||||
c_out
|
||||
);
|
||||
|
||||
assert!(
|
||||
h_in + 2 * padding >= kernel,
|
||||
"padded H_in ({}) is smaller than kernel ({})",
|
||||
h_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
assert!(
|
||||
w_in + 2 * padding >= kernel,
|
||||
"padded W_in ({}) is smaller than kernel ({})",
|
||||
w_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
let h_out = (h_in + 2 * padding - kernel) / stride + 1;
|
||||
let w_out = (w_in + 2 * padding - kernel) / stride + 1;
|
||||
|
||||
let kern = Conv2DKernel {
|
||||
c_in,
|
||||
h_in,
|
||||
w_in,
|
||||
c_out,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
h_out,
|
||||
w_out,
|
||||
};
|
||||
let cx: &mut Graph = unsafe { &mut *input.graph_ref };
|
||||
cx.custom_op(
|
||||
Conv2DCustom(kern),
|
||||
vec![input, weight, bias],
|
||||
(c_out, h_out, w_out),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
@@ -1,378 +0,0 @@
|
||||
// =========================================================================
|
||||
// Generic CUDA elementwise ops used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// CUDA elementwise execution is represented as a FusionEnd-rooted region even
|
||||
// for a single op. These ops are therefore region-internal only; standalone
|
||||
// compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (CudaUnaryElementwise, CudaBinaryElementwise);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
fn extract_string_label(egraph: &SerializedEGraph, node: &ENodeId) -> String {
|
||||
egraph.enodes[node].0.trim_matches('"').to_string()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaUnaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaUnaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaUnaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
for (hlir, opcode) in [
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?out_s) (ICons ?x (INil))))
|
||||
(= ?dt (dtype ?u))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?out_s ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Exp\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?exp2 ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(datatype*
|
||||
(CudaSigmoidScaledState
|
||||
(MkCudaSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function cuda_sigmoid_scaled (IR) CudaSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (cuda_sigmoid_scaled ?scaled)
|
||||
(MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-region-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (cuda_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Sigmoid\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?sig_out ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaUnaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaUnaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaUnaryElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaBinaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaBinaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaBinaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Add ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?bin))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Add\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Add\")",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Mul ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?a))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Mul\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Mul\")",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[4], list_cache, expr_cache).unwrap();
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype: extract_dtype(egraph, kind_children[5]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaBinaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaBinaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes() * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaBinaryElementwise"
|
||||
}
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
// =========================================================================
|
||||
// Fusion boundary markers — FusionStart and FusionEnd.
|
||||
//
|
||||
// Tag-like LLIR ops that bracket a region of elementwise ops destined to
|
||||
// be emitted as a single CUDA kernel:
|
||||
// - N FusionStart nodes per region (one per FS leaf — distinct external
|
||||
// reads),
|
||||
// - exactly 1 FusionEnd per region.
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Both markers' `compile()` is
|
||||
// `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (FusionStart, FusionEnd);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// FusionStart
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionStart {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionStart",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No idempotence rule. `FusionStart(FusionStart(x)) ≡ FusionStart(x)`
|
||||
// would unify nested markers and create eclass cycles via the
|
||||
// pair-fuse rules; without it, occasional re-firings produce extra
|
||||
// semantically-correct identity layers, bounded by the run schedule.
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionStart {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionStart must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// FusionEnd
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionEnd {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionEnd",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Generic region growth works directly from HLIR elementwise ops into
|
||||
// `Cuda*Elementwise` region nodes. The concrete HLIR op still appears in
|
||||
// the egraph, so fusion remains a normal nondestructive alternative, but
|
||||
// the region-internal representation is arity based instead of one
|
||||
// dedicated fused sort per operation.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
];
|
||||
let binaries: &[(&str, &str)] = &[("Add", "Add"), ("Mul", "Mul")];
|
||||
|
||||
// Grow FE → unary consumer: U(FE(inner)) → FE(CudaUnary(inner)).
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Grow FE → binary consumer, left and right orientations.
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Absorb an elementwise producer through a FusionStart boundary. This
|
||||
// makes a region that initially treats `producer(...)` as an external
|
||||
// input able to pull that producer inside later.
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?x (INil))))
|
||||
(= ?fs_u (Op (FusionStart ?shape ?s ?dt) (ICons ?u (INil))))
|
||||
) (
|
||||
(let ?fs_x (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?fs_x (INil))))
|
||||
(union ?fs_u ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-U-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?bad_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?bad_fs (INil))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-unary-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?fs_bin (Op (FusionStart ?shape ?out_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(union ?fs_bin ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-B-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?bad_fs (ICons ?fs_b (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?bad_fs (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(CudaBinary(ia, ib)).
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// No dissolve rule (`FS(FE(x)) → x`): unioning FS's eclass with FE's
|
||||
// inner eclass creates self-referential eclasses after grow rules
|
||||
// extend the downstream region, and extraction then panics with
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionEnd {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionEnd must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionEnd"
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
//! Binary-inclusive elementwise kernel fusion.
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `elementwise` — generic region-internal CUDA elementwise op variants.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / generic elementwise / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod elementwise;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use elementwise::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior generic elementwise variants). Combined into a flat
|
||||
/// tuple for the `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, elementwise::Ops);
|
||||
@@ -1,639 +0,0 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// Older fusion lowering left elementwise / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — unfused non-region kernels, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior elementwise DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all elementwise bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior Cuda*Elementwise / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use as_any::Downcast;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::elementwise::{CudaBinaryElementwise, CudaUnaryElementwise},
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Compile units — what `kernel_to_host` iterates over instead of nodes.
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior Cuda*Elementwise nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub elementwise_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
/// strides; the host launch passes the same device pointer twice.
|
||||
pub fs_nodes: Vec<NodeIndex>,
|
||||
/// External producer NodeIndices, one per `fs_nodes` entry in the same
|
||||
/// order. Becomes the `inputs` field of the FE's `CompiledKernel`, and
|
||||
/// the kernel function's `in0`, `in1`, ... parameters in that order.
|
||||
pub external_inputs: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum CompileUnit {
|
||||
Single(NodeIndex),
|
||||
Region(RegionUnit),
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region detection.
|
||||
// =========================================================================
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// Cuda*Elementwise and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / Cuda*Elementwise nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
/// subgraph, because `partition_marked_convex` may put a shared FS leaf
|
||||
/// (one whose e-graph congruence-deduplicated it across multiple
|
||||
/// regions) into a different subgraph than the FE that absorbs it.
|
||||
/// Without this global view, `build_compile_units` running on the FS's
|
||||
/// subgraph would not see any FE walking back to the FS and would emit the
|
||||
/// FS as `CompileUnit::Single`; marker standalone compilation is not supported.
|
||||
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
for fe in llir_graph.node_indices() {
|
||||
if name_of(fe) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = vec![fe];
|
||||
visited.insert(fe);
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
absorbed.insert(pred);
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
absorbed
|
||||
}
|
||||
|
||||
pub(crate) fn build_compile_units(
|
||||
topo_order: &[NodeIndex],
|
||||
llir_graph: &LLIRGraph,
|
||||
globally_absorbed: &FxHashSet<NodeIndex>,
|
||||
) -> Vec<CompileUnit> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
// First pass: every FusionEnd in the subgraph anchors a region; gather
|
||||
// the region's interior + FS leaves by walking incoming edges
|
||||
// backward, stopping at FusionStart (a leaf — its predecessor is the
|
||||
// external producer, outside the region).
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut regions: FxHashMap<NodeIndex, RegionUnit> = FxHashMap::default();
|
||||
|
||||
for &node in topo_order {
|
||||
if name_of(node) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut interior: Vec<NodeIndex> = Vec::new();
|
||||
let mut fs_nodes: Vec<NodeIndex> = Vec::new();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = Vec::new();
|
||||
stack.push(node);
|
||||
visited.insert(node);
|
||||
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
fs_nodes.push(pred);
|
||||
// Don't recurse past FS — its predecessor is
|
||||
// external (outside the region).
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// A nested FE inside a region. Under the current
|
||||
// rule design these are cascade artifacts — treat
|
||||
// them as transparent (walk through) rather than
|
||||
// as a separate region. The outer region absorbs
|
||||
// them. They do not become CompileUnit::Region
|
||||
// anchors because their eclass is already the
|
||||
// outer region's.
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-elementwise predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
// malformed and we likely should not have a
|
||||
// region at all; caller will see incomplete
|
||||
// interior.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Topological order on the interior + FS nodes (so the kernel
|
||||
// emits `let v = ...;` lines after their inputs are bound). We
|
||||
// use the parent graph's toposort filtered to in-region nodes.
|
||||
let mut region_set: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
region_set.extend(interior.iter().copied());
|
||||
region_set.extend(fs_nodes.iter().copied());
|
||||
let topo = toposort(llir_graph, None).expect("LLIR cycle in region detection");
|
||||
let interior_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && interior.contains(n))
|
||||
.collect();
|
||||
let fs_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && fs_nodes.contains(n))
|
||||
.collect();
|
||||
|
||||
// External producer for each FS leaf, in the same order.
|
||||
let external_inputs: Vec<NodeIndex> = fs_topo
|
||||
.iter()
|
||||
.map(|&fs| {
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap_or_else(|| {
|
||||
// Dump the malformed structure: which FE
|
||||
// triggered the walk, every node in fs_topo and
|
||||
// interior_topo, and each FS's incoming /
|
||||
// outgoing degree. Helps localize whether the
|
||||
// missing edge came from extraction or a
|
||||
// downstream LLIR transform.
|
||||
if std::env::var("LUMINAL_DEBUG_FUSION_PANIC").is_ok() {
|
||||
eprintln!(
|
||||
"FusionStart panic: fe={} (kernel={:?})",
|
||||
node.index(),
|
||||
llir_graph.node_weight(node).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
}),
|
||||
);
|
||||
eprintln!(" fs_topo ({}):", fs_topo.len());
|
||||
for &f in &fs_topo {
|
||||
let in_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Incoming)
|
||||
.count();
|
||||
let out_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Outgoing)
|
||||
.count();
|
||||
let kn = llir_graph
|
||||
.node_weight(f)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(
|
||||
" fs={} kind={} in_deg={} out_deg={}",
|
||||
f.index(),
|
||||
kn,
|
||||
in_deg,
|
||||
out_deg,
|
||||
);
|
||||
}
|
||||
eprintln!(" interior_topo ({}):", interior_topo.len());
|
||||
for &i in &interior_topo {
|
||||
let kn = llir_graph
|
||||
.node_weight(i)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(" interior={} kind={}", i.index(), kn);
|
||||
}
|
||||
}
|
||||
panic!("FusionStart with no predecessor")
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
absorbed.extend(interior_topo.iter().copied());
|
||||
absorbed.extend(fs_topo.iter().copied());
|
||||
|
||||
regions.insert(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
elementwise_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Second pass: emit compile units in original topo order, replacing
|
||||
// FE nodes with their RegionUnit and skipping anything absorbed —
|
||||
// either by a region in *this* subgraph (`absorbed`) or by any
|
||||
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
|
||||
// latter prevents shared FS markers whose consumers live in other
|
||||
// convex subgraphs from being emitted as standalone compile units:
|
||||
// those FSes are absorbed by some other region, and the consuming
|
||||
// region reads from FS's external producer.
|
||||
let mut units: Vec<CompileUnit> = Vec::new();
|
||||
for &node in topo_order {
|
||||
if let Some(region) = regions.remove(&node) {
|
||||
units.push(CompileUnit::Region(region));
|
||||
} else if absorbed.contains(&node) || globally_absorbed.contains(&node) {
|
||||
continue;
|
||||
} else {
|
||||
units.push(CompileUnit::Single(node));
|
||||
}
|
||||
}
|
||||
units
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-elementwise body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
|
||||
llir_graph
|
||||
.node_weight(node)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>())
|
||||
.is_some_and(|op| {
|
||||
(***op).downcast_ref::<CudaUnaryElementwise>().is_some()
|
||||
|| (***op).downcast_ref::<CudaBinaryElementwise>().is_some()
|
||||
})
|
||||
}
|
||||
|
||||
fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("static_cast<float>({local})")
|
||||
} else {
|
||||
local.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("{cuda_ty}({expr})")
|
||||
} else {
|
||||
expr.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
|
||||
let a = || elementwise_value(locals[0], dtype);
|
||||
let b = || elementwise_value(locals[1], dtype);
|
||||
match op {
|
||||
"Sin" => format!("sinf({})", a()),
|
||||
"Sqrt" => format!("sqrtf({})", a()),
|
||||
"Exp" => format!("expf({})", a()),
|
||||
"Exp2" => format!("exp2f({})", a()),
|
||||
"Log2" => format!("log2f({})", a()),
|
||||
"Recip" => format!("1.0f / {}", a()),
|
||||
"Sigmoid" => format!("1.0f / (1.0f + expf(-{}))", a()),
|
||||
"Add" => format!("{} + {}", a(), b()),
|
||||
"Mul" => format!("{} * {}", a(), b()),
|
||||
other => panic!("region_codegen: unknown elementwise op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region compilation — emit one CUDA kernel for the whole region.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct CompiledRegion {
|
||||
pub function: CudaFunction,
|
||||
pub module: Arc<CudaModule>,
|
||||
pub kernel_str: String,
|
||||
pub grid: (Expression, Expression, Expression),
|
||||
pub block: (Expression, Expression, Expression),
|
||||
pub shared_mem: Expression,
|
||||
pub constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn compile_region(
|
||||
region: &RegionUnit,
|
||||
llir_graph: &LLIRGraph,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompiledRegion {
|
||||
// Resolve FE: shape, strides (for the write), dtype.
|
||||
let fe_op = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.expect("FE node must be a KernelOp");
|
||||
let fe_struct: &FusionEnd = (***fe_op)
|
||||
.downcast_ref::<FusionEnd>()
|
||||
.expect("region root must be FusionEnd");
|
||||
let out_shape: &[Expression] = &fe_struct.shape;
|
||||
let out_strides: &[Expression] = &fe_struct.strides;
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides and elementwise shapes.
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
for &fs_idx in ®ion.fs_nodes {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
for &elem_idx in ®ion.elementwise_topo {
|
||||
let elem_op = llir_graph[elem_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
if let Some(elem) = (***elem_op).downcast_ref::<CudaUnaryElementwise>() {
|
||||
all_vars.extend(elem.shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.in_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
} else if let Some(elem) = (***elem_op).downcast_ref::<CudaBinaryElementwise>() {
|
||||
all_vars.extend(elem.out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.a_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.b_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&all_vars);
|
||||
let dyn_dims_param = if all_vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
|
||||
// Build kernel signature: out, then one input per FS leaf in
|
||||
// `region.fs_nodes` order. The `external_inputs` list (parallel to
|
||||
// `fs_nodes`) is what the host wires into the launch params.
|
||||
let mut signature_params: Vec<String> = vec![format!("{cuda_ty} *out")];
|
||||
for i in 0..region.fs_nodes.len() {
|
||||
signature_params.push(format!("const {cuda_ty} *in{i}"));
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk elementwise nodes in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), elementwise nodes get fs_nodes.len()..(+ elementwise_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.elementwise_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
|
||||
let mut body = String::new();
|
||||
body.push_str(&format!(
|
||||
" long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n"
|
||||
));
|
||||
|
||||
// FS leaves: each reads from its corresponding `in_i` parameter using
|
||||
// its own strides.
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
let read_idx = flatten_strides(out_shape, &fs_struct.strides).to_kernel();
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = in{i}[{read_idx}];\n",
|
||||
name = local_name(fs_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// Elementwise ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.elementwise_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let (elem_name, elem_dtype) =
|
||||
if let Some(elem) = (***op_ref).downcast_ref::<CudaUnaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else if let Some(elem) = (***op_ref).downcast_ref::<CudaBinaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else {
|
||||
panic!(
|
||||
"region_codegen: expected Cuda*Elementwise op, got {}",
|
||||
op_ref.kernel_name()
|
||||
);
|
||||
};
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|(_, src)| local_name(src))
|
||||
.collect();
|
||||
// Sort by edge id like the rest of the codegen does for stable
|
||||
// input ordering.
|
||||
let mut edges: Vec<(_, NodeIndex)> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect();
|
||||
edges.sort_by_key(|(eid, _)| *eid);
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = elementwise_body(elem_name, &inputs_ref, elem_dtype);
|
||||
let expr = elementwise_init_expr(&expr, elem_dtype, cuda_ty);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the elementwise node feeding FE (its single incoming edge in
|
||||
// the region — an elementwise node or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionEnd with no predecessor");
|
||||
let fe_input_local = local_name(fe_input);
|
||||
let write_idx = flatten_strides(out_shape, out_strides).to_kernel();
|
||||
body.push_str(&format!(" out[{write_idx}] = {fe_input_local};\n"));
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}\n\
|
||||
{dyn_defines}\n\
|
||||
extern \"C\" {{\n\
|
||||
\x20 __global__ void fused_region_k({signature}{dyn_dims_param}) {{\n\
|
||||
{body}\
|
||||
\x20 }}\n\
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, function) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("region kernel PTX compile failed");
|
||||
let module = stream
|
||||
.context()
|
||||
.load_module(ptx)
|
||||
.expect("module load failed");
|
||||
let function = module
|
||||
.load_function("fused_region_k")
|
||||
.expect("region kernel function not found");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), function.clone()));
|
||||
(module, function)
|
||||
};
|
||||
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
|
||||
CompiledRegion {
|
||||
function,
|
||||
module,
|
||||
kernel_str: kernel,
|
||||
grid: (out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
block: (out_size.min(256), 1.into(), 1.into()),
|
||||
shared_mem: 0.into(),
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::kernel::fusion::elementwise::CudaBinaryElementwise;
|
||||
use luminal::op::LLIROp;
|
||||
use luminal::prelude::petgraph::algo::toposort;
|
||||
|
||||
/// Helper: wrap a `KernelOp` in an `LLIROp` of the kernel dialect.
|
||||
fn llir_of(op: impl KernelOp + 'static) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(op) as Box<dyn KernelOp>)
|
||||
}
|
||||
|
||||
/// Reproducer for the `FusionStart with no predecessor` panic at
|
||||
/// `region_codegen.rs:232`. The egglog rolling pass + iterated mode
|
||||
/// (`LUMINAL_LOOP_ROLL_ITERATE=1`) has been observed to produce LLIR
|
||||
/// graphs where a `FusionStart` marker is reached as a region leaf
|
||||
/// during the FE→FS walk but has no incoming edge — meaning the
|
||||
/// region has nothing to read from. `build_compile_units` then
|
||||
/// panics when constructing `external_inputs` because every FS leaf
|
||||
/// is required to have exactly one external producer.
|
||||
///
|
||||
/// Until that path is fixed, this test pins the failure mode so a
|
||||
/// regression doesn't silently change the panic message or location.
|
||||
/// `should_panic` rather than `ignore` so it stays runnable in CI
|
||||
/// and surfaces if the panic ever moves.
|
||||
#[test]
|
||||
#[should_panic(expected = "FusionStart with no predecessor")]
|
||||
fn fusion_start_with_no_predecessor_panics() {
|
||||
// Minimal reproducer:
|
||||
//
|
||||
// (no input) ──▶ FusionStart ──▶ CudaBinaryElementwise ──▶ FusionEnd
|
||||
//
|
||||
// CudaBinaryElementwise is a binary op (n_inputs = 2) so a real region would
|
||||
// have two FS leaves. For this panic-shape test only the *first*
|
||||
// FS leaf needs a missing predecessor — `build_compile_units`
|
||||
// panics in `expect("FusionStart with no predecessor")` as soon
|
||||
// as any FS in `fs_topo` lacks one. We add only one FS edge so
|
||||
// CudaBinaryElementwise has a dangling second input slot, but that's fine:
|
||||
// we're testing the specific panic path inside `build_compile_units`,
|
||||
// not full kernel codegen.
|
||||
let mut llir: LLIRGraph = LLIRGraph::default();
|
||||
|
||||
let fs_node = llir.add_node(llir_of(FusionStart::default()));
|
||||
let fadd_node = llir.add_node(llir_of(CudaBinaryElementwise::default()));
|
||||
let fe_node = llir.add_node(llir_of(FusionEnd::default()));
|
||||
|
||||
// FusionStart → CudaBinaryElementwise → FusionEnd.
|
||||
llir.add_edge(fs_node, fadd_node, ());
|
||||
llir.add_edge(fadd_node, fe_node, ());
|
||||
|
||||
let topo = toposort(&llir, None).expect("LLIR cycle in test setup");
|
||||
let absorbed = globally_absorbed_markers(&llir);
|
||||
|
||||
// This is the call that panics with `FusionStart with no
|
||||
// predecessor` because `fs_node`'s incoming-edges iterator is
|
||||
// empty.
|
||||
let _ = build_compile_units(&topo, &llir, &absorbed);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,427 +0,0 @@
|
||||
//! Direct 2D matmul kernel — bypasses egglog rewrites, used as a custom op
|
||||
//! for matmul shapes where the cublaslt egg rules don't reliably fire.
|
||||
//!
|
||||
//! The cublaslt 2D rules in `host/cublaslt/cublaslt_*Cm_rewrite.egg` /
|
||||
//! `cublaslt_Rm*_rewrite.egg` are *supposed* to match any 2D matmul whose
|
||||
//! Mul + SumReduce broadcast lowering has the expected stride patterns,
|
||||
//! and the conditional matmul cleanup is *supposed* to delete the
|
||||
//! elementwise Mul + KernelSumReduce fallback whenever a cublaslt alternative
|
||||
//! exists. In practice both fail to fire reliably for the VAE's mid-block
|
||||
//! `AttnBlock` matmuls — at 1024² that lets the search occasionally pick
|
||||
//! the broadcast-Mul path for `q @ kᵀ`, generating a `(HW, HW, C) =
|
||||
//! (16384, 16384, 512)` ≈ 524 GiB single intermediate that OOMs the GPU.
|
||||
//!
|
||||
//! Same approach as `kernel::conv2d`: define a `KernelOp`, wrap it in a
|
||||
//! `CustomOp`, expose a tiny `pub fn` so callers don't see the
|
||||
//! `cx.custom_op` plumbing. This is opaque to egglog by design — we
|
||||
//! aren't trying to fuse with surrounding ops, just guarantee a sane
|
||||
//! lowering for the matmuls we know are problematic.
|
||||
//!
|
||||
//! The CUDA implementation is a textbook 2D-blocked SGEMM:
|
||||
//! * 16×16 output tile per block (256 threads)
|
||||
//! * Tiled load of A and B into shared memory in K-size chunks
|
||||
//! * Each thread accumulates one output element across all K-tiles
|
||||
//! * Optional bias broadcast along the M axis at write-out
|
||||
//! * `transpose_b` toggles between row-major B `(K, N)` and row-major
|
||||
//! B `(N, K)` (i.e. the `A @ Bᵀ` pattern that linear/projection
|
||||
//! layers use).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct 2D matmul `(M, K) × {(K, N) | (N, K)} → (M, N)` with optional
|
||||
/// per-output-column bias and an optional batch axis. A and output are
|
||||
/// always F32. B can be F32 or BF16; BF16 is converted to F32 on each
|
||||
/// load, which avoids materializing the cast as a separate intermediate
|
||||
/// tensor (important for the text encoder / transformer where the F32-
|
||||
/// cast weights would not fit in GPU memory). All shape parameters are
|
||||
/// static (baked into the CUDA source via #defines).
|
||||
///
|
||||
/// When `batch > 1` the kernel does `batch` independent 2D matmuls in
|
||||
/// parallel: A is `(batch, M, K)`, B is `(batch, *, *)` with the same
|
||||
/// per-batch shape, output is `(batch, M, N)`. All three are assumed
|
||||
/// contiguous row-major across batches (i.e. `a_batch_stride = M*K`,
|
||||
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
|
||||
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
|
||||
/// `(N,)` and broadcast across batches.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DKernel {
|
||||
pub m: usize,
|
||||
pub n: usize,
|
||||
pub k: usize,
|
||||
pub batch: usize,
|
||||
/// If `true`, B is interpreted as `(N, K)` row-major and accessed as
|
||||
/// `B[n][k]` (i.e. `A @ Bᵀ`). If `false`, B is `(K, N)` row-major and
|
||||
/// accessed as `B[k][n]` (i.e. `A @ B`).
|
||||
pub transpose_b: bool,
|
||||
pub has_bias: bool,
|
||||
/// Storage dtype of B. Currently F32 or BF16 are supported.
|
||||
pub weight_dtype: DType,
|
||||
}
|
||||
|
||||
const TILE: usize = 16;
|
||||
|
||||
impl KernelOp for Matmul2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let bias_param = if self.has_bias {
|
||||
", const float* __restrict__ bias"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let bias_add = if self.has_bias {
|
||||
" acc += bias[n];\n"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
|
||||
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
|
||||
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// Plus the per-batch offset (`b_batch_off`).
|
||||
let b_index_expr = if self.transpose_b {
|
||||
"b_batch_off + (b_n_base + tx) * K + (k0 + ty)"
|
||||
} else {
|
||||
"b_batch_off + (k0 + ty) * N + (b_n_base + tx)"
|
||||
};
|
||||
// Convert B's element to float on load. For BF16 we declare B as
|
||||
// `__nv_bfloat16*` and use `__bfloat162float`; for F32 it's a no-op.
|
||||
let (b_param_type, b_load_expr, bf16_include) = match self.weight_dtype {
|
||||
DType::F32 => (
|
||||
"const float* __restrict__ B",
|
||||
format!("B[{b_index_expr}]"),
|
||||
"",
|
||||
),
|
||||
DType::Bf16 => (
|
||||
"const __nv_bfloat16* __restrict__ B",
|
||||
format!("__bfloat162float(B[{b_index_expr}])"),
|
||||
"#include <cuda_bf16.h>\n",
|
||||
),
|
||||
other => panic!("Matmul2DKernel: unsupported weight_dtype {other:?}"),
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ A,
|
||||
{b_param_type}{bias_param}
|
||||
) {{
|
||||
const int M = {m};
|
||||
const int N = {n};
|
||||
const int K = {k};
|
||||
const int TILE = {tile};
|
||||
|
||||
__shared__ float As[{tile}][{tile}];
|
||||
__shared__ float Bs[{tile}][{tile}];
|
||||
|
||||
int bx = blockIdx.x; // tile column (n)
|
||||
int by = blockIdx.y; // tile row (m)
|
||||
int batch = blockIdx.z; // batch index (0..BATCH-1)
|
||||
int tx = threadIdx.x; // 0..TILE-1, output col within tile
|
||||
int ty = threadIdx.y; // 0..TILE-1, output row within tile
|
||||
|
||||
int m_global = by * TILE + ty;
|
||||
int n_global = bx * TILE + tx;
|
||||
|
||||
int a_m_base = by * TILE;
|
||||
int b_n_base = bx * TILE;
|
||||
|
||||
// Per-batch base pointer offsets (contiguous row-major across batches).
|
||||
int a_batch_off = batch * (M * K);
|
||||
int b_batch_off = batch * (K * N);
|
||||
int c_batch_off = batch * (M * N);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
int n_tiles = (K + TILE - 1) / TILE;
|
||||
for (int t = 0; t < n_tiles; ++t) {{
|
||||
int k0 = t * TILE;
|
||||
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
|
||||
int a_m = a_m_base + ty;
|
||||
int a_k = k0 + tx;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
|
||||
|
||||
// Load B tile depending on transpose_b
|
||||
int b_n_or_k = b_n_base + tx; // for transpose_b=true this is N; for =false this is N
|
||||
int b_k_or_k = k0 + ty; // similarly
|
||||
// We compute Bs[ty][tx] such that the inner loop reads Bs[k_local][n_local] = B[k][n].
|
||||
// For transpose_b=true (B is (N,K)): B[k][n] in math = B_storage[n][k] = B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// For transpose_b=false (B is (K,N)): B[k][n] in math = B_storage[k][n] = B[(k0+ty)*N + (b_n_base+tx)]
|
||||
bool b_in_bounds = ({transpose_b} ? (b_n_or_k < N && b_k_or_k < K)
|
||||
: (b_k_or_k < K && b_n_or_k < N));
|
||||
Bs[ty][tx] = b_in_bounds ? ({b_load_expr}) : 0.0f;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < {tile}; ++kk) {{
|
||||
acc += As[ty][kk] * Bs[kk][tx];
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
|
||||
if (m_global < M && n_global < N) {{
|
||||
int n = n_global;
|
||||
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
m = self.m,
|
||||
n = self.n,
|
||||
k = self.k,
|
||||
tile = TILE,
|
||||
transpose_b = self.transpose_b,
|
||||
b_load_expr = b_load_expr,
|
||||
b_param_type = b_param_type,
|
||||
bias_param = bias_param,
|
||||
bias_add = bias_add,
|
||||
bf16_include = bf16_include,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("matmul_2d_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let grid_x = self.n.div_ceil(TILE);
|
||||
let grid_y = self.m.div_ceil(TILE);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(grid_y),
|
||||
Expression::from(self.batch),
|
||||
),
|
||||
(
|
||||
Expression::from(TILE),
|
||||
Expression::from(TILE),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.m * self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// K elements from A (F32) + K elements from B (F32 or BF16) + maybe bias (F32).
|
||||
let b_bytes = match self.weight_dtype {
|
||||
DType::F32 => 4,
|
||||
DType::Bf16 => 2,
|
||||
_ => 4,
|
||||
};
|
||||
let bias_bytes = if self.has_bias { 4 } else { 0 };
|
||||
Expression::from(
|
||||
self.batch * self.m * self.n * (self.k * 4 + self.k * b_bytes + bias_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let per_out = self.k * 2 + if self.has_bias { 1 } else { 0 };
|
||||
Expression::from(self.batch * self.m * self.n * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Matmul2D"
|
||||
}
|
||||
}
|
||||
|
||||
/// CustomOp wrapper for [`Matmul2DKernel`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DCustom(pub Matmul2DKernel);
|
||||
|
||||
impl CustomOp for Matmul2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
|
||||
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
|
||||
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
|
||||
/// pattern produced by linear / projection layers (`x @ w.t()`).
|
||||
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
|
||||
/// `(N,)`, row-major F32 throughout.
|
||||
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
|
||||
}
|
||||
|
||||
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
|
||||
///
|
||||
/// Lowers as plain HLIR — `Cast(A, BF16) @ permute(B_bf16) → Cast(F32)`.
|
||||
/// The activation cast and output cast are tiny (M*K and M*N elements;
|
||||
/// the K=hidden weight stays BF16). The inner BF16 matmul matches the
|
||||
/// existing cublaslt rewrite rules and runs as
|
||||
/// `CUBLAS_COMPUTE_32F_FAST_16BF` — Hopper's native 2× BF16 path.
|
||||
pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "linear_no_bias_bf16_w expects F32 A");
|
||||
assert_eq!(
|
||||
b_bf16.dtype,
|
||||
DType::Bf16,
|
||||
"linear_no_bias_bf16_w expects BF16 B"
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b_bf16.dims();
|
||||
assert_eq!(a_dims.len(), 2);
|
||||
assert_eq!(b_dims.len(), 2);
|
||||
let a_bf16 = a.cast(DType::Bf16);
|
||||
let b_kn = b_bf16.permute((1, 0));
|
||||
a_bf16.matmul(b_kn).cast(DType::F32)
|
||||
}
|
||||
|
||||
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
|
||||
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
|
||||
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
fn matmul_inner(
|
||||
a: GraphTensor,
|
||||
b: GraphTensor,
|
||||
transpose_b: bool,
|
||||
bias: Option<GraphTensor>,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
|
||||
let weight_dtype = b.dtype;
|
||||
assert!(
|
||||
matches!(weight_dtype, DType::F32 | DType::Bf16),
|
||||
"matmul B must be F32 or BF16, got {weight_dtype:?}",
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b.dims();
|
||||
assert_eq!(
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
"matmul A/B rank mismatch: {} vs {}",
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
);
|
||||
assert!(
|
||||
a_dims.len() == 2 || a_dims.len() == 3,
|
||||
"matmul expects rank 2 or 3, got rank {}",
|
||||
a_dims.len(),
|
||||
);
|
||||
|
||||
let (batch, a_off) = if a_dims.len() == 3 {
|
||||
let ba = a_dims[0].to_usize().expect("batch dim must be static");
|
||||
let bb = b_dims[0].to_usize().expect("batch dim must be static");
|
||||
assert_eq!(
|
||||
ba, bb,
|
||||
"matmul batch dim mismatch: A batch={ba}, B batch={bb}"
|
||||
);
|
||||
(ba, 1)
|
||||
} else {
|
||||
(1, 0)
|
||||
};
|
||||
|
||||
let m = a_dims[a_off].to_usize().expect("M must be a static dim");
|
||||
let k_a = a_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (A) must be a static dim");
|
||||
let (n, k_b) = if transpose_b {
|
||||
// B per-batch is (N, K)
|
||||
let n = b_dims[a_off].to_usize().expect("N must be a static dim");
|
||||
let k = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
(n, k)
|
||||
} else {
|
||||
// B per-batch is (K, N)
|
||||
let k = b_dims[a_off]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
let n = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("N must be a static dim");
|
||||
(n, k)
|
||||
};
|
||||
assert_eq!(k_a, k_b, "matmul K mismatch: A K={k_a}, B K={k_b}");
|
||||
let k = k_a;
|
||||
|
||||
let has_bias = bias.is_some();
|
||||
if let Some(bias) = bias {
|
||||
let bdims = bias.dims();
|
||||
assert_eq!(bdims.len(), 1, "matmul bias must be 1D");
|
||||
assert_eq!(
|
||||
bdims[0].to_usize().expect("bias dim must be static"),
|
||||
n,
|
||||
"matmul bias size must equal N"
|
||||
);
|
||||
assert_eq!(bias.dtype, DType::F32, "matmul bias must be F32");
|
||||
}
|
||||
|
||||
let kern = Matmul2DKernel {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch,
|
||||
transpose_b,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
};
|
||||
let cx = unsafe { &mut *a.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
vec![a, b, bias]
|
||||
} else {
|
||||
vec![a, b]
|
||||
};
|
||||
if batch == 1 {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
|
||||
} else {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,13 @@ use luminal_tracing::schema::{
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
pub mod rope;
|
||||
|
||||
pub use conv2d::{Conv2DCustom, Conv2DKernel, conv2d_bias};
|
||||
pub use cuda_graph::*;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
|
||||
@@ -23,6 +23,9 @@ pub type Ops = (
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
KernelFusedElementwise,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -126,8 +129,7 @@ impl KernelOp for KernelMeanReduce {
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block: usize = 256; // 8 warps per block
|
||||
let n_warps = threads_per_block / 32;
|
||||
let threads_per_block = 256; // 8 warps per block
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
@@ -148,24 +150,12 @@ extern \"C\" {{
|
||||
long long iters = {iters};
|
||||
long long iter_stride = {iter_stride};
|
||||
|
||||
float thread_sum = 0.0f;
|
||||
for (long long i = threadIdx.x; i < iters; i += {threads_per_block})
|
||||
thread_sum += (float)in[in_start + i * iter_stride];
|
||||
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
|
||||
|
||||
__shared__ float warp_sums[{n_warps}];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp = threadIdx.x >> 5;
|
||||
if (lane == 0) warp_sums[warp] = thread_sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {{
|
||||
float sum = 0.0f;
|
||||
for (int w = 0; w < {n_warps}; w++) sum += warp_sums[w];
|
||||
out[{out_index}] = ({dtype})(sum / (float)iters);
|
||||
{dtype} sum = 0;
|
||||
for (long long i = 0; i < iters; i++) {{
|
||||
sum += in[in_start + i * iter_stride];
|
||||
}}
|
||||
|
||||
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
@@ -178,8 +168,6 @@ extern \"C\" {{
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel(),
|
||||
threads_per_block = threads_per_block,
|
||||
n_warps = n_warps,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
@@ -196,9 +184,9 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(threads_per_block.into(), 1.into(), 1.into()), // block
|
||||
0.into(), // shmem size
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
@@ -292,9 +280,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match KernelScatter and rewrite to KernelScatterNoCopy with ConsumedBuffer on dest.
|
||||
// ConsumedBuffer wraps dest to signal in-place modification.
|
||||
// This is only valid when the destination buffer can also represent
|
||||
// the scatter output layout. If dest is a strided/broadcast view,
|
||||
// regular Scatter must first materialize a contiguous output copy.
|
||||
//
|
||||
// Two-phase resolution:
|
||||
// 1. During (run): cleanup rules delete ConsumedBuffer if dest is shared (another op uses it)
|
||||
@@ -305,31 +290,12 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
|
||||
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
|
||||
let mut rules = vec![
|
||||
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail)))
|
||||
((consumed_buffer_ilist_contains ?list ?head))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-head\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail))
|
||||
(consumed_buffer_ilist_contains ?tail ?item))
|
||||
((consumed_buffer_ilist_contains ?list ?item))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-tail\"
|
||||
)",
|
||||
),
|
||||
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?dst ?os)
|
||||
(= ?dty (dtype ?src))
|
||||
)
|
||||
(
|
||||
@@ -339,7 +305,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
(union ?scatter ?nocopy)
|
||||
(set (dtype ?nocopy) ?dty)
|
||||
)
|
||||
:ruleset buffer_reuse
|
||||
:name \"scatter to scatter-no-copy\"
|
||||
)",
|
||||
),
|
||||
@@ -349,7 +314,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?dt (dtype ?a)))
|
||||
((set (dtype ?cb) ?dt))
|
||||
:ruleset dtype_prop
|
||||
:name \"consumed-buffer-dtype\"
|
||||
)",
|
||||
),
|
||||
@@ -359,28 +323,13 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?op1 (Op ?k1 ?ilist1))
|
||||
(consumed_buffer_ilist_contains ?ilist1 ?cb)
|
||||
(= ?ilist1 (ICons ?cb ?rest1))
|
||||
(= ?op2 (Op ?k2 ?ilist2))
|
||||
(!= ?op1 ?op2)
|
||||
(consumed_buffer_ilist_contains ?ilist2 ?a))
|
||||
(= ?ilist2 (ICons ?a ?t2)))
|
||||
((delete (ConsumedBuffer ?a)))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-cleanup-shared-op-use\"
|
||||
)",
|
||||
));
|
||||
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
|
||||
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?dest))
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
|
||||
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
|
||||
:ruleset post_cleanup
|
||||
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
|
||||
:name \"consumed-buffer-cleanup-pos\"
|
||||
)",
|
||||
));
|
||||
// Surviving ConsumedBuffers are valid — union with source and delete.
|
||||
@@ -507,8 +456,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
scatter_kernel,
|
||||
(n_src.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(n_src, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -623,7 +572,7 @@ extern \"C\" {{
|
||||
// KernelBatchMatVec: Fused batched matrix-vector product for attention
|
||||
// Matches: Mul(broadcast) + Sum pattern for [B, 1, K] x [B, K, N] -> [B, 1, N]
|
||||
// or [B, M, K] x [B, K, N] -> [B, M, N] with small M
|
||||
// Replaces the broadcast elementwise Mul + single-threaded KernelSumReduce pipeline
|
||||
// Replaces the broadcast KernelMul + single-threaded KernelSumReduce pipeline
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -711,7 +660,6 @@ impl EgglogOp for KernelBatchMatVec {
|
||||
(union ?sum ?bmv)
|
||||
(set (dtype ?bmv) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch mat-vec\"
|
||||
)"
|
||||
)]
|
||||
@@ -992,7 +940,6 @@ impl EgglogOp for KernelBatchMatMul {
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
@@ -1232,7 +1179,6 @@ impl EgglogOp for KernelSoftmax {
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
@@ -1454,3 +1400,650 @@ extern \"C\" {{
|
||||
"Softmax"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelExp: native exp (uses expf instead of exp2f * constant)
|
||||
// Single-kernel alternative to the 3-kernel Constant+Mul+Exp2 path.
|
||||
// Improves numerical precision by avoiding the truncated log2(e) constant.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelExp {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelExp {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelExp",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match Exp2(Mul(x, log2e_constant)) directly.
|
||||
// This matches the pattern created by frontend exp() = (self * (1/ln(2))).exp2()
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?kexp (Op (KernelExp ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelExp {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void exp_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = expf(in[{in_idx}]);
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("exp_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Exp"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelSigmoid: fused sigmoid = 1/(1+exp(-x))
|
||||
// Single-kernel alternative to the 5-kernel Neg+Exp+Const+Add+Recip path.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSigmoid {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelSigmoid",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelSigmoid {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void sigmoid_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / (1.0f + expf(-in[{in_idx}]));
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sigmoid_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// neg + exp + add + recip = ~4 ops per element
|
||||
self.shape.iter().copied().product::<Expression>() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
/// A unary math function that can appear inside a fused elementwise kernel.
|
||||
/// Each variant has a stable string name (used both as the egglog token in
|
||||
/// the rule-generated ops string and as the `kernel_name()` of the source
|
||||
/// unary kernel op).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum UnaryFn {
|
||||
Sin,
|
||||
Sqrt,
|
||||
Exp2,
|
||||
Log2,
|
||||
Recip,
|
||||
}
|
||||
|
||||
impl UnaryFn {
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
UnaryFn::Sin => "Sin",
|
||||
UnaryFn::Sqrt => "Sqrt",
|
||||
UnaryFn::Exp2 => "Exp2",
|
||||
UnaryFn::Log2 => "Log2",
|
||||
UnaryFn::Recip => "Recip",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
match name {
|
||||
"Sin" => UnaryFn::Sin,
|
||||
"Sqrt" => UnaryFn::Sqrt,
|
||||
"Exp2" => UnaryFn::Exp2,
|
||||
"Log2" => UnaryFn::Log2,
|
||||
"Recip" => UnaryFn::Recip,
|
||||
_ => panic!("invalid UnaryFn name: {name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An LLIR-only op created by fusing a chain of unary elementwise kernels.
|
||||
/// Only fires when every op in the chain shares the same stride pattern,
|
||||
/// so reads and writes use a single `strides` field.
|
||||
///
|
||||
/// The `ops` sequence is carried as a comma-separated egglog `String`
|
||||
/// (e.g. `"Sin,Sqrt,Exp2"`) — it's pure codegen metadata that egglog never
|
||||
/// reasons about, and `String` is a primitive sort, so this avoids
|
||||
/// introducing a new datatype/sort just to carry the list.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelFusedElementwise {
|
||||
shape: Vec<Expression>,
|
||||
strides: Vec<Expression>,
|
||||
ops: Vec<UnaryFn>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl KernelFusedElementwise {
|
||||
pub fn ops(&self) -> &[UnaryFn] {
|
||||
&self.ops
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelFusedElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelFusedElementwise",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("ops", STRING),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let unaries = [
|
||||
("KernelSin", UnaryFn::Sin),
|
||||
("KernelSqrt", UnaryFn::Sqrt),
|
||||
("KernelExp2", UnaryFn::Exp2),
|
||||
("KernelLog2", UnaryFn::Log2),
|
||||
("KernelRecip", UnaryFn::Recip),
|
||||
];
|
||||
let mut rules = Vec::with_capacity(unaries.len() * unaries.len() + unaries.len());
|
||||
|
||||
// Pair fusion: two adjacent pure-elementwise unaries -> Fused[a, b].
|
||||
for (a_name, a_fn) in unaries {
|
||||
for (b_name, b_fn) in unaries {
|
||||
let (a_str, b_str) = (a_fn.name(), b_fn.name());
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?a (Op ({a_name} ?shape ?strides ?strides ?dt) (ICons ?inp (INil))))
|
||||
(= ?b (Op ({b_name} ?shape ?strides ?strides ?dt) (ICons ?a (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (KernelFusedElementwise ?shape ?strides
|
||||
\"{a_str},{b_str}\" ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?b ?fused)
|
||||
)
|
||||
:name \"fuse-{a_name}-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Chain extend: Fused[ops] -> unary -> Fused[ops + \",<new>\"]. One
|
||||
// rule per outer unary. `+` is the builtin variadic string concat,
|
||||
// so this is O(1) per firing and handles chains of any length
|
||||
// without recursion.
|
||||
for (b_name, b_fn) in unaries {
|
||||
let b_str = b_fn.name();
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?fused (Op (KernelFusedElementwise ?shape ?strides ?ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(= ?next (Op ({b_name} ?shape ?strides ?strides ?dt)
|
||||
(ICons ?fused (INil))))
|
||||
)
|
||||
(
|
||||
(let ?new_ops (+ ?ops \",{b_str}\"))
|
||||
(let ?new_fused (Op (KernelFusedElementwise ?shape ?strides ?new_ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?next ?new_fused)
|
||||
)
|
||||
:name \"extend-Fused-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// The `ops` field is a String enode; its label is the quoted
|
||||
// literal (e.g. `"Sin,Sqrt"`), so strip the quotes and split.
|
||||
let ops_str = egraph.enodes[kind_children[2]].0.replace('"', "");
|
||||
let ops = if ops_str.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
ops_str.split(',').map(UnaryFn::from_name).collect()
|
||||
};
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
ops,
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelFusedElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let idx = flatten_strides(&self.shape, &self.strides).to_kernel();
|
||||
let ops_body = self
|
||||
.ops
|
||||
.iter()
|
||||
.map(|op| match op {
|
||||
UnaryFn::Sin => "val = sinf(val);",
|
||||
UnaryFn::Sqrt => "val = sqrtf(val);",
|
||||
UnaryFn::Exp2 => "val = exp2f(val);",
|
||||
UnaryFn::Log2 => "val = log2f(val);",
|
||||
UnaryFn::Recip => "val = 1.0f / val;",
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n ");
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void fused_elementwise_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
long long idx = {idx};
|
||||
{dtype} val = in[idx];
|
||||
{ops_body}
|
||||
out[idx] = val;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_elementwise_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * (self.ops.len() as i32)
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusedElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
//! Fused RoPE (rotary position embedding) — interleaved-pair convention.
|
||||
//!
|
||||
//! Replaces flux2's 6-op RoPE chain (split / slice / squeeze / neg / concat /
|
||||
//! merge_dims / 4× cast / mul / add) with a single kernel launch per call.
|
||||
//! ~120 RoPE calls per forward pass at full DiT depth.
|
||||
//!
|
||||
//! Convention: `repeat_interleave_real=True` (Flux 2 / diffusers), so adjacent
|
||||
//! dim pairs rotate together. For an input `[a0, b0, a1, b1, ...]` and per-
|
||||
//! position `(cos, sin)`, the output is
|
||||
//! `out[2j] = x[2j] * cos[2j] - x[2j+1] * sin[2j]`
|
||||
//! `out[2j+1] = x[2j+1] * cos[2j+1] + x[2j] * sin[2j+1]`
|
||||
//!
|
||||
//! Layout: x `(S, H, D)`, cos/sin `(S, D)` (broadcast across H).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPEKernel {
|
||||
pub s: usize,
|
||||
pub h: usize,
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
const TPB: usize = 64;
|
||||
|
||||
impl KernelOp for RoPEKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let s = self.s;
|
||||
let h = self.h;
|
||||
let d = self.d;
|
||||
assert!(d.is_multiple_of(2), "RoPE head_dim must be even");
|
||||
let kernel = format!(
|
||||
r#"
|
||||
extern "C" __global__ void rope_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ cos_,
|
||||
const float* __restrict__ sin_
|
||||
) {{
|
||||
const int S = {s};
|
||||
const int H = {h};
|
||||
const int D = {d};
|
||||
int sh = blockIdx.x; // 0..S*H
|
||||
int s_idx = sh / H;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const float* xr = x + sh * D;
|
||||
const float* cosr = cos_ + s_idx * D;
|
||||
const float* sinr = sin_ + s_idx * D;
|
||||
float* yr = out + sh * D;
|
||||
|
||||
for (int i = tid; i < D; i += {TPB}) {{
|
||||
float xi = xr[i];
|
||||
float xpair;
|
||||
if ((i & 1) == 0) {{
|
||||
// even: paired with i+1, rotated value is -x[i+1]
|
||||
xpair = -xr[i + 1];
|
||||
}} else {{
|
||||
// odd: paired with i-1, rotated value is +x[i-1]
|
||||
xpair = xr[i - 1];
|
||||
}}
|
||||
yr[i] = xi * cosr[i] + xpair * sinr[i];
|
||||
}}
|
||||
}}
|
||||
"#
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("rope_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
"rope_kernel".to_string(),
|
||||
(
|
||||
Expression::from(s * h),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(TPB),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.s * self.h * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// x: full (S,H,D); cos/sin: (S,D) read H times each but cached.
|
||||
Expression::from(self.s * self.h * self.d * 4 + self.s * self.d * 4 * 2)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 4 per output element (mul, neg/load, mul, add).
|
||||
Expression::from(self.s * self.h * self.d * 4)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"RoPE"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPECustom(pub RoPEKernel);
|
||||
|
||||
impl CustomOp for RoPECustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE: `x` shape `(S, H, D)` F32, `cos`/`sin` shape `(S, D)` F32.
|
||||
/// Returns `(S, H, D)` F32.
|
||||
pub fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(x.dtype, DType::F32, "RoPE x must be F32");
|
||||
let cos = if cos.dtype == DType::F32 {
|
||||
cos
|
||||
} else {
|
||||
cos.cast(DType::F32)
|
||||
};
|
||||
let sin = if sin.dtype == DType::F32 {
|
||||
sin
|
||||
} else {
|
||||
sin.cast(DType::F32)
|
||||
};
|
||||
let x_dims = x.dims();
|
||||
assert_eq!(x_dims.len(), 3, "RoPE x must be 3-D (S, H, D)");
|
||||
let s = x_dims[0].to_usize().expect("RoPE: S must be static");
|
||||
let h = x_dims[1].to_usize().expect("RoPE: H must be static");
|
||||
let d = x_dims[2].to_usize().expect("RoPE: D must be static");
|
||||
let cos_dims = cos.dims();
|
||||
let sin_dims = sin.dims();
|
||||
assert_eq!(cos_dims.len(), 2, "RoPE cos must be 2-D (S, D)");
|
||||
assert_eq!(sin_dims.len(), 2, "RoPE sin must be 2-D (S, D)");
|
||||
assert_eq!(cos_dims[0].to_usize().unwrap(), s, "RoPE cos S mismatch");
|
||||
assert_eq!(cos_dims[1].to_usize().unwrap(), d, "RoPE cos D mismatch");
|
||||
assert_eq!(sin_dims[0].to_usize().unwrap(), s, "RoPE sin S mismatch");
|
||||
assert_eq!(sin_dims[1].to_usize().unwrap(), d, "RoPE sin D mismatch");
|
||||
|
||||
let kern = RoPEKernel { s, h, d };
|
||||
let cx = unsafe { &mut *x.graph_ref };
|
||||
cx.custom_op(RoPECustom(kern), vec![x, cos, sin], (s, h, d), DType::F32)
|
||||
}
|
||||
@@ -13,7 +13,6 @@ use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
@@ -23,11 +22,10 @@ use luminal::{
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
fusion::region_codegen::{self, CompileUnit},
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
@@ -48,12 +46,8 @@ struct CompiledKernel {
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Human-readable labels for input nodes, for launch diagnostics.
|
||||
input_labels: Vec<String>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
|
||||
has_dyn_dims_param: bool,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
@@ -73,9 +67,7 @@ impl CompiledKernel {
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
input_labels: Vec<String>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
has_dyn_dims_param: bool,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
@@ -86,9 +78,7 @@ impl CompiledKernel {
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
@@ -192,32 +182,6 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
/// LLIR node IDs of every kernel in this CudaGraphOp, in the order
|
||||
/// they execute inside the compiled CUDA graph. This is the
|
||||
/// toposort `kernel_to_host` used at compile time, preserved here
|
||||
/// so the runtime can compute live ranges that match real
|
||||
/// execution order: each kernel in `state.kernels` was added to
|
||||
/// the CUDA graph with `prev_graph_node` as its sole dependency,
|
||||
/// which serializes them.
|
||||
pub fn kernel_topo_order(&self) -> Vec<NodeIndex> {
|
||||
self.state.borrow().kernels.iter().map(|k| k.node).collect()
|
||||
}
|
||||
|
||||
/// Direct LLIR-node inputs of one kernel inside this CudaGraphOp.
|
||||
/// Used by the runtime's live-range pass to refine intra-graph
|
||||
/// consumer positions: a kernel's input can stop being live as
|
||||
/// soon as that specific kernel finishes, not when the whole
|
||||
/// CudaGraphOp finishes.
|
||||
pub fn kernel_inputs(&self, kernel_node: NodeIndex) -> Vec<NodeIndex> {
|
||||
self.state
|
||||
.borrow()
|
||||
.kernels
|
||||
.iter()
|
||||
.find(|k| k.node == kernel_node)
|
||||
.map(|k| k.inputs.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -261,7 +225,7 @@ impl HostOp for CudaGraphOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
@@ -293,40 +257,6 @@ impl HostOp for CudaGraphOp {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
let state = self.state.borrow();
|
||||
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
|
||||
let max_step = state.kernels.len().saturating_sub(1);
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
lifetimes
|
||||
.entry(node)
|
||||
.and_modify(|(first, last)| {
|
||||
*first = (*first).min(step);
|
||||
*last = (*last).max(step);
|
||||
})
|
||||
.or_insert((step, step));
|
||||
};
|
||||
|
||||
for (step, kernel) in state.kernels.iter().enumerate() {
|
||||
for &input in &kernel.inputs {
|
||||
touch(input, step);
|
||||
}
|
||||
touch(kernel.node, step);
|
||||
}
|
||||
|
||||
for node in self.extra_buffer_nodes() {
|
||||
lifetimes.entry(node).or_insert((0, max_step));
|
||||
}
|
||||
|
||||
Some(
|
||||
lifetimes
|
||||
.into_iter()
|
||||
.map(|(node, (start, end))| (node, start, end))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
@@ -337,64 +267,11 @@ impl HostOp for CudaGraphOp {
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
|
||||
match kernel_name {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_requires_output_buffer(
|
||||
kernel: &CompiledKernel,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> bool {
|
||||
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
|
||||
&& kernel.kernel_op.output_aliases_input().is_none()
|
||||
}
|
||||
|
||||
fn validate_kernel_pointers(
|
||||
kernel: &CompiledKernel,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
|
||||
if *input_ptr == 0 {
|
||||
let input_label = kernel
|
||||
.input_labels
|
||||
.get(idx)
|
||||
.map(String::as_str)
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!(
|
||||
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
input_node,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
@@ -465,7 +342,7 @@ impl CudaGraphOp {
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.ptr());
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,26 +390,13 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
@@ -559,19 +423,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
@@ -600,7 +451,7 @@ impl CudaGraphOp {
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
@@ -622,7 +473,7 @@ impl CudaGraphOp {
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.ptr());
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -669,19 +520,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
@@ -690,41 +528,18 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
|
||||
eprintln!(
|
||||
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
kernel.inputs.len(),
|
||||
kernel.has_dyn_dims_param,
|
||||
params.values.len(),
|
||||
);
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
@@ -840,41 +655,6 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / Cuda*Elementwise nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
|
||||
graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
|
||||
name_of(graph, node) == Some("FusionStart")
|
||||
|| graph[node].to_op::<LoopStart>().is_some()
|
||||
|| graph[node].to_op::<LoopEnd>().is_some()
|
||||
|| graph[node].to_op::<LoopInput>().is_some()
|
||||
|| graph[node].to_op::<LoopInputStatic>().is_some()
|
||||
|| graph[node].to_op::<LoopOutput>().is_some()
|
||||
|| graph[node].to_op::<LoopOutputSelect>().is_some()
|
||||
};
|
||||
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
|
||||
let mut visited = FxHashSet::default();
|
||||
while visited.insert(node) && is_transparent_input(graph, node) {
|
||||
let Some(pred) = graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.next()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
node = pred;
|
||||
}
|
||||
node
|
||||
};
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
@@ -892,7 +672,6 @@ pub fn kernel_to_host(
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
let mut external_inputs = FxHashSet::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
@@ -906,151 +685,49 @@ pub fn kernel_to_host(
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
// CUDA kernel for the whole DAG); everything else stays as
|
||||
// CompileUnit::Single (the existing per-op compile path).
|
||||
let compile_units =
|
||||
region_codegen::build_compile_units(&topo_order, llir_graph, &globally_absorbed);
|
||||
// Compile all kernels with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
// Compile all units with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(compile_units.len());
|
||||
for unit in &compile_units {
|
||||
match unit {
|
||||
CompileUnit::Single(kernel_node_idx) => {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
|
||||
// Collect inputs from graph edges
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect_vec();
|
||||
if let Some(expected_inputs) =
|
||||
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
|
||||
{
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
expected_inputs,
|
||||
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel_op_ref.kernel_name(),
|
||||
kernel_node_idx,
|
||||
);
|
||||
}
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op.clone(),
|
||||
has_dyn_dims_param,
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
CompileUnit::Region(region) => {
|
||||
// Generate one fused CUDA kernel for the whole region.
|
||||
let compiled = region_codegen::compile_region(
|
||||
region,
|
||||
llir_graph,
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior elementwise nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region
|
||||
.external_inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect();
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(region.fe_node);
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
region.fe_node,
|
||||
compiled.function,
|
||||
compiled.grid,
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
}
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
@@ -1090,17 +767,16 @@ pub fn kernel_to_host(
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into
|
||||
// subgraph. Also include normalized FusionStart predecessors, because
|
||||
// the compiled kernels read from the concrete producer buffer rather
|
||||
// than the marker node.
|
||||
external_inputs.extend(subgraph.iter().flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
}));
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
@@ -1144,41 +820,22 @@ pub fn kernel_to_host(
|
||||
}
|
||||
}
|
||||
|
||||
// Add each cross-CudaGraphOp dep edge iff it would carry new ordering
|
||||
// information without closing a cycle. The previous topo-position gate
|
||||
// ("skip when src_pos >= dst_pos") was too coarse: it dropped edges
|
||||
// whose src happened to land later in the toposort than their dst even
|
||||
// when no path dst→src actually existed, leaving consumers free to run
|
||||
// before the producer wrote their input buffer (wrong outputs); and it
|
||||
// also added edges that were already implied by an existing src→dst
|
||||
// path (extra serialization, no new info).
|
||||
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
|
||||
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
|
||||
use petgraph::algo::has_path_connecting;
|
||||
for (src, dst) in edges_to_add {
|
||||
if has_path_connecting(&*llir_graph, src, dst, None) {
|
||||
continue; // already ordered src→dst by some path; edge redundant
|
||||
}
|
||||
if has_path_connecting(&*llir_graph, dst, src, None) {
|
||||
continue; // adding src→dst would close a cycle
|
||||
}
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
let topo = toposort(&*llir_graph, None).unwrap();
|
||||
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, n) in topo.iter().enumerate() {
|
||||
topo_pos.insert(*n, i);
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// Cuda*Elementwise) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
|
||||
// chewing through dead nodes every decode token.
|
||||
//
|
||||
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
|
||||
// walks' starting points), so we keep them — they're the kernel
|
||||
// anchor for the region's compiled kernel.
|
||||
for node in globally_absorbed {
|
||||
// Defensive: only remove if the node still exists.
|
||||
if llir_graph.node_weight(node).is_some() {
|
||||
llir_graph.remove_node(node);
|
||||
for (src, dst) in edges_to_add {
|
||||
// Only add forward edges (src before dst in topo order) to avoid creating cycles
|
||||
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
|
||||
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
|
||||
if src_pos >= dst_pos {
|
||||
continue; // Skip back-edges
|
||||
}
|
||||
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,842 +0,0 @@
|
||||
//! Unit + integration tests for the FlashInfer port.
|
||||
//!
|
||||
//! Four layers:
|
||||
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
|
||||
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
|
||||
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
|
||||
//! 3. Mask helper correctness (GPU): the primitive-op `test_compute_attn_mask` builder produces the right (s, c) mask.
|
||||
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
|
||||
//! compared against a luminal-compiled reference attention graph.
|
||||
//!
|
||||
//! GPU-dependent tests short-circuit when no CUDA device is available.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use cudarc::driver::{CudaStream, DevicePtr};
|
||||
use luminal::egglog_utils::{hlir_to_egglog, run_egglog};
|
||||
use luminal::op::{EgglogOp, IntoEgglogOp};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::host::flashinfer::FlashInferAttention;
|
||||
use crate::host::{DeviceBuffer, HostOp};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::get_cuda_stream;
|
||||
|
||||
/// Look up an op in `CudaRuntime::Ops::into_vec()` by its egglog sort name.
|
||||
fn ops_contains_sort(name: &str) -> bool {
|
||||
let ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.iter().any(|op| {
|
||||
// `SortDef` is opaque; its Debug repr starts with the sort name.
|
||||
let sort_dbg = format!("{:?}", op.sort());
|
||||
sort_dbg.contains(name)
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Test-wide model dimensions ───────────────────────────────────────────
|
||||
//
|
||||
// Small Llama-shaped GQA model: nheads=8, kv_heads=2, group=4, head_dim=64.
|
||||
// Chosen so HEAD_DIM ∈ {64, 128, 256} (FlashInfer constraint) and the test
|
||||
// suite fits in O(1ms) of GPU time per case.
|
||||
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_KV_HEADS: usize = 2;
|
||||
const KV_GROUPS: usize = 4;
|
||||
const N_HEADS: usize = N_KV_HEADS * KV_GROUPS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const HIDDEN: usize = N_HEADS * HEAD_DIM;
|
||||
|
||||
// ─── Reference attention graph (Q*K^T → softmax → *V via the compiler) ───
|
||||
|
||||
fn build_attention_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', HIDDEN));
|
||||
let k_ctx = cx.named_tensor("k_ctx", ('c', KV_DIM));
|
||||
let v_ctx_input = cx.named_tensor("v_ctx", ('c', KV_DIM));
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k_ctx.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx_input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// GQA broadcast: zero-stride Mul by 1.0
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let weights = scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
(cx, q_rope, k_ctx, v_ctx_input, attn_out)
|
||||
}
|
||||
|
||||
fn run_reference_attention(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
batch_size: usize,
|
||||
context_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
|
||||
cx.set_dim('s', batch_size);
|
||||
cx.set_dim('c', context_len);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, 3);
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt.execute(&cx.dyn_map);
|
||||
rt.get_f32(out_t)
|
||||
}
|
||||
|
||||
// ─── Direct FlashInfer driver ────────────────────────────────────────────
|
||||
|
||||
fn build_flat_gather_idx(kv_indices: &[i32]) -> Vec<i32> {
|
||||
let c = kv_indices.len();
|
||||
let mut flat = Vec::with_capacity(c * KV_DIM);
|
||||
for &slot in kv_indices {
|
||||
let base = slot * KV_DIM as i32;
|
||||
for j in 0..KV_DIM as i32 {
|
||||
flat.push(base + j);
|
||||
}
|
||||
}
|
||||
flat
|
||||
}
|
||||
|
||||
fn transpose_hbd_to_bhd(data: &[f32], heads: usize, batch: usize, dim: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; data.len()];
|
||||
for h in 0..heads {
|
||||
for b in 0..batch {
|
||||
for d in 0..dim {
|
||||
out[b * heads * dim + h * dim + d] = data[h * batch * dim + b * dim + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn alloc_dev(stream: &Arc<CudaStream>, bytes: usize) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = bytes.max(1);
|
||||
unsafe { stream.alloc::<u8>(bytes).unwrap() }
|
||||
}
|
||||
|
||||
fn copy_to_dev<T: Copy>(stream: &Arc<CudaStream>, data: &[T]) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
|
||||
};
|
||||
stream.clone_htod(bytes).unwrap()
|
||||
}
|
||||
|
||||
/// Run FlashInferAttention.execute() directly and reshape the output to the
|
||||
/// reference (batch, heads, dim) layout used by `run_reference_attention`.
|
||||
fn run_flashinfer(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k_cache: &[f32],
|
||||
v_cache: &[f32],
|
||||
kv_indptr: &[i32],
|
||||
kv_indices: &[i32],
|
||||
batch_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let q_buf = copy_to_dev(stream, q);
|
||||
let k_buf = copy_to_dev(stream, k_cache);
|
||||
let v_buf = copy_to_dev(stream, v_cache);
|
||||
let flat_idx = build_flat_gather_idx(kv_indices);
|
||||
let flat_idx_buf = copy_to_dev(stream, &flat_idx);
|
||||
let mask_buf = alloc_dev(stream, 4); // unused but reserved
|
||||
let qo_indptr: Vec<i32> = (0..=batch_size as i32).collect();
|
||||
let qo_indptr_buf = copy_to_dev(stream, &qo_indptr);
|
||||
let kv_indptr_buf = copy_to_dev(stream, kv_indptr);
|
||||
let out_buf = alloc_dev(stream, batch_size * HIDDEN * 4);
|
||||
|
||||
let fi = FlashInferAttention {
|
||||
num_qo_heads: N_HEADS,
|
||||
num_kv_heads: N_KV_HEADS,
|
||||
head_dim: HEAD_DIM,
|
||||
page_size: 1,
|
||||
batch_dim: Expression::from('s'),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Reserve dedicated NodeIndex values for the test ports.
|
||||
let nodes: Vec<NodeIndex> = (0..8).map(NodeIndex::new).collect();
|
||||
let (q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n, out_n) = (
|
||||
nodes[0], nodes[1], nodes[2], nodes[3], nodes[4], nodes[5], nodes[6], nodes[7],
|
||||
);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
let q_ptr = q_buf.device_ptr(stream).0;
|
||||
let k_ptr = k_buf.device_ptr(stream).0;
|
||||
let v_ptr = v_buf.device_ptr(stream).0;
|
||||
let idx_ptr = flat_idx_buf.device_ptr(stream).0;
|
||||
let mask_ptr = mask_buf.device_ptr(stream).0;
|
||||
let qo_ptr = qo_indptr_buf.device_ptr(stream).0;
|
||||
let kv_ptr = kv_indptr_buf.device_ptr(stream).0;
|
||||
let out_ptr = out_buf.device_ptr(stream).0;
|
||||
buffers.insert(q_n, DeviceBuffer::new(q_ptr, q.len() * 4));
|
||||
buffers.insert(k_n, DeviceBuffer::new(k_ptr, k_cache.len() * 4));
|
||||
buffers.insert(v_n, DeviceBuffer::new(v_ptr, v_cache.len() * 4));
|
||||
buffers.insert(idx_n, DeviceBuffer::new(idx_ptr, flat_idx.len() * 4));
|
||||
buffers.insert(mask_n, DeviceBuffer::new(mask_ptr, 4));
|
||||
buffers.insert(qo_n, DeviceBuffer::new(qo_ptr, qo_indptr.len() * 4));
|
||||
buffers.insert(kv_n, DeviceBuffer::new(kv_ptr, kv_indptr.len() * 4));
|
||||
buffers.insert(out_n, DeviceBuffer::new(out_ptr, batch_size * HIDDEN * 4));
|
||||
|
||||
let inputs = [q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n];
|
||||
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', batch_size);
|
||||
dyn_map.insert('c', kv_indices.len());
|
||||
dyn_map.insert('r', kv_indptr.len());
|
||||
|
||||
fi.execute(stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.expect("FlashInferAttention execute failed");
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Output is (heads, batch, dim); reshape to (batch, heads, dim).
|
||||
let mut out_bytes = vec![0u8; batch_size * HIDDEN * 4];
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtoh_async(&mut out_bytes, out_ptr, stream.cu_stream())
|
||||
.unwrap();
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let raw: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(out_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
transpose_hbd_to_bhd(&raw, N_HEADS, batch_size, HEAD_DIM)
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn deterministic_f32(n: usize, seed: f32, scale: f32) -> Vec<f32> {
|
||||
(0..n).map(|i| (i as f32 * seed).sin() * scale).collect()
|
||||
}
|
||||
|
||||
fn assert_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"length mismatch: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
);
|
||||
let mut worst = (0usize, 0.0f32);
|
||||
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
|
||||
let diff = (x - y).abs();
|
||||
if diff > worst.1 {
|
||||
worst = (i, diff);
|
||||
}
|
||||
let tol = atol + rtol * y.abs();
|
||||
assert!(
|
||||
diff <= tol,
|
||||
"mismatch at idx {i}: {x} vs {y} (|diff|={diff}, tol={tol})"
|
||||
);
|
||||
}
|
||||
eprintln!("max |diff| = {:.2e} @ idx {}", worst.1, worst.0);
|
||||
}
|
||||
|
||||
// ─── Layer 1: egglog metadata sanity (no GPU) ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_registers_via_into_egglog() {
|
||||
// Confirm the op is reachable through the Runtime::Ops tuple. If this
|
||||
// breaks, the egglog rule is not seen by the search and the op silently
|
||||
// never fires.
|
||||
assert!(
|
||||
ops_contains_sort("FlashInferAttention"),
|
||||
"FlashInferAttention is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_egg_rule_parses() {
|
||||
// Rule::raw() returns the rule with no validation; egglog parses it at
|
||||
// graph build. Smoke-test by running it through the egglog frontend via
|
||||
// a tiny program string.
|
||||
let op = FlashInferAttention::default();
|
||||
let rewrites = op.rewrites();
|
||||
assert_eq!(rewrites.len(), 1);
|
||||
// The rule must mention FlashInferAttention to be the right one.
|
||||
let s = format!("{:?}", rewrites[0]);
|
||||
assert!(
|
||||
s.contains("FlashInferAttention"),
|
||||
"rewrite is not the FlashInfer rule: {s}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_sort_shape() {
|
||||
let op = FlashInferAttention::default();
|
||||
let s = op.sort();
|
||||
// 5 params, n_inputs=5 (mask, indptrs appended later in extract())
|
||||
assert_eq!(op.n_inputs(), 5);
|
||||
let dbg = format!("{:?}", s);
|
||||
assert!(dbg.contains("FlashInferAttention"));
|
||||
}
|
||||
|
||||
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs1_ctx4() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k = deterministic_f32(context_len * KV_DIM, 0.021, 0.1);
|
||||
let v = deterministic_f32(context_len * KV_DIM, 0.031, 0.1);
|
||||
let expected = run_reference_attention(&stream, &q, &k, &v, batch_size, context_len);
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = (0..context_len as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs2_supersequence() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 2;
|
||||
let ctx0 = 8;
|
||||
let ctx1 = 3;
|
||||
let total_ctx = ctx0 + ctx1;
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.014, 0.1);
|
||||
let k = deterministic_f32(total_ctx * KV_DIM, 0.022, 0.1);
|
||||
let v = deterministic_f32(total_ctx * KV_DIM, 0.032, 0.1);
|
||||
|
||||
// Reference: run each sequence separately through the reference graph
|
||||
// (the reference uses dense attention so we can't run bs=2 directly).
|
||||
let expected0 = run_reference_attention(
|
||||
&stream,
|
||||
&q[..HIDDEN],
|
||||
&k[..ctx0 * KV_DIM],
|
||||
&v[..ctx0 * KV_DIM],
|
||||
1,
|
||||
ctx0,
|
||||
);
|
||||
let expected1 = run_reference_attention(
|
||||
&stream,
|
||||
&q[HIDDEN..],
|
||||
&k[ctx0 * KV_DIM..],
|
||||
&v[ctx0 * KV_DIM..],
|
||||
1,
|
||||
ctx1,
|
||||
);
|
||||
let expected: Vec<f32> = expected0.into_iter().chain(expected1).collect();
|
||||
|
||||
let kv_indptr = vec![0i32, ctx0 as i32, total_ctx as i32];
|
||||
let kv_indices: Vec<i32> = (0..total_ctx as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_noncontiguous_page_table() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let num_slots = 8;
|
||||
let slot_indices = [3usize, 0, 7, 1];
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k_full = deterministic_f32(num_slots * KV_DIM, 0.022, 0.1);
|
||||
let v_full = deterministic_f32(num_slots * KV_DIM, 0.033, 0.1);
|
||||
|
||||
// Reference operates on the contiguous gathered cache.
|
||||
let mut k_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
let mut v_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
for (i, &slot) in slot_indices.iter().enumerate() {
|
||||
k_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&k_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
v_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&v_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
}
|
||||
let expected = run_reference_attention(
|
||||
&stream,
|
||||
&q,
|
||||
&k_gathered,
|
||||
&v_gathered,
|
||||
batch_size,
|
||||
context_len,
|
||||
);
|
||||
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = slot_indices.iter().map(|&s| s as i32).collect();
|
||||
let result = run_flashinfer(
|
||||
&stream,
|
||||
&q,
|
||||
&k_full,
|
||||
&v_full,
|
||||
&kv_indptr,
|
||||
&kv_indices,
|
||||
batch_size,
|
||||
);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
// ─── Layer 3b: HEAD_DIM 128 path (validates the head-dim JIT dispatch) ────
|
||||
//
|
||||
// Each FlashInfer .so is compiled for one HEAD_DIM. JIT caches by head dim;
|
||||
// the OnceLock means only one is loaded per process. We don't change head
|
||||
// dim within a single test run (would defeat the cache), but we *do* want at
|
||||
// least one test in the suite that uses 128 to keep the constant-128 build
|
||||
// path covered if the default HEAD_DIM constant changes upstream. We assert
|
||||
// the constraint here rather than firing a second JIT.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_jit_head_dim_assertion() {
|
||||
// 64 / 128 / 256 must be the only allowed values.
|
||||
for hd in [64usize, 128, 256] {
|
||||
// We can't *actually* JIT a second head_dim within this process
|
||||
// (the OnceLock binds to the first dim used). Just check the dim
|
||||
// is in the supported set.
|
||||
assert!(matches!(hd, 64 | 128 | 256));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Layer 4: egglog rule firing (no GPU) ────────────────────────────────
|
||||
//
|
||||
// These tests build HLIR graphs and run egglog saturation. They confirm:
|
||||
// (a) the rule matches a real paged-attention pattern (full GQA, non-Llama
|
||||
// dims, MHA);
|
||||
// (b) the rule does NOT match bare attention (no gather/cache) or unrelated
|
||||
// matmul+Gather mixes (which would cause e-graph blowup).
|
||||
//
|
||||
// Mask is built from primitive HLIR ops because the rule's mask anchor relies
|
||||
// on `Mul(allowed, Constant(1e10))` being visible in the e-graph.
|
||||
|
||||
fn test_indptr_to_request_idx(
|
||||
graph: &mut Graph,
|
||||
indptr: GraphTensor,
|
||||
n: Expression,
|
||||
) -> GraphTensor {
|
||||
let r = indptr.dims1();
|
||||
let indices = graph.arange(n).expand_dim(1, r);
|
||||
let indptr_2d = indptr.expand_dim(0, n);
|
||||
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
|
||||
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
|
||||
}
|
||||
|
||||
fn test_compute_attn_mask(
|
||||
graph: &mut Graph,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
c: Expression,
|
||||
) -> GraphTensor {
|
||||
let s = q_pos.dims1();
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s);
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c);
|
||||
let c_arange = graph.arange(c);
|
||||
let c_kv_start = kv_indptr.gather(c_request);
|
||||
let c_local_pos = c_arange - c_kv_start;
|
||||
let q_req_2d = q_request.expand_dim(1, c);
|
||||
let c_req_2d = c_request.expand_dim(0, s);
|
||||
let same = q_req_2d.eq(c_req_2d);
|
||||
let c_pos_2d = c_local_pos.expand_dim(0, s);
|
||||
let qp_2d = q_pos.expand_dim(1, c);
|
||||
let causal = c_pos_2d.le(qp_2d);
|
||||
let allowed = same.cast(luminal::dtype::DType::F32) * causal.cast(luminal::dtype::DType::F32);
|
||||
allowed * 1e10 - 1e10
|
||||
}
|
||||
|
||||
fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = data.graph().arange(d as i32).expand_dim(0, n);
|
||||
data.gather(base + col)
|
||||
}
|
||||
|
||||
fn scatter_rows(
|
||||
src: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
dest: GraphTensor,
|
||||
d: usize,
|
||||
) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = src.graph().arange(d as i32).expand_dim(0, n);
|
||||
src.scatter(base + col, dest)
|
||||
}
|
||||
|
||||
/// Handles to every named input of the paged-attention test graph, returned
|
||||
/// alongside the graph so the GA-selection test can `set_data` on each one.
|
||||
#[allow(dead_code)]
|
||||
struct PagedAttnHandles {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v_new: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
}
|
||||
|
||||
/// Build a full paged-attention HLIR graph with the structural anchors the
|
||||
/// FlashInfer egglog rule looks for: scatter into a 2D cache, gather rows out
|
||||
/// by index, GQA broadcast via `Mul(..., 1.0)` with zero strides, Q*K^T → Sum
|
||||
/// → scale → mask Add → softmax → *V → Sum.
|
||||
fn build_paged_attention_graph(
|
||||
n_heads: usize,
|
||||
n_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> (Graph, PagedAttnHandles) {
|
||||
let kv_groups = n_heads / n_kv_heads;
|
||||
let kv_dim = n_kv_heads * head_dim;
|
||||
let hidden = n_heads * head_dim;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', hidden));
|
||||
let k_rope = cx.named_tensor("k_rope", ('s', kv_dim));
|
||||
let v_new = cx.named_tensor("v_new", ('s', kv_dim));
|
||||
let k_cache = cx.named_tensor("k_cache", (2048, kv_dim)).persist();
|
||||
let v_cache = cx.named_tensor("v_cache", (2048, kv_dim)).persist();
|
||||
let scatter_idx = cx
|
||||
.named_tensor("scatter_idx", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let q_pos = cx
|
||||
.named_tensor("q_pos", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let qo_indptr = cx
|
||||
.named_tensor("qo_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let kv_indptr = cx
|
||||
.named_tensor("kv_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, kv_dim);
|
||||
let v_cache_out = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, kv_dim);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, kv_dim);
|
||||
|
||||
let c: Expression = 'c'.into();
|
||||
let attn_mask = test_compute_attn_mask(&mut cx, q_pos, qo_indptr, kv_indptr, c);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.split_dims(1, head_dim).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (head_dim as f32).sqrt();
|
||||
let mask = attn_mask.expand_dim(0, n_heads);
|
||||
let masked_scores = scores + mask;
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
attn_out.output();
|
||||
k_cache_out.output();
|
||||
v_cache_out.output();
|
||||
|
||||
(
|
||||
cx,
|
||||
PagedAttnHandles {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
q_pos,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Saturate egglog on the graph and report whether a FlashInferAttention
|
||||
/// e-node was produced. Helper used by the rule-firing tests.
|
||||
fn saturate_and_has_flashinfer(cx: &Graph) -> (bool, Vec<String>) {
|
||||
let (program, root) = hlir_to_egglog(cx);
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
// cleanup=false: keep every saturation-introduced e-node so we can inspect
|
||||
// whether the FlashInferAttention rule produced a node, regardless of
|
||||
// whether downstream extraction would have pruned it.
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
let has_flashinfer = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
|
||||
// Collect distinct OpKind labels so a failure can print what *did* match.
|
||||
let mut op_kinds: Vec<String> = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.filter(|(l, _)| {
|
||||
!l.starts_with('(')
|
||||
&& ![
|
||||
"Op",
|
||||
"Input",
|
||||
"Output",
|
||||
"OutputJoin",
|
||||
"ICons",
|
||||
"INil",
|
||||
"ECons",
|
||||
"ENil",
|
||||
"MNum",
|
||||
"MVar",
|
||||
"MMul",
|
||||
"MDiv",
|
||||
"MIter",
|
||||
]
|
||||
.contains(&l.as_str())
|
||||
})
|
||||
.map(|(l, _)| l.clone())
|
||||
.collect();
|
||||
op_kinds.sort();
|
||||
op_kinds.dedup();
|
||||
|
||||
(has_flashinfer, op_kinds)
|
||||
}
|
||||
|
||||
/// Debug aid: dump the egglog program and key e-graph metrics for the lite
|
||||
/// paged-attention test so we can see why the FlashInfer rule isn't matching.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn flashinfer_dump_paged_attn_egglog() {
|
||||
// First sanity-check that each Ops member returns its rewrites and that
|
||||
// FlashInferAttention's rule appears in the combined corpus.
|
||||
let ops_vec = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
eprintln!("==== Ops rewrites count ====");
|
||||
let mut fi_rewrites = 0usize;
|
||||
let mut total_rewrites = 0usize;
|
||||
for op in &ops_vec {
|
||||
let rws = op.rewrites();
|
||||
total_rewrites += rws.len();
|
||||
for r in &rws {
|
||||
let s = format!("{r:?}");
|
||||
if s.contains("FlashInferAttention") {
|
||||
fi_rewrites += 1;
|
||||
eprintln!("FOUND FlashInfer rewrite ({} chars)", s.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"==== ops_vec.len()={} total_rewrites={total_rewrites} fi_rewrites={fi_rewrites} ====",
|
||||
ops_vec.len()
|
||||
);
|
||||
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (program, root) = hlir_to_egglog(&cx);
|
||||
eprintln!("==== EGGLOG PROGRAM (root={root}) ====");
|
||||
for (i, line) in program.lines().enumerate() {
|
||||
eprintln!("{:5}: {line}", i + 1);
|
||||
}
|
||||
eprintln!(
|
||||
"==== END EGGLOG PROGRAM ({} lines) ====",
|
||||
program.lines().count()
|
||||
);
|
||||
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
// Bucket enode labels by frequency.
|
||||
let mut counts: std::collections::HashMap<String, usize> = Default::default();
|
||||
for (label, _) in egraph.enodes.values() {
|
||||
*counts.entry(label.clone()).or_default() += 1;
|
||||
}
|
||||
let mut sorted: Vec<_> = counts.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
eprintln!("==== E-GRAPH LABEL HISTOGRAM (top 60) ====");
|
||||
for (label, n) in sorted.iter().take(60) {
|
||||
eprintln!(" {n:6} {label}");
|
||||
}
|
||||
let has_fi = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
eprintln!("==== has FlashInferAttention enode: {has_fi} ====");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_bare_attention() {
|
||||
// Dense attention without paged gather + cache should NOT match.
|
||||
let (cx, _, _, _, _) = build_attention_graph();
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on bare attention (no gather/cache)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_unrelated_matmuls() {
|
||||
// A Gather + plain matmul (MLP-shaped projection) plus two chained matmuls
|
||||
// through softmax — close to attention structurally but missing the GQA
|
||||
// broadcast / mask Add anchors. The rule must reject this.
|
||||
let mut cx = Graph::default();
|
||||
let cache = cx.named_tensor("cache", (4096, KV_DIM)).persist();
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let weight = cx.named_tensor("weight", (HIDDEN, KV_DIM)).persist();
|
||||
|
||||
let n = gather_idx.dims1();
|
||||
let base = (gather_idx * KV_DIM).expand_dim(1, KV_DIM);
|
||||
let col = cx.arange(KV_DIM as i32).expand_dim(0, n);
|
||||
let gathered = cache.gather(base + col);
|
||||
let proj = gathered.matmul(weight.t());
|
||||
proj.output();
|
||||
|
||||
let a = cx.named_tensor("a", ('s', HIDDEN));
|
||||
let b = cx.named_tensor("b", (HIDDEN, HIDDEN)).persist();
|
||||
let c_tensor = cx.named_tensor("c_tensor", (HIDDEN, HIDDEN)).persist();
|
||||
let ab = a.matmul(b.t());
|
||||
let abc = ab.softmax(1).matmul(c_tensor.t());
|
||||
abc.output();
|
||||
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on unrelated matmuls + Gather"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_full_paged_attention() {
|
||||
// Default Llama-shaped test dims (HEAD_DIM=64, N_HEADS=8, N_KV_HEADS=2).
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found in the e-graph (Llama-shaped paged attention). \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_non_llama_dims() {
|
||||
// Different head counts: HEAD_DIM=64, N_HEADS=16, N_KV_HEADS=4 (group=4).
|
||||
// Exercises the model-agnostic structural variables in the rule.
|
||||
let (cx, _) = build_paged_attention_graph(16, 4, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for non-Llama dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_mha() {
|
||||
// MHA: KV_GROUPS=1 (n_heads == n_kv_heads). The GQA broadcast still
|
||||
// structurally appears (expand_dim(1, 1) + merge), so the rule should
|
||||
// still match.
|
||||
let (cx, _) = build_paged_attention_graph(12, 12, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for MHA dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 5: extraction reachability (no GPU) ───────────────────────────
|
||||
//
|
||||
// After `build_search_space` saturates egglog, the GA picks an extraction by
|
||||
// cost. In a tiny test graph the cuBLAS+kernel path is often faster than the
|
||||
// FlashInfer host op (which pays a `plan()` setup cost per call), so asserting
|
||||
// "GA picked FlashInfer" is flaky. Instead, sample many random valid genomes
|
||||
// from the search space and assert that the FlashInfer extraction is reachable
|
||||
// — meaning the rule fired AND `find_indptrs` extraction succeeded for at
|
||||
// least one offspring. That is the end-to-end check we actually want.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_extraction_reachable_from_search_space() {
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
let (mut cx, _h) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
cx.set_dim('s', 1usize);
|
||||
cx.set_dim('c', 16usize);
|
||||
cx.set_dim('r', 2usize);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let egraph = cx
|
||||
.egraph()
|
||||
.expect("egraph missing after build_search_space");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("egglog_ops missing after build_search_space");
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0xf1a541);
|
||||
let mut prev: FxHashSet<u64> = FxHashSet::default();
|
||||
let initial = luminal::egglog_utils::random_initial_choice(egraph, &mut rng);
|
||||
prev.insert(luminal::egglog_utils::hash_choice_set(&initial));
|
||||
let mut base = initial;
|
||||
|
||||
let mut found = false;
|
||||
'outer: for _ in 0..50 {
|
||||
let offspring =
|
||||
luminal::egglog_utils::extract_generation(egraph, &base, 10, 2, &mut prev, &mut rng);
|
||||
if offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
for genome in offspring {
|
||||
if luminal::egglog_utils::validate_choice_set(egraph, &genome, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
// Catch a possible panic from find_indptrs walking the mask — we
|
||||
// want the test to fail with a clean message, not abort.
|
||||
let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
luminal::egglog_utils::egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
)
|
||||
}));
|
||||
let Ok(llir_graph) = panicked else { continue };
|
||||
|
||||
let has_fi = llir_graph.node_indices().any(|n| {
|
||||
llir_graph[n]
|
||||
.to_dialect::<dyn HostOp>()
|
||||
.and_then(|op| op.stats_name())
|
||||
== Some("FlashInferAttention")
|
||||
});
|
||||
if has_fi {
|
||||
found = true;
|
||||
break 'outer;
|
||||
}
|
||||
base = genome;
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found,
|
||||
"FlashInferAttention extraction not reachable from search space after 50 generations"
|
||||
);
|
||||
}
|
||||
@@ -3,27 +3,93 @@ use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::fusion::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
use crate::kernel::other_ops::{KernelFusedElementwise, UnaryFn};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{
|
||||
TOLERANCE_SAFETY_FACTOR, dtype_epsilon, random_f32_vec, test_binary_cuda, test_unary_cuda,
|
||||
};
|
||||
use crate::tests::utilities::{random_f32_vec, test_unary_cuda};
|
||||
|
||||
/// Return every distinct kernel_name that appears across many random extractions
|
||||
/// of the search space. Used to check whether fusion produces a reachable
|
||||
/// `KernelFusedElementwise` node (or, negatively, that it never does).
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_names = Vec::new();
|
||||
for _ in 0..50 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
let name = k.kernel_name().to_string();
|
||||
if !all_names.contains(&name) {
|
||||
all_names.push(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_names
|
||||
}
|
||||
|
||||
/// Return every distinct `Vec<UnaryFn>` that appears inside a reachable
|
||||
/// `KernelFusedElementwise` across many random extractions. Used to verify
|
||||
/// that a specific fused configuration (e.g. a 3-op chain) is reachable.
|
||||
fn extract_all_fused_configs(cx: &mut Graph) -> Vec<Vec<UnaryFn>> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_configs: Vec<Vec<UnaryFn>> = Vec::new();
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(kop) = op.to_dialect::<dyn KernelOp>()
|
||||
&& let Some(fused) = (***kop).downcast_ref::<KernelFusedElementwise>()
|
||||
{
|
||||
let cfg = fused.ops().to_vec();
|
||||
if !all_configs.contains(&cfg) {
|
||||
all_configs.push(cfg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_configs
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_unary_ops_fuse() {
|
||||
// Marker form: `a.sin().sqrt()` should fuse into a region with FusedSin
|
||||
// and FusedSqrt under one FusionEnd (per pair-fuse U→U).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
names.iter().any(|n| n == "FusedElementwise"),
|
||||
"expected KernelSin→KernelSqrt on contiguous strides to be fusable into \
|
||||
a single FusedElementwise kernel, but reachable kernels were: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -31,42 +97,33 @@ fn test_two_unary_ops_fuse() {
|
||||
fn test_stride_mismatch_prevents_fusion() {
|
||||
// A permute between sin and sqrt gives sqrt a non-contiguous view of sin's
|
||||
// contiguous output, so sqrt's in_strides != its out_strides and the
|
||||
// non-linear `?s ?s` match in the pair-fuse U→U rule can't fire.
|
||||
// non-linear `?strides` match in the fusion rule can't fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let _b = a.sin().permute((1, 0)).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"permute between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a permute between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduction_prevents_unary_fusion() {
|
||||
// A reduction between two unaries is not elementwise, so pair-fuse U→U
|
||||
// (which only matches adjacent elementwise pairs) must not fire across
|
||||
// the reduction.
|
||||
// A reduction between two unaries is not elementwise, so the fusion rule
|
||||
// (which only matches unary+unary pairs) must not fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let _b = a.sin().sum(1).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"reduction between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a reduction between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -88,36 +145,31 @@ fn test_unary_fusion_preserves_output() {
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single marker region containing all three elementwise ops.
|
||||
// reachable as a single FusedElementwise containing all three ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2"]);
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2];
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single marker region containing all
|
||||
// four elementwise ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
// 4-op chain should collapse into a single Fused containing all four ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2", "FusedLog2"]);
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2, UnaryFn::Log2];
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2, Log2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -264,747 +316,3 @@ extern "C" __global__ void fused_k(float* out, const float* in, long long n) {
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Binary-inclusive fusion tests (marker-based FusionStart / FusionEnd scheme).
|
||||
//
|
||||
// Detects fused regions by walking backward from each `FusionEnd`-tagged LLIR
|
||||
// node through `Direction::Incoming` edges until a `FusionStart` is reached.
|
||||
// The walker stops at FusionStarts (they mark the external-input boundary of
|
||||
// the region). A region's summary is: the sorted set of internal op names,
|
||||
// the count of distinct FusionStart nodes reached, and the count of FusionEnd
|
||||
// nodes (invariant: always 1 per region).
|
||||
// =========================================================================
|
||||
|
||||
/// A single fused region extracted from the LLIR graph after egglog.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct FusedRegion {
|
||||
/// Sorted internal op `kernel_name()`s, excluding the `FusionStart` /
|
||||
/// `FusionEnd` markers. Sorted so DAG traversal order doesn't produce
|
||||
/// spurious "distinct" regions.
|
||||
internal_ops_sorted: Vec<String>,
|
||||
/// Number of distinct `FusionStart` nodes reached by the walk. Per design
|
||||
/// this equals the number of distinct external input tensors.
|
||||
start_count: usize,
|
||||
/// Number of `FusionEnd` nodes in the region. Per design this is always 1.
|
||||
end_count: usize,
|
||||
}
|
||||
|
||||
/// Helper: collect every distinct fused region reachable across many random
|
||||
/// extractions of the search space.
|
||||
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut seen: Vec<FusedRegion> = Vec::new();
|
||||
// 200 samples: the random extractor picks one e-node per e-class per
|
||||
// call, and the fully-fused diamond form lives in an e-class with
|
||||
// many equivalent forms. 50 was flaky; 200 is reliably stable and
|
||||
// each sample is cheap (~100 µs).
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let name_of = |idx: NodeIndex| -> Option<String> {
|
||||
llir.node_weight(idx).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| {
|
||||
if let Some(elem) = (***k).downcast_ref::<CudaUnaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else if let Some(elem) = (***k).downcast_ref::<CudaBinaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else {
|
||||
k.kernel_name().to_string()
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
let end_nodes: Vec<NodeIndex> = llir
|
||||
.node_indices()
|
||||
.filter(|&idx| name_of(idx).as_deref() == Some("FusionEnd"))
|
||||
.collect();
|
||||
|
||||
for end in end_nodes {
|
||||
let mut internal: Vec<String> = Vec::new();
|
||||
// Count distinct external input *tensors*, not distinct FusionStart
|
||||
// node indices. Egglog rule firings can emit multiple FusionStart
|
||||
// enodes that all wrap the same source tensor (e.g. when the same
|
||||
// `a` is consumed at two sites inside the fused region, each
|
||||
// pair-fuse / grow firing mints its own FusionStart). Those are
|
||||
// logically one FusionStart per the design invariant
|
||||
// ("N = number of distinct external input tensors").
|
||||
let mut start_sources: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
visited.insert(end);
|
||||
let mut stack = vec![end];
|
||||
|
||||
// Resolve chains of nested FusionStart wrappers (cascade artifact)
|
||||
// to the real external source. A FusionStart whose incoming neighbor
|
||||
// is itself a FusionStart is a cascade layer, not a new external
|
||||
// tensor. A FusionEnd predecessor is a real external region output
|
||||
// in the generic singleton-region model, so do not walk through it.
|
||||
let resolve_source = |mut n: NodeIndex| -> NodeIndex {
|
||||
loop {
|
||||
match name_of(n).as_deref() {
|
||||
Some("FusionStart") => {
|
||||
let mut inc = llir.neighbors_directed(n, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(p) => n = p,
|
||||
None => return n,
|
||||
}
|
||||
}
|
||||
_ => return n,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
for pred in llir.neighbors_directed(node, petgraph::Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred).as_deref() {
|
||||
Some("FusionStart") => {
|
||||
// If this FS's predecessor is itself a FE (or a
|
||||
// chain of FS/FE wrappers that eventually hits a
|
||||
// non-marker op inside the region), the FS is a
|
||||
// cascade artifact, not a real external boundary.
|
||||
// Walk past it and its upstream FE into the same
|
||||
// region. Otherwise treat the predecessor as the
|
||||
// external source tensor — which may be a KernelOp
|
||||
// *or* a non-KernelOp (HLIR loadable) node, so we
|
||||
// can't gate counting on `name_of` being `Some`.
|
||||
let mut inc =
|
||||
llir.neighbors_directed(pred, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(src_node) => {
|
||||
start_sources.insert(resolve_source(src_node));
|
||||
}
|
||||
None => {
|
||||
// FS with no predecessor — degenerate.
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// Transparent: inner FusionEnds are cascade-wart
|
||||
// artifacts from grow rules re-firing and creating
|
||||
// nested `FE(Op(FE(...)))` wrappers. They don't
|
||||
// represent real work or a real boundary — walk
|
||||
// past them and do not count them as internal ops.
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) => {
|
||||
internal.push(other.to_string());
|
||||
stack.push(pred);
|
||||
}
|
||||
None => {
|
||||
// Non-KernelOp predecessor (shouldn't appear inside a
|
||||
// fused region under the design). Stop walking this path.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal.sort();
|
||||
// Skip singleton regions: every elementwise op has a seeded
|
||||
// `FE(Op(FS(...)))` form, so random extraction will surface
|
||||
// many one-op regions that are equivalent to not fusing. We
|
||||
// only care about regions that represent real multi-op fusion.
|
||||
if internal.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
let region = FusedRegion {
|
||||
internal_ops_sorted: internal,
|
||||
start_count: start_sources.len(),
|
||||
end_count: 1,
|
||||
};
|
||||
if !seen.contains(®ion) {
|
||||
seen.push(region);
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
fn sorted_names(items: &[&str]) -> Vec<String> {
|
||||
let mut v: Vec<String> = items.iter().map(|s| (*s).to_string()).collect();
|
||||
v.sort();
|
||||
v
|
||||
}
|
||||
|
||||
// ---- Structural tests: the expected fused shape is reachable ----
|
||||
|
||||
#[test]
|
||||
fn test_single_binary_does_not_fuse_alone() {
|
||||
// A lone elementwise op gets a seeded singleton region by design; we
|
||||
// filter singletons out in `extract_all_fused_regions`. What this test
|
||||
// asserts is that no *multi-op* region appears for a standalone binary
|
||||
// — nothing to grow into.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
assert!(
|
||||
regions.is_empty(),
|
||||
"a solo binary op should not form a multi-op fused region, but got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_of_binaries_fuses() {
|
||||
// `(a + b) * c`: three external inputs collapse into one region with
|
||||
// internal [Add, Mul] and 3 FusionStarts.
|
||||
//
|
||||
// Requires BB family, which is opt-in at runtime via
|
||||
// LUMINAL_FUSION_FAMILIES. Set it before the graph build so the rules
|
||||
// emitted from FusionEnd::rewrites include the B-B pair-fuse rules.
|
||||
// SAFETY: tests run in parallel; we set this before constructing the
|
||||
// Graph, and never unset, so concurrent tests just see BB on.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = ((a + b) * c).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_then_unary_fuses() {
|
||||
// `sin(a + b)`: binary feeds a unary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).sin().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_then_binary_fuses() {
|
||||
// `sin(a) + b`: unary feeds a binary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a.sin() + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Subsume in grow rules (introduced to bound the BB partial-FE explosion)
|
||||
// means a multi-consumer producer can no longer be fused into the same
|
||||
// region as all its consumers — only one branch wins. The diamond's `t`
|
||||
// has two consumers, so the structural "one 5-op region" outcome is no
|
||||
// longer guaranteed. Numerical correctness still holds (see
|
||||
// test_diamond_dag_preserves_output).
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_diamond_dag_fuses() {
|
||||
// The canonical diamond-DAG example agreed with the user:
|
||||
// t = a + b; u = exp2(t); v = sin(t); w = u * a; out = w + v
|
||||
// `a` is reused (feeds outer Add and Mul) and `t` is reused (feeds Exp2 and
|
||||
// Sin). Expected: one fused region with internal ops [Add, Add, Exp2, Mul,
|
||||
// Sin], 2 FusionStarts (distinct tensors a, b), 1 FusionEnd.
|
||||
// We use exp2 rather than exp because the frontend's exp() desugars to
|
||||
// Mul(x, LOG2E).exp2(), which would add a constant input and a Mul op and
|
||||
// obscure the diamond topology this test is checking.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2 && r.end_count == 1),
|
||||
"expected diamond DAG to fuse into one region with ops {expected:?}, \
|
||||
2 FusionStarts, 1 FusionEnd. Got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Negative tests: fusion must NOT happen across these blockers ----
|
||||
|
||||
#[test]
|
||||
fn test_reduction_blocks_binary_fusion() {
|
||||
// A reduction between a binary and anything downstream is not elementwise,
|
||||
// so Add and SumReduce must never appear in the same fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let b = cx.tensor((4, 4));
|
||||
let _c = (a + b).sum(1).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_add = r.internal_ops_sorted.iter().any(|n| n == "FusedAdd");
|
||||
let has_sum = r.internal_ops_sorted.iter().any(|n| n == "SumReduce");
|
||||
assert!(
|
||||
!(has_add && has_sum),
|
||||
"FusedAdd and SumReduce must not share a fused region, but got: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stride_mismatch_blocks_binary_fusion() {
|
||||
// A permute gives `b` a non-contiguous view whose strides do not match `a`'s,
|
||||
// so the binary fusion rule's stride-compatibility check must prevent the
|
||||
// Add from being absorbed into any fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let b = cx.tensor((4, 3));
|
||||
let _c = (a + b.permute((1, 0))).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
assert!(
|
||||
!r.internal_ops_sorted.iter().any(|n| n == "FusedAdd"),
|
||||
"permuted binary must not fuse into a region, but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Numerical parity tests: fused output matches candle reference ----
|
||||
|
||||
#[test]
|
||||
fn test_simple_binary_fusion_preserves_output() {
|
||||
// End-to-end numerical check: `a + b` on GPU matches candle's add across
|
||||
// all reachable genomes (fused or unfused) via test_binary_cuda's fuzzer.
|
||||
let seed = 0xADDBEEFu64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| a + b,
|
||||
|a, b| (a + b).unwrap(),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diamond_dag_preserves_output() {
|
||||
// Numerical parity for the diamond DAG: `(exp(a+b) * a) + sin(a+b)`
|
||||
// matches candle's equivalent across fused and unfused genomes.
|
||||
// Inputs are drawn from [-1, 1] so exp() doesn't overflow.
|
||||
let seed = 0xD1A_0D1Au64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
// Five-op chain with exp + sin: allow ~5x safety to absorb accumulated
|
||||
// rounding vs candle's kernels.
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR * 5.0;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| {
|
||||
let t = a + b;
|
||||
let u = t.exp();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
w + v
|
||||
},
|
||||
|a, b| {
|
||||
let t = (&a + &b).unwrap();
|
||||
let u = t.exp().unwrap();
|
||||
let v = t.sin().unwrap();
|
||||
let w = (&u * &a).unwrap();
|
||||
(&w + &v).unwrap()
|
||||
},
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Marker invariant tests ----
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_has_exactly_one_end() {
|
||||
// Design invariant: a fused region always has exactly one FusionEnd.
|
||||
// Uses the diamond DAG so there's real fan-in/out inside the region.
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
let full = regions
|
||||
.iter()
|
||||
.find(|r| r.internal_ops_sorted == expected)
|
||||
.expect("expected at least one extraction to produce the full 5-op diamond region");
|
||||
assert_eq!(
|
||||
full.end_count, 1,
|
||||
"fused region must have exactly one FusionEnd, got {}",
|
||||
full.end_count
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_starts_match_distinct_external_tensors() {
|
||||
// Design invariant: FusionStart count == number of distinct external input
|
||||
// tensors, NOT number of edges crossing the boundary. In the diamond DAG
|
||||
// `a` is consumed inside the region by two ops (outer Add + Mul), so a
|
||||
// per-edge counting scheme would give 3; the correct per-distinct-tensor
|
||||
// count is 2 ({a, b}).
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
// Multiple 5-op extractions are reachable: the merge-FE-FE rule fires
|
||||
// across paths that may have minted distinct FS enodes for the shared
|
||||
// tensor `a` at separate sites. The design invariant is that *some*
|
||||
// extraction collapses those into the deduped form (one FS per distinct
|
||||
// tensor → 2 FS for {a, b}); we don't require every random sample to.
|
||||
let matching: Vec<&FusedRegion> = regions
|
||||
.iter()
|
||||
.filter(|r| r.internal_ops_sorted == expected)
|
||||
.collect();
|
||||
assert!(
|
||||
!matching.is_empty(),
|
||||
"expected at least one extraction to produce the full 5-op diamond region, \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
assert!(
|
||||
matching
|
||||
.iter()
|
||||
.any(|r| r.start_count == 2 && r.end_count == 1),
|
||||
"expected at least one 5-op diamond extraction with FusionStart count == 2 \
|
||||
(one per distinct external tensor) and FusionEnd count == 1; got: {matching:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Targeted rule-family tests (one per family / orientation) ----
|
||||
//
|
||||
// The structural and diamond tests above hit several rule families at once.
|
||||
// These narrow tests pin each rule family / orientation independently so a
|
||||
// regression in one rule shows up as a single failing test rather than a
|
||||
// confusing diamond mismatch.
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_unary_marker_form() {
|
||||
// Pair-fuse U→U: `a.sin().sqrt()` should be reachable as a marker-bracketed
|
||||
// region containing FusedSin and FusedSqrt (with one FusionStart for `a`).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_to_binary_rhs() {
|
||||
// Pair-fuse U→B (RHS variant): `a + b.sin()`. The unary is on the
|
||||
// binary's B input, so the rule's RHS-orientation version is what fires.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b.sin()).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts (RHS-side unary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_binary_to_binary_rhs() {
|
||||
// Pair-fuse B→B (RHS variant): `c * (a + b)`. The inner binary feeds the
|
||||
// outer binary's B input, exercising the mirror direction of the rule
|
||||
// covered by test_chain_of_binaries_fuses.
|
||||
// See test_chain_of_binaries_fuses for the LUMINAL_FUSION_FAMILIES note.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c * (a + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts (RHS-side inner binary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grow_fe_to_binary_rhs() {
|
||||
// Grow FE→B (RHS variant): `c + (a.sin() + b)`. Once the inner
|
||||
// `a.sin() + b` is fused, the outer `+ c` consumes that FE on its B input
|
||||
// (because we wrote `c + (...)` — `c` is on LHS, FE on RHS), exercising
|
||||
// grow-FE-B-rhs to absorb the outer Add into the same region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c + (a.sin() + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a 3-op fused region of {expected:?} with 3 FusionStarts (grow into RHS), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume two-FE merge shape; numerical correctness preserved"]
|
||||
fn test_merge_two_regions_at_outer_binary() {
|
||||
// Merge: `(sin(a) + b) + (sqrt(c) + d)`. Each side independently pair-fuses
|
||||
// U→B on its own (the unary gives the inner Add a fusion partner that
|
||||
// doesn't pull in the outer Add), so both sides become FEs. The outer Add
|
||||
// then fires merge-FE-FE-Add to collapse them into a single region.
|
||||
// Without the unaries, `(a+b) + (c+d)` would only ever pair-fuse one
|
||||
// inner Add at a time with the outer Add — merge wouldn't have two FEs to
|
||||
// combine because the inner Adds never become singleton FEs on their own.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let d = cx.tensor(8);
|
||||
let _e = ((a.sin() + b) + (c.sqrt() + d)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedAdd", "FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 4),
|
||||
"expected a 5-op merged region (two pair-fused sides combined at outer Add) with \
|
||||
4 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Microbench: time three unfused kernels (`add_k` → `sin_k` → `sqrt_k`)
|
||||
/// vs one fused kernel (`(a + b).sin().sqrt()` in a single launch) on a
|
||||
/// fixed-size input, using CUDA events for device-side timing. Mirrors
|
||||
/// the existing sqrt→recip bench but on the binary-inclusive 3-op DAG
|
||||
/// PR2's region codegen targets.
|
||||
///
|
||||
/// Ignored by default — run with
|
||||
/// `cargo test -p luminal_cuda_lite -- --ignored bench_fused_region_vs_unfused_3op --nocapture`.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn bench_fused_region_vs_unfused_3op() {
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
|
||||
|
||||
const N: usize = 1 << 20; // 1M elements
|
||||
const WARMUP: usize = 100;
|
||||
const TRIALS: usize = 2000;
|
||||
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return, // no GPU available, skip
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Inputs in (0, 1] keep `sin` < 1 and `sqrt` well-defined post-add.
|
||||
let host_a: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let host_b: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let d_a = stream.clone_htod(&host_a).unwrap();
|
||||
let d_b = stream.clone_htod(&host_b).unwrap();
|
||||
let mut d_scratch1 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_scratch2 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_out = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
|
||||
let compile = |src: &str, name: &str| {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
module.load_function(name).unwrap()
|
||||
};
|
||||
|
||||
let add_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void add_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = a[i] + b[i];
|
||||
}
|
||||
"#,
|
||||
"add_k",
|
||||
);
|
||||
let sin_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sin_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sinf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sin_k",
|
||||
);
|
||||
let sqrt_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sqrt_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sqrtf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sqrt_k",
|
||||
);
|
||||
let fused_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void fused_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
float v = a[i] + b[i];
|
||||
v = sinf(v);
|
||||
v = sqrtf(v);
|
||||
out[i] = v;
|
||||
}
|
||||
"#,
|
||||
"fused_k",
|
||||
);
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(N as u32);
|
||||
let n_arg: i64 = N as i64;
|
||||
|
||||
let launch_unfused =
|
||||
|d_out: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch1: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch2: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&add_k);
|
||||
b.arg(&mut *d_scratch1).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sin_k);
|
||||
b.arg(&mut *d_scratch2).arg(&*d_scratch1).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sqrt_k);
|
||||
b.arg(d_out).arg(&*d_scratch2).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
let launch_fused = |d_out: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&fused_k);
|
||||
b.arg(d_out).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
|
||||
// Warmup
|
||||
for _ in 0..WARMUP {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Host-side wall-clock timing: synchronize before/after each batch so the
|
||||
// measured interval covers exactly the GPU work for `TRIALS` iterations.
|
||||
// (CUDA event-based timing is the more precise option in principle, but
|
||||
// `event.elapsed_ms` on this driver/cudarc combo errors with
|
||||
// CUDA_ERROR_INVALID_HANDLE — see bench_fused_vs_unfused_sqrt_recip
|
||||
// above which fails the same way. Wall-clock is reliable here.)
|
||||
let unfused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let unfused_total_ms = unfused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let fused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let fused_total_ms = fused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let unfused_us = unfused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let fused_us = fused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let speedup = unfused_us / fused_us;
|
||||
|
||||
println!(
|
||||
"\n[fusion microbench, (a+b).sin().sqrt(), N={N}, trials={TRIALS}]\n\
|
||||
unfused (add_k; sin_k; sqrt_k): {unfused_us:8.3} us/iter ({unfused_total_ms:.2} ms total)\n\
|
||||
fused (one kernel): {fused_us:8.3} us/iter ({fused_total_ms:.2} ms total)\n\
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,10 +5,6 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod cublaslt_rewrite_tests;
|
||||
#[cfg(test)]
|
||||
mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
@@ -19,8 +15,4 @@ mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod rope_test;
|
||||
#[cfg(test)]
|
||||
mod search_equivalence_fuzz;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
|
||||
//!
|
||||
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
|
||||
//! reference to catch incorrect HLIR kernel rewrites.
|
||||
//!
|
||||
//! These are marked ignored by default because each test builds a model-shaped
|
||||
//! graph and checks many extraction genomes. Run them explicitly with
|
||||
//! `cargo test -p luminal_cuda_lite -- --ignored` when touching extraction,
|
||||
//! scheduling, or model-pattern rewrites.
|
||||
//! reference to catch incorrect HLIR kernel fallback rewrites.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
@@ -305,7 +300,7 @@ fn fuzz_layer_no_attn(
|
||||
}
|
||||
|
||||
/// Test a SwiGLU MLP with HLIR-only to specifically verify
|
||||
/// the HLIR matmul decomposition (elementwise Mul + KernelSumReduce).
|
||||
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
|
||||
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -382,38 +377,32 @@ mod llama {
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
|
||||
}
|
||||
|
||||
/// Force HLIR-only (no block ops) to specifically test that extraction path.
|
||||
/// Force HLIR-only (no block ops) to specifically test the fallback path.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
|
||||
}
|
||||
@@ -435,26 +424,22 @@ mod gemma {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
|
||||
}
|
||||
|
||||
/// Gemma has extra post-attention and post-feedforward norms.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer_full_norms() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -579,14 +564,12 @@ mod gemma {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Gemma dimensions.
|
||||
/// Force HLIR-only to test fallback path with Gemma dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
|
||||
}
|
||||
@@ -608,26 +591,22 @@ mod qwen {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
|
||||
}
|
||||
|
||||
/// Qwen uses tied embeddings: lm_head = embedding^T
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_lm_head() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -689,20 +668,17 @@ mod qwen {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Qwen dimensions.
|
||||
/// Force HLIR-only to test fallback path with Qwen dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
|
||||
}
|
||||
|
||||
@@ -16,16 +16,9 @@ use super::utilities::{
|
||||
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
|
||||
};
|
||||
|
||||
// The property-based op tests each build/search CUDA graphs for multiple random
|
||||
// shapes. They are ignored by default to keep the main CUDA unit suite short;
|
||||
// run `cargo test -p luminal_cuda_lite -- --ignored` for the broader sweeps.
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -35,9 +28,6 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -47,27 +37,18 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_matmul(
|
||||
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
|
||||
@@ -138,8 +119,6 @@ proptest! {
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
@@ -148,9 +127,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
@@ -159,9 +135,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -169,9 +142,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
@@ -179,9 +149,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
@@ -190,17 +157,12 @@ proptest! {
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
test_mod(x, x, |a, b| a % b, seed);
|
||||
test_mod((y, x), (y, x), |a, b| a % b, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
@@ -373,8 +335,6 @@ proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
|
||||
use luminal::dtype::DType;
|
||||
@@ -567,9 +527,6 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(3))]
|
||||
|
||||
// This walks random extraction genomes and is intentionally opt-in so the
|
||||
// default CUDA unit suite keeps a tight feedback loop.
|
||||
#[ignore = "expensive CUDA genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
#[test]
|
||||
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
|
||||
fuzz_test_cuda_genomes_impl(seed);
|
||||
@@ -637,9 +594,6 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_embed_proptest(
|
||||
vocab_size in 10usize..200,
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Test that measures bandwidth utilization for a large element-wise add kernel.
|
||||
/// This demonstrates that generic fused Add can achieve reasonable bandwidth with large tensors.
|
||||
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
|
||||
#[test]
|
||||
pub fn kernel_add_bandwidth_test() {
|
||||
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
|
||||
@@ -40,7 +40,7 @@ pub fn kernel_add_bandwidth_test() {
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Print stats
|
||||
println!("\n=== Large Fused Add Bandwidth Test ===");
|
||||
println!("\n=== Large KernelAdd Bandwidth Test ===");
|
||||
println!(
|
||||
"Tensor size: {} elements ({} MB per tensor)",
|
||||
size,
|
||||
|
||||
@@ -3,15 +3,18 @@ use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 12;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
@@ -58,7 +61,6 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_values = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
@@ -72,9 +74,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -131,9 +133,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
@@ -174,9 +176,10 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
@@ -271,7 +274,7 @@ fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -293,7 +296,7 @@ fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::{graph::Graph, op::Runtime};
|
||||
|
||||
use crate::{kernel::apply_rope, runtime::CudaRuntime};
|
||||
|
||||
fn cpu_rope(x: &[f32], cos: &[f32], sin: &[f32], s: usize, h: usize, d: usize) -> Vec<f32> {
|
||||
assert!(d.is_multiple_of(2));
|
||||
let mut out = vec![0.0f32; s * h * d];
|
||||
for si in 0..s {
|
||||
for hi in 0..h {
|
||||
for i in 0..d {
|
||||
let xi = x[si * h * d + hi * d + i];
|
||||
let xpair = if i % 2 == 0 {
|
||||
-x[si * h * d + hi * d + i + 1]
|
||||
} else {
|
||||
x[si * h * d + hi * d + i - 1]
|
||||
};
|
||||
let c = cos[si * d + i];
|
||||
let sn = sin[si * d + i];
|
||||
out[si * h * d + hi * d + i] = xi * c + xpair * sn;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_matches_cpu_reference() {
|
||||
let s = 8;
|
||||
let h = 4;
|
||||
let d = 32;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
let x_data: Vec<f32> = (0..s * h * d).map(|i| ((i as f32) * 0.013).sin()).collect();
|
||||
let cos_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).cos()).collect();
|
||||
let sin_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).sin()).collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-5, "max abs error {max_err} too high");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_flux2_shape() {
|
||||
// Flux 2 transformer attention: S=1536 (img+txt), H=48, D=128.
|
||||
let s = 1536;
|
||||
let h = 48;
|
||||
let d = 128;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(11);
|
||||
let x_data: Vec<f32> = (0..s * h * d)
|
||||
.map(|_| rng.random_range(-2.0..2.0_f32))
|
||||
.collect();
|
||||
let cos_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
let sin_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope flux2: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-4, "max abs error {max_err} too high");
|
||||
}
|
||||
@@ -1,374 +0,0 @@
|
||||
//! End-to-end e-graph search-space equivalence fuzz tests.
|
||||
//!
|
||||
//! These tests do not compare against a hand-written reference. They assert the
|
||||
//! stronger search invariant: every selectable LLIR graph from the same e-graph
|
||||
//! must produce the same outputs for the same runtime inputs.
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[path = "../../../../examples/llama/src/model.rs"]
|
||||
mod llama_model;
|
||||
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
|
||||
use super::utilities::{CudaSearchEquivalenceFuzzer, get_cuda_stream, random_f32_vec};
|
||||
|
||||
const SEARCH_EQUIV_SAMPLES: usize = 32;
|
||||
|
||||
fn random_bf16_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<bf16> {
|
||||
random_f32_vec(n, seed, low, high)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llama_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const CTX: usize = 3;
|
||||
const SLOTS: usize = 4;
|
||||
|
||||
let config = llama_model::LlamaConfig {
|
||||
layers: 2,
|
||||
hidden: 32,
|
||||
intermediate: 64,
|
||||
head_dim: 8,
|
||||
kv_groups: 2,
|
||||
vocab_size: 64,
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim('s', SEQ);
|
||||
cx.set_dim('c', CTX);
|
||||
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = llama_model::KVCache::new_with_config(&mut cx, SLOTS, config);
|
||||
let llama = llama_model::Llama::init_with_config(&mut cx, config);
|
||||
|
||||
let (logits, cache_outputs) =
|
||||
llama.forward(input, q_pos, scatter_idx, gather_idx, attn_mask, &kv_cache);
|
||||
let logits = logits.output();
|
||||
let mut fuzzer = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x5EED_1234)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 3e-3, 3e-3);
|
||||
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
|
||||
let k_out = k_out.output();
|
||||
let v_out = v_out.output();
|
||||
fuzzer = fuzzer.output_f32(k_out.id, format!("layer{layer}.k_cache"), 3e-3, 3e-3);
|
||||
fuzzer = fuzzer.output_f32(v_out.id, format!("layer{layer}.v_cache"), 3e-3, 3e-3);
|
||||
}
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0x11A_AA55);
|
||||
fuzzer = fuzzer
|
||||
.input_i32(input.id, vec![3, 17])
|
||||
.input_i32(q_pos.id, vec![1, 2])
|
||||
.input_i32(scatter_idx.id, vec![1, 2])
|
||||
.input_i32(gather_idx.id, vec![0, 1, 2])
|
||||
.input_f32(attn_mask.id, vec![0.0, 0.0, -1e4, 0.0, 0.0, 0.0]);
|
||||
|
||||
let kv_dim = config.kv_dim();
|
||||
for tensor in kv_cache.tensors() {
|
||||
fuzzer = fuzzer.input_f32(tensor.id, vec![0.0; SLOTS * kv_dim]);
|
||||
}
|
||||
for tensor in llama.parameter_tensors() {
|
||||
let elements = tensor
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|dim| dim.to_usize().expect("tiny llama test uses static params"))
|
||||
.product::<usize>();
|
||||
let data = (0..elements)
|
||||
.map(|_| rng.random_range(-0.08f32..0.08f32))
|
||||
.collect::<Vec<_>>();
|
||||
fuzzer = fuzzer.input_f32(tensor.id, data);
|
||||
}
|
||||
|
||||
let report = fuzzer.run();
|
||||
eprintln!("llama search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemma_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
const Q_DIM: usize = 24;
|
||||
const INTERMEDIATE: usize = 64;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let attn_norm_w = cx.tensor(HIDDEN);
|
||||
let post_attn_norm_w = cx.tensor(HIDDEN);
|
||||
let pre_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let post_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let proj_w = cx.tensor((Q_DIM, HIDDEN));
|
||||
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
|
||||
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
|
||||
let normed = rms_norm(input, attn_norm_w, EPS);
|
||||
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
|
||||
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
|
||||
let x = input + attn_normed;
|
||||
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
|
||||
let mlp_out =
|
||||
(gemma_gelu(ff_normed.matmul(w_gate.t())) * ff_normed.matmul(w_up.t())).matmul(w_down.t());
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x6E4D_4DAA)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 101, -0.15, 0.15))
|
||||
.input_f32(attn_norm_w.id, random_f32_vec(HIDDEN, 102, 0.7, 1.3))
|
||||
.input_f32(post_attn_norm_w.id, random_f32_vec(HIDDEN, 103, 0.7, 1.3))
|
||||
.input_f32(pre_ff_norm_w.id, random_f32_vec(HIDDEN, 104, 0.7, 1.3))
|
||||
.input_f32(post_ff_norm_w.id, random_f32_vec(HIDDEN, 105, 0.7, 1.3))
|
||||
.input_f32(proj_w.id, random_f32_vec(Q_DIM * HIDDEN, 106, -0.08, 0.08))
|
||||
.input_f32(
|
||||
o_proj_w.id,
|
||||
random_f32_vec(HIDDEN * Q_DIM, 107, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_gate.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 108, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_up.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 109, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_down.id,
|
||||
random_f32_vec(HIDDEN * INTERMEDIATE, 110, -0.08, 0.08),
|
||||
)
|
||||
.output_f32(out.id, "gemma_block", 5e-3, 5e-3)
|
||||
.run();
|
||||
eprintln!("gemma search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x0DEE_55EE)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.input_f32(
|
||||
router_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 201, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(
|
||||
expert_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 202, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(router_scale.id, random_f32_vec(HIDDEN, 203, 0.7, 1.3))
|
||||
.input_f32(
|
||||
router_proj.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 204, -0.2, 0.2),
|
||||
)
|
||||
.input_f32(
|
||||
per_expert_scale.id,
|
||||
random_f32_vec(NUM_EXPERTS, 205, 0.5, 1.5),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 206, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 207, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "gemma_moe_block", 5e-2, 5e-2)
|
||||
.run();
|
||||
eprintln!("moe search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_native_reference_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = input.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = input.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_weights = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let input_exp = input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = input_exp
|
||||
.matmul(gate_up_gathered.transpose(2, 3))
|
||||
.squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x51A7_E5ED)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.native_reference()
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 301, -0.15, 0.15))
|
||||
.input_f32(
|
||||
router.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 302, -0.2, 0.2),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 303, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 304, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "qwen_swiglu_moe_native_reference", 6e-2, 6e-2)
|
||||
.run();
|
||||
eprintln!("moe native-reference fuzz report: {report:?}");
|
||||
}
|
||||
@@ -2,8 +2,7 @@ use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use luminal::egglog_utils::{
|
||||
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
|
||||
validate_choice_set,
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use num_traits::{Num, Signed};
|
||||
@@ -129,399 +128,6 @@ pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CudaFuzzInput {
|
||||
F32(NodeIndex, Vec<f32>),
|
||||
Bf16(NodeIndex, Vec<bf16>),
|
||||
I32(NodeIndex, Vec<i32>),
|
||||
}
|
||||
|
||||
impl CudaFuzzInput {
|
||||
fn apply(&self, rt: &mut CudaRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_native(&self, rt: &mut NativeRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct F32OutputCheck {
|
||||
pub id: NodeIndex,
|
||||
pub name: String,
|
||||
pub rtol: f32,
|
||||
pub atol: f32,
|
||||
}
|
||||
|
||||
impl F32OutputCheck {
|
||||
pub fn new(id: NodeIndex, name: impl Into<String>, rtol: f32, atol: f32) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name: name.into(),
|
||||
rtol,
|
||||
atol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchEquivalenceFuzzConfig {
|
||||
pub seed: u64,
|
||||
pub samples: usize,
|
||||
pub generation_size: usize,
|
||||
pub mutations: usize,
|
||||
pub max_attempts: usize,
|
||||
pub build_options: BuildSearchSpaceOptions,
|
||||
pub reference: SearchEquivalenceReference,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SearchEquivalenceReference {
|
||||
FirstCudaExtraction,
|
||||
NativeRuntime,
|
||||
}
|
||||
|
||||
impl Default for SearchEquivalenceFuzzConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
seed: 0,
|
||||
samples: 32,
|
||||
generation_size: 16,
|
||||
mutations: 2,
|
||||
max_attempts: 1_000,
|
||||
build_options: BuildSearchSpaceOptions::default(),
|
||||
reference: SearchEquivalenceReference::FirstCudaExtraction,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SearchEquivalenceFuzzReport {
|
||||
pub tested: usize,
|
||||
pub skipped_invalid: usize,
|
||||
}
|
||||
|
||||
pub struct CudaSearchEquivalenceFuzzer<'a> {
|
||||
cx: &'a mut Graph,
|
||||
stream: &'a Arc<cudarc::driver::CudaStream>,
|
||||
inputs: Vec<CudaFuzzInput>,
|
||||
outputs: Vec<F32OutputCheck>,
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
}
|
||||
|
||||
impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
pub fn new(cx: &'a mut Graph, stream: &'a Arc<cudarc::driver::CudaStream>) -> Self {
|
||||
Self {
|
||||
cx,
|
||||
stream,
|
||||
inputs: Vec::new(),
|
||||
outputs: Vec::new(),
|
||||
config: SearchEquivalenceFuzzConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.config.seed = seed;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn samples(mut self, samples: usize) -> Self {
|
||||
self.config.samples = samples;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn generation_size(mut self, generation_size: usize) -> Self {
|
||||
self.config.generation_size = generation_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mutations(mut self, mutations: usize) -> Self {
|
||||
self.config.mutations = mutations;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_options(mut self, build_options: BuildSearchSpaceOptions) -> Self {
|
||||
self.config.build_options = build_options;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn native_reference(mut self) -> Self {
|
||||
self.config.reference = SearchEquivalenceReference::NativeRuntime;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_f32(mut self, id: NodeIndex, data: Vec<f32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::F32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_bf16(mut self, id: NodeIndex, data: Vec<bf16>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::Bf16(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_i32(mut self, id: NodeIndex, data: Vec<i32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::I32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn output_f32(
|
||||
mut self,
|
||||
id: NodeIndex,
|
||||
name: impl Into<String>,
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
) -> Self {
|
||||
self.outputs.push(F32OutputCheck::new(id, name, rtol, atol));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn run(self) -> SearchEquivalenceFuzzReport {
|
||||
fuzz_cuda_search_space_equivalence(
|
||||
self.cx,
|
||||
self.stream,
|
||||
&self.inputs,
|
||||
&self.outputs,
|
||||
self.config,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// End-to-end search-space equivalence fuzzing for CUDA.
|
||||
///
|
||||
/// This builds the normal CUDA e-graph search space, extracts random selectable
|
||||
/// LLIR graphs, runs each with identical inputs, and verifies every requested
|
||||
/// f32 output matches the first valid extraction. The reference is intentionally
|
||||
/// another selected LLIR graph, not a hand-written CPU implementation: this
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge.
|
||||
pub fn fuzz_cuda_search_space_equivalence(
|
||||
cx: &mut Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
) -> SearchEquivalenceFuzzReport {
|
||||
assert!(
|
||||
!outputs.is_empty(),
|
||||
"fuzz harness needs at least one output"
|
||||
);
|
||||
|
||||
let native_reference_outputs = if config.reference == SearchEquivalenceReference::NativeRuntime
|
||||
{
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_options(
|
||||
NativeRuntime::default(),
|
||||
SearchOptions::new(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
input.apply_native(&mut native_rt);
|
||||
}
|
||||
native_rt.execute(&cx.dyn_map);
|
||||
Some(
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| native_rt.get_f32(out.id).clone())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
cx.build_search_space_with_options::<CudaRuntime>(config.build_options);
|
||||
|
||||
let egraph = cx.egraph().expect("search space should be built");
|
||||
let ops = cx.egglog_ops().expect("search ops should be built");
|
||||
let seed = if native_reference_outputs.is_some() {
|
||||
config.seed.wrapping_add(0xC0DA_C0DA)
|
||||
} else {
|
||||
config.seed
|
||||
};
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let mut prev_selected = FxHashSet::default();
|
||||
let mut base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
|
||||
let mut skipped_invalid = 0usize;
|
||||
let reference_is_cuda = native_reference_outputs.is_none();
|
||||
let (reference_hash, reference_outputs, mut tested) =
|
||||
if let Some(reference_outputs) = native_reference_outputs {
|
||||
(0, reference_outputs, 0usize)
|
||||
} else {
|
||||
let mut attempts = 0usize;
|
||||
let (reference_hash, reference_outputs) = loop {
|
||||
attempts += 1;
|
||||
if attempts > config.max_attempts {
|
||||
panic!(
|
||||
"failed to extract a valid reference LLIR after {} attempts",
|
||||
config.max_attempts
|
||||
);
|
||||
}
|
||||
if validate_choice_set(egraph, &base, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
} else {
|
||||
let hash = hash_choice_set(&base);
|
||||
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
|
||||
Ok(values) => break (hash, values),
|
||||
Err(err) => {
|
||||
skipped_invalid += 1;
|
||||
eprintln!("skipping invalid reference candidate hash={hash}: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
};
|
||||
(reference_hash, reference_outputs, 1usize)
|
||||
};
|
||||
|
||||
let mut attempts = 0usize;
|
||||
while tested < config.samples && attempts < config.max_attempts {
|
||||
attempts += 1;
|
||||
let mut candidates = extract_generation(
|
||||
egraph,
|
||||
&base,
|
||||
config.generation_size,
|
||||
config.mutations,
|
||||
&mut prev_selected,
|
||||
&mut rng,
|
||||
);
|
||||
if candidates.is_empty() {
|
||||
let next = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&next));
|
||||
candidates.push(next);
|
||||
}
|
||||
|
||||
for candidate in candidates {
|
||||
if tested >= config.samples {
|
||||
break;
|
||||
}
|
||||
let candidate_hash = hash_choice_set(&candidate);
|
||||
if reference_is_cuda && candidate_hash == reference_hash {
|
||||
continue;
|
||||
}
|
||||
if validate_choice_set(egraph, &candidate, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate_outputs = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
|
||||
assert_fuzz_outputs_close(
|
||||
outputs,
|
||||
&reference_outputs,
|
||||
&candidate_outputs,
|
||||
reference_hash,
|
||||
candidate_hash,
|
||||
);
|
||||
base = candidate;
|
||||
tested += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
tested, config.samples,
|
||||
"only tested {tested}/{} LLIR samples before exhausting attempts",
|
||||
config.samples
|
||||
);
|
||||
SearchEquivalenceFuzzReport {
|
||||
tested,
|
||||
skipped_invalid,
|
||||
}
|
||||
}
|
||||
|
||||
fn run_choice_outputs<'a>(
|
||||
cx: &'a Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
) -> Result<Vec<Vec<f32>>, String> {
|
||||
let egraph = cx.egraph().ok_or("search space was not built")?;
|
||||
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let mut llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
choices.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
for input in inputs {
|
||||
input.apply(&mut rt);
|
||||
}
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
Ok(outputs.iter().map(|out| rt.get_f32(out.id)).collect())
|
||||
}
|
||||
|
||||
fn assert_fuzz_outputs_close(
|
||||
outputs: &[F32OutputCheck],
|
||||
expected: &[Vec<f32>],
|
||||
actual: &[Vec<f32>],
|
||||
reference_hash: u64,
|
||||
candidate_hash: u64,
|
||||
) {
|
||||
for ((spec, expected), actual) in outputs.iter().zip(expected.iter()).zip(actual.iter()) {
|
||||
assert_eq!(
|
||||
expected.len(),
|
||||
actual.len(),
|
||||
"output {} length mismatch for candidate hash={candidate_hash} reference hash={reference_hash}",
|
||||
spec.name
|
||||
);
|
||||
let mut max_abs = 0.0f32;
|
||||
let mut max_rel = 0.0f32;
|
||||
let mut worst = 0usize;
|
||||
for (i, (&a, &b)) in actual.iter().zip(expected.iter()).enumerate() {
|
||||
assert!(
|
||||
a.is_finite(),
|
||||
"output {} candidate hash={candidate_hash} produced non-finite value {a} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
assert!(
|
||||
b.is_finite(),
|
||||
"output {} reference hash={reference_hash} produced non-finite value {b} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
let abs = (a - b).abs();
|
||||
let rel = abs / b.abs().max(1e-12);
|
||||
if abs > max_abs {
|
||||
max_abs = abs;
|
||||
max_rel = rel;
|
||||
worst = i;
|
||||
}
|
||||
if abs > spec.atol + spec.rtol * b.abs() {
|
||||
panic!(
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={}",
|
||||
spec.name,
|
||||
spec.atol + spec.rtol * b.abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"fuzz output {} ok: candidate hash={candidate_hash} max_abs={max_abs} max_rel={max_rel} worst={worst}",
|
||||
spec.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
@@ -530,15 +136,14 @@ pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
|
||||
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
|
||||
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
|
||||
let Some((major, minor)) = gpu_compute_cap() else {
|
||||
let Some((major, _)) = gpu_compute_cap() else {
|
||||
return false;
|
||||
};
|
||||
match dtype {
|
||||
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
|
||||
luminal::dtype::DType::F8E4M3 | luminal::dtype::DType::F8E5M2 => {
|
||||
major > 8 || (major == 8 && minor >= 9)
|
||||
} // Ada/Hopper (sm_89+)
|
||||
luminal::dtype::DType::F4E2M1 | luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
luminal::dtype::DType::F4E2M1
|
||||
| luminal::dtype::DType::F8E4M3
|
||||
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
[package]
|
||||
name = "luminal_metal"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
description = "Metal backend for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
metal = { version = "0.31", features = ["mps"] }
|
||||
metal = "0.31"
|
||||
objc = "0.2"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
half = "2.7.1"
|
||||
tracing = "0.1.43"
|
||||
safetensors = "0.7.0"
|
||||
memmap2 = "0.9.9"
|
||||
bytemuck = "1.24.0"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, bytes_to_native_data, compile_backend};
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
|
||||
@@ -1,5 +1,227 @@
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum MPSMatrixLayout {
|
||||
RowMajor,
|
||||
TransposedRowMajor,
|
||||
use super::{MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum MetalMatmulFamily {
|
||||
#[default]
|
||||
Naive,
|
||||
RegularTiled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulDescriptor {
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub batch_shape: Vec<Expression>,
|
||||
pub lhs_strides: Vec<Expression>,
|
||||
pub rhs_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub transpose_lhs: bool,
|
||||
pub transpose_rhs: bool,
|
||||
}
|
||||
|
||||
impl MatmulDescriptor {
|
||||
pub fn from_mul_and_sum(
|
||||
mul_info: &MetalMulInfo,
|
||||
sum_info: &MetalSumReduceInfo,
|
||||
) -> Option<Self> {
|
||||
let zero = Expression::from(0);
|
||||
let z = Expression::from('z');
|
||||
|
||||
let is_simple_2d_matmul = mul_info.shape.len() == 3
|
||||
&& sum_info.shape.len() == 2
|
||||
&& mul_info.a_strides.len() == 3
|
||||
&& mul_info.b_strides.len() == 3
|
||||
&& sum_info.strides.len() == 2
|
||||
&& mul_info.shape[0] == sum_info.shape[0]
|
||||
&& mul_info.shape[1] == sum_info.shape[1]
|
||||
&& mul_info.shape[2] == sum_info.iters
|
||||
&& mul_info.a_strides[1] == zero
|
||||
&& mul_info.a_strides[2] == z
|
||||
&& mul_info.b_strides[0] == zero
|
||||
&& mul_info.b_strides[1] == z
|
||||
&& sum_info.strides[1] == z
|
||||
&& sum_info.iter_stride == z;
|
||||
|
||||
if !is_simple_2d_matmul {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
m: sum_info.shape[0],
|
||||
n: sum_info.shape[1],
|
||||
k: sum_info.iters,
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: mul_info.a_strides.clone(),
|
||||
rhs_strides: mul_info.b_strides.clone(),
|
||||
out_strides: sum_info.strides.clone(),
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulPlan {
|
||||
pub family: MetalMatmulFamily,
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub lda: Expression,
|
||||
pub ldb: Expression,
|
||||
pub ldd: Expression,
|
||||
pub batch_size: u32,
|
||||
pub batch_stride_a: u32,
|
||||
pub batch_stride_b: u32,
|
||||
pub batch_stride_d: u32,
|
||||
pub bm: u16,
|
||||
pub bn: u16,
|
||||
pub bk: u16,
|
||||
pub wm: u16,
|
||||
pub wn: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct MetalMatmulPlanner;
|
||||
|
||||
impl MetalMatmulPlanner {
|
||||
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
|
||||
let family = if desc.batch_shape.is_empty()
|
||||
&& desc.m.as_num().is_some_and(|m| m >= 32)
|
||||
&& desc.n.as_num().is_some_and(|n| n >= 32)
|
||||
&& desc.k.as_num().is_some_and(|k| k >= 32)
|
||||
{
|
||||
MetalMatmulFamily::RegularTiled
|
||||
} else {
|
||||
MetalMatmulFamily::Naive
|
||||
};
|
||||
MatmulPlan {
|
||||
family,
|
||||
m: desc.m,
|
||||
n: desc.n,
|
||||
k: desc.k,
|
||||
lda: desc.lhs_strides[0],
|
||||
ldb: desc.rhs_strides[2],
|
||||
ldd: desc.out_strides[0],
|
||||
batch_size: 1,
|
||||
batch_stride_a: 0,
|
||||
batch_stride_b: 0,
|
||||
batch_stride_d: 0,
|
||||
bm: 16,
|
||||
bn: 16,
|
||||
bk: 8,
|
||||
wm: 2,
|
||||
wn: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn descriptor_recovers_simple_2d_matmul() {
|
||||
let mul = MetalMulInfo {
|
||||
shape: vec![
|
||||
Expression::from(4),
|
||||
Expression::from(8),
|
||||
Expression::from(16),
|
||||
],
|
||||
a_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
b_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
output_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from('z') * 8,
|
||||
Expression::from('z'),
|
||||
],
|
||||
};
|
||||
let sum = MetalSumReduceInfo {
|
||||
shape: vec![Expression::from(4), Expression::from(8)],
|
||||
strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
iters: Expression::from(16),
|
||||
iter_stride: Expression::from('z'),
|
||||
};
|
||||
|
||||
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
|
||||
assert_eq!(desc.m, Expression::from(4));
|
||||
assert_eq!(desc.n, Expression::from(8));
|
||||
assert_eq!(desc.k, Expression::from(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_keeps_small_problems_on_naive_path() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(4),
|
||||
n: Expression::from(8),
|
||||
k: Expression::from(16),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::Naive);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
assert_eq!(plan.lda, Expression::from('z') * 16);
|
||||
assert_eq!(plan.ldb, Expression::from('z') * 8);
|
||||
assert_eq!(plan.ldd, Expression::from('z') * 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_promotes_large_problems_to_regular_tiled() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(64),
|
||||
n: Expression::from(64),
|
||||
k: Expression::from(64),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 64,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 64,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ pub use ops::*;
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
@@ -32,7 +32,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
device: &Device,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) -> Option<ComputePipelineState>;
|
||||
) -> ComputePipelineState;
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.first().copied().unwrap_or(DType::F32)
|
||||
@@ -40,7 +40,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
fn encode_compute(
|
||||
fn encode(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
pipeline: &ComputePipelineState,
|
||||
@@ -49,26 +49,6 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
);
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
dyn_buffer: &Buffer,
|
||||
_input_dtypes: &[DType],
|
||||
_output_dtype: DType,
|
||||
) {
|
||||
let pipeline = pipeline.expect("compute pipeline not compiled");
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let dyn_idx = inputs.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(dyn_buffer), 0);
|
||||
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Performance Metrics for MBU/MFU Calculation
|
||||
// ========================================================================
|
||||
@@ -93,10 +73,6 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_matmul(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,31 +1,21 @@
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalKernelOp};
|
||||
use half::{bf16, f16};
|
||||
use crate::kernel::{
|
||||
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
|
||||
};
|
||||
use half::f16;
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::{BucketLLIR, DimBucket, Graph, LLIRGraph},
|
||||
graph::LLIRGraph,
|
||||
hlir::{Input, NativeData, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
prelude::{
|
||||
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
|
||||
FxHashMap, NodeIndex, ToId,
|
||||
petgraph::{Direction, algo::toposort, prelude::StableGraph, visit::EdgeRef},
|
||||
},
|
||||
};
|
||||
use memmap2::MmapOptions;
|
||||
use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptions};
|
||||
use objc::rc::autoreleasepool;
|
||||
use objc::runtime::Object;
|
||||
use safetensors::{Dtype, SafeTensors};
|
||||
use std::{fs::File, time::Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalCompiledBucket {
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: LLIRGraph,
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
}
|
||||
use std::time::Duration;
|
||||
|
||||
pub struct MetalRuntime {
|
||||
device: Device,
|
||||
@@ -44,124 +34,83 @@ pub struct MetalRuntime {
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
/// LLIR output node -> input node whose buffer contains the output.
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Bucket definitions for dynamic dimensions.
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Compiled LLIR variants, one per bucket combination.
|
||||
compiled_buckets: Vec<MetalCompiledBucket>,
|
||||
/// Currently active compiled bucket.
|
||||
active_bucket: usize,
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn input_dtype(&self, id: NodeIndex) -> Option<DType> {
|
||||
self.llir_graph.node_indices().find_map(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_op::<Input>()
|
||||
.and_then(|input| (input.node == id.index()).then_some(input.dtype))
|
||||
})
|
||||
}
|
||||
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
|
||||
let mut graph = llir_graph.clone();
|
||||
let planner = MetalMatmulPlanner;
|
||||
let mut rewrites = Vec::new();
|
||||
|
||||
fn output_data_node(&self, id: NodeIndex) -> NodeIndex {
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
for sum_node in graph.node_indices().collect::<Vec<_>>() {
|
||||
let Some(sum_info) = graph[sum_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.sum_reduce_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
self.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap()
|
||||
}
|
||||
let input_edges: Vec<_> = graph
|
||||
.edges_directed(sum_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if input_edges.len() != 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
fn follow_aliases(&self, mut node: NodeIndex) -> NodeIndex {
|
||||
while let Some(target) = self.output_alias_map.get(&node) {
|
||||
node = *target;
|
||||
let mul_node = input_edges[0];
|
||||
let Some(mul_info) = graph[mul_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.mul_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mul_inputs: Vec<_> = graph
|
||||
.edges_directed(mul_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
|
||||
}
|
||||
node
|
||||
}
|
||||
|
||||
fn buffer_for_llir_node<'a>(
|
||||
&'a self,
|
||||
node: NodeIndex,
|
||||
llir_to_hlir: &FxHashMap<NodeIndex, NodeIndex>,
|
||||
) -> &'a Buffer {
|
||||
let data_node = self.follow_aliases(node);
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&data_node) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&data_node)
|
||||
.expect("Intermediate buffer not found!")
|
||||
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
|
||||
graph[sum_node] =
|
||||
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
|
||||
m: plan.m,
|
||||
n: plan.n,
|
||||
k: plan.k,
|
||||
lda: plan.lda,
|
||||
ldb: plan.ldb,
|
||||
ldd: plan.ldd,
|
||||
family: plan.family,
|
||||
bm: plan.bm,
|
||||
bn: plan.bn,
|
||||
bk: plan.bk,
|
||||
wm: plan.wm,
|
||||
wn: plan.wn,
|
||||
batch_size: plan.batch_size,
|
||||
batch_stride_a: plan.batch_stride_a,
|
||||
batch_stride_b: plan.batch_stride_b,
|
||||
batch_stride_d: plan.batch_stride_d,
|
||||
}));
|
||||
|
||||
graph.remove_node(mul_node);
|
||||
graph.add_edge(mul_inputs[0], sum_node, ());
|
||||
graph.add_edge(mul_inputs[1], sum_node, ());
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_from_slice<T>(&self, values: &[T]) -> Buffer {
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
graph
|
||||
}
|
||||
|
||||
fn buffer_from_safetensor(
|
||||
&self,
|
||||
tensor: &safetensors::tensor::TensorView<'_>,
|
||||
dtype: DType,
|
||||
) -> Buffer {
|
||||
match (tensor.dtype(), dtype) {
|
||||
(Dtype::F32, DType::F32) | (Dtype::F16, DType::F16) => {
|
||||
let data = tensor.data();
|
||||
self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const _,
|
||||
data.len() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
(Dtype::F16, DType::F32) => {
|
||||
let values: Vec<f32> = bytemuck::cast_slice::<u8, f16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::BF16, DType::F32) => {
|
||||
let values: Vec<f32> = bytemuck::cast_slice::<u8, bf16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::F32, DType::F16) => {
|
||||
let values: Vec<f16> = bytemuck::cast_slice::<u8, f32>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::BF16, DType::F16) => {
|
||||
let values: Vec<f16> = bytemuck::cast_slice::<u8, bf16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(v.to_f32()))
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(tensor_dtype, dtype) => {
|
||||
panic!("Cannot load safetensor dtype {tensor_dtype:?} into Metal dtype {dtype:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn contains_matmul(&self) -> bool {
|
||||
self.llir_graph.node_indices().any(|node| {
|
||||
@@ -183,69 +132,29 @@ impl MetalRuntime {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_safetensors(&mut self, cx: &Graph, file_path: &str) {
|
||||
let f = File::open(file_path).unwrap();
|
||||
let mmap = unsafe { MmapOptions::new().map(&f).unwrap() };
|
||||
let st = SafeTensors::deserialize(&mmap).unwrap();
|
||||
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node]).as_any().downcast_ref::<Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let buffer = self.buffer_from_safetensor(&tensor, input.dtype);
|
||||
self.input_data.remove(&node);
|
||||
self.hlir_buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
|
||||
let id = id.to_id();
|
||||
let data = data.into();
|
||||
if let Some(dtype) = self.input_dtype(id) {
|
||||
let buffer = self.create_input_buffer(&data, dtype);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
}
|
||||
self.input_data.insert(id, data);
|
||||
}
|
||||
|
||||
pub fn set_zeros(&mut self, id: impl ToId, num_bytes: usize) {
|
||||
let id = id.to_id();
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(num_bytes as u64, MTLResourceOptions::StorageModeShared);
|
||||
unsafe {
|
||||
std::ptr::write_bytes(buffer.contents(), 0, num_bytes);
|
||||
}
|
||||
self.input_data.remove(&id);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
}
|
||||
|
||||
pub fn remove_buffer(&mut self, id: impl ToId) -> Buffer {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
|
||||
if let Some(buffer) = self.buffers.remove(&data_id) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
if let Some(Input { node, .. }) = self.llir_graph[data_id].to_op::<Input>() {
|
||||
return self
|
||||
.hlir_buffers
|
||||
.remove(&NodeIndex::new(*node))
|
||||
.expect("Cannot find input tensor in runtime!");
|
||||
}
|
||||
|
||||
panic!("Cannot find tensor in runtime!");
|
||||
}
|
||||
|
||||
pub fn set_buffer(&mut self, id: impl ToId, buffer: Buffer) {
|
||||
let id = id.to_id();
|
||||
self.input_data.remove(&id);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
self.input_data.insert(id.to_id(), data.into());
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
let id = id.to_id();
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
|
||||
let data_id = self
|
||||
.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
let buffer = self
|
||||
.buffers
|
||||
@@ -322,10 +231,6 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: StableGraph::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
output_alias_map: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
compiled_buckets: vec![],
|
||||
active_bucket: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,10 +240,50 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
self.buffers.clear();
|
||||
self.dim_buckets.clear();
|
||||
self.compiled_buckets = vec![self.compile_bucket(FxHashMap::default(), llir_graph)];
|
||||
self.activate_bucket(0);
|
||||
self.hlir_buffers.clear();
|
||||
self.node_dtypes.clear();
|
||||
self.llir_graph = Self::fuse_matmuls(llir_graph);
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
self.node_dtypes.insert(node, input.dtype);
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -347,7 +292,6 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
@@ -366,105 +310,73 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) -> Self::ExecReturn {
|
||||
autoreleasepool(|| {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
});
|
||||
}
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
fn clear_intermediate_buffers(&mut self) {
|
||||
self.buffers.clear();
|
||||
}
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
fn load_llir_buckets(
|
||||
&mut self,
|
||||
dim_buckets: &FxHashMap<char, Vec<DimBucket>>,
|
||||
bucket_llirs: &[BucketLLIR],
|
||||
) {
|
||||
self.buffers.clear();
|
||||
self.dim_buckets = dim_buckets.clone();
|
||||
self.compiled_buckets = bucket_llirs
|
||||
.iter()
|
||||
.map(|(bucket_indices, _, llir)| self.compile_bucket(bucket_indices.clone(), llir))
|
||||
.collect();
|
||||
assert!(
|
||||
!self.compiled_buckets.is_empty(),
|
||||
"Metal runtime received no bucketed LLIRs"
|
||||
);
|
||||
self.activate_bucket(0);
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
// Bind dyn dims right after the output slot:
|
||||
// [inputs..., output, dyn, bytes...]
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,164 +437,23 @@ impl MetalRuntime {
|
||||
}
|
||||
|
||||
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
}
|
||||
|
||||
fn allocate_active_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let mut planned = Vec::new();
|
||||
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
if kernel_op.output_aliases_input().is_some() {
|
||||
continue;
|
||||
}
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let bytes = (size * dtype.bits().div_ceil(8)) as u64;
|
||||
let needs_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.is_none_or(|buffer| buffer.length() != bytes);
|
||||
|
||||
planned.push((node, bytes, needs_buffer));
|
||||
}
|
||||
}
|
||||
|
||||
for (node, bytes, needs_buffer) in planned {
|
||||
if needs_buffer {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(bytes, MTLResourceOptions::StorageModeShared);
|
||||
let buffer = self.device.new_buffer(
|
||||
(size * dtype.bits().div_ceil(8)) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_bucket(
|
||||
&self,
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: &LLIRGraph,
|
||||
) -> MetalCompiledBucket {
|
||||
let mut node_dtypes = FxHashMap::default();
|
||||
let mut pipelines = FxHashMap::default();
|
||||
let mut output_alias_map = FxHashMap::default();
|
||||
let llir_graph = llir_graph.clone();
|
||||
|
||||
let topo_order = toposort(&llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = llir_graph[node].to_op::<Input>() {
|
||||
node_dtypes.insert(node, input.dtype);
|
||||
continue;
|
||||
}
|
||||
|
||||
if llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
node_dtypes.insert(node, output_dtype);
|
||||
if let Some(pipeline) = pipeline {
|
||||
pipelines.insert(node, pipeline);
|
||||
}
|
||||
if let Some(input_idx) = kernel_op.output_aliases_input()
|
||||
&& let Some(target) = input_nodes.get(input_idx).copied()
|
||||
{
|
||||
output_alias_map.insert(node, target);
|
||||
}
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
|
||||
MetalCompiledBucket {
|
||||
bucket_indices,
|
||||
llir_graph,
|
||||
node_dtypes,
|
||||
pipelines,
|
||||
output_alias_map,
|
||||
}
|
||||
}
|
||||
|
||||
fn activate_bucket(&mut self, index: usize) {
|
||||
let bucket = self
|
||||
.compiled_buckets
|
||||
.get(index)
|
||||
.unwrap_or_else(|| panic!("Metal bucket index {index} is not compiled"))
|
||||
.clone();
|
||||
self.active_bucket = index;
|
||||
self.llir_graph = bucket.llir_graph;
|
||||
self.node_dtypes = bucket.node_dtypes;
|
||||
self.pipelines = bucket.pipelines;
|
||||
self.output_alias_map = bucket.output_alias_map;
|
||||
self.refresh_input_data_buffers();
|
||||
self.buffers.clear();
|
||||
}
|
||||
|
||||
fn refresh_input_data_buffers(&mut self) {
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn select_bucket(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
if self.compiled_buckets.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let index = self.resolve_bucket(dyn_map);
|
||||
if index != self.active_bucket {
|
||||
self.activate_bucket(index);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_bucket(&self, dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
self.compiled_buckets
|
||||
.iter()
|
||||
.position(|bucket| {
|
||||
self.dim_buckets.iter().all(|(dim, buckets)| {
|
||||
let value = dyn_map.get(dim).copied().unwrap_or(0);
|
||||
let bucket_index = bucket.bucket_indices.get(dim).copied().unwrap_or(0);
|
||||
buckets
|
||||
.get(bucket_index)
|
||||
.map(|bucket| bucket.contains(value))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"No Metal bucket matches dyn_map {:?}. Defined buckets: {:?}",
|
||||
dyn_map, self.dim_buckets
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn update_dyn_buffer(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let ptr = self.dyn_buffer.contents() as *mut i32;
|
||||
unsafe {
|
||||
@@ -702,99 +473,87 @@ impl MetalRuntime {
|
||||
|
||||
/// Execute and return GPU-side execution time in microseconds.
|
||||
fn execute_timed(&mut self, dyn_map: &FxHashMap<char, usize>) -> (f64, TimingMethod) {
|
||||
autoreleasepool(|| {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
// gpuStartTime and gpuEndTime are available on macOS 10.15+
|
||||
let gpu_start: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUStartTime]
|
||||
};
|
||||
let gpu_end: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUEndTime]
|
||||
};
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
let gpu_time_seconds = gpu_end - gpu_start;
|
||||
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
(gpu_time_us, TimingMethod::DeviceTimestamp)
|
||||
})
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
// gpuStartTime and gpuEndTime are available on macOS 10.15+
|
||||
let gpu_start: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUStartTime]
|
||||
};
|
||||
let gpu_end: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUEndTime]
|
||||
};
|
||||
|
||||
let gpu_time_seconds = gpu_end - gpu_start;
|
||||
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
|
||||
|
||||
(gpu_time_us, TimingMethod::DeviceTimestamp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
use crate::{kernel::lower_expression_for_metal, runtime::MetalRuntime};
|
||||
use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
|
||||
use half::{bf16, f16};
|
||||
use half::f16;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use safetensors::{Dtype, tensor::TensorView};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::PathBuf,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
static SAFETENSORS_TEST_FILE_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
|
||||
assert_eq!(
|
||||
@@ -34,32 +26,6 @@ fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
bytemuck::cast_slice(values).to_vec()
|
||||
}
|
||||
|
||||
fn write_test_safetensors(tensors: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> PathBuf {
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = tensors
|
||||
.iter()
|
||||
.map(|(name, dtype, shape, data)| {
|
||||
(
|
||||
(*name).to_string(),
|
||||
TensorView::new(*dtype, shape.clone(), data).unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let serialized = safetensors::serialize(&tensor_views, None).unwrap();
|
||||
let id = SAFETENSORS_TEST_FILE_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let mut path = std::env::temp_dir();
|
||||
path.push(format!(
|
||||
"luminal_metal_runtime_{}_{}.safetensors",
|
||||
std::process::id(),
|
||||
id
|
||||
));
|
||||
std::fs::write(&path, serialized).unwrap();
|
||||
path
|
||||
}
|
||||
|
||||
const TRANSFORMER_SEQ: usize = 4;
|
||||
const TRANSFORMER_HIDDEN: usize = 16;
|
||||
const TRANSFORMER_INTERMEDIATE: usize = 32;
|
||||
@@ -284,53 +250,6 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
assert_close(&out, &[9.0, 12.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(('s', 4));
|
||||
let output = (input + input).output();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
cx.set_dim('s', 1);
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, vec![1.0f32; 4]);
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
rt.set_data(input, s1_input.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let s1_out = rt.get_f32(output);
|
||||
assert_close(&s1_out[..4], &[2.0, 4.0, 6.0, 8.0], 0.001);
|
||||
|
||||
cx.set_dim('s', 3);
|
||||
let s3_input: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
let s3_expected: Vec<f32> = s3_input.iter().map(|v| v * 2.0).collect();
|
||||
rt.set_data(input, s3_input);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let s3_out = rt.get_f32(output);
|
||||
assert_close(&s3_out[..12], &s3_expected, 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_int_arithmetic_preserves_large_values() {
|
||||
let mut cx = Graph::default();
|
||||
let token = cx.tensor(1).as_dtype(DType::Int);
|
||||
let large_index = (token * 1024) + 123;
|
||||
let mod_output = (large_index % 65_537).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(mod_output), vec![891.0]);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
@@ -709,13 +628,8 @@ fn metal_regular_tiled_matmul_path() {
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSMatmul")),
|
||||
"expected MPS matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"MPS-compatible matmul should not extract the generic fallback, kernels: {:?}",
|
||||
kernels.iter().any(|k| k.contains("family: RegularTiled")),
|
||||
"expected regular tiled matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
@@ -733,287 +647,6 @@ fn metal_regular_tiled_matmul_path() {
|
||||
assert_close(&result, &expected, 2e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 7;
|
||||
let k = 11;
|
||||
let n = 13;
|
||||
let a = cx.tensor((m, k));
|
||||
let weight = cx.tensor((n, k));
|
||||
let output = a.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.35, -0.17);
|
||||
let weight_data = seeded_data(n * k, 0.21, -0.09);
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
|
||||
let ref_weight = CandleTensor::from_vec(weight_data, (n, k), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_weight.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 5;
|
||||
let k = 9;
|
||||
let n = 6;
|
||||
let lhs_storage = cx.tensor((k, m));
|
||||
let rhs = cx.tensor((k, n));
|
||||
let output = lhs_storage.t().matmul(rhs).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let lhs_data = seeded_data(k * m, 0.31, -0.12);
|
||||
let rhs_data = seeded_data(k * n, 0.27, -0.08);
|
||||
|
||||
rt.set_data(lhs_storage, &lhs_data);
|
||||
rt.set_data(rhs, &rhs_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_lhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major LHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_lhs = CandleTensor::from_vec(lhs_data, (k, m), &device)
|
||||
.unwrap()
|
||||
.t()
|
||||
.unwrap();
|
||||
let ref_rhs = CandleTensor::from_vec(rhs_data, (k, n), &device).unwrap();
|
||||
let expected = ref_lhs.matmul(&ref_rhs).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_batched_matmul_row_row_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let batch = 3;
|
||||
let m = 4;
|
||||
let k = 5;
|
||||
let n = 6;
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let b = cx.tensor((batch, k, n));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.17, -0.08);
|
||||
let b_data = seeded_data(batch * k * n, 0.11, -0.05);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSBatchedMatmul")),
|
||||
"expected MPS batched matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; batch * m * n];
|
||||
for batch_idx in 0..batch {
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..k {
|
||||
sum += a_data[batch_idx * m * k + row * k + inner]
|
||||
* b_data[batch_idx * k * n + inner * n + col];
|
||||
}
|
||||
expected[batch_idx * m * n + row * n + col] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, &attn_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"expected generic matmul fallback for non-contiguous merged-head projection, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| {
|
||||
k.contains("MetalMul") && k.contains(&format!("shape: [{seq}, {out_dim}, {hidden}]"))
|
||||
}),
|
||||
"generic fallback should remove the broadcast multiply intermediate, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; seq * out_dim];
|
||||
for token in 0..seq {
|
||||
for out_col in 0..out_dim {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..hidden {
|
||||
let head = inner / head_dim;
|
||||
let dim = inner % head_dim;
|
||||
let attn_idx = head * seq * head_dim + token * head_dim + dim;
|
||||
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
|
||||
}
|
||||
expected[token * out_dim + out_col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_batched_matmul_transposed_rhs_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let batch = 4;
|
||||
let m = 3;
|
||||
let k = 7;
|
||||
let n = 5;
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let weight = cx.tensor((batch, n, k));
|
||||
let output = a.matmul(weight.permute((0, 2, 1))).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.13, -0.06);
|
||||
let weight_data = seeded_data(batch * n * k, 0.09, -0.04);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels
|
||||
.iter()
|
||||
.any(|k| k.contains("MPSBatchedMatmul") && k.contains("transpose_rhs: true")),
|
||||
"expected MPS batched matmul transposed RHS path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; batch * m * n];
|
||||
for batch_idx in 0..batch {
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..k {
|
||||
sum += a_data[batch_idx * m * k + row * k + inner]
|
||||
* weight_data[batch_idx * n * k + col * k + inner];
|
||||
}
|
||||
expected[batch_idx * m * n + row * n + col] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 6;
|
||||
let k = 10;
|
||||
let n = 7;
|
||||
let a = cx.tensor((m, k)).as_dtype(DType::F16);
|
||||
let weight = cx.tensor((n, k)).as_dtype(DType::F16);
|
||||
let output = a.matmul(weight.t()).cast(DType::F32).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.22, -0.07);
|
||||
let weight_data = seeded_data(n * k, 0.18, -0.05);
|
||||
|
||||
rt.set_data(a, to_f16_vec(&a_data));
|
||||
rt.set_data(weight, to_f16_vec(&weight_data));
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS F16 matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
|
||||
let ref_weight = CandleTensor::from_vec(weight_data, (n, k), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_weight.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 5e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_rms_norm() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1338,153 +971,6 @@ fn test_scatter_basic() {
|
||||
assert_close(&out, &[0.0, 10.0, 0.0, 20.0, 30.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_buffer_roundtrip() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(1);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let cache = cx.tensor(4).persist();
|
||||
let cache_out = src.scatter(indexes, cache);
|
||||
let read = cache_out.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[0.0]);
|
||||
rt.set_data(indexes, &[0.0]);
|
||||
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
for (pos, value, expected) in [
|
||||
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
|
||||
(1, 20.0, [10.0, 20.0, 0.0, 0.0]),
|
||||
(2, 30.0, [10.0, 20.0, 30.0, 0.0]),
|
||||
] {
|
||||
rt.set_data(src, &[value]);
|
||||
rt.set_data(indexes, &[pos as f32]);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(read), &expected, 0.001);
|
||||
|
||||
let updated_cache = rt.remove_buffer(cache_out);
|
||||
rt.set_buffer(cache, updated_cache);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
|
||||
let mut cx = Graph::default();
|
||||
let weights = cx.named_tensor("weights", 3);
|
||||
let bias = cx.named_tensor("bias", 3);
|
||||
let out = (weights + bias).output();
|
||||
|
||||
let weight_values = [1.25f32, -2.5, 4.0];
|
||||
let tensors = [("weights", Dtype::F32, vec![3], bytes_of(&weight_values))];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(weights, &[99.0, 99.0, 99.0]);
|
||||
rt.set_data(bias, &[0.5, 1.0, -1.5]);
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[1.75, -1.5, 2.5], 0.001);
|
||||
std::fs::remove_file(path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_safetensors_converts_supported_float_dtypes() {
|
||||
let mut cx = Graph::default();
|
||||
let f16_to_f32 = cx.named_tensor("f16_to_f32", 2);
|
||||
let bf16_to_f32 = cx.named_tensor("bf16_to_f32", 2);
|
||||
let f16_to_f16 = cx.named_tensor("f16_to_f16", 2).as_dtype(DType::F16);
|
||||
let f32_to_f16 = cx.named_tensor("f32_to_f16", 2).as_dtype(DType::F16);
|
||||
let bf16_to_f16 = cx.named_tensor("bf16_to_f16", 2).as_dtype(DType::F16);
|
||||
|
||||
let f16_to_f32_out = (f16_to_f32 + 0.0).output();
|
||||
let bf16_to_f32_out = (bf16_to_f32 + 0.0).output();
|
||||
let f16_to_f16_out = f16_to_f16.cast(DType::F32).output();
|
||||
let f32_to_f16_out = f32_to_f16.cast(DType::F32).output();
|
||||
let bf16_to_f16_out = bf16_to_f16.cast(DType::F32).output();
|
||||
|
||||
let f16_to_f32_values = [f16::from_f32(1.5), f16::from_f32(-2.25)];
|
||||
let bf16_to_f32_values = [bf16::from_f32(3.5), bf16::from_f32(-4.25)];
|
||||
let f16_to_f16_values = [f16::from_f32(5.5), f16::from_f32(-6.25)];
|
||||
let f32_to_f16_values = [7.5f32, -8.25];
|
||||
let bf16_to_f16_values = [bf16::from_f32(9.5), bf16::from_f32(-10.25)];
|
||||
let tensors = [
|
||||
(
|
||||
"f16_to_f32",
|
||||
Dtype::F16,
|
||||
vec![2],
|
||||
bytes_of(&f16_to_f32_values),
|
||||
),
|
||||
(
|
||||
"bf16_to_f32",
|
||||
Dtype::BF16,
|
||||
vec![2],
|
||||
bytes_of(&bf16_to_f32_values),
|
||||
),
|
||||
(
|
||||
"f16_to_f16",
|
||||
Dtype::F16,
|
||||
vec![2],
|
||||
bytes_of(&f16_to_f16_values),
|
||||
),
|
||||
(
|
||||
"f32_to_f16",
|
||||
Dtype::F32,
|
||||
vec![2],
|
||||
bytes_of(&f32_to_f16_values),
|
||||
),
|
||||
(
|
||||
"bf16_to_f16",
|
||||
Dtype::BF16,
|
||||
vec![2],
|
||||
bytes_of(&bf16_to_f16_values),
|
||||
),
|
||||
];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(f16_to_f32_out), &[1.5, -2.25], 0.001);
|
||||
assert_close(&rt.get_f32(bf16_to_f32_out), &[3.5, -4.25], 0.001);
|
||||
assert_close(&rt.get_f32(f16_to_f16_out), &[5.5, -6.25], 0.001);
|
||||
assert_close(&rt.get_f32(f32_to_f16_out), &[7.5, -8.25], 0.001);
|
||||
assert_close(&rt.get_f32(bf16_to_f16_out), &[9.5, -10.25], 0.001);
|
||||
std::fs::remove_file(path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((4, 3));
|
||||
let data = input.transpose(0, 1);
|
||||
let indexes = cx.tensor((2, 2)).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(
|
||||
input,
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[0.0, 9.0, 1.0, 10.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_into_nonzero_dest() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1499,12 +985,6 @@ fn test_scatter_into_nonzero_dest() {
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"expected no-copy scatter for consumed destination, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1512,89 +992,6 @@ fn test_scatter_into_nonzero_dest() {
|
||||
assert_close(&out, &[1.0, 2.0, 99.0, 4.0, 5.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_remove_buffer_aliases_dest() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(2);
|
||||
let indexes = cx.tensor(2).as_dtype(DType::Int);
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[7.0, 8.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let moved = rt.remove_buffer(result);
|
||||
let moved_values = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
moved.contents() as *const f32,
|
||||
moved.length() as usize / std::mem::size_of::<f32>(),
|
||||
)
|
||||
.to_vec()
|
||||
};
|
||||
assert_close(&moved_values, &[10.0, 7.0, 30.0, 8.0, 50.0], 0.001);
|
||||
rt.set_buffer(dest.id, moved);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_handles_2d_destination() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(2);
|
||||
let indexes = cx.tensor(2).as_dtype(DType::Int);
|
||||
let dest = cx.tensor((2, 3));
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[9.0, 8.0]);
|
||||
rt.set_data(indexes, &[2.0, 4.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"expected no-copy scatter for 2D destination, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(result), &[1.0, 2.0, 9.0, 4.0, 8.0, 6.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(1);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let dest = cx.tensor(4);
|
||||
let scatter = src.scatter(indexes, dest).output();
|
||||
let dest_plus_one = (dest + 1.0).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"no-copy scatter should not be selected when dest is also consumed, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(scatter), &[10.0, 99.0, 30.0, 40.0], 0.001);
|
||||
assert_close(&rt.get_f32(dest_plus_one), &[11.0, 21.0, 31.0, 41.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_all_positions() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1615,21 +1012,3 @@ fn test_scatter_all_positions() {
|
||||
let out = rt.get_f32(result);
|
||||
assert_close(&out, &[10.0, 20.0, 30.0, 40.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_preserves_data_dtype() {
|
||||
let mut cx = Graph::default();
|
||||
let data = cx.tensor(2);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[2.5], 0.001);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "luminal_nn"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
@@ -61,8 +61,7 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -71,7 +70,7 @@ impl MoE {
|
||||
mod tests {
|
||||
use super::MoE;
|
||||
use luminal::prelude::*;
|
||||
use rand::{Rng, rng};
|
||||
use rand::{rng, Rng};
|
||||
|
||||
fn random_vec(n: usize) -> Vec<f32> {
|
||||
let mut r = rng();
|
||||
@@ -479,8 +478,7 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -782,86 +782,3 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
|
||||
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
|
||||
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-30 — `translate_grouped_mm` casted the full expert weight to F32, OOMing search on Qwen3-MoE
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`benchmarks/ttft/run.py --config qwen3-moe` crashed every search-profile attempt with:
|
||||
```
|
||||
crates/luminal_cuda_lite/src/runtime.rs:711: called `Result::unwrap()` on an `Err` value:
|
||||
DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")
|
||||
```
|
||||
The DB shows this had been failing every run for ~2 weeks. The rust `examples/qwen3_moe` ran fine end-to-end. python_baseline / python_torch_compile / qwen3-4b were all fine — only python_luminal × qwen3-moe failed.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`translate_grouped_mm` in `crates/luminal_python/rust/src/translator/tensor.rs` was lowering HF's `_grouped_mm(input, weight, offs)` op to a *full-broadcast* batched matmul plus a group-mask:
|
||||
|
||||
```rust
|
||||
let weight_f = weight.cast(DType::F32); // [G=128, K, N] cast → 1.5 GB / layer
|
||||
let input_batched = input_f.expand_dim(0, g);
|
||||
let all_out = input_batched.matmul(weight_f); // [G, S, N]
|
||||
let mask = ... (g_arange == expert_id).cast(F32);
|
||||
let out = (all_out * mask.expand_dim(2, n)).sum(0); // mask + sum over G
|
||||
```
|
||||
|
||||
The full `[G, K, N]` F32 cast intermediate is 1.5 GB / layer for gate-up and 0.6 GB / layer for down on Qwen3-30B-A3B. With 60 GB of persistent bf16 weights already on a 97 GB GPU, the search-time profiler ran out of memory allocating those casts.
|
||||
|
||||
By contrast, `examples/qwen3_moe`'s `gather_experts` gathers only the top-K active experts per token first, then casts that small `[s, k, d1, d2]` slice (~100 MB / layer). The GLUMoE host op (`crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg`) is also wired to this gather pattern.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Code path was reasonable in isolation**: at small scale (`test_grouped_mm_fallback`: g=2, K=8, N=16) the broadcast version was fine — the F32 cast was only 1 KB, and search profiling never noticed.
|
||||
2. **The error reported "out of memory" but the rest of the system looked healthy**: 60 GB weights + 37 GB headroom looks like plenty until you realise 48 layers × 2.1 GB cast intermediates per layer doesn't fit, even after loop rolling.
|
||||
3. **The DB's `code 1` failures looked the same as a Python exception** — the actual panic site (`runtime.rs:711:64` `stream.alloc_zeros(needed_bytes).unwrap()`) had to be recovered from a tmux scrollback because the orchestrator's stdout was already torn down by the time we looked.
|
||||
|
||||
### The fix
|
||||
|
||||
Rewrote `translate_grouped_mm` to gather first, matmul second:
|
||||
|
||||
```rust
|
||||
// expert_id[m] = first g s.t. m < offs[g], clamped to [0, G-1]
|
||||
let expert_id = ge_boundary.sum(0).minimum_f32(g_max_f).cast(DType::Int);
|
||||
|
||||
// flat_idx = expert_id * (K*N) + iota('z', (K, N)) — same shape as
|
||||
// rust qwen3_moe's `gather_experts`
|
||||
let flat_idx = (expert_id * (k * n))
|
||||
.expand_dim(1, k).expand_dim(2, n)
|
||||
+ self.graph.iota(Expression::from('z'), (k, n)).expand_dim(0, s);
|
||||
|
||||
let weight_gathered = weight.gather(flat_idx); // [S, K, N], bf16
|
||||
let result = input.cast(F32).unsqueeze(1)
|
||||
.matmul(weight_gathered.cast(F32)) // [S, 1, N]
|
||||
.squeeze(1);
|
||||
```
|
||||
|
||||
Two important details:
|
||||
|
||||
1. **Clamp `expert_id` to `[0, G-1]`**: at search time, dummy data fills `offs` with all-1s (`make_ones_bytes` in `compile_backend`). For S>1 that pushes `expert_id` to G (boundary count = G), which is one past the last valid expert and OOBs the gather. HF's own grouped-MM forward also clamps for the same reason (invalid expert IDs from EP).
|
||||
2. **Don't cast the full weight**: the cast moved from before the batched-matmul (over `[G, K, N]`) to after the gather (over `[S, K, N]`). 16× shrink at prefill (S=top_k=8 vs G=128).
|
||||
|
||||
### Result
|
||||
|
||||
`search-iters=1` end-to-end works on Qwen3-30B-A3B: `BENCH_RESULT … "ttft_ms": 9350.5, "tpot_ms": 1166.7`. The OOM is gone.
|
||||
|
||||
`search-iters>=5` still crashes — but with a *different*, downstream `CUDA_ERROR_ILLEGAL_ADDRESS` during execution after search completes. That looks like the same family as the 2026-03-07 / 2026-03-09 egglog-extractor non-determinism bugs (some mutation during search picks a kernel/rewrite combo that's broken at this scale). It's a separate investigation — the gather-based lowering is correct in isolation (`test_grouped_mm_fallback` passes; a synthetic `g=128, S=8, K=2048, N=1536` bf16 test passes with max-diff ~2.4e-4).
|
||||
|
||||
### General principle
|
||||
|
||||
**When lowering an op that takes a per-row index over a large parameter, gather first and cast second — never cast the full parameter to F32 just because your matmul kernel is F32-only.** A "broadcast over G + mask" pattern is mathematically equivalent to "gather per-row" but materialises a G× larger intermediate — fine for tests, ruinous on real MoE checkpoints. When in doubt, mirror the rust example's pattern: the egglog fusion rules (GLUMoE here) are written to recognise the gather form, not the broadcast-and-mask form.
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
2. **Root cause #1**: the dispatch table in `crates/luminal_python/rust/src/translator/dispatch.rs` mapped `sigmoid`, `tanh`, `relu` etc. but not `gelu` or `silu`. Whisper's encoder uses `F.gelu`, so the activation hit a hole.
|
||||
3. **Root cause #2**: PyTorch serializes `float("-inf")` in PT2 as the string `"-Infinity"` (and `"NaN"`/`"Infinity"` analogously). `translate_full`'s `get_float_arg` only accepts numeric float/int payloads, so any `torch.full((..), -inf)` (the obvious way to write a causal mask) blows up. Decoder mask code is the most common spot.
|
||||
4. **Why it was tricky**: both errors arrive from inside `pt2_backend` with a stack trace that ends in `process_pt2`, hiding the actual ATen target inside the message. You only see the offending op name in the error string itself, so you have to read `RuntimeError: Failed to translate node N: …` carefully and grep `dispatch.rs` for it.
|
||||
5. **Fix in this session**:
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
# luminal_python
|
||||
|
||||
PyTorch `torch.compile` integration for Luminal.
|
||||
|
||||
## CUDA Tests
|
||||
|
||||
The Python CUDA CI job builds the Rust extension with the CUDA feature and runs
|
||||
the non-slow pytest suite:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
|
||||
```
|
||||
|
||||
The slow tests are explicit opt-in. They include large/pretrained model tests,
|
||||
full-width architecture compiles, Whisper end-to-end cases, and other cases that
|
||||
can take a long time or need a large GPU / Hugging Face cache.
|
||||
|
||||
Run the full Python CUDA suite, including slow tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s
|
||||
```
|
||||
|
||||
Run only the slow Python CUDA tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m slow
|
||||
```
|
||||
|
||||
The helper script follows the same convention:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
./run_tests_cuda.sh # non-slow CUDA suite
|
||||
./run_tests_cuda.sh --slow-only # only slow CUDA tests
|
||||
./run_tests_cuda.sh --include-slow
|
||||
```
|
||||
|
||||
The GitHub/Modal entrypoint uses the same marker split:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s -m "not slow"
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s
|
||||
```
|
||||
|
||||
@@ -1,497 +0,0 @@
|
||||
"""Whisper transcription demo using the luminal torch.compile backend.
|
||||
|
||||
Implements a small PyTorch port of ``openai/whisper-tiny.en`` that mirrors the
|
||||
luminal Rust example (``examples/whisper`` in the workspace), loads the official
|
||||
HuggingFace weights, and runs greedy decoding through the luminal backend via
|
||||
``torch.compile``.
|
||||
|
||||
Usage::
|
||||
|
||||
uv run python examples/whisper.py [path/to/audio.wav]
|
||||
|
||||
If no path is provided, falls back to the JFK sample bundled with the Rust
|
||||
``examples/whisper`` crate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
WhisperFeatureExtractor,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperTokenizer,
|
||||
)
|
||||
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
REPO_ID = "openai/whisper-tiny.en"
|
||||
|
||||
# whisper-tiny.en hyperparameters
|
||||
N_MELS = 80
|
||||
N_AUDIO_CTX = 1500
|
||||
D_MODEL = 384
|
||||
N_HEADS = 6
|
||||
HEAD_DIM = D_MODEL // N_HEADS
|
||||
N_AUDIO_LAYER = 4
|
||||
N_TEXT_LAYER = 4
|
||||
N_TEXT_CTX = 448
|
||||
FF_DIM = 4 * D_MODEL
|
||||
N_VOCAB = 51864
|
||||
LAYER_NORM_EPS = 1e-5
|
||||
|
||||
# Decoder special tokens
|
||||
TOKEN_SOT = 50257
|
||||
TOKEN_NO_TIMESTAMPS = 50362
|
||||
TOKEN_EOT = 50256
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model — mirrors the HLIR encoder/decoder in examples/whisper/src/model.rs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhisperAttention(torch.nn.Module):
|
||||
"""Multi-head attention with separate q/k/v projections (no bias on k_proj)."""
|
||||
|
||||
def __init__(self, d_model: int = D_MODEL, n_heads: int = N_HEADS):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
self.q_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.out_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
kv_input: Optional[torch.Tensor] = None,
|
||||
causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# x: (seq, d_model). kv_input is None → self-attn; otherwise cross-attn.
|
||||
kv = x if kv_input is None else kv_input
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(kv)
|
||||
v = self.v_proj(kv)
|
||||
|
||||
seq_q = q.shape[0]
|
||||
seq_kv = k.shape[0]
|
||||
|
||||
# (seq, d_model) -> (n_heads, seq, head_dim)
|
||||
q = q.reshape(seq_q, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
v = v.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
scale = 1.0 / (self.head_dim**0.5)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (h, sq, sk)
|
||||
if causal:
|
||||
# Use a large finite negative instead of -inf so the export pipeline
|
||||
# serializes a float instead of the unsupported "-Infinity" sentinel.
|
||||
mask = torch.triu(
|
||||
torch.full((seq_q, seq_kv), -1e10, device=x.device),
|
||||
diagonal=1,
|
||||
)
|
||||
scores = scores + mask
|
||||
weights = torch.softmax(scores, dim=-1)
|
||||
attn = torch.matmul(weights, v) # (h, sq, hd)
|
||||
merged = attn.transpose(0, 1).reshape(seq_q, -1)
|
||||
return self.out_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x))
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(
|
||||
N_MELS, D_MODEL, kernel_size=3, padding=1, bias=True
|
||||
)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
D_MODEL, D_MODEL, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
# Position embedding stored as a regular parameter (matches HF layout).
|
||||
self.embed_positions = torch.nn.Embedding(N_AUDIO_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[EncoderLayer() for _ in range(N_AUDIO_LAYER)]
|
||||
)
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
# mel: (n_mels, 3000) -> add batch dim for conv1d
|
||||
x = mel.unsqueeze(0)
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
# (1, d_model, 1500) -> (1500, d_model)
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
x = x + self.embed_positions.weight
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.layer_norm(x)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.encoder_attn = WhisperAttention()
|
||||
self.encoder_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x), causal=True)
|
||||
x = x + self.encoder_attn(self.encoder_attn_layer_norm(x), kv_input=xa)
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(N_VOCAB, D_MODEL)
|
||||
self.embed_positions = torch.nn.Embedding(N_TEXT_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList([DecoderLayer() for _ in range(N_TEXT_LAYER)])
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
# tokens: (seq,) of int64 — absolute positions are 0..seq-1
|
||||
seq = tokens.shape[0]
|
||||
pos = torch.arange(seq, dtype=torch.long, device=tokens.device)
|
||||
x = self.embed_tokens(tokens) + self.embed_positions(pos)
|
||||
for layer in self.layers:
|
||||
x = layer(x, xa)
|
||||
x = self.layer_norm(x)
|
||||
# Tied projection
|
||||
return torch.matmul(x, self.embed_tokens.weight.transpose(0, 1))
|
||||
|
||||
|
||||
class Whisper(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = WhisperEncoder()
|
||||
self.decoder = WhisperDecoder()
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
|
||||
xa = self.encoder(mel)
|
||||
return self.decoder(tokens, xa)
|
||||
|
||||
|
||||
class DecoderWithFixedXa(torch.nn.Module):
|
||||
"""Wraps the decoder with the encoder output stored as a buffer.
|
||||
|
||||
The audio is fixed for the whole utterance, so ``xa`` is a constant relative
|
||||
to the per-token decode loop. Storing it as a buffer lets us compile the
|
||||
decoder once with a single dynamic-length ``tokens`` input, avoiding a full
|
||||
recompilation at every step as the sequence grows.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder: WhisperDecoder, xa: torch.Tensor):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.register_buffer("xa", xa)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(tokens, self.xa)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight loading: HF state_dict -> our model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_hf_weights_into(model: Whisper) -> None:
|
||||
"""Copy HF whisper-tiny.en weights into our matching modules."""
|
||||
hf = WhisperForConditionalGeneration.from_pretrained(REPO_ID).eval()
|
||||
sd = hf.state_dict()
|
||||
|
||||
def get(name: str) -> torch.Tensor:
|
||||
return sd[f"model.{name}"].clone()
|
||||
|
||||
enc = model.encoder
|
||||
enc.conv1.weight.data.copy_(get("encoder.conv1.weight"))
|
||||
enc.conv1.bias.data.copy_(get("encoder.conv1.bias"))
|
||||
enc.conv2.weight.data.copy_(get("encoder.conv2.weight"))
|
||||
enc.conv2.bias.data.copy_(get("encoder.conv2.bias"))
|
||||
enc.embed_positions.weight.data.copy_(get("encoder.embed_positions.weight"))
|
||||
enc.layer_norm.weight.data.copy_(get("encoder.layer_norm.weight"))
|
||||
enc.layer_norm.bias.data.copy_(get("encoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(enc.layers):
|
||||
prefix = f"encoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
dec = model.decoder
|
||||
dec.embed_tokens.weight.data.copy_(get("decoder.embed_tokens.weight"))
|
||||
dec.embed_positions.weight.data.copy_(get("decoder.embed_positions.weight"))
|
||||
dec.layer_norm.weight.data.copy_(get("decoder.layer_norm.weight"))
|
||||
dec.layer_norm.bias.data.copy_(get("decoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(dec.layers):
|
||||
prefix = f"decoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.encoder_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.q_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.k_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.bias")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio loading + decoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_wav_16k_mono(path: Path) -> np.ndarray:
|
||||
with wave.open(str(path), "rb") as w:
|
||||
sr = w.getframerate()
|
||||
n = w.getnframes()
|
||||
ch = w.getnchannels()
|
||||
sw = w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
|
||||
if sw == 2:
|
||||
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif sw == 4:
|
||||
samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
elif sw == 1:
|
||||
samples = (
|
||||
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||
) / 128.0
|
||||
else:
|
||||
raise ValueError(f"unsupported sample width {sw}")
|
||||
|
||||
if ch > 1:
|
||||
samples = samples.reshape(-1, ch).mean(axis=1)
|
||||
|
||||
if sr != 16000:
|
||||
ratio = sr / 16000
|
||||
out_len = int(len(samples) / ratio)
|
||||
idx = np.arange(out_len, dtype=np.float64) * ratio
|
||||
lo = idx.astype(np.int64)
|
||||
frac = (idx - lo).astype(np.float32)
|
||||
hi = np.clip(lo + 1, 0, len(samples) - 1)
|
||||
samples = samples[lo] * (1.0 - frac) + samples[hi] * frac
|
||||
|
||||
return samples.astype(np.float32)
|
||||
|
||||
|
||||
def greedy_decode(logits_row: torch.Tensor, suppress_first_eot: bool) -> int:
|
||||
masked = logits_row.clone()
|
||||
masked[TOKEN_SOT:] = float("-inf")
|
||||
if suppress_first_eot:
|
||||
masked[TOKEN_EOT] = float("-inf")
|
||||
return int(torch.argmax(masked).item())
|
||||
|
||||
|
||||
def find_default_audio() -> Optional[Path]:
|
||||
here = Path(__file__).resolve()
|
||||
workspace_root = here.parents[3]
|
||||
candidate = workspace_root / "examples" / "whisper" / "assets" / "jfk.wav"
|
||||
return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
audio_arg = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
if audio_arg:
|
||||
audio_path = Path(audio_arg)
|
||||
else:
|
||||
audio_path = find_default_audio()
|
||||
if audio_path is None:
|
||||
print(
|
||||
"error: no audio file given and bundled jfk.wav not found",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
print("Loading audio:", audio_path)
|
||||
audio = load_wav_16k_mono(audio_path)
|
||||
|
||||
print("Computing log-mel features...")
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(REPO_ID)
|
||||
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
|
||||
mel: torch.Tensor = features.input_features[0].to(device) # (80, 3000)
|
||||
assert mel.shape == (N_MELS, 3000), mel.shape
|
||||
|
||||
print("Building model and loading weights...")
|
||||
model = Whisper().eval().to(device)
|
||||
load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = 100
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
# 1. Run the encoder once eagerly. The audio doesn't change during decode,
|
||||
# so xa is a constant input to the decoder.
|
||||
with torch.no_grad():
|
||||
xa = model.encoder(mel)
|
||||
|
||||
# 2. Wrap the decoder so its only varying input is `tokens`, then compile
|
||||
# once with a dynamic length dim. Subsequent calls reuse the same
|
||||
# compiled graph — no recompile per token.
|
||||
decoder_only = DecoderWithFixedXa(model.decoder, xa).eval().to(device)
|
||||
example_tokens = torch.tensor(
|
||||
[TOKEN_SOT, TOKEN_NO_TIMESTAMPS], dtype=torch.long, device=device
|
||||
)
|
||||
print(
|
||||
f"Compiling decoder with dynamic seq dim (search_iters={search_iters})..."
|
||||
)
|
||||
compile_start = time.time()
|
||||
compiled_decoder = luminal_compile(
|
||||
decoder_only,
|
||||
example_tokens,
|
||||
search_iterations=search_iters,
|
||||
dynamic_dim=0,
|
||||
)
|
||||
print(f"Compiled in {time.time() - compile_start:.1f}s")
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
out = compiled_decoder(decoder_input_ids)
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
else:
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return model(mel, decoder_input_ids)
|
||||
|
||||
tokens = [TOKEN_SOT, TOKEN_NO_TIMESTAMPS]
|
||||
|
||||
print("Transcribing", end="", flush=True)
|
||||
decode_start = time.time()
|
||||
for step in range(max_new_tokens):
|
||||
decoder_input_ids = torch.tensor(tokens, dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
logits = step_logits(decoder_input_ids)
|
||||
|
||||
next_token = greedy_decode(logits[-1], suppress_first_eot=(step == 0))
|
||||
if next_token == TOKEN_EOT:
|
||||
break
|
||||
tokens.append(next_token)
|
||||
piece = tokenizer.decode([next_token], skip_special_tokens=False)
|
||||
print(piece, end="", flush=True)
|
||||
elapsed = time.time() - decode_start
|
||||
print()
|
||||
|
||||
transcription = tokenizer.decode(tokens[2:], skip_special_tokens=True)
|
||||
print(f"\nFinal transcription: {transcription}")
|
||||
print(
|
||||
f"Generated {len(tokens) - 2} tokens in {elapsed:.2f}s "
|
||||
f"({(len(tokens) - 2) / max(elapsed, 1e-6):.1f} tok/s)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -22,7 +22,7 @@ from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 2 * 60 * 60
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
@@ -168,37 +168,6 @@ def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _build_cuda_extension(env: dict[str, str]) -> None:
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"maturin",
|
||||
"develop",
|
||||
"--manifest-path",
|
||||
f"{PROJECT_DIR}/rust/Cargo.toml",
|
||||
"--features",
|
||||
"cuda",
|
||||
"--profile",
|
||||
"release",
|
||||
]
|
||||
subprocess.run(cmd, env=env, cwd=PROJECT_DIR, check=True)
|
||||
|
||||
|
||||
def _effective_timeout(timeout: int) -> int:
|
||||
if os.environ.get("GITHUB_ACTIONS") == "true" and timeout < DEFAULT_TIMEOUT:
|
||||
print(
|
||||
f"Using Modal timeout {DEFAULT_TIMEOUT}s instead of requested "
|
||||
f"{timeout}s in GitHub Actions.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return DEFAULT_TIMEOUT
|
||||
return timeout
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
@@ -225,8 +194,6 @@ class TestRunner:
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
_build_cuda_extension(env)
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
@@ -251,6 +218,8 @@ class TestRunner:
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
@@ -316,7 +285,7 @@ class TestRunner:
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int, bool, str | None, list[str]]:
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
@@ -331,8 +300,7 @@ def _parse_cli_args(
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=DEFAULT_TIMEOUT,
|
||||
help="Modal execution timeout in seconds. Defaults to %(default)s seconds.",
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
@@ -366,11 +334,11 @@ def main(*cli_args: str):
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
timeout = _effective_timeout(timeout)
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
runner_options["timeout"] = timeout
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
|
||||
@@ -32,7 +32,7 @@ module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: tests that download large models, compile full-width model graphs, fuzz many CUDA search choices, or otherwise require explicit opt-in",
|
||||
"slow: tests that download large models or require pre-generated artifacts",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -1,43 +1,34 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/"
|
||||
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 1: Building native backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
uv run --group dev pytest $NATIVE_TESTS -v
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 2: Building CUDA backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "Slow CUDA tests are opt-in. To include them, run:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -v -s"
|
||||
echo "Or, for only slow tests:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -m slow -v -s"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -4,34 +4,17 @@ set -e
|
||||
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
|
||||
echo ""
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
PYTEST_MARK='not slow'
|
||||
if [[ "${1:-}" == "--include-slow" ]]; then
|
||||
PYTEST_MARK=''
|
||||
elif [[ "${1:-}" == "--slow-only" ]]; then
|
||||
PYTEST_MARK='slow'
|
||||
elif [[ "${1:-}" != "" ]]; then
|
||||
echo "Usage: ./run_tests_cuda.sh [--include-slow|--slow-only]"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
if [[ -n "$PYTEST_MARK" ]]; then
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -m "$PYTEST_MARK" -v -s
|
||||
else
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -v -s
|
||||
fi
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -12,67 +12,6 @@ use crate::typed_data::TypedData;
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Recover a single-variable dim's variable value from an observed runtime size.
|
||||
///
|
||||
/// Returns `Some((var, value))` when the expression contains exactly one
|
||||
/// variable, is affine in that variable, and `value` round-trips through
|
||||
/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise
|
||||
/// — multi-variable expressions, non-affine forms, slope==0, and inversions
|
||||
/// that don't divide cleanly are all rejected so we never write a wrong
|
||||
/// guess into `dyn_map`.
|
||||
fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> {
|
||||
use luminal::shape::Term;
|
||||
let terms = expr.terms.read();
|
||||
|
||||
// Identify the unique variable, if any.
|
||||
let mut var: Option<char> = None;
|
||||
for t in terms.iter() {
|
||||
if let Term::Var(c) = t {
|
||||
match var {
|
||||
None => var = Some(*c),
|
||||
Some(existing) if existing == *c => {}
|
||||
Some(_) => return None, // multi-var — bail out
|
||||
}
|
||||
}
|
||||
}
|
||||
let var = var?;
|
||||
|
||||
// Bare-var fast path — terms is exactly `[Var]`.
|
||||
if terms.len() == 1 {
|
||||
return Some((var, dim_val));
|
||||
}
|
||||
|
||||
// Probe two points to recover slope/intercept of an assumed affine form
|
||||
// `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default
|
||||
// dynamic-dim min is 2, and 3 keeps the inputs small in case the
|
||||
// expression includes a multiplication that could overflow at scale).
|
||||
drop(terms);
|
||||
let f2 = expr.exec_single_var_checked(2)? as i64;
|
||||
let f3 = expr.exec_single_var_checked(3)? as i64;
|
||||
let slope = f3 - f2;
|
||||
if slope == 0 {
|
||||
return None;
|
||||
}
|
||||
let intercept = f2 - 2 * slope;
|
||||
let target = dim_val as i64 - intercept;
|
||||
if slope == 0 || target % slope != 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = target / slope;
|
||||
if candidate < 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = candidate as usize;
|
||||
|
||||
// Verify by re-evaluating with the candidate value. Catches non-affine
|
||||
// forms whose probe points happen to be collinear (e.g. `min(s, 100)`
|
||||
// would look affine for s ∈ {2, 3} but flatten beyond 100).
|
||||
if expr.exec_single_var_checked(candidate)? != dim_val {
|
||||
return None;
|
||||
}
|
||||
Some((var, candidate))
|
||||
}
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
@@ -98,12 +37,7 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -129,9 +63,7 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -158,21 +90,17 @@ impl CompiledGraph {
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
let WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
} = weight_data;
|
||||
|
||||
// Build compile args from WeightData.
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weights
|
||||
weights: weight_data
|
||||
.weights
|
||||
.iter()
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
};
|
||||
|
||||
// Create backend via the factory directly
|
||||
@@ -291,27 +219,17 @@ impl CompiledGraph {
|
||||
}
|
||||
|
||||
/// Auto-detect and set dynamic dimensions from input tensor shapes.
|
||||
///
|
||||
/// For each user input we walk the symbolic shape expressions side-by-side
|
||||
/// with the concrete sizes Dynamo handed us at runtime and try to recover
|
||||
/// each unbound variable's value. Two cases are handled:
|
||||
///
|
||||
/// * Bare-variable dim (`s`): set directly from the size.
|
||||
/// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a`
|
||||
/// by sampling the expression at two probe points to extract the
|
||||
/// slope, recovering the intercept, and verifying that plugging the
|
||||
/// recovered value back through `exec_single_var_checked` reproduces
|
||||
/// the observed size. The verification step rejects everything
|
||||
/// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong
|
||||
/// guess to `dyn_map`.
|
||||
///
|
||||
/// Multi-variable dims are skipped here; another input's shape — or an
|
||||
/// explicit `set_dim` call — is expected to bind those.
|
||||
/// For each user input, matches the concrete shape against its symbolic
|
||||
/// shape expressions and sets the corresponding dyn_map entries.
|
||||
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
|
||||
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
|
||||
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
|
||||
if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) {
|
||||
self.graph.set_dim(var, value);
|
||||
// Check if this expression is a bare symbolic variable
|
||||
let terms = dim_expr.terms.read();
|
||||
if terms.len() == 1
|
||||
&& let luminal::shape::Term::Var(c) = terms[0]
|
||||
{
|
||||
self.graph.set_dim(c, dim_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -391,7 +309,7 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
@@ -452,7 +370,7 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Register a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
@@ -487,7 +405,10 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes.clone()
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -3,7 +3,6 @@ pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
mod pt2_expr;
|
||||
mod pt2_parser;
|
||||
mod pt2_schema;
|
||||
mod pt2_util;
|
||||
|
||||
@@ -2,11 +2,10 @@ use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use pyo3::types::{PyAny, PyCapsule, PyCapsuleMethods, PyDict};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_expr::parse_sympy_expr;
|
||||
use crate::pt2_schema;
|
||||
use crate::translator;
|
||||
use crate::typed_data::TypedData;
|
||||
@@ -15,6 +14,58 @@ use crate::{pt2_parser, pt2_util};
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct CompileOptions {
|
||||
search_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
search_iterations: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
fn from_py(options: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
|
||||
let mut parsed = Self::default();
|
||||
|
||||
let Some(options) = options else {
|
||||
return Ok(parsed);
|
||||
};
|
||||
|
||||
let options = options.cast::<PyDict>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err("luminal backend options must be a dict")
|
||||
})?;
|
||||
|
||||
for (key, value) in options.iter() {
|
||||
let key = key.extract::<String>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option keys must be strings",
|
||||
)
|
||||
})?;
|
||||
|
||||
match key.as_str() {
|
||||
"search_iterations" => {
|
||||
parsed.search_iterations = value.extract::<usize>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option 'search_iterations' must be an integer",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
other => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Unsupported luminal backend option '{other}'. Supported options: search_iterations",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
@@ -22,47 +73,32 @@ fn resolve_dim_sizes(
|
||||
sizes
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int),
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
let s = e.as_expr.expr_str.trim();
|
||||
// Try the full sympy-style parse first so compound forms like
|
||||
// `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and
|
||||
// similar dim-altering ops) propagate as a real Expression
|
||||
// rather than collapsing to the size-1 fallback. Fall back to
|
||||
// the bare-Symbol fast path when that fails — the parser
|
||||
// bails on unrecognised heads (Pow, Min, etc.) and we'd
|
||||
// rather lose the symbolic info than misinterpret it.
|
||||
parse_sympy_expr(s, sym_to_char)
|
||||
.or_else(|| {
|
||||
pt2_parser::extract_symbol_name_pub(s)
|
||||
.and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c)))
|
||||
})
|
||||
.or_else(|| {
|
||||
// As a last resort, if the EP gave us a concrete `hint`
|
||||
// (the value used to seed shape tracing), use it. The
|
||||
// dim is technically dynamic but at least output-shape
|
||||
// resolution won't return 1 for unset dims.
|
||||
e.as_expr
|
||||
.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(Expression::from)
|
||||
})
|
||||
.unwrap_or_else(|| Expression::from(1usize))
|
||||
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
|
||||
if let Some(c) = sym_to_char.get(&sym) {
|
||||
Expression::from(*c)
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
#[pyo3(signature = (pt2_path, weights_path, factory_capsule, weight_device_ptrs=None, options=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
options: Option<&Bound<'_, PyAny>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
let options = CompileOptions::from_py(options)?;
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
@@ -100,7 +136,7 @@ pub fn process_pt2(
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
search_iters,
|
||||
&options,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
@@ -110,14 +146,14 @@ pub fn process_pt2(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
options: &CompileOptions,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
|
||||
CompiledGraph::parse_graph(translation, weights, factory, options.search_iterations)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
@@ -130,13 +166,10 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,14 +185,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -423,3 +456,72 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::CompileOptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
use std::sync::Once;
|
||||
|
||||
fn with_python(f: impl FnOnce(Python<'_>)) {
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(Python::initialize);
|
||||
Python::attach(f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_apply() {
|
||||
let options = CompileOptions::from_py(None).unwrap();
|
||||
assert_eq!(options.search_iterations, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_dict_overlays_defaults() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", 3).unwrap();
|
||||
|
||||
let parsed = CompileOptions::from_py(Some(options.as_any())).unwrap();
|
||||
assert_eq!(parsed.search_iterations, 3);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_unknown_keys() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("unknown", 1).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Unsupported luminal backend option 'unknown'")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_non_dict() {
|
||||
with_python(|py| {
|
||||
let options = 123usize.into_pyobject(py).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("options must be a dict"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_bad_search_iterations_type() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", "fast").unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("search_iterations"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,713 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_parser::SymDimMap;
|
||||
use crate::pt2_schema::RangeConstraint;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) struct ExprBounds {
|
||||
pub(crate) min: Option<i64>,
|
||||
pub(crate) max: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct ParsedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
impl ParsedExpr {
|
||||
fn exact(expr: Expression, value: i64) -> Self {
|
||||
Self {
|
||||
expr,
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct BoundedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal `Expression`.
|
||||
///
|
||||
/// Supports the subset of sympy heads PT2 emits for symbolic shape metadata.
|
||||
pub(crate) fn parse_sympy_expr(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr, sym_to_char, &HashMap::new())
|
||||
}
|
||||
|
||||
pub(crate) fn parse_sympy_expr_with_ranges(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_inner(expr, sym_to_char, ranges).map(|parsed| parsed.expr)
|
||||
}
|
||||
|
||||
pub(crate) fn sym_char_ranges(sym_map: &SymDimMap) -> FxHashMap<char, ExprBounds> {
|
||||
sym_map
|
||||
.sym_to_char
|
||||
.iter()
|
||||
.map(|(sym_name, sym_char)| {
|
||||
let range = sym_map.ranges.get(sym_name);
|
||||
let min = range
|
||||
.and_then(|range| range.min_val)
|
||||
.map(|min| min.max(0))
|
||||
.or(Some(0));
|
||||
let max = range
|
||||
.and_then(|range| range.max_val)
|
||||
.filter(|max| *max >= 0);
|
||||
(*sym_char, ExprBounds { min, max })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn simplify_expr_with_ranges(
|
||||
expr: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Expression {
|
||||
simplify_bound_expr(expr, sym_ranges).expr
|
||||
}
|
||||
|
||||
pub(crate) fn same_expr_with_ranges(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
let lhs = simplify_bound_expr(lhs, sym_ranges);
|
||||
let rhs = simplify_bound_expr(rhs, sym_ranges);
|
||||
lhs.expr == rhs.expr
|
||||
|| lhs.expr.egglog_equal(rhs.expr)
|
||||
|| (exact_value(lhs) == exact_value(rhs) && exact_value(lhs).is_some())
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_equal_expr(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Option<Expression> {
|
||||
if !same_expr_with_ranges(lhs, rhs, sym_ranges) {
|
||||
return None;
|
||||
}
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, sym_ranges);
|
||||
Some(if lhs_simplified.len() <= rhs_simplified.len() {
|
||||
lhs_simplified
|
||||
} else {
|
||||
rhs_simplified
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_sympy_expr_inner(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<ParsedExpr> {
|
||||
let expr = expr.trim();
|
||||
if expr.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Ok(value) = expr.parse::<i64>() {
|
||||
return Some(ParsedExpr::exact(Expression::from(value), value));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(expr)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
let name = extract_first_quoted(body)?;
|
||||
let bounds = infer_symbol_bounds(body, ranges.get(&name));
|
||||
sym_to_char.get(&name).map(|c| ParsedExpr {
|
||||
expr: Expression::from(*c),
|
||||
bounds,
|
||||
})
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let value = body.trim().parse::<i64>().ok()?;
|
||||
Some(ParsedExpr::exact(Expression::from(value), value))
|
||||
}
|
||||
"NegativeOne" => Some(ParsedExpr::exact(Expression::from(-1i64), -1)),
|
||||
"Zero" => Some(ParsedExpr::exact(Expression::from(0i64), 0)),
|
||||
"One" => Some(ParsedExpr::exact(Expression::from(1i64), 1)),
|
||||
"Mul" | "Add" | "Min" | "Max" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr_inner(iter.next()?, sym_to_char, ranges)?;
|
||||
for part in iter {
|
||||
let rhs = parse_sympy_expr_inner(part, sym_to_char, ranges)?;
|
||||
acc = match head {
|
||||
"Mul" => ParsedExpr {
|
||||
expr: normalize_mul_expr(acc.expr, rhs.expr),
|
||||
bounds: mul_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Add" => ParsedExpr {
|
||||
expr: normalize_add_expr(acc.expr, rhs.expr),
|
||||
bounds: add_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Min" => reduce_min(acc, rhs),
|
||||
"Max" => reduce_max(acc, rhs),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
"FloorDiv" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr / rhs.expr,
|
||||
bounds: div_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
"Mod" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr % rhs.expr,
|
||||
bounds: mod_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_symbol_bounds(body: &str, range: Option<&RangeConstraint>) -> ExprBounds {
|
||||
let mut bounds = ExprBounds::default();
|
||||
if body.contains("positive=True") {
|
||||
bounds.min = Some(1);
|
||||
} else if body.contains("nonnegative=True") {
|
||||
bounds.min = Some(0);
|
||||
}
|
||||
if let Some(range) = range {
|
||||
bounds.min = match (bounds.min, range.min_val) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
(None, Some(rhs)) => Some(rhs),
|
||||
(lhs, None) => lhs,
|
||||
};
|
||||
bounds.max = range.max_val;
|
||||
}
|
||||
bounds
|
||||
}
|
||||
|
||||
fn exact_expr(value: i64) -> BoundedExpr {
|
||||
BoundedExpr {
|
||||
expr: Expression::from(value),
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn exact_value(expr: BoundedExpr) -> Option<i64> {
|
||||
expr.expr.as_num().or({
|
||||
(expr.bounds.min == expr.bounds.max)
|
||||
.then_some(expr.bounds.min)
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn exact_bound_value(bounds: ExprBounds) -> Option<i64> {
|
||||
(bounds.min == bounds.max).then_some(bounds.min).flatten()
|
||||
}
|
||||
|
||||
fn with_bounds(expr: Expression, bounds: ExprBounds) -> BoundedExpr {
|
||||
BoundedExpr { expr, bounds }
|
||||
}
|
||||
|
||||
fn bool_bounds() -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: Some(0),
|
||||
max: Some(1),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expr(expr: Expression) -> Expression {
|
||||
if expr.len() <= 16 {
|
||||
expr.simplify()
|
||||
} else {
|
||||
expr
|
||||
}
|
||||
}
|
||||
|
||||
fn commutative_key(expr: Expression) -> (usize, String) {
|
||||
(expr.len(), format!("{expr:?}"))
|
||||
}
|
||||
|
||||
fn sort_commutative(lhs: Expression, rhs: Expression) -> (Expression, Expression) {
|
||||
if commutative_key(lhs) <= commutative_key(rhs) {
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
(rhs, lhs)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_add_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs + rhs)
|
||||
}
|
||||
|
||||
fn normalize_mul_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs * rhs)
|
||||
}
|
||||
|
||||
fn checked_add_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_add(rhs))
|
||||
}
|
||||
|
||||
fn checked_sub_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_sub(rhs))
|
||||
}
|
||||
|
||||
fn checked_mul_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_mul(rhs))
|
||||
}
|
||||
|
||||
fn add_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_add_opt(lhs.min, rhs.min),
|
||||
max: checked_add_opt(lhs.max, rhs.max),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) >= 0 && rhs.min.unwrap_or(i64::MIN) >= 0 {
|
||||
return ExprBounds {
|
||||
min: checked_mul_opt(lhs.min, rhs.min),
|
||||
max: checked_mul_opt(lhs.max, rhs.max),
|
||||
};
|
||||
}
|
||||
ExprBounds::default()
|
||||
}
|
||||
|
||||
fn sub_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_sub_opt(lhs.min, rhs.max),
|
||||
max: checked_sub_opt(lhs.max, rhs.min),
|
||||
}
|
||||
}
|
||||
|
||||
fn div_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
let (Some(rhs_min), Some(rhs_max)) = (rhs.min, rhs.max) else {
|
||||
return ExprBounds::default();
|
||||
};
|
||||
if rhs_min <= 0 || rhs_max <= 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
ExprBounds {
|
||||
min: lhs.min.and_then(|lhs_min| lhs_min.checked_div(rhs_max)),
|
||||
max: lhs.max.and_then(|lhs_max| lhs_max.checked_div(rhs_min)),
|
||||
}
|
||||
}
|
||||
|
||||
fn mod_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) < 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
match exact_bound_value(rhs) {
|
||||
Some(rhs_exact) if rhs_exact > 0 => ExprBounds {
|
||||
min: Some(0),
|
||||
max: rhs_exact.checked_sub(1),
|
||||
},
|
||||
_ => ExprBounds::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_min(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.min(rhs.expr),
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_max(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.max(rhs.expr),
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn min_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn max_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_is_offset_by_small_const(lhs: Expression, rhs: Expression) -> bool {
|
||||
(1..=8).any(|delta| lhs.egglog_equal(rhs + delta))
|
||||
}
|
||||
|
||||
fn split_add_const(expr: Expression) -> Option<(i64, Expression)> {
|
||||
let terms = expr.terms.read();
|
||||
if terms.len() >= 3 && terms.last() == Some(&Term::Add) {
|
||||
if let Some(Term::Num(n)) = terms.first() {
|
||||
return Some((*n, Expression::new(terms[1..terms.len() - 1].to_vec())));
|
||||
}
|
||||
if let Some(Term::Num(n)) = terms.get(terms.len() - 2) {
|
||||
return Some((*n, Expression::new(terms[..terms.len() - 2].to_vec())));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn simplify_add(lhs: BoundedExpr, rhs: BoundedExpr) -> BoundedExpr {
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) => rhs.expr,
|
||||
(_, Some(0)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs + rhs),
|
||||
(_, Some(rhs)) => normalize_add_expr(lhs.expr, Expression::from(rhs)),
|
||||
(Some(lhs), _) => normalize_add_expr(Expression::from(lhs), rhs.expr),
|
||||
_ => normalize_add_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
with_bounds(expr, add_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_sub(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return exact_expr(0);
|
||||
}
|
||||
let expr = match exact_value(rhs) {
|
||||
Some(0) => lhs.expr,
|
||||
Some(rhs_const) => {
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr) {
|
||||
normalize_expr(lhs_base + (lhs_const - rhs_const))
|
||||
} else {
|
||||
normalize_expr(lhs.expr - rhs_const)
|
||||
}
|
||||
}
|
||||
None => normalize_expr(lhs.expr - rhs.expr),
|
||||
};
|
||||
with_bounds(expr, sub_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_min(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = min_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.min(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_max(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = max_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.max(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_bound_expr(expr: Expression, sym_ranges: &FxHashMap<char, ExprBounds>) -> BoundedExpr {
|
||||
let mut stack: Vec<BoundedExpr> = Vec::new();
|
||||
let terms = expr.terms.read().clone();
|
||||
for term in terms {
|
||||
match term {
|
||||
Term::Num(n) => stack.push(exact_expr(n)),
|
||||
Term::Var(c) => stack.push(with_bounds(
|
||||
Expression::from(c),
|
||||
sym_ranges.get(&c).copied().unwrap_or_default(),
|
||||
)),
|
||||
Term::Add => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_add(lhs, rhs));
|
||||
}
|
||||
Term::Sub => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_sub(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Mul => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(0)) => Expression::from(0),
|
||||
(Some(1), _) => rhs.expr,
|
||||
(_, Some(1)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs * rhs),
|
||||
_ => normalize_mul_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mul_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Div | Term::CeilDiv => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(_, Some(0), _) => Expression::from(0),
|
||||
(_, _, Some(1)) => lhs.expr,
|
||||
(Term::Div, Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs / rhs),
|
||||
(Term::CeilDiv, Some(lhs), Some(rhs)) if rhs > 0 => {
|
||||
Expression::from(if lhs % rhs != 0 {
|
||||
lhs / rhs + 1
|
||||
} else {
|
||||
lhs / rhs
|
||||
})
|
||||
}
|
||||
(Term::Div, _, _) => normalize_expr(lhs.expr / rhs.expr),
|
||||
(Term::CeilDiv, _, _) => normalize_expr(lhs.expr.ceil_div(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, div_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Mod => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(1)) => Expression::from(0),
|
||||
(Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs % rhs),
|
||||
_ => normalize_expr(lhs.expr % rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mod_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Min => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_min(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Max => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_max(lhs, rhs, sym_ranges));
|
||||
}
|
||||
term @ (Term::And | Term::Or | Term::Gte | Term::Lt) => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(Term::And, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 && rhs != 0) as i64)
|
||||
}
|
||||
(Term::And, _, _) => normalize_expr(lhs.expr & rhs.expr),
|
||||
(Term::Or, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 || rhs != 0) as i64)
|
||||
}
|
||||
(Term::Or, _, _) => normalize_expr(lhs.expr | rhs.expr),
|
||||
(Term::Gte, Some(lhs), Some(rhs)) => Expression::from((lhs >= rhs) as i64),
|
||||
(Term::Gte, _, _) => normalize_expr(lhs.expr.gte(rhs.expr)),
|
||||
(Term::Lt, Some(lhs), Some(rhs)) => Expression::from((lhs < rhs) as i64),
|
||||
(Term::Lt, _, _) => normalize_expr(lhs.expr.lt(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, bool_bounds()));
|
||||
}
|
||||
}
|
||||
}
|
||||
stack
|
||||
.pop()
|
||||
.unwrap_or(with_bounds(expr, ExprBounds::default()))
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into `(head, body)`.
|
||||
fn split_head(expr: &str) -> Option<(&str, &str)> {
|
||||
let open = expr.find('(')?;
|
||||
if !expr.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&expr[..open], &expr[open + 1..expr.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list.
|
||||
fn extract_first_quoted(expr: &str) -> Option<String> {
|
||||
let bytes = expr.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(expr[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split a sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Drops `key=value` kwargs.
|
||||
fn split_top_level_args(expr: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = expr.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = expr[start..i].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = expr[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
return !key.is_empty() && key.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
|
||||
}
|
||||
false
|
||||
}
|
||||
@@ -15,16 +15,7 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
pub min_val: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
fn same_dim(lhs: Expression, rhs: Expression) -> bool {
|
||||
lhs == rhs || lhs.simplify() == rhs.simplify() || lhs.egglog_equal(rhs)
|
||||
}
|
||||
|
||||
/// Binary operation type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryOp {
|
||||
@@ -55,7 +51,7 @@ pub fn broadcast_binary(mut a: GraphTensor, mut b: GraphTensor) -> (GraphTensor,
|
||||
let a_dim = a.shape.dims[i];
|
||||
let b_dim = b.shape.dims[i];
|
||||
|
||||
if same_dim(a_dim, b_dim) {
|
||||
if a_dim == b_dim {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Which SDPA variant we're translating. Governs argument positions and
|
||||
/// which output slots are consumed downstream.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum SdpaVariant {
|
||||
/// `aten._scaled_dot_product_efficient_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False, *, scale=None)
|
||||
/// -> (output, log_sumexp, philox_seed, philox_offset)`
|
||||
Efficient,
|
||||
/// `aten._scaled_dot_product_flash_attention.default(q, k, v, dropout_p=0.,
|
||||
/// is_causal=False, return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// rng_state, unused, debug_attn_mask)`
|
||||
Flash,
|
||||
/// `aten._scaled_dot_product_flash_attention_for_cpu.default(q, k, v,
|
||||
/// dropout_p=0., is_causal=False, *, attn_mask=None, scale=None)
|
||||
/// -> (output, logsumexp)`
|
||||
FlashForCpu,
|
||||
/// `aten._scaled_dot_product_cudnn_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False,
|
||||
/// return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// philox_seed, philox_offset, debug_attn_mask)`
|
||||
Cudnn,
|
||||
/// `aten.scaled_dot_product_attention.default(q, k, v, attn_mask=None,
|
||||
/// dropout_p=0., is_causal=False, *, scale=None, enable_gqa=False)
|
||||
/// -> Tensor` (single output, no tuple).
|
||||
Unified,
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate any SDPA op variant into `softmax((Q@K^T)*scale + causal_mask +
|
||||
/// attn_bias) @ V`. Stores the primary `output` by the node's first output
|
||||
/// name. Other tuple outputs (logsumexp, philox_seed, etc.) are unused in
|
||||
/// inference — left unbound; the downstream `getitem(node, 0)` resolves
|
||||
/// to `output` via the tuple-output name list.
|
||||
pub(crate) fn translate_sdpa(&mut self, node: &Node, variant: SdpaVariant) -> Result<()> {
|
||||
let query = self.get_input_tensor(node, 0)?;
|
||||
let key = self.get_input_tensor(node, 1)?;
|
||||
let value = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// Resolve args by NAME rather than positional index. PT2 serializes
|
||||
// kwargs inline in `node.inputs` with `kind=2`, so any arg that wasn't
|
||||
// passed positionally by the caller shifts the indices of subsequent
|
||||
// positional args. Name-based lookup is unambiguous across variants
|
||||
// and across caller argument-passing styles.
|
||||
let arg_by_name =
|
||||
|name: &str| -> Option<&NodeInput> { node.inputs.iter().find(|i| i.name == name) };
|
||||
let tensor_arg = |name: &str| -> Option<GraphTensor> {
|
||||
arg_by_name(name)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.and_then(|n| self.get_tensor(n).ok())
|
||||
};
|
||||
let float_arg =
|
||||
|name: &str| -> Option<f64> { arg_by_name(name).and_then(|i| i.arg.as_float()) };
|
||||
let bool_arg =
|
||||
|name: &str| -> Option<bool> { arg_by_name(name).and_then(|i| i.arg.as_bool()) };
|
||||
|
||||
// attn_bias (Efficient/Cudnn/Unified) or attn_mask (FlashForCpu/Unified).
|
||||
let additive = tensor_arg("attn_bias").or_else(|| tensor_arg("attn_mask"));
|
||||
|
||||
let dropout_p = float_arg("dropout_p").unwrap_or(0.0) as f32;
|
||||
anyhow::ensure!(
|
||||
dropout_p == 0.0,
|
||||
"SDPA: dropout_p={dropout_p} unsupported (inference only)"
|
||||
);
|
||||
let is_causal = bool_arg("is_causal").unwrap_or(false);
|
||||
// Silence compiler warnings — variant arg remains for branch-specific
|
||||
// logic (output tuple-name resolution below) and for future divergence.
|
||||
let _ = variant;
|
||||
|
||||
// `scale` kwarg, default 1/sqrt(head_dim).
|
||||
let head_dim = query
|
||||
.shape
|
||||
.dims
|
||||
.last()
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA: query head_dim must be concrete")?;
|
||||
let default_scale = 1.0_f32 / (head_dim as f32).sqrt();
|
||||
let scale = float_arg("scale")
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(default_scale);
|
||||
|
||||
// Math form: scores = (Q @ K^T) * scale; + causal_mask; + attn_bias;
|
||||
// attn = softmax(scores, dim=-1); out = attn @ V.
|
||||
let q_ndim = query.shape.len();
|
||||
anyhow::ensure!(
|
||||
q_ndim >= 2,
|
||||
"SDPA: query must have at least 2 dims (got {q_ndim})"
|
||||
);
|
||||
// Transpose last two dims of key.
|
||||
let mut perm: Vec<usize> = (0..q_ndim).collect();
|
||||
perm.swap(q_ndim - 2, q_ndim - 1);
|
||||
let key_t = key.permute(perm);
|
||||
let (q_for_mm, k_for_mm) = ensure_same_dtype(query, key_t);
|
||||
let scores = q_for_mm.matmul(k_for_mm);
|
||||
let scale_t = self
|
||||
.graph
|
||||
.constant_float(scale)
|
||||
.cast(scores.dtype)
|
||||
.expand_rhs(scores.shape);
|
||||
let mut scores = scores * scale_t;
|
||||
|
||||
if is_causal {
|
||||
let s_q = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 2)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_q must be concrete")?;
|
||||
let s_k = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 1)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_k must be concrete")?;
|
||||
let size = s_q.max(s_k);
|
||||
// triu with diagonal=1 = 1 strictly above diagonal, 0 elsewhere.
|
||||
let mut mask = self.graph.triu(size, 1).cast(DType::F32);
|
||||
if s_q != size || s_k != size {
|
||||
mask = mask.slice_along(0..s_q, 0).slice_along(0..s_k, 1);
|
||||
}
|
||||
// -1e9 * mask ≈ -inf where masked, 0 otherwise. Broadcast across
|
||||
// batch/head prefix dims of `scores`.
|
||||
let neg_large = mask * (-1e9_f32);
|
||||
let mut neg_large = neg_large.cast(scores.dtype);
|
||||
for _ in 0..(q_ndim - 2) {
|
||||
neg_large = neg_large.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let (scores_b, mask_b) = broadcast_binary(scores, neg_large);
|
||||
scores = scores_b + mask_b;
|
||||
}
|
||||
if let Some(bias) = additive {
|
||||
let (scores_b, bias_b) = ensure_same_dtype(scores, bias);
|
||||
let (scores_b, bias_b) = broadcast_binary(scores_b, bias_b);
|
||||
scores = scores_b + bias_b;
|
||||
}
|
||||
|
||||
let attn = scores.softmax(q_ndim - 1);
|
||||
let (attn, value) = ensure_same_dtype(attn, value);
|
||||
let out = attn.matmul(value);
|
||||
|
||||
// Store the primary output by name. The other tuple outputs are
|
||||
// inference-time dead ends — downstream getitem(node, 0) resolves to
|
||||
// the same tensor name we bind here, because pt2 serializes the
|
||||
// multi-output name list with output[0] as the primary slot.
|
||||
let out_name = if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else if variant == SdpaVariant::Unified {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.or_else(|| {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(name) = out_name
|
||||
&& !name.is_empty()
|
||||
{
|
||||
self.tensors.insert(name, out);
|
||||
} else {
|
||||
anyhow::bail!("SDPA: no output tensor name found on node {}", node.target);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for SdpaVariant {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
matches!(
|
||||
(self, other),
|
||||
(SdpaVariant::Efficient, SdpaVariant::Efficient)
|
||||
| (SdpaVariant::Flash, SdpaVariant::Flash)
|
||||
| (SdpaVariant::FlashForCpu, SdpaVariant::FlashForCpu)
|
||||
| (SdpaVariant::Cudnn, SdpaVariant::Cudnn)
|
||||
| (SdpaVariant::Unified, SdpaVariant::Unified)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,40 +1,11 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, same_expr_with_ranges, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
fn normalize_equal_dims(
|
||||
a: &mut GraphTensor,
|
||||
b: &mut GraphTensor,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..a.shape.len() {
|
||||
let lhs = a.shape.dims[i];
|
||||
let rhs = b.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs, rhs, sym_ranges) {
|
||||
a.shape.dims[i] = canonical;
|
||||
b.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn same_dims(
|
||||
lhs: &[Expression],
|
||||
rhs: &[Expression],
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
lhs.len() == rhs.len()
|
||||
&& lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.all(|(lhs, rhs)| same_expr_with_ranges(*lhs, *rhs, sym_ranges))
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -42,18 +13,7 @@ impl<'a> Translator<'a> {
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (mut a, mut b) = broadcast_binary(a, b);
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
normalize_equal_dims(&mut a, &mut b, &sym_ranges);
|
||||
let lhs_dims = a.dims();
|
||||
let rhs_dims = b.dims();
|
||||
if !same_dims(&lhs_dims, &rhs_dims, &sym_ranges) {
|
||||
anyhow::bail!(
|
||||
"binary op {} still has mismatched dims after broadcast: lhs={lhs_dims:?} rhs={rhs_dims:?} inputs={:?}",
|
||||
node.target,
|
||||
node.inputs
|
||||
);
|
||||
}
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
@@ -61,12 +21,6 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -78,13 +32,6 @@ impl<'a> Translator<'a> {
|
||||
op: BinaryOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -107,47 +54,4 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / scalar,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_symbolic_scalar_op(
|
||||
&mut self,
|
||||
a: GraphTensor,
|
||||
val: Expression,
|
||||
op: BinaryOp,
|
||||
) -> GraphTensor {
|
||||
match op {
|
||||
BinaryOp::Add => a + val,
|
||||
BinaryOp::Mul => a * val,
|
||||
BinaryOp::Sub => a - val,
|
||||
BinaryOp::Div => a / val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pt2_expr::simplify_expr_with_ranges;
|
||||
|
||||
#[test]
|
||||
fn simplifies_mark_dynamic_slice_shapes_using_lower_bound() {
|
||||
let a = Expression::from('a');
|
||||
let lhs = (a.min(1) + a).min(a + 1) - 1;
|
||||
let rhs = (a.min(1) + a).min(a);
|
||||
let sym_ranges = [(
|
||||
'a',
|
||||
ExprBounds {
|
||||
min: Some(2),
|
||||
max: None,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, &sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, &sym_ranges);
|
||||
|
||||
assert_eq!(lhs_simplified, Expression::from('a'));
|
||||
assert_eq!(rhs_simplified, Expression::from('a'));
|
||||
assert!(same_expr_with_ranges(lhs, rhs, &sym_ranges));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,11 +389,8 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -5,8 +5,6 @@ use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -70,8 +68,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
|
||||
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
@@ -148,7 +144,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -188,28 +183,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
// `empty` and `empty_permuted` allocate uninitialised tensors of
|
||||
// a given shape; the caller fills them. We lower to zeros with
|
||||
// the same shape+dtype — downstream reads are officially UB on
|
||||
// PyTorch's side, and downstream writes overwrite our zeros.
|
||||
// Qwen3MoE's MoE block uses `empty_permuted` to allocate the
|
||||
// expert-output staging tensor before scatter-adding into it.
|
||||
"torch.ops.aten.empty.memory_format" | "torch.ops.aten.empty_permuted.default" => {
|
||||
self.translate_empty(node)?
|
||||
}
|
||||
// Qwen3-MoE's expert-balance counts tokens-per-expert via histc.
|
||||
"torch.ops.aten.histc.default" => self.translate_histc(node)?,
|
||||
|
||||
// Grouped matmul (MoE expert dispatch).
|
||||
// aten._grouped_mm is the native op; transformers::grouped_mm_fallback
|
||||
// is a Python-implemented custom_op (transformers/integrations/moe.py)
|
||||
// used by HF MoE when _grouped_mm isn't available for the activation
|
||||
// dtype. Both have identical (input, weight, offs) signature; route
|
||||
// both through the same batched-matmul + group-mask lowering.
|
||||
"torch.ops.aten._grouped_mm.default"
|
||||
| "torch.ops.transformers.grouped_mm_fallback.default" => {
|
||||
self.translate_grouped_mm(node)?
|
||||
}
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
@@ -221,16 +194,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -248,13 +211,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -270,11 +226,7 @@ impl<'a> Translator<'a> {
|
||||
let b = b.cast(DType::F32);
|
||||
(a * b).cast(DType::Bool)
|
||||
}
|
||||
"torch.ops.aten.bitwise_or.Tensor" | "torch.ops.aten.logical_or.default" => {
|
||||
// Both arms use the same bool-OR lowering. Gemma-4's sliding+full
|
||||
// attention mask fusion emits bitwise_or on boolean tensors; the
|
||||
// integer semantics of bitwise_or aren't exercised by any op in
|
||||
// the test suite, so we rely on inputs being boolean-typed.
|
||||
"torch.ops.aten.logical_or.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -293,27 +245,18 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -326,14 +269,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.ceil.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// ceil(x) = trunc(x) + (x > trunc(x)).
|
||||
// Cast-to-Int rounds toward zero, so for any positive fractional
|
||||
// `x` the trunc sits below `x` and we add 1; for negatives we
|
||||
// have `trunc >= x` and adjust=0. Avoids the two extra
|
||||
// mul-by-(-1) nodes that the `-floor(-x)` lowering emits.
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = a.gt(trunc).cast(DType::F32);
|
||||
trunc + adjust
|
||||
// ceil(x) = -floor(-x)
|
||||
let neg_a = a * (-1.0);
|
||||
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = neg_a.lt(trunc).cast(DType::F32);
|
||||
let floor_neg = trunc - adjust;
|
||||
floor_neg * (-1.0)
|
||||
}
|
||||
"torch.ops.aten.erf.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -409,17 +350,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -450,29 +380,6 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Scaled dot-product attention — each variant binds args slightly
|
||||
// differently but all lower to matmul+softmax via translate_sdpa.
|
||||
"torch.ops.aten._scaled_dot_product_efficient_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Efficient)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Flash)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::FlashForCpu)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_cudnn_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Cudnn)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten.scaled_dot_product_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Unified)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
@@ -483,28 +390,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//!
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod attention;
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
@@ -17,7 +16,6 @@ use anyhow::{Context, Result};
|
||||
use luminal::graph::Graph;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_expr::parse_sympy_expr_with_ranges;
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
@@ -189,21 +187,8 @@ impl<'a> Translator<'a> {
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Ok(v);
|
||||
}
|
||||
// Fall through to symbolic-aware resolution. Op-arg slots like `dim`
|
||||
// and `axis` are always concrete in practice, but with dynamic shapes
|
||||
// PT2 occasionally hands us a SymInt that is fully bound at export
|
||||
// time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`);
|
||||
// accept those when they reduce to a concrete int rather than failing
|
||||
// with the misleading "not an int" diagnostic.
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg)
|
||||
&& let Some(v) = expr.to_usize()
|
||||
{
|
||||
return Ok(v as i64);
|
||||
}
|
||||
anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg)
|
||||
arg.as_int()
|
||||
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
|
||||
@@ -222,37 +207,11 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
|
||||
use crate::pt2_schema::SymIntEntry;
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
// Symbolic int lists: tolerate them as long as every entry is a
|
||||
// bound concrete value. Prevents false "not an int list" failures on
|
||||
// graphs where torch.export emits sym_ints for what is dimensionally
|
||||
// a static parameter (kernel sizes, etc. with dynamic batch).
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
let mut out = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
let v = match entry {
|
||||
SymIntEntry::Int(i) => Some(i.as_int),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.and_then(|e| e.to_usize().map(|u| u as i64)),
|
||||
};
|
||||
match v {
|
||||
Some(n) => out.push(n),
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Input {idx} of {} contains an unresolved sym_int entry",
|
||||
node.target
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(out);
|
||||
}
|
||||
arg.as_ints()
|
||||
.map(|v| v.to_vec())
|
||||
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))
|
||||
@@ -280,13 +239,13 @@ impl<'a> Translator<'a> {
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(ints) = arg.as_ints() {
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v)).collect());
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v as usize)).collect());
|
||||
}
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
return entries
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.with_context(|| format!("Cannot resolve sym_int: {}", s.as_name)),
|
||||
@@ -319,13 +278,17 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn dim_size_to_expr(&self, dim: &DimSize) -> Result<Expression> {
|
||||
match dim {
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
DimSize::Expr(e) => self.resolve_expr_value(&e.as_expr).with_context(|| {
|
||||
format!(
|
||||
"Cannot resolve symbolic dimension expression: {}",
|
||||
e.as_expr.expr_str
|
||||
)
|
||||
}),
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
DimSize::Expr(e) => {
|
||||
let sym_name = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
.with_context(|| format!("Cannot parse symbol: {}", e.as_expr.expr_str))?;
|
||||
let c = self
|
||||
.sym_map
|
||||
.sym_to_char
|
||||
.get(&sym_name)
|
||||
.with_context(|| format!("Unknown symbol: {sym_name}"))?;
|
||||
Ok(Expression::from(*c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,9 +299,10 @@ impl<'a> Translator<'a> {
|
||||
.get("as_expr")
|
||||
.and_then(|e| e.get("expr_str"))
|
||||
.and_then(|s| s.as_str())
|
||||
&& let Some(expr) = self.resolve_expr_str(expr_str)
|
||||
&& let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(expr);
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = val
|
||||
.get("as_expr")
|
||||
@@ -346,7 +310,7 @@ impl<'a> Translator<'a> {
|
||||
.and_then(|h| h.get("as_int"))
|
||||
.and_then(|v| v.as_i64())
|
||||
{
|
||||
return Some(Expression::from(hint));
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
@@ -354,32 +318,21 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn resolve_arg_as_expression(&self, arg: &Argument) -> Option<Expression> {
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Some(Expression::from(v));
|
||||
return Some(Expression::from(v as usize));
|
||||
}
|
||||
if let Some(name) = arg.as_sym_int_name() {
|
||||
return self.resolve_sym_int(name);
|
||||
}
|
||||
if let Argument::Expr(e) = arg {
|
||||
return self.resolve_expr_value(&e.as_expr);
|
||||
if let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = e.as_expr.hint.as_ref().and_then(|h| h.as_int()) {
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_str(&self, expr_str: &str) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr_str, &self.sym_map.sym_to_char, &self.sym_map.ranges)
|
||||
.or_else(|| {
|
||||
crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
.and_then(|sym| self.sym_map.sym_to_char.get(&sym).copied())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_value(&self, expr: &ExprValue) -> Option<Expression> {
|
||||
self.resolve_expr_str(&expr.expr_str).or_else(|| {
|
||||
expr.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
@@ -13,25 +11,6 @@ const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
fn normalize_concat_dims(
|
||||
lhs: &mut GraphTensor,
|
||||
rhs: &mut GraphTensor,
|
||||
skip_dim: Option<usize>,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..lhs.shape.len() {
|
||||
if Some(i) == skip_dim {
|
||||
continue;
|
||||
}
|
||||
let lhs_dim = lhs.shape.dims[i];
|
||||
let rhs_dim = rhs.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs_dim, rhs_dim, sym_ranges) {
|
||||
lhs.shape.dims[i] = canonical;
|
||||
rhs.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -141,47 +120,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -222,17 +160,8 @@ impl<'a> Translator<'a> {
|
||||
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
for t in &tensors[1..] {
|
||||
let mut next = *t;
|
||||
normalize_concat_dims(&mut result, &mut next, Some(dim), &sym_ranges);
|
||||
|
||||
let lhs_axis = result.dims()[dim];
|
||||
let rhs_axis = next.dims()[dim];
|
||||
let mut lhs_padded = result.pad_along(0, rhs_axis, dim, 0.);
|
||||
let mut rhs_padded = next.pad_along(lhs_axis, 0, dim, 0.);
|
||||
normalize_concat_dims(&mut lhs_padded, &mut rhs_padded, None, &sym_ranges);
|
||||
result = lhs_padded + rhs_padded;
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -330,15 +259,21 @@ impl<'a> Translator<'a> {
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
|
||||
// Normalize negative indices for this dimension. Stay in Int —
|
||||
// multiplying an Int tensor by an Expression broadcasts the axis
|
||||
// size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int
|
||||
// for the result, Bool→F32 for the negative mask) per indexed dim.
|
||||
let axis_size = src_shape[dim_idx];
|
||||
let idx_int = idx_tensor.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
|
||||
let is_negative = idx_int.lt(zero).cast(DType::Int);
|
||||
let idx_int = idx_int + is_negative * axis_size;
|
||||
// Normalize negative indices for this dimension
|
||||
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"index.Tensor: dim {} must be concrete for negative index normalization",
|
||||
dim_idx
|
||||
)
|
||||
})?;
|
||||
let idx_f32 = idx_tensor.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_size as f32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_negative = idx_f32.lt(zero).cast(DType::F32);
|
||||
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let stride = &strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
@@ -404,34 +339,20 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
// (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for
|
||||
// the negative mask) that the previous F32-routed path emitted.
|
||||
let axis_dim = a.shape.dims[dim];
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(indices_int.shape);
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
|
||||
})?;
|
||||
let indices_f32 = indices.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_dim as f32)
|
||||
.expand_rhs(indices_f32.shape);
|
||||
let is_negative = indices_f32.lt(zero).cast(DType::F32);
|
||||
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -475,39 +396,14 @@ impl<'a> Translator<'a> {
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
// Broadcast the (often scalar) value tensor to match data shape,
|
||||
// then blend by mask. Cast mask to data's dtype for the
|
||||
// arithmetic so this works for both integer and float data.
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// where(mask, value, a) as `a + mask*(value - a)`. Saves a mul
|
||||
// and the `1.0` constant compared to the `a*(1 - m) + v*m`
|
||||
// form; works for any numeric dtype without a dedicated cond.
|
||||
return Ok(a + mask_f * (values_b - a));
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. idx_tensor can
|
||||
// be ANY shape — its whole shape is "batch dims" in scatter_nd terms,
|
||||
// and K is always 1 (number of dims we're indexing into). Always pad
|
||||
// a trailing size-1 dim so the rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// scatter_nd expects indices of shape [batch, K] where K = number of index dims.
|
||||
// PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
|
||||
let indices = if indices.shape.len() == 1 {
|
||||
indices.expand_dim(1, Expression::from(1usize))
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
|
||||
@@ -6,20 +6,6 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -51,26 +37,16 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -94,100 +70,4 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,97 +72,6 @@ impl<'a> Translator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Lower `aten.histc.default` for the integer-bincount case.
|
||||
///
|
||||
/// Qwen3-MoE's expert-balance layer calls
|
||||
/// `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)` to count how
|
||||
/// many tokens were routed to each expert. With those args every
|
||||
/// integer value `i ∈ [0, K-1]` maps to exactly bin `i`, and the result
|
||||
/// is equivalent to `torch.bincount`. We implement that case as a
|
||||
/// broadcast equality + sum:
|
||||
///
|
||||
/// counts[b] = sum_i (input[i] == b + min) for b in [0, bins)
|
||||
///
|
||||
/// More general histc bin widths (`bins != max - min + 1`, or
|
||||
/// non-integer values that span fractional bins) are not supported
|
||||
/// today — the equality path would silently drop them. We bail rather
|
||||
/// than produce wrong counts.
|
||||
pub(crate) fn translate_histc(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let bins_i64: i64 = self
|
||||
.get_int_arg(node, 1)
|
||||
.context("histc: missing `bins` arg (#1)")?;
|
||||
// `min`/`max` are float kwargs (default 0.0 each, which means
|
||||
// "auto-pick from input"); for the qwen3-moe call they're always
|
||||
// integers passed as floats.
|
||||
let min = self.get_float_arg(node, 2).unwrap_or(0.0);
|
||||
let max = self.get_float_arg(node, 3).unwrap_or(0.0);
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 1,
|
||||
"histc: only 1D input is supported, got {}D",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
bins_i64 > 0,
|
||||
"histc: bins must be positive, got {}",
|
||||
bins_i64
|
||||
);
|
||||
// Bincount-equivalent case: one integer value per bin.
|
||||
anyhow::ensure!(
|
||||
(max - min - (bins_i64 - 1) as f64).abs() < 1e-6,
|
||||
"histc: only the bincount-equivalent case (bins == max - min + 1) is \
|
||||
supported; got bins={}, min={}, max={}. Other cases would need a \
|
||||
general bin-width / right-edge-inclusion implementation.",
|
||||
bins_i64,
|
||||
min,
|
||||
max,
|
||||
);
|
||||
|
||||
let bins_u = bins_i64 as usize;
|
||||
let n = input.shape.dims[0];
|
||||
|
||||
// arange(bins) [bins] → cast to input dtype, optionally shift by min,
|
||||
// broadcast to [bins, N], compare for equality with input broadcast.
|
||||
let mut bins_arange = self.graph.arange(Expression::from(bins_u));
|
||||
if min != 0.0 {
|
||||
// `min` is non-zero (uncommon in the qwen3-moe path but legal)
|
||||
// — shift the comparison values to start at min.
|
||||
let min_i = min as i64;
|
||||
let shift = self
|
||||
.graph
|
||||
.constant_float(min_i as f32)
|
||||
.cast(bins_arange.dtype)
|
||||
.expand_rhs(bins_arange.shape);
|
||||
bins_arange += shift;
|
||||
}
|
||||
let bins_expanded = bins_arange.cast(input.dtype).expand_dim(1, n);
|
||||
let input_expanded = input.expand_dim(0, Expression::from(bins_u));
|
||||
let matches = input_expanded.eq(bins_expanded); // Bool [bins, N]
|
||||
|
||||
let out_dtype = self.output_meta_dtype(node)?;
|
||||
Ok(matches.cast(out_dtype).sum(1))
|
||||
}
|
||||
|
||||
/// Lower `aten.empty.memory_format` and `aten.empty_permuted.default`.
|
||||
///
|
||||
/// Both allocate an uninitialised tensor; the caller is responsible for
|
||||
/// writing into it. We materialise zeros instead — luminal has no
|
||||
/// "uninitialised" notion, and PyTorch's contract on `empty` outputs is
|
||||
/// undefined for any read prior to a write, so a zero-fill is sound.
|
||||
/// `aten.empty_permuted` additionally takes a `physical_layout` arg
|
||||
/// (the storage permutation); for a zero-filled tensor that's a no-op.
|
||||
pub(crate) fn translate_empty(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let zero = self.graph.constant_float(0.0).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
zero
|
||||
} else {
|
||||
zero.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
@@ -193,154 +102,33 @@ impl<'a> Translator<'a> {
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
/// Translate `aten._grouped_mm.default(input, weight, offs)` → `Tensor[S, N]`.
|
||||
///
|
||||
/// Grouped matmul: `input` is `[S, K]` (tokens sorted by expert), `weight` is
|
||||
/// `[G, K, N]` (per-expert weights), `offs` is `[G]` cumulative token counts.
|
||||
/// Output `[S, N]` where token m (in group g s.t. `offs[g-1] <= m < offs[g]`)
|
||||
/// is multiplied by `weight[g]`.
|
||||
///
|
||||
/// Implementation: for each token m we (a) compute its expert id from offs,
|
||||
/// (b) gather only that expert's `[K, N]` slice from weight, and (c) do a
|
||||
/// single per-token matmul. The gather pattern mirrors the rust qwen3_moe
|
||||
/// example's `gather_experts`, which the GLUMoE host-op fusion in
|
||||
/// `luminal_cuda_lite` is designed to recognise.
|
||||
///
|
||||
/// Why not the straightforward `[G, S, K] @ [G, K, N] → [G, S, N]` + mask:
|
||||
/// it forces a full F32 cast of the entire `[G, K, N]` weight tensor as
|
||||
/// search-time intermediate, which OOMs on real MoE checkpoints
|
||||
/// (Qwen3-30B-A3B: 1.5 GB / layer × 48 layers for gate-up alone). Gathering
|
||||
/// first keeps the F32 cast on `[S, K, N]` instead — for prefill (S = top_k)
|
||||
/// that is a 16× shrink (G=128, top_k=8).
|
||||
///
|
||||
/// `offs` flows through as a runtime tensor — the routing decision is computed
|
||||
/// at execution time by the gate network and the same compiled graph handles
|
||||
/// any routing pattern without recompilation.
|
||||
pub(crate) fn translate_grouped_mm(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let offs = self.get_input_tensor(node, 2)?;
|
||||
let out_dtype = self.output_meta_dtype(node)?;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 2,
|
||||
"_grouped_mm: input must be 2D, got {}D",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
weight.shape.len() == 3,
|
||||
"_grouped_mm: weight must be 3D, got {}D",
|
||||
weight.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
offs.shape.len() == 1,
|
||||
"_grouped_mm: offs must be 1D, got {}D",
|
||||
offs.shape.len()
|
||||
);
|
||||
|
||||
let s = input.shape.dims[0];
|
||||
let g = weight.shape.dims[0];
|
||||
let k = weight.shape.dims[1];
|
||||
let n = weight.shape.dims[2];
|
||||
|
||||
// expert_id[m] = number of g s.t. m >= offs[g], clamped to [0, G-1].
|
||||
// Same value as HF MoE's `expert_ids.clamp(0, num_experts-1)` for
|
||||
// invalid expert IDs from EP, AND protects search-time profiling:
|
||||
// dummy-1 input bytes give offs=[1,…,1], which pushes the raw count
|
||||
// to G for any token with index ≥ 1 and would OOB the weight gather.
|
||||
//
|
||||
// Stay in Int throughout — arange / offs are already Int, ge → Bool
|
||||
// → cast(Int), sum stays Int, and the binary `minimum` handles the
|
||||
// clamp without an F32 round-trip.
|
||||
let _ = g
|
||||
.to_usize()
|
||||
.context("_grouped_mm: G (num_experts) must be concrete")?;
|
||||
let s_arange = self.graph.arange(s); // Int [S]
|
||||
let ge_int = s_arange
|
||||
.expand_dim(0, g)
|
||||
.ge(offs.expand_dim(1, s)) // Bool [G, S]
|
||||
.cast(DType::Int); // Int [G, S]
|
||||
let raw = ge_int.sum(0); // Int [S], values in [0, G]
|
||||
let cap = self.graph.constant(g - 1).expand_dim(0, s); // Int [S], all G-1
|
||||
let expert_id = raw.minimum(cap); // Int [S]
|
||||
|
||||
// Flat gather index into weight (treated as a length-G*K*N 1D buffer):
|
||||
// flat[m, k_, n_] = expert_id[m] * (K*N) + k_ * N + n_
|
||||
// Encoded as `Mul(expert_id, Iota(io_const)) + Iota(MIter, K*N)` so the
|
||||
// resulting Gather matches the GLUMoE / gather-experts egglog patterns.
|
||||
let io = k * n;
|
||||
let base = expert_id * io;
|
||||
let within = self.graph.iota(Expression::from('z'), (k, n));
|
||||
let exp_base = base.expand_dim(1, k).expand_dim(2, n);
|
||||
let exp_within = within.expand_dim(0, s);
|
||||
let flat_idx = exp_base + exp_within;
|
||||
|
||||
// Gather → [S, K, N], then normalize both operands to the op's declared
|
||||
// output dtype before matmul. On real Qwen3-MoE bf16 checkpoints the FX
|
||||
// graph inserts casts on the activation path, and relying on the input
|
||||
// tensor's translated dtype can leave us with mixed F32/Bf16 operands
|
||||
// by the time matmul expands into elementwise Mul. Using the PT2 output
|
||||
// metadata keeps the matmul dtype aligned with the exported contract
|
||||
// without upcasting the full expert weight bank.
|
||||
let weight_gathered = weight.gather(flat_idx).cast(out_dtype);
|
||||
let input = input.cast(out_dtype);
|
||||
|
||||
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
|
||||
// Operands stay in their native dtype — no F32 cast on the gathered
|
||||
// weight or the input. The earlier cast(F32) was a holdover from the
|
||||
// broadcast-and-mask version (which had to use F32 because of the
|
||||
// cast(F32) on the mask). Gather-then-matmul has no such requirement,
|
||||
// and casting `[S, K, N]` to F32 doubled the gather scratch (~100 MB
|
||||
// to ~200 MB per layer for Qwen3-30B-A3B prefill). Matmul rewrites
|
||||
// (cuBLASLt etc.) handle bf16 input with F32 accumulator internally.
|
||||
let result = input.unsqueeze(1).matmul(weight_gathered).squeeze(1);
|
||||
|
||||
Ok(result.cast(out_dtype))
|
||||
}
|
||||
|
||||
/// Build the where-formula graph: `cond * x + (1 - cond) * y`, computed
|
||||
/// in F32, cast back to `out_dtype`. Shared between `translate_where`,
|
||||
/// `translate_where_scalar_other`, and `translate_masked_fill_scalar` so
|
||||
/// they all go through one well-tested code path.
|
||||
pub(crate) fn where_formula(
|
||||
&mut self,
|
||||
cond: GraphTensor,
|
||||
x: GraphTensor,
|
||||
y: GraphTensor,
|
||||
out_dtype: DType,
|
||||
) -> GraphTensor {
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
// Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops
|
||||
// plus the explicit `1.0` constant. Mathematically identical for
|
||||
// c ∈ {0, 1} and produces the same F32 output type.
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
// Cast back: an F32 result downstream-interpreted as bf16 walks the
|
||||
// buffer at half-stride, returning every-other-element zeros.
|
||||
(y_f + c * (x_f - y_f)).cast(out_dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
let out_dtype = x.dtype;
|
||||
Ok(self.where_formula(cond, x, y, out_dtype))
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
let out_dtype = x.dtype;
|
||||
// Build a tensor for the scalar `other` matching `x`'s shape so we
|
||||
// can route through the shared where_formula helper.
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(x.shape);
|
||||
Ok(self.where_formula(cond, x, other, out_dtype))
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -395,37 +183,33 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names
|
||||
let tuple_outputs = node.outputs.first().and_then(|o| o.as_tensors.as_ref());
|
||||
let values_name = if let Some(ts) = tuple_outputs {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
};
|
||||
let indices_name = if let Some(ts) = tuple_outputs {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build top-k outputs from a full stable argsort. Slice the indices
|
||||
// before gathering values so the gather shape matches the requested
|
||||
// top-k output rather than the full sort width.
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(topk_indices, dim);
|
||||
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
self.tensors.insert(idx_name, topk_indices);
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -51,19 +51,13 @@ impl<'a> Translator<'a> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
let dtype_int = input
|
||||
.arg
|
||||
.as_int()
|
||||
.map(|i| i as u32)
|
||||
.or_else(|| input.arg.as_scalar_type());
|
||||
if let Some(d) = dtype_int {
|
||||
let dtype = torch_dtype_int_to_luminal(d);
|
||||
// Skip emitting a Cast op when the dtype already matches —
|
||||
// PT2 graphs frequently emit `_to_copy` purely as a clone hint
|
||||
// (e.g. dtype=float32 on a tensor that is already F32), and
|
||||
// every redundant Cast inflates the graph and survives until
|
||||
// optimization passes can prove it as a no-op.
|
||||
return Ok(if a.dtype == dtype { a } else { a.cast(dtype) });
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,34 +131,37 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
// `masked_fill(input, mask, fill)` = `where(mask, fill, input)`.
|
||||
// Routes through the shared `where_formula` helper so we exercise
|
||||
// the exact same code path as `aten.where.self`, which is verified
|
||||
// to handle the bf16 cast-back correctly. Hand-rolling the same
|
||||
// formula directly here used to drift (egglog made different
|
||||
// rewrite choices on the rebuilt-locally graph), so we deliberately
|
||||
// re-use the helper.
|
||||
// `aten.masked_fill.Scalar(input, mask, fill)` ≡
|
||||
// `aten.where.self(mask, full_like(input, fill), input)`. The
|
||||
// `full_like + where` sequence is the verified-working path
|
||||
// (test: `where(mask, torch.zeros_like(x), x)` round-trips with
|
||||
// max_diff = 0); we reproduce its exact graph-build order here.
|
||||
// Hand-rolling the formula in any other shape (single-mul, F32
|
||||
// throughout, alternative constant-cast orderings) routes egglog
|
||||
// through a rewrite that returns an F32 buffer downstream-read as
|
||||
// bf16 — the every-other-element-zero pattern.
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let out_dtype = input.dtype;
|
||||
// Build fill_t exactly like translate_full_like does:
|
||||
// constant_float(val).cast(dtype).expand_rhs(reference.shape)
|
||||
let fill_t = self
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(out_dtype)
|
||||
.expand_rhs(input.shape);
|
||||
Ok(self.where_formula(mask, fill_t, input, out_dtype))
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -213,18 +210,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -275,52 +266,4 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,10 +77,7 @@ class CompiledModel:
|
||||
)
|
||||
user_inputs = inputs
|
||||
|
||||
# Use the first *user* input for device detection — when torch.compile
|
||||
# has lifted SymInts or weights into the call args, `inputs[0]` may not
|
||||
# be a tensor. user_inputs has been filtered to actual tensors.
|
||||
input_device = user_inputs[0].device if user_inputs else torch.device("cpu")
|
||||
input_device = inputs[0].device if inputs else torch.device("cpu")
|
||||
|
||||
# Auto-detect dynamic dims from input shapes
|
||||
if self._has_dynamic_dims:
|
||||
@@ -135,11 +132,6 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
@@ -155,12 +147,11 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -188,13 +179,9 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype in _int_dtypes:
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
|
||||
@@ -11,21 +11,14 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device."""
|
||||
# Dynamo can prefix `example_inputs` with SymInt entries when shapes are
|
||||
# dynamic — those have no `.device`. Pick the first real tensor instead.
|
||||
first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None)
|
||||
device = first_tensor.device if first_tensor is not None else torch.device("cpu")
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError) as exc:
|
||||
raise RuntimeError(
|
||||
"CUDA input was provided, but luminal_python was not built with "
|
||||
"the cuda feature. Rebuild with `maturin develop --features cuda` "
|
||||
"or run through `run_tests_cuda.sh`/the Modal CUDA test runner."
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
@@ -83,7 +76,7 @@ def register_backend(factory_capsule):
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule)
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
@@ -102,7 +95,7 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -110,8 +103,8 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule)
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule, options=options)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user