mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
24 Commits
perf/paral
...
cuda_133
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
caa036dca8 | ||
|
|
4cd47ffa45 | ||
|
|
db72cf505c | ||
|
|
766db93b08 | ||
|
|
4e93f02725 | ||
|
|
25393a9fdd | ||
|
|
81ea750e6b | ||
|
|
f94335b1b8 | ||
|
|
f62e3c50d0 | ||
|
|
eeeabd7c20 | ||
|
|
0f02466f3d | ||
|
|
156fac518e | ||
|
|
a3df68bd43 | ||
|
|
7a95e56a8b | ||
|
|
e558ce6849 | ||
|
|
c898b7fd53 | ||
|
|
6cfbf538d0 | ||
|
|
966f6f8147 | ||
|
|
8ea9a71747 | ||
|
|
861c3f0419 | ||
|
|
8f17561094 | ||
|
|
d5e9001c8b | ||
|
|
6416ddb5f8 | ||
|
|
c9d4ce6217 |
@@ -1,3 +1,6 @@
|
||||
[alias]
|
||||
examples = "run --release --bin examples-perf --"
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
rustflags = [
|
||||
"-Ctarget-feature=+fp16,+fhm"
|
||||
|
||||
2
.github/workflows/test-core.yml
vendored
2
.github/workflows/test-core.yml
vendored
@@ -21,4 +21,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
run: cargo test --release -p luminal -p luminal_nn -p luminal_tracing -p luminal_python --verbose
|
||||
|
||||
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Test Full CUDA
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
rust_cuda_ignored_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Rust CUDA Ignored Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run ignored CUDA Rust tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
GPU_TYPE: H100
|
||||
MODAL_TIMEOUT: "14400"
|
||||
CARGO_TEST_ARGS: "--ignored --test-threads=1"
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
python_cuda_slow_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Python CUDA Slow Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run slow pytest CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100-80GB --timeout 14400 tests/ -v -s -m slow
|
||||
17
.github/workflows/test-metal.yml
vendored
17
.github/workflows/test-metal.yml
vendored
@@ -17,3 +17,20 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
llama_1b_metal_example:
|
||||
name: Llama 1B Metal Example
|
||||
runs-on: macos-14-xlarge
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Print runner hardware
|
||||
run: system_profiler SPHardwareDataType SPDisplaysDataType
|
||||
- name: Cache Hugging Face models
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: llama-1b-metal-hf-${{ runner.os }}-${{ runner.arch }}-v1
|
||||
- name: Run Llama 1B Metal example and validate output
|
||||
run: rustup update; python3 ci/metal_llama_1b_example.py
|
||||
|
||||
12
AGENTS.md
12
AGENTS.md
@@ -8,4 +8,14 @@ All other functionality is split into crates in the `crates/` directory. For ins
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
|
||||
## Debugging and Correctness
|
||||
- Treat model examples as specifications of the intended architecture. Do not change model code, prompt templates, weights, or example logic to hide compiler/runtime/search bugs unless the model code is demonstrably semantically wrong.
|
||||
- When outputs are incorrect, first root-cause the failing compiler/runtime path. Prefer isolating the bad LLIR/HLIR graph, rewrite, op lowering, shape/stride assumption, layout contract, or runtime implementation that caused the mismatch.
|
||||
- Avoid narrow special-case fixes. A fix should state and enforce the general invariant it relies on, or explicitly document why the affected operation is only valid for a restricted layout/shape and ensure rewrites enforce that restriction.
|
||||
- For e-graph/search issues, assume all selectable LLIR graphs are intended to be semantically equivalent. If two selectable graphs disagree, debug the equivalence violation rather than selecting around the bad graph.
|
||||
- Add regression tests at the level where the bug occurred. Prefer tests that compare against a semantic reference such as `NativeRuntime` or a small independent reference, and use fixed seeds for any randomized search/fuzz test so failures are reproducible.
|
||||
|
||||
## Compiler Rewrite Boundary
|
||||
- All graph pattern matching and op selection must be expressed in egglog rewrites. Do not add Rust-side LLIR graph post-passes that search for op patterns, fuse kernels, select backend ops, or otherwise rewrite extracted graphs after egglog. If a backend needs a fused/specialized op, add the match and rewrite in egglog and let extraction produce that op directly.
|
||||
|
||||
50
README.md
50
README.md
@@ -55,23 +55,27 @@ Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100.
|
||||
|
||||
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
### PyTorch-native
|
||||
|
||||
Luminal directly integrates with PyTorch as a compiler backend. Simply do `torch.compile(model, backend=luminal_cuda)` to compile your PyTorch models. We also have an excellent tensor API in Rust.
|
||||
|
||||
### RISC-style architecture
|
||||
|
||||
Everything in Luminal boils down to 14 primitive ops:
|
||||
Everything in Luminal boils down to 15 primitive ops:
|
||||
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Scatter, Cast`
|
||||
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model.
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model in the world.
|
||||
|
||||
### Search
|
||||
|
||||
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
|
||||
The best heuristic is no heuristic. Luminal tries to search every possible decision to give the compiler the flexibility to discover complex optimizations. This allows us to automatically discover Flash Attention and other similarly complex optimizations without relying on hand-written operations or heuristics. It also allows us to stay extremely small and simple long into the future and beat the performance of far larger frameworks.
|
||||
|
||||
### Native
|
||||
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the accelerator APIs (CUDA, Metal, etc.). No indirections or abstractions, compatability layers, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
|
||||
@@ -85,39 +89,45 @@ Most deep learning libraries are eager-first, meaning each op call directly oper
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### What about XLA?
|
||||
|
||||
XLA, torch.compile, TVM, and other traditional compiler stacks suffer from complexity explosion. They are made up of a very large set of destructive (one-direction) rewrite rules that lower and optimize a graph from a high-level representation to low-level machine code. But since these rules are destructive, they are required to only fire when it's certian that there's a performance benefit. This leads to the rules becoming very complex, special-cased, and numerous. Once additional hardware backends, model architectures, and new dtypes get thrown in, they suffer from the weight of their complexity and often produce very suboptimal code, requiring DSLs like Pallas or Triton to regain performance.
|
||||
|
||||
### Compile everything
|
||||
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as a static computation graphs, compiled, and executed later.
|
||||
|
||||
### First-class dynamism
|
||||
|
||||
A fully-static world would be nice, but we live in a world of nessecary dynamism. So we model dynamic shapes natively, as symbolic dimensions. Luminal supports arbitrary symbolic dimensions, including complex expressions, to give us shapes like `(s, 4096)`, `(b, h, w + 3)`, etc. This rich representation gives the compiler full visibility into shapes and lets it still do aggressive specialization.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, Luminal has global knowledge. This means we can push most ML complexity to the compiler. For instance, devices, datatypes, and even autograd is modeled ahead of time and optimized by the compiler!
|
||||
|
||||
Now we can do:
|
||||
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
- Low-precision dtypes (mxfp4, nvfp4, fp8, etc.)
|
||||
- Complex mutli-device parallelism topologies, searched ahead-of-time
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures
|
||||
|
||||
## Where are we?
|
||||
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- Native PyTorch support
|
||||
- Many kernel libraries supported in the search space (FlashInfer, cuBLASLt, etc.)
|
||||
- Many models implemented in our Rust tensor API in `examples/`.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
|
||||
Some things on the roadmap:
|
||||
|
||||
- Expand the search space to utilize Tensor Cores more flexibly
|
||||
- Bring cuda to parity with Metal
|
||||
- Add Blackwell intrinsics, such as TMEM and TMA
|
||||
- Build a ROCm backend
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM inference _and_ training
|
||||
- More fine-grained dialects supporting thread- and warp-level intrinsics like TMA and tcgen.05
|
||||
- ROCm backend
|
||||
- More public infernce accelerator backends (coming very soon...)
|
||||
- Public benchmarking suite
|
||||
- Automatically searched model parallelism (TP, PP, EPS, EPR, SP, etc.)
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
|
||||
85
ci/example_output.py
Normal file
85
ci/example_output.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import re
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
EXPECTED_CONCEPTS = {
|
||||
"llama": [
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "adapt"],
|
||||
["data", "patterns", "features"],
|
||||
],
|
||||
"gemma": [
|
||||
["neural network", "neural networks"],
|
||||
["nodes", "neurons"],
|
||||
["layers"],
|
||||
["weights"],
|
||||
["training", "learn", "learns"],
|
||||
],
|
||||
"qwen": [
|
||||
["neural network", "neural networks"],
|
||||
["computational model", "computational system"],
|
||||
["brain"],
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "training"],
|
||||
],
|
||||
"qwen3_moe": [
|
||||
["capital"],
|
||||
["france"],
|
||||
["paris"],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
normalized_output = normalize_output(output)
|
||||
|
||||
expected_concepts = EXPECTED_CONCEPTS.get(example)
|
||||
if expected_concepts is not None:
|
||||
missing = [
|
||||
concept_group
|
||||
for concept_group in expected_concepts
|
||||
if not any(normalize_output(term) in normalized_output for term in concept_group)
|
||||
]
|
||||
if missing:
|
||||
expected = "\n - ".join(" / ".join(group) for group in expected_concepts)
|
||||
missing_terms = "\n - ".join(" / ".join(group) for group in missing)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}.\n"
|
||||
f"Expected concept groups:\n - {expected}\n"
|
||||
f"Missing concept groups:\n - {missing_terms}"
|
||||
)
|
||||
|
||||
expected = ", ".join(" / ".join(group) for group in expected_concepts)
|
||||
print(f"\nOutput check passed for {example!r}: found concepts {expected}")
|
||||
return
|
||||
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
185
ci/examples_perf.py
Normal file
185
ci/examples_perf.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
|
||||
DEFAULT_EXAMPLES = ["llama", "gemma", "qwen", "qwen3_moe", "gemma4_moe", "whisper"]
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"llama": ["run", "--release", "-p", "llama"],
|
||||
"gemma": ["run", "--release", "-p", "gemma"],
|
||||
"qwen": ["run", "--release", "-p", "qwen", "--features", "cuda"],
|
||||
"qwen3_moe": ["run", "--release", "-p", "qwen3_moe"],
|
||||
"gemma4_moe": ["run", "--release", "-p", "gemma4_moe"],
|
||||
"whisper": ["run", "--release", "-p", "whisper"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
ttft_ms: float | None = None
|
||||
tpot_ms: float | None = None
|
||||
tps: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleResult:
|
||||
name: str
|
||||
ok: bool
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
wall_s: float = 0.0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = [arg for arg in sys.argv[1:] if arg != "--"]
|
||||
if any(arg in {"-h", "--help"} for arg in args):
|
||||
print_help()
|
||||
return
|
||||
if "--list" in args:
|
||||
print("\n".join(DEFAULT_EXAMPLES))
|
||||
return
|
||||
|
||||
examples = args or DEFAULT_EXAMPLES
|
||||
results = [run_example(example) for example in examples]
|
||||
print_table(results)
|
||||
if any(not result.ok for result in results):
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def print_help() -> None:
|
||||
print(
|
||||
"Run validated Luminal examples, validate textual output, and summarize perf.\n"
|
||||
"\n"
|
||||
"Usage:\n"
|
||||
" cargo examples\n"
|
||||
" cargo examples llama qwen whisper\n"
|
||||
"\n"
|
||||
"Options:\n"
|
||||
" --list Print the default validated examples\n"
|
||||
" -h, --help\n"
|
||||
"\n"
|
||||
f"The default set matches the Modal examples CI: {', '.join(DEFAULT_EXAMPLES)}."
|
||||
)
|
||||
|
||||
|
||||
def run_example(example: str) -> ExampleResult:
|
||||
cargo_args = EXAMPLE_CARGO_ARGS.get(example)
|
||||
if cargo_args is None:
|
||||
known = ", ".join(DEFAULT_EXAMPLES)
|
||||
return ExampleResult(example, False, error=f"unknown example; known examples: {known}")
|
||||
|
||||
print(f"\n=== Running {example} ===")
|
||||
print(f"$ cargo {' '.join(cargo_args)}")
|
||||
started = time.monotonic()
|
||||
env = os.environ.copy()
|
||||
env.setdefault("CUDARC_CUDA_VERSION", "12080")
|
||||
process = subprocess.Popen(
|
||||
["cargo", *cargo_args],
|
||||
cwd=repo_root(),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
wall_s = time.monotonic() - started
|
||||
metrics = parse_metrics(output)
|
||||
|
||||
if return_code:
|
||||
return ExampleResult(
|
||||
example,
|
||||
False,
|
||||
metrics=metrics,
|
||||
wall_s=wall_s,
|
||||
error=f"process exited with code {return_code}",
|
||||
)
|
||||
|
||||
try:
|
||||
validate_output(example, output)
|
||||
except Exception as exc:
|
||||
return ExampleResult(example, False, metrics=metrics, wall_s=wall_s, error=str(exc))
|
||||
|
||||
return ExampleResult(example, True, metrics=metrics, wall_s=wall_s)
|
||||
|
||||
|
||||
def repo_root() -> str:
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def parse_metrics(output: str) -> Metrics:
|
||||
metrics = Metrics()
|
||||
for line in output.splitlines():
|
||||
if "TTFT:" in line:
|
||||
metrics.ttft_ms = parse_number_after(line, "TTFT:")
|
||||
if "TPOT:" in line:
|
||||
metrics.tpot_ms = parse_number_after(line, "TPOT:")
|
||||
if "tok/s" in line:
|
||||
metrics.tps = parse_tok_per_second(line)
|
||||
if metrics.tps is None and metrics.tpot_ms:
|
||||
metrics.tps = 1000.0 / metrics.tpot_ms
|
||||
return metrics
|
||||
|
||||
|
||||
def parse_number_after(line: str, marker: str) -> float | None:
|
||||
tail = line.split(marker, 1)[1].lstrip()
|
||||
chars = []
|
||||
for char in tail:
|
||||
if char.isdigit() or char == ".":
|
||||
chars.append(char)
|
||||
else:
|
||||
break
|
||||
if not chars:
|
||||
return None
|
||||
return float("".join(chars))
|
||||
|
||||
|
||||
def parse_tok_per_second(line: str) -> float | None:
|
||||
head = line.split("tok/s", 1)[0].rstrip(" (")
|
||||
parts = head.split()
|
||||
if not parts:
|
||||
return None
|
||||
try:
|
||||
return float(parts[-1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def print_table(results: list[ExampleResult]) -> None:
|
||||
print("\nSummary")
|
||||
print(f"{'example':<14} {'status':<8} {'TTFT ms':>10} {'TPOT ms':>10} {'tok/s':>10} {'wall s':>10}")
|
||||
print("-" * 68)
|
||||
for result in results:
|
||||
status = "ok" if result.ok else "failed"
|
||||
print(
|
||||
f"{result.name:<14} {status:<8} "
|
||||
f"{format_metric(result.metrics.ttft_ms):>10} "
|
||||
f"{format_metric(result.metrics.tpot_ms):>10} "
|
||||
f"{format_metric(result.metrics.tps):>10} "
|
||||
f"{result.wall_s:>10.1f}"
|
||||
)
|
||||
if result.error:
|
||||
print(f" error: {result.error}")
|
||||
|
||||
|
||||
def format_metric(value: float | None) -> str:
|
||||
return "-" if value is None else f"{value:.2f}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
ci/metal_llama_1b_example.py
Normal file
48
ci/metal_llama_1b_example.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
sys.path.insert(0, os.path.join(repo_root, "ci"))
|
||||
from example_output import validate_output
|
||||
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "luminal_metal", "--example", "llama_1b"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("Llama 1B Metal example did not complete generation")
|
||||
validate_output("llama", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
46
ci/metal_qwen_example.py
Normal file
46
ci/metal_qwen_example.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "qwen", "--features", "metal"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("qwen Metal example did not complete generation")
|
||||
validate_output("qwen", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,10 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import shlex
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
modal_timeout = int(os.environ.get("MODAL_TIMEOUT", "7200"))
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
@@ -28,7 +30,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=modal_timeout,
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -43,17 +45,20 @@ def run_cargo_test():
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
test_args = shlex.split(os.environ.get("CARGO_TEST_ARGS", "--test-threads=1"))
|
||||
cmd = [
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
*test_args,
|
||||
]
|
||||
print("Running:", " ".join(cmd), flush=True)
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cmd,
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
@@ -21,28 +20,8 @@ hf_cache = modal.Volume.from_name(
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"llama": [
|
||||
"complex system modeled after the structure and function of the human brain",
|
||||
],
|
||||
"gemma": [
|
||||
"recognize pictures of cats",
|
||||
"little detectives looking for specific features",
|
||||
],
|
||||
"qwen": [
|
||||
"computational model inspired by the structure and function of the human brain",
|
||||
],
|
||||
"qwen3_moe": [
|
||||
"The capital of France is Paris",
|
||||
],
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"qwen": ["--features", "cuda"],
|
||||
}
|
||||
|
||||
|
||||
@@ -72,28 +51,6 @@ def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str
|
||||
return output
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
normalized_output = normalize_output(output)
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
@@ -123,6 +80,8 @@ cuda_image = (
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
sys.path.insert(0, f"{WORKDIR}/ci")
|
||||
from example_output import validate_output
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
@@ -130,7 +89,7 @@ def run_example(example: str):
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release"],
|
||||
["cargo", "run", "--release", *EXAMPLE_CARGO_ARGS.get(example, [])],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ fn run_metal_pattern_benchmark(
|
||||
let mut cx = Graph::default();
|
||||
pattern.build_graph(&mut cx, *size);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, 5);
|
||||
let mut rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
|
||||
@@ -41,7 +41,7 @@ struct PreparedBench {
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, 5);
|
||||
let rt = cx.search(rt, CompileOptions::new(5));
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
|
||||
@@ -41,7 +41,7 @@ mod metal_backend {
|
||||
const NAME: &'static str = "Metal";
|
||||
|
||||
fn build_search_space(cx: &mut Graph) {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
cudarc = {version="0.19.7", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
@@ -29,6 +29,7 @@ colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
|
||||
@@ -29,9 +29,21 @@ impl DynBackend for CudaLiteDynBackend {
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_f16(&self, node: NodeIndex) -> Vec<half::f16> {
|
||||
self.runtime.get_f16(node)
|
||||
}
|
||||
fn get_output_bf16(&self, node: NodeIndex) -> Vec<half::bf16> {
|
||||
self.runtime.get_bf16(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
|
||||
self.runtime.get_i64(node)
|
||||
}
|
||||
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
|
||||
self.runtime.get_f64(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
//! ComputeAttnMask — fused op that computes the paged attention mask from indptrs.
|
||||
//!
|
||||
//! This op exists so the indptr tensors (qo_indptr, kv_indptr) are visible in the
|
||||
//! same e-graph chunk as the attention pattern, letting the FlashInfer egglog rule
|
||||
//! capture them directly.
|
||||
//!
|
||||
//! Inputs (3): q_pos (s,) Int, qo_indptr (r,) Int, kv_indptr (r,) Int.
|
||||
//! Output: mask (s, c) F32 where mask[i, j] = 0.0 (attend) or -1e10 (block).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, HLIROp, LLIROp},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaStream, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Computes the paged attention mask from indptr arrays.
|
||||
///
|
||||
/// The mask encodes both request-membership and causality:
|
||||
/// `mask[i, j] = 0.0` if query `i` and context `j` belong to the same request AND
|
||||
/// context `j`'s local position is `<= q_pos[i]`; `-1e10` otherwise.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ComputeAttnMask {
|
||||
pub s_dim: Expression,
|
||||
pub c_dim: Expression,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ComputeAttnMask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ComputeAttnMask(s={}, c={})", self.s_dim, self.c_dim)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for ComputeAttnMask {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (ComputeAttnMask {} {}) (ICons {} (ICons {} (ICons {} (INil)))))",
|
||||
self.s_dim.to_egglog(),
|
||||
self.c_dim.to_egglog(),
|
||||
inputs[0].1, // q_pos
|
||||
inputs[1].1, // qo_indptr
|
||||
inputs[2].1, // kv_indptr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for ComputeAttnMask {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"ComputeAttnMask",
|
||||
&[("s_dim", EXPRESSION), ("c_dim", EXPRESSION)],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrites — inserted directly by model code.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let s_dim = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let c_dim = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let op = Self { s_dim, c_dim };
|
||||
let llir_op = LLIROp::new::<dyn HostOp>(Box::new(op) as Box<dyn HostOp>);
|
||||
(llir_op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for ComputeAttnMask {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 3 {
|
||||
anyhow::bail!(
|
||||
"ComputeAttnMask expects 3 inputs (q_pos, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let s = self
|
||||
.s_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask s_dim unresolved"))?;
|
||||
let c = self
|
||||
.c_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask c_dim unresolved"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask requires dynamic dim 'r'"))?;
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("ComputeAttnMask missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_pos_buf = get_buf("q_pos", inputs[0])?;
|
||||
let qo_indptr_buf = get_buf("qo_indptr", inputs[1])?;
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[2])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
let q_pos = dtoh_i32(stream, q_pos_buf.ptr(), s)?;
|
||||
let qo_indptr = dtoh_i32(stream, qo_indptr_buf.ptr(), r)?;
|
||||
let kv_indptr = dtoh_i32(stream, kv_indptr_buf.ptr(), r)?;
|
||||
|
||||
let mut mask = vec![-1e10f32; s * c];
|
||||
for i in 0..s {
|
||||
let q_req = indptr_to_request(&qo_indptr, i as i32);
|
||||
for j in 0..c {
|
||||
let c_req = indptr_to_request(&kv_indptr, j as i32);
|
||||
if q_req == c_req && q_req >= 0 {
|
||||
let c_local = j as i32 - kv_indptr[c_req as usize];
|
||||
if c_local <= q_pos[i] {
|
||||
mask[i * c + j] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mask_bytes =
|
||||
unsafe { std::slice::from_raw_parts(mask.as_ptr() as *const u8, mask.len() * 4) };
|
||||
unsafe {
|
||||
let res = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
out_buf.ptr(),
|
||||
mask_bytes.as_ptr() as *const std::ffi::c_void,
|
||||
mask_bytes.len(),
|
||||
);
|
||||
if res != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
|
||||
anyhow::bail!("ComputeAttnMask cuMemcpyHtoD failed: {res:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.s_dim * self.c_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("ComputeAttnMask")
|
||||
}
|
||||
}
|
||||
|
||||
fn dtoh_i32(stream: &Arc<CudaStream>, dev_ptr: u64, len: usize) -> anyhow::Result<Vec<i32>> {
|
||||
let mut host = vec![0u8; len * std::mem::size_of::<i32>()];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, dev_ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let v = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host);
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut i32, len, len)
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Given an indptr array `[0, a, b, ...]`, find which segment `idx` belongs to.
|
||||
/// Returns `count(indptr[i] <= idx) - 1`.
|
||||
fn indptr_to_request(indptr: &[i32], idx: i32) -> i32 {
|
||||
indptr.iter().filter(|&&v| v <= idx).count() as i32 - 1
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
|
||||
|
||||
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
|
||||
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
// Strip quotes if present (egglog strings are stored with quotes)
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasSgemmV2 {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
/// Lazily initialized cuBLAS handle - created on first execute
|
||||
cublas: OnceLock<Arc<CudaBlas>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasSgemmV2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
cublas: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cublas: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasSgemmV2 {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as i32;
|
||||
let n = self.n.exec(dyn_map).unwrap() as i32;
|
||||
let k = self.k.exec(dyn_map).unwrap() as i32;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i32;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * 4,
|
||||
b_buf.len(),
|
||||
k * n * 4,
|
||||
c_buf.len(),
|
||||
m * n * 4
|
||||
);
|
||||
let _sgemm_span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLAS_SGEMM_V2",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
?a_layout,
|
||||
?b_layout,
|
||||
)
|
||||
.entered();
|
||||
|
||||
// Use shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
|
||||
|
||||
// Set the stream for this operation (cuBLAS handle can work with any stream)
|
||||
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
|
||||
unsafe {
|
||||
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
|
||||
}
|
||||
|
||||
let status = unsafe {
|
||||
cublasSgemm_v2(
|
||||
*cublas.handle(),
|
||||
a_layout,
|
||||
b_layout,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha as *const f32,
|
||||
a_ptr as *const f32,
|
||||
lda,
|
||||
b_ptr as *const f32,
|
||||
ldb,
|
||||
&beta as *const f32,
|
||||
c_ptr as *mut f32,
|
||||
ldc,
|
||||
)
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLAS SGEMM TN failed with status: {:?}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
|
||||
self.output_size() * 4
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -79,8 +81,12 @@
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -10,8 +10,454 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the scaled FP8 linear form directly before the unscaled FP8
|
||||
; matmul rewrite can hide the quantize/dequant scale structure.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
(= ?scaled (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(= ?cast (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
(
|
||||
(delete (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(delete (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers direct output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Fusion growth can make the live path consume a raw FP8 cuBLASLt
|
||||
; candidate through an internal CudaBinaryElementwise scale multiply,
|
||||
; instead of the original HLIR output-scale Mul. The scalar scale
|
||||
; product is tensor-wide, so the two scalar factors can be passed as
|
||||
; cuBLASLt A/B scale inputs and the internal multiply can be bypassed.
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(let ?fs_sgemm (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
(union ?fused_scale ?fs_sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
(set (dtype ?fs_sgemm) (F32))
|
||||
)
|
||||
:ruleset fusion_grow
|
||||
:name "cublaslt scaled fp8 fused output-scale f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(= ?fused_scale (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
)
|
||||
(
|
||||
(delete (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(delete (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers fused output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Batched form of the scaled FP8 linear rewrite. The scale operands are
|
||||
; scalar tensors expanded across the last three output/activation axes.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -59,8 +505,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -108,8 +558,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -157,8 +611,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -220,8 +678,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -283,8 +745,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -5,8 +5,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -54,8 +58,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -103,8 +111,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -152,8 +164,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -201,8 +217,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -264,8 +284,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -327,8 +351,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -390,8 +418,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -35,10 +35,20 @@ use crate::{
|
||||
},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasLt {
|
||||
@@ -69,6 +79,8 @@ pub struct CuBlasLt {
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
a_scale_input: bool,
|
||||
b_scale_input: bool,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
|
||||
@@ -103,52 +115,62 @@ impl Default for CuBlasLt {
|
||||
alpha: 1.0,
|
||||
beta: 0.0,
|
||||
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
a_scale_input: false,
|
||||
b_scale_input: false,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CuBlasLtScaled;
|
||||
|
||||
fn cublaslt_sort(name: &'static str) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
name,
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("a_order", STRING),
|
||||
("b_order", STRING),
|
||||
("c_order", STRING),
|
||||
("d_order", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("ldd", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("stride_d", EXPRESSION),
|
||||
("a_dtype", DTYPE),
|
||||
("b_dtype", DTYPE),
|
||||
("c_dtype", DTYPE),
|
||||
("d_dtype", DTYPE),
|
||||
("compute_type", STRING),
|
||||
("scale_dtype", STRING),
|
||||
("alpha", F64),
|
||||
("beta", F64),
|
||||
("epilogue", STRING),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLt {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublaslt",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("a_order", STRING),
|
||||
("b_order", STRING),
|
||||
("c_order", STRING),
|
||||
("d_order", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("ldd", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("stride_d", EXPRESSION),
|
||||
("a_dtype", DTYPE),
|
||||
("b_dtype", DTYPE),
|
||||
("c_dtype", DTYPE),
|
||||
("d_dtype", DTYPE),
|
||||
("compute_type", STRING),
|
||||
("scale_dtype", STRING),
|
||||
("alpha", F64),
|
||||
("beta", F64),
|
||||
("epilogue", STRING),
|
||||
],
|
||||
)
|
||||
cublaslt_sort("cublaslt")
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
let c_input = usize::from(self.beta != 0.0);
|
||||
let bias_input = usize::from(epilogue_uses_bias(self.epilogue));
|
||||
2 + c_input + bias_input
|
||||
let scale_inputs = usize::from(self.a_scale_input) + usize::from(self.b_scale_input);
|
||||
2 + c_input + bias_input + scale_inputs
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
@@ -158,40 +180,69 @@ impl EgglogOp for CuBlasLt {
|
||||
(cublaslt_base_dtype (F32))
|
||||
(cublaslt_base_dtype (F16))
|
||||
(cublaslt_base_dtype (Bf16))
|
||||
(cublaslt_base_dtype (TF32))",
|
||||
(cublaslt_base_dtype (TF32))
|
||||
(relation cublaslt_fp8_dtype (DType))
|
||||
(cublaslt_fp8_dtype (F8E4M3))
|
||||
(cublaslt_fp8_dtype (F8E5M2))
|
||||
(relation cublaslt_fp8_f32_output_pair (DType DType))
|
||||
(cublaslt_fp8_f32_output_pair (F8E4M3) (F8E4M3))
|
||||
(cublaslt_fp8_f32_output_pair (F8E4M3) (F8E5M2))
|
||||
(cublaslt_fp8_f32_output_pair (F8E5M2) (F8E4M3))",
|
||||
),
|
||||
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
Rule::raw(include_str!["cublaslt_fp8_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_mixed_dtype_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_scale_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_beta_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_epilogue_rewrite.egg"]),
|
||||
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
|
||||
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
|
||||
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
|
||||
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
|
||||
// (not the Mul eclass), so they survive the cascade.
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
// cuBLASLt now specializes GenericMatmul, so cleanup should prune
|
||||
// the matmul output alternatives directly. Do not delete the
|
||||
// broadcast Mul here; it may still have non-matmul consumers.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-scaled-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
]
|
||||
}
|
||||
@@ -277,6 +328,104 @@ impl EgglogOp for CuBlasLt {
|
||||
alpha,
|
||||
beta,
|
||||
epilogue,
|
||||
a_scale_input: false,
|
||||
b_scale_input: false,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLtScaled {
|
||||
fn sort(&self) -> SortDef {
|
||||
cublaslt_sort("cublaslt_scaled")
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
let a_layout = parse_cublas_op(&egraph.enodes[kind_children[3]].0);
|
||||
let b_layout = parse_cublas_op(&egraph.enodes[kind_children[4]].0);
|
||||
let a_order = parse_cublaslt_order(&egraph.enodes[kind_children[5]].0);
|
||||
let b_order = parse_cublaslt_order(&egraph.enodes[kind_children[6]].0);
|
||||
let c_order = parse_cublaslt_order(&egraph.enodes[kind_children[7]].0);
|
||||
let d_order = parse_cublaslt_order(&egraph.enodes[kind_children[8]].0);
|
||||
|
||||
let lda = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
|
||||
let ldd = extract_expr(egraph, kind_children[12], expr_cache).unwrap();
|
||||
|
||||
let batch_count = extract_expr(egraph, kind_children[13], expr_cache).unwrap();
|
||||
let stride_a = extract_expr(egraph, kind_children[14], expr_cache).unwrap();
|
||||
let stride_b = extract_expr(egraph, kind_children[15], expr_cache).unwrap();
|
||||
let stride_c = extract_expr(egraph, kind_children[16], expr_cache).unwrap();
|
||||
let stride_d = extract_expr(egraph, kind_children[17], expr_cache).unwrap();
|
||||
|
||||
let a_dtype = extract_dtype(egraph, kind_children[18]);
|
||||
let b_dtype = extract_dtype(egraph, kind_children[19]);
|
||||
let c_dtype = extract_dtype(egraph, kind_children[20]);
|
||||
let d_dtype = extract_dtype(egraph, kind_children[21]);
|
||||
let compute_type_str = &egraph.enodes[kind_children[22]].0;
|
||||
let scale_dtype_str = &egraph.enodes[kind_children[23]].0;
|
||||
let compute_type = parse_cublaslt_compute_type(compute_type_str, a_dtype);
|
||||
let scale_dtype = parse_cublaslt_scale_dtype(scale_dtype_str, a_dtype);
|
||||
let alpha = parse_cublaslt_scalar(&egraph.enodes[kind_children[24]].0);
|
||||
let beta = parse_cublaslt_scalar(&egraph.enodes[kind_children[25]].0);
|
||||
let epilogue = parse_cublaslt_epilogue(&egraph.enodes[kind_children[26]].0);
|
||||
|
||||
let extracted_state = CuBlasLt {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
a_order,
|
||||
b_order,
|
||||
c_order,
|
||||
d_order,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
batch_count,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
stride_d,
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
c_dtype,
|
||||
d_dtype,
|
||||
compute_type,
|
||||
scale_dtype,
|
||||
alpha,
|
||||
beta,
|
||||
epilogue,
|
||||
a_scale_input: true,
|
||||
b_scale_input: true,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
@@ -520,6 +669,8 @@ struct LtMatmulPointers {
|
||||
c: u64,
|
||||
d: u64,
|
||||
bias: Option<u64>,
|
||||
a_scale: Option<u64>,
|
||||
b_scale: Option<u64>,
|
||||
}
|
||||
|
||||
struct LtRawDescriptors {
|
||||
@@ -667,12 +818,12 @@ fn run_cublaslt_matmul(
|
||||
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
@@ -728,13 +879,17 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(scale) = &a_scale {
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &a_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(scale) = &b_scale {
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &b_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
@@ -857,6 +1012,8 @@ fn resolve_cublaslt_pointers(
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
beta: f64,
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
a_scale_input: bool,
|
||||
b_scale_input: bool,
|
||||
) -> anyhow::Result<LtMatmulPointers> {
|
||||
if inputs.len() < 2 {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -877,24 +1034,25 @@ fn resolve_cublaslt_pointers(
|
||||
.get(&self_node)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt output buffer"))?
|
||||
.ptr();
|
||||
let mut next_input = 2;
|
||||
let c = if beta == 0.0 {
|
||||
d
|
||||
} else if let Some(c_input) = inputs.get(2) {
|
||||
} else {
|
||||
let c_input = inputs.get(next_input).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul with beta={beta} requires a third C input")
|
||||
})?;
|
||||
next_input += 1;
|
||||
buffers
|
||||
.get(c_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt C input buffer"))?
|
||||
.ptr()
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLASLt matmul with beta={beta} requires a third C input"
|
||||
));
|
||||
};
|
||||
|
||||
let bias_input_index = if beta == 0.0 { 2 } else { 3 };
|
||||
let bias = if epilogue_uses_bias(epilogue) {
|
||||
let bias_input = inputs.get(bias_input_index).ok_or_else(|| {
|
||||
let bias_input = inputs.get(next_input).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul with {epilogue:?} epilogue requires a bias input")
|
||||
})?;
|
||||
next_input += 1;
|
||||
Some(
|
||||
buffers
|
||||
.get(bias_input)
|
||||
@@ -905,7 +1063,44 @@ fn resolve_cublaslt_pointers(
|
||||
None
|
||||
};
|
||||
|
||||
Ok(LtMatmulPointers { a, b, c, d, bias })
|
||||
let a_scale = if a_scale_input {
|
||||
let scale_input = inputs
|
||||
.get(next_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("cuBLASLt matmul requires an A scale input pointer"))?;
|
||||
next_input += 1;
|
||||
Some(
|
||||
buffers
|
||||
.get(scale_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt A scale input buffer"))?
|
||||
.ptr(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let b_scale = if b_scale_input {
|
||||
let scale_input = inputs
|
||||
.get(next_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("cuBLASLt matmul requires a B scale input pointer"))?;
|
||||
Some(
|
||||
buffers
|
||||
.get(scale_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt B scale input buffer"))?
|
||||
.ptr(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(LtMatmulPointers {
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
d,
|
||||
bias,
|
||||
a_scale,
|
||||
b_scale,
|
||||
})
|
||||
}
|
||||
|
||||
fn epilogue_uses_bias(epilogue: cublasLtEpilogue_t) -> bool {
|
||||
@@ -978,6 +1173,11 @@ impl CuBlasLt {
|
||||
&& normalize(self.stride_c) == normalize(self.stride_d)
|
||||
&& self.c_order == self.d_order
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn tensor_scale_inputs(&self) -> (bool, bool) {
|
||||
(self.a_scale_input, self.b_scale_input)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
@@ -1022,7 +1222,15 @@ impl HostOp for CuBlasLt {
|
||||
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
|
||||
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(self_node, inputs, buffers, self.beta, self.epilogue)?;
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
self_node,
|
||||
inputs,
|
||||
buffers,
|
||||
self.beta,
|
||||
self.epilogue,
|
||||
self.a_scale_input,
|
||||
self.b_scale_input,
|
||||
)?;
|
||||
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
@@ -1197,6 +1405,8 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1221,6 +1431,8 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1245,6 +1457,8 @@ mod tests {
|
||||
&buffers,
|
||||
1.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1269,6 +1483,8 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1279,6 +1495,41 @@ mod tests {
|
||||
assert_eq!(ptrs.bias, Some(0xB1A5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_pointers_use_tensor_scale_inputs_after_base_inputs() {
|
||||
let output = NodeIndex::new(0);
|
||||
let a = NodeIndex::new(1);
|
||||
let b = NodeIndex::new(2);
|
||||
let a_scale = NodeIndex::new(3);
|
||||
let b_scale = NodeIndex::new(4);
|
||||
let buffers = buffers_for(&[
|
||||
(output, 0xD000),
|
||||
(a, 0xA000),
|
||||
(b, 0xB000),
|
||||
(a_scale, 0xA5A5),
|
||||
(b_scale, 0xB5B5),
|
||||
]);
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
output,
|
||||
&[a, b, a_scale, b_scale],
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
true,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ptrs.a, 0xA000);
|
||||
assert_eq!(ptrs.b, 0xB000);
|
||||
assert_eq!(ptrs.c, 0xD000);
|
||||
assert_eq!(ptrs.d, 0xD000);
|
||||
assert_eq!(ptrs.bias, None);
|
||||
assert_eq!(ptrs.a_scale, Some(0xA5A5));
|
||||
assert_eq!(ptrs.b_scale, Some(0xB5B5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_pointers_reject_two_input_nonzero_beta() {
|
||||
let output = NodeIndex::new(0);
|
||||
@@ -1292,6 +1543,8 @@ mod tests {
|
||||
&buffers,
|
||||
1.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
@@ -1314,6 +1567,8 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
|
||||
@@ -27,19 +27,16 @@ pub fn find_indptr_inputs<'a>(
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
assert!(
|
||||
mask_kind_label.contains("Add"),
|
||||
"find_indptr_inputs: mask is not an Add (kind={mask_kind_label})"
|
||||
);
|
||||
|
||||
let mask_inputs = walk_ilist_simple(egraph, &mask_children[1]);
|
||||
let mask_inputs = logical_binary_inputs(egraph, mask_node, "Add").unwrap_or_else(|| {
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
panic!("find_indptr_inputs: mask is not an Add (kind={mask_kind_label})");
|
||||
});
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
@@ -98,15 +95,9 @@ fn find_1e10_mul<'a>(
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
if label != "Op" {
|
||||
let Some(mul_inputs) = logical_binary_inputs(egraph, input_node, "Mul") else {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("Mul") {
|
||||
continue;
|
||||
}
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
};
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
@@ -152,6 +143,7 @@ fn find_1e10_mul<'a>(
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let node = resolve_op_with_kind(egraph, node, "Constant").unwrap_or(node);
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
@@ -246,3 +238,91 @@ fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) ->
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
|
||||
fn resolve_op_with_kind<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
kind_substr: &str,
|
||||
) -> Option<&'a NodeId> {
|
||||
let class = egraph.node_to_class.get(node)?;
|
||||
for candidate in &egraph.eclasses[class].1 {
|
||||
let (label, children) = &egraph.enodes[candidate];
|
||||
if label != "Op" || children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains(kind_substr) {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn logical_binary_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
op_name: &str,
|
||||
) -> Option<Vec<&'a NodeId>> {
|
||||
if let Some(op_node) = resolve_op_with_kind(egraph, node, op_name) {
|
||||
let (_, children) = &egraph.enodes[op_node];
|
||||
return Some(walk_ilist_simple(egraph, &children[1]));
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains("CudaBinaryElementwise") {
|
||||
let opcode_class = egraph.enodes[kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
return Some(
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
if !egraph.enodes[kind].0.contains("FusionEnd") {
|
||||
return None;
|
||||
}
|
||||
let fe_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
let elem = *fe_inputs.first()?;
|
||||
let (elem_label, elem_children) = &egraph.enodes[elem];
|
||||
if elem_label != "Op" || elem_children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let elem_kind = resolve_first_node(egraph, &elem_children[0]);
|
||||
if !egraph.enodes[elem_kind].0.contains("CudaBinaryElementwise") {
|
||||
return None;
|
||||
}
|
||||
let opcode_class = egraph.enodes[elem_kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
walk_ilist_simple(egraph, &elem_children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn unwrap_fusion_start<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> &'a NodeId {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return node;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("FusionStart") {
|
||||
return node;
|
||||
}
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.first()
|
||||
.copied()
|
||||
.unwrap_or(node)
|
||||
}
|
||||
|
||||
@@ -89,6 +89,16 @@
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; FlashInfer needs qo_indptr/kv_indptr to be recoverable from the mask
|
||||
; expression. Do not match examples that pass a precomputed mask Input.
|
||||
(= ?mask (Op (Add ?inner_mask_shape ?inner_mask_a_strides ?inner_mask_b_strides ?inner_mask_out_strides)
|
||||
(ICons ?mask_scaled_allowed (ICons ?mask_offset (INil)))))
|
||||
(= ?mask_scaled_allowed (Op (Mul ?allowed_shape ?allowed_strides ?scale_const_strides ?scaled_allowed_strides)
|
||||
(ICons ?mask_allowed (ICons ?mask_scale_const (INil)))))
|
||||
(= ?mask_scale_const (Op (Constant ?mask_scale_val) (INil)))
|
||||
(> ?mask_scale_val 9999999999.0)
|
||||
(< ?mask_scale_val 10000000001.0)
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
|
||||
@@ -2,19 +2,14 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub use compute_attn_mask::ComputeAttnMask;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
compute_attn_mask::ComputeAttnMask,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
@@ -79,6 +74,16 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTensorScaleInputs = (bool, bool);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_tensor_scale_inputs(op: &dyn HostOp) -> Option<CublasLtTensorScaleInputs> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::tensor_scale_inputs)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
|
||||
@@ -195,6 +195,10 @@
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
@@ -211,6 +215,37 @@
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 2 (SwiGLU with row-normalized top-k weights) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?normed_topk (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 2))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (normalized swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
|
||||
@@ -50,7 +50,7 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - SwiGLU/SwiGLUNormalized: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
@@ -78,6 +78,7 @@ pub struct GLUMoE {
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
SwiGLUNormalized,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
@@ -85,6 +86,7 @@ impl GLUMoEMode {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
2 => Self::SwiGLUNormalized,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
@@ -93,7 +95,7 @@ impl GLUMoEMode {
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU => 0,
|
||||
Self::SwiGLU | Self::SwiGLUNormalized => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
@@ -383,22 +385,22 @@ impl HostOp for GLUMoE {
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let topk_bytes = seq * top_k * 4;
|
||||
let min_topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < topk_bytes {
|
||||
if topk_idx_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < topk_bytes {
|
||||
if topk_vals_buf.len() < min_topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
@@ -440,24 +442,83 @@ impl HostOp for GLUMoE {
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
|
||||
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
|
||||
);
|
||||
if !topk_idx_i32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index element count {} is not divisible by seq {seq}",
|
||||
topk_idx_i32.len()
|
||||
);
|
||||
}
|
||||
if !topk_vals_f32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value element count {} is not divisible by seq {seq}",
|
||||
topk_vals_f32.len()
|
||||
);
|
||||
}
|
||||
let topk_idx_row_stride = topk_idx_i32.len() / seq;
|
||||
let topk_vals_row_stride = topk_vals_f32.len() / seq;
|
||||
if topk_idx_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index row stride {topk_idx_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
if topk_vals_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value row stride {topk_vals_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
|
||||
let topk_idx_at = |token: usize, expert: usize| -> i32 {
|
||||
topk_idx_i32[token * topk_idx_row_stride + expert]
|
||||
};
|
||||
let topk_val_at = |token: usize, expert: usize| -> f32 {
|
||||
topk_vals_f32[token * topk_vals_row_stride + expert]
|
||||
};
|
||||
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i);
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at token {t} top-k position {i} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - SwiGLUNormalized: normalize topk values row-wise
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::SwiGLU => {
|
||||
if topk_vals_row_stride == top_k {
|
||||
topk_vals_f32
|
||||
} else {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i);
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
}
|
||||
GLUMoEMode::SwiGLUNormalized => {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i) * inv_norm;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
@@ -471,12 +532,10 @@ impl HostOp for GLUMoE {
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
@@ -485,7 +544,8 @@ impl HostOp for GLUMoE {
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
expert_weights_storage[t * top_k + i] =
|
||||
topk_val_at(t, i) * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
@@ -525,12 +585,10 @@ impl HostOp for GLUMoE {
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
let expert_idx = expert_idx as usize;
|
||||
for (i, &weight) in weights.iter().enumerate() {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
|
||||
// a. Gate+Up matmul (BF16 in, BF16 out)
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
|
||||
738
crates/luminal_cuda_lite/src/kernel/conv2d.rs
Normal file
738
crates/luminal_cuda_lite/src/kernel/conv2d.rs
Normal file
@@ -0,0 +1,738 @@
|
||||
//! CUDA conv2d-with-bias backend rewrite.
|
||||
//!
|
||||
//! `KernelConv2D` is selected by egglog from pure HLIR conv graphs and lowers
|
||||
//! to a one-thread-per-output CUDA kernel. It avoids materializing unfold/im2col
|
||||
//! intermediates while keeping model code free of custom ops.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::FxHashMap;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::FxHashSet,
|
||||
shape::{Expression, flatten_strides},
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::{KernelOp, hlir::generate_dyn_dims_defines};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConv2D {
|
||||
out_shape: Vec<Expression>,
|
||||
input_shape: Vec<Expression>,
|
||||
input_stride: Vec<Expression>,
|
||||
weight_co_stride: Expression,
|
||||
weight_inner_stride: Expression,
|
||||
bias_c_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
kernel_h: Expression,
|
||||
kernel_w: Expression,
|
||||
stride_h: Expression,
|
||||
stride_w: Expression,
|
||||
dilation_h: Expression,
|
||||
dilation_w: Expression,
|
||||
pad_h: Expression,
|
||||
pad_w: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelConv2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelConv2D",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("input_shape", ELIST),
|
||||
("input_stride", ELIST),
|
||||
("weight_co_stride", EXPRESSION),
|
||||
("weight_inner_stride", EXPRESSION),
|
||||
("bias_c_stride", EXPRESSION),
|
||||
("out_stride", ELIST),
|
||||
("kernel_h", EXPRESSION),
|
||||
("kernel_w", EXPRESSION),
|
||||
("stride_h", EXPRESSION),
|
||||
("stride_w", EXPRESSION),
|
||||
("dilation_h", EXPRESSION),
|
||||
("dilation_w", EXPRESSION),
|
||||
("pad_h", EXPRESSION),
|
||||
("pad_w", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// 1x1 convs in Flux2's VAE are represented without `unfold`:
|
||||
//
|
||||
// input.permute([H,W,C]).merge(H,W)
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute back to [C_out,H,W]
|
||||
// -> + channel bias
|
||||
//
|
||||
// The lowered form is still the same Mul -> KernelSum -> Add
|
||||
// matmul skeleton, but the lhs FusionStart reads directly from the
|
||||
// original input instead of a KernelGather window tensor.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered bias matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the same conv after generic CUDA lowering has normalized
|
||||
// the elementwise pieces into fusion regions:
|
||||
//
|
||||
// KernelGather(input windows)
|
||||
// -> CudaBinaryElementwise("Mul", weight)
|
||||
// -> KernelSum(reduce K)
|
||||
// -> CudaBinaryElementwise("Add", bias)
|
||||
//
|
||||
// This is the form that survives long enough for CUDA search in
|
||||
// real models. The KernelConv2D op consumes the pre-gather input
|
||||
// and avoids materializing both the im2col window tensor and the
|
||||
// elementwise product tensor.
|
||||
//
|
||||
// TODO(egglog-shapes): the current e-graph does not reliably prove
|
||||
// the derived arithmetic equalities for this chain after CUDA
|
||||
// normalization:
|
||||
// * `M == H_out * W_out`
|
||||
// * `K == C_in * KH * KW`
|
||||
// * separately-derived but structurally identical stride
|
||||
// expressions, e.g. the Mul output stride and KernelSum input
|
||||
// stride, belong to the same e-class.
|
||||
// Keep the rewrite anchored on the stable conv layout facts the
|
||||
// graph does carry today: six-axis unfold window shape, flattened
|
||||
// `[M, C_out, K]` product, reduction over `K`, the three-axis
|
||||
// `[C_out, H_out, W_out]` output view, and channel-only bias
|
||||
// broadcast. Once expression/list canonicalization can prove those
|
||||
// equalities, tighten this rule and its regression tests.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the im2col-style HLIR conv used by Flux2:
|
||||
//
|
||||
// input.unfold([1, kh, kw], [1, 1, 1], [1, 1, 1])
|
||||
// -> squeeze/permute/merge view
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute view
|
||||
// -> + bias.expand_dim(1, h_out).expand_dim(2, w_out)
|
||||
//
|
||||
// The kernel consumes the pre-unfold input directly. That input may
|
||||
// already be a padded HLIR tensor, so the rewrite is still correct
|
||||
// for Flux2's padded convs while removing the large patch matrix.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
; This rewrite is for stride=1, dilation=1 over the
|
||||
; tensor passed to unfold. Padded HLIR inputs are already
|
||||
; represented as their own tensor, so padding is 0 here.
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?add (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (Add ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?fe (Op (FusionEnd ?shape ?os ?dt) ?inputs))
|
||||
(= ?fe (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?conv_dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (FusionEnd ?shape ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a luminal::egglog_utils::NodeId],
|
||||
input_enodes: Vec<&'a luminal::egglog_utils::NodeId>,
|
||||
list_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::egglog_utils::NodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
weight_co_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
weight_inner_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
|
||||
bias_c_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
kernel_h: extract_expr(egraph, kind_children[7], expr_cache).unwrap(),
|
||||
kernel_w: extract_expr(egraph, kind_children[8], expr_cache).unwrap(),
|
||||
stride_h: extract_expr(egraph, kind_children[9], expr_cache).unwrap(),
|
||||
stride_w: extract_expr(egraph, kind_children[10], expr_cache).unwrap(),
|
||||
dilation_h: extract_expr(egraph, kind_children[11], expr_cache).unwrap(),
|
||||
dilation_w: extract_expr(egraph, kind_children[12], expr_cache).unwrap(),
|
||||
pad_h: extract_expr(egraph, kind_children[13], expr_cache).unwrap(),
|
||||
pad_w: extract_expr(egraph, kind_children[14], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[15]),
|
||||
}) as Box<dyn KernelOp>),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelConv2D {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
assert_eq!(self.dtype, DType::F32, "KernelConv2D currently emits F32");
|
||||
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let c_out = self.out_shape[0].to_kernel();
|
||||
let h_out = self.out_shape[1].to_kernel();
|
||||
let w_out = self.out_shape[2].to_kernel();
|
||||
let c_in = self.input_shape[0].to_kernel();
|
||||
let h_in = self.input_shape[1].to_kernel();
|
||||
let w_in = self.input_shape[2].to_kernel();
|
||||
let weight_co_stride = self
|
||||
.weight_co_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let weight_inner_stride = self
|
||||
.weight_inner_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let bias_c_stride = self
|
||||
.bias_c_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let kh = self.kernel_h.to_kernel();
|
||||
let kw = self.kernel_w.to_kernel();
|
||||
let stride_h = self.stride_h.to_kernel();
|
||||
let stride_w = self.stride_w.to_kernel();
|
||||
let dilation_h = self.dilation_h.to_kernel();
|
||||
let dilation_w = self.dilation_w.to_kernel();
|
||||
let pad_h = self.pad_h.to_kernel();
|
||||
let pad_w = self.pad_w.to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let input_idx = flatten_strides(&self.input_shape, &self.input_stride)
|
||||
.to_kernel()
|
||||
.replace("const_z", "input_linear");
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_conv2d_bias(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias{dyn_dims_param}
|
||||
) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long long total = {total};
|
||||
if (const_z >= total) return;
|
||||
|
||||
const long long COUT = {c_out};
|
||||
const long long HOUT = {h_out};
|
||||
const long long WOUT = {w_out};
|
||||
const long long CIN = {c_in};
|
||||
const long long HIN = {h_in};
|
||||
const long long WIN = {w_in};
|
||||
const long long KH = {kh};
|
||||
const long long KW = {kw};
|
||||
const long long SH = {stride_h};
|
||||
const long long SW = {stride_w};
|
||||
const long long DH = {dilation_h};
|
||||
const long long DW = {dilation_w};
|
||||
const long long PH = {pad_h};
|
||||
const long long PW = {pad_w};
|
||||
const long long W_CO_STRIDE = {weight_co_stride};
|
||||
const long long W_INNER_STRIDE = {weight_inner_stride};
|
||||
const long long BIAS_C_STRIDE = {bias_c_stride};
|
||||
|
||||
long long co = const_z / (HOUT * WOUT);
|
||||
long long rem = const_z - co * HOUT * WOUT;
|
||||
long long oh = rem / WOUT;
|
||||
long long ow = rem - oh * WOUT;
|
||||
|
||||
float acc = bias[co * BIAS_C_STRIDE];
|
||||
for (long long ci = 0; ci < CIN; ++ci) {{
|
||||
for (long long r = 0; r < KH; ++r) {{
|
||||
long long ih = oh * SH + r * DH - PH;
|
||||
if (ih < 0 || ih >= HIN) continue;
|
||||
for (long long s = 0; s < KW; ++s) {{
|
||||
long long iw = ow * SW + s * DW - PW;
|
||||
if (iw < 0 || iw >= WIN) continue;
|
||||
long long input_linear = (ci * HIN + ih) * WIN + iw;
|
||||
long long input_idx = {input_idx};
|
||||
long long inner = (ci * KH + r) * KW + s;
|
||||
long long weight_idx = co * W_CO_STRIDE + inner * W_INNER_STRIDE;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
out[{out_idx}] = acc;
|
||||
}}
|
||||
}}",
|
||||
total = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_conv2d_bias").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs.ceil_div(256), 1.into(), 1.into()),
|
||||
(n_outputs.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2 * 4 + self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericConv2D"
|
||||
}
|
||||
}
|
||||
@@ -498,8 +498,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -530,8 +530,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -568,8 +568,8 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
@@ -610,8 +610,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
@@ -641,8 +641,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
@@ -674,8 +674,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
393
crates/luminal_cuda_lite/src/kernel/fusion/elementwise.rs
Normal file
393
crates/luminal_cuda_lite/src/kernel/fusion/elementwise.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
// =========================================================================
|
||||
// Generic CUDA elementwise ops used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// CUDA elementwise execution is represented as a FusionEnd-rooted region even
|
||||
// for a single op. These ops are therefore region-internal only; standalone
|
||||
// compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (CudaUnaryElementwise, CudaBinaryElementwise);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
fn extract_string_label(egraph: &SerializedEGraph, node: &ENodeId) -> String {
|
||||
egraph.enodes[node].0.trim_matches('"').to_string()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaUnaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaUnaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaUnaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
for (hlir, opcode) in [
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?out_s) (ICons ?x (INil))))
|
||||
(= ?dt (dtype ?u))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?out_s ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?sqrt (Op (Sqrt ?shape ?x_stride ?sqrt_stride) (ICons ?x (INil))))
|
||||
(= ?recip (Op (Recip ?shape ?sqrt_stride ?out_stride) (ICons ?sqrt (INil))))
|
||||
(= ?dt (dtype ?recip))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Rsqrt\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?recip ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-rsqrt-from-sqrt-recip\")",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Exp\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?exp2 ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(datatype*
|
||||
(CudaSigmoidScaledState
|
||||
(MkCudaSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function cuda_sigmoid_scaled (IR) CudaSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (cuda_sigmoid_scaled ?scaled)
|
||||
(MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-region-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (cuda_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Sigmoid\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?sig_out ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaUnaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaUnaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaUnaryElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaBinaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaBinaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaBinaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Add ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?bin))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Add\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Add\")",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Mul ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?a))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Mul\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Mul\")",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[4], list_cache, expr_cache).unwrap();
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype: extract_dtype(egraph, kind_children[5]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaBinaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaBinaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes() * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaBinaryElementwise"
|
||||
}
|
||||
}
|
||||
@@ -1,301 +0,0 @@
|
||||
// =========================================================================
|
||||
// Fused elementwise op variants used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// Each `FusedX` struct mirrors its un-fused `KernelX` sibling field-for-field
|
||||
// and serves a single purpose: give the egglog rules a distinct sort to
|
||||
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
|
||||
// pattern. Cascade prevention by typing.
|
||||
//
|
||||
// Each FusedX must be absorbed into a FusionEnd-rooted region and compiled by
|
||||
// `region_codegen`; standalone compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (
|
||||
FusedSin,
|
||||
FusedSqrt,
|
||||
FusedExp,
|
||||
FusedExp2,
|
||||
FusedLog2,
|
||||
FusedRecip,
|
||||
FusedAdd,
|
||||
FusedMul,
|
||||
);
|
||||
|
||||
// Standard `compile()` return tuple (matches the trait signature).
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
|
||||
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
|
||||
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
|
||||
macro_rules! impl_fused_unary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $body:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// As `impl_fused_unary!` but for binary ops: 5-field sort signature
|
||||
/// (shape + per-input strides + out_stride + dtype), n_inputs = 2.
|
||||
/// `$op_str` is the CUDA infix operator, e.g. `"+"`, `"*"`.
|
||||
macro_rules! impl_fused_binary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $op_str:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[0],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[3],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let bytes = (self.output_size() * self.dtype.bits()).ceil_div(8);
|
||||
bytes + bytes
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_fused_unary!(FusedSin, "FusedSin", "fused_sin_k", "sinf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedSqrt,
|
||||
"FusedSqrt",
|
||||
"fused_sqrt_k",
|
||||
"sqrtf(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(FusedExp, "FusedExp", "fused_exp_k", "expf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedExp2,
|
||||
"FusedExp2",
|
||||
"fused_exp2_k",
|
||||
"exp2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedLog2,
|
||||
"FusedLog2",
|
||||
"fused_log2_k",
|
||||
"log2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedRecip,
|
||||
"FusedRecip",
|
||||
"fused_recip_k",
|
||||
"1.0f / in[{in_idx}]"
|
||||
);
|
||||
|
||||
impl_fused_binary!(FusedAdd, "FusedAdd", "fused_add_k", "+");
|
||||
impl_fused_binary!(FusedMul, "FusedMul", "fused_mul_k", "*");
|
||||
@@ -9,8 +9,8 @@
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Like FusedX, both markers'
|
||||
// `compile()` is `unreachable!()` — region codegen folds them away
|
||||
// codegen lives in `region_codegen`. Both markers' `compile()` is
|
||||
// `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
@@ -142,218 +142,164 @@ impl EgglogOp for FusionEnd {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
// rule's own output cannot re-match its LHS — cascade is prevented
|
||||
// by typing rather than by a discriminator field.
|
||||
//
|
||||
// Stride compatibility is expressed by reusing variable names: a
|
||||
// unary inside a region matches `(KernelU ?shape ?s ?s ?dt)` (in =
|
||||
// out, no transpose); a binary feeding a downstream op binds the
|
||||
// binary's out-stride to the downstream op's in-stride along the
|
||||
// connecting side.
|
||||
// Generic region growth works directly from HLIR elementwise ops into
|
||||
// `Cuda*Elementwise` region nodes. The concrete HLIR op still appears in
|
||||
// the egraph, so fusion remains a normal nondestructive alternative, but
|
||||
// the region-internal representation is arity based instead of one
|
||||
// dedicated fused sort per operation.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// (KernelX kind, FusedX kind)
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("KernelSin", "FusedSin"),
|
||||
("KernelSqrt", "FusedSqrt"),
|
||||
("KernelExp", "FusedExp"),
|
||||
("KernelExp2", "FusedExp2"),
|
||||
("KernelLog2", "FusedLog2"),
|
||||
("KernelRecip", "FusedRecip"),
|
||||
];
|
||||
// (KernelX kind, FusedX kind, rule-name label)
|
||||
let binaries: &[(&str, &str, &str)] = &[
|
||||
("KernelAdd", "FusedAdd", "Add"),
|
||||
("KernelMul", "FusedMul", "Mul"),
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
];
|
||||
let binaries: &[(&str, &str)] = &[("Add", "Add"), ("Mul", "Mul")];
|
||||
|
||||
// 1. Pair-fuse U → U: U2(U1(x)) → FE(FU2(FU1(FS(x)))).
|
||||
for (ki1, fi1) in unaries {
|
||||
for (ko2, fo2) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u1 (Op ({ki1} ?shape ?s ?s ?dt) (ICons ?x (INil))))
|
||||
(= ?u2 (Op ({ko2} ?shape ?s ?s ?dt) (ICons ?u1 (INil))))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?fu1 (Op ({fi1} ?shape ?s ?s ?dt) (ICons ?fs (INil))))
|
||||
(let ?fu2 (Op ({fo2} ?shape ?s ?s ?dt) (ICons ?fu1 (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu2 (INil))))
|
||||
(union ?u2 ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Pair-fuse B → U: U(B(a, b)) → FE(FU(FB(FS(a), FS(b)))).
|
||||
for (kb, fb, lb) in binaries {
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?u (Op ({ku} ?shape ?o_s ?o_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fu (Op ({fu} ?shape ?o_s ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Pair-fuse U → B (lhs / rhs): unary feeds binary's A or B input.
|
||||
// LHS: B(U(a), b) → FE(FB(FU(FS(a)), FS(b))).
|
||||
// RHS: B(a, U(b)) → FE(FB(FS(a), FU(FS(b)))).
|
||||
for (ku, fu) in unaries {
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?u (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?u_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fu (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?u (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?u_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fu (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Pair-fuse B → B (lhs / rhs): inner binary feeds outer's A or B.
|
||||
for (kbi, fbi, lbi) in binaries {
|
||||
for (kbo, fbo, lbo) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?bi (ICons ?c (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?fbi (ICons ?fs_c (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?c (ICons ?bi (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?fs_c (ICons ?fbi (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Grow FE → U: U(FE(inner)) → FE(FU(inner)). No new FS.
|
||||
for (ku, fu) in unaries {
|
||||
// Grow FE → unary consumer: U(FE(inner)) → FE(CudaUnary(inner)).
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({ku} ?shape ?s ?s ?dt) (ICons ?fe (INil))))
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{ku}\")"
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 6. Grow FE → B (lhs / rhs): one input is the FE, the other external.
|
||||
for (kb, fb, lb) in binaries {
|
||||
// Grow FE → binary consumer, left and right orientations.
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{lb}\")"
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{lb}\")"
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 7. Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(FB(ia, ib)).
|
||||
//
|
||||
// This is destructive: after creating the larger region, subsume the
|
||||
// two smaller FusionEnd rows. Without that, independently-grown left
|
||||
// and right regions form a Cartesian product, then those alternatives
|
||||
// can merge again higher in the graph.
|
||||
for (kb, fb, lb) in binaries {
|
||||
// Absorb an elementwise producer through a FusionStart boundary. This
|
||||
// makes a region that initially treats `producer(...)` as an external
|
||||
// input able to pull that producer inside later.
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?x (INil))))
|
||||
(= ?fs_u (Op (FusionStart ?shape ?s ?dt) (ICons ?u (INil))))
|
||||
) (
|
||||
(let ?fs_x (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?fs_x (INil))))
|
||||
(union ?fs_u ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-U-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?bad_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?bad_fs (INil))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-unary-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?fs_bin (Op (FusionStart ?shape ?out_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(union ?fs_bin ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-B-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?bad_fs (ICons ?fs_b (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?bad_fs (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(CudaBinary(ia, ib)).
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(subsume (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(subsume (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{lb}\")"
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -363,6 +309,61 @@ impl EgglogOp for FusionEnd {
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-strides\")",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
|
||||
@@ -2,25 +2,21 @@
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `fused_ops` — eight `FusedX` op variants (interior to a region) so
|
||||
//! pair-fuse rules' RHS sit in a different egglog sort than their LHS,
|
||||
//! blocking cascade by typing.
|
||||
//! - `elementwise` — generic region-internal CUDA elementwise op variants.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / `FusedX` / `FusionEnd` nodes after
|
||||
//! The LLIR keeps `FusionStart` / generic elementwise / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod fused_ops;
|
||||
pub mod elementwise;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use fused_ops::{
|
||||
FusedAdd, FusedExp, FusedExp2, FusedLog2, FusedMul, FusedRecip, FusedSin, FusedSqrt,
|
||||
};
|
||||
pub use elementwise::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior FusedX variants). Combined into a flat tuple for the
|
||||
/// `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, fused_ops::Ops);
|
||||
/// (markers + interior generic elementwise variants). Combined into a flat
|
||||
/// tuple for the `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, elementwise::Ops);
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// PR1 left FusedX / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// Older fusion lowering left elementwise / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — un-fused KernelX, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior FusedX DAG +
|
||||
// - CompileUnit::Single(node) — unfused non-region kernels, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior elementwise DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all FusedX bodies through
|
||||
// chains all elementwise bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior FusedX / FusionStart nodes
|
||||
// device pointers at execute time. Interior Cuda*Elementwise / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
@@ -40,6 +40,7 @@ use as_any::Downcast;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::elementwise::{CudaBinaryElementwise, CudaUnaryElementwise},
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
@@ -52,10 +53,10 @@ use crate::{
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior FusedX nodes, in topological order (predecessors before
|
||||
/// Interior Cuda*Elementwise nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub fusedx_topo: Vec<NodeIndex>,
|
||||
pub elementwise_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
@@ -79,13 +80,13 @@ pub(crate) enum CompileUnit {
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// FusedX and FusionStart nodes are absorbed into that region and removed
|
||||
/// Cuda*Elementwise and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / `FusedX` nodes, stopping at
|
||||
/// incoming edges through `FusionEnd` / Cuda*Elementwise nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
@@ -123,7 +124,7 @@ pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<Nod
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
@@ -187,12 +188,12 @@ pub(crate) fn build_compile_units(
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-FusedX predecessor inside what
|
||||
// Non-marker, non-elementwise predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
@@ -229,7 +230,56 @@ pub(crate) fn build_compile_units(
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionStart with no predecessor")
|
||||
.unwrap_or_else(|| {
|
||||
// Dump the malformed structure: which FE
|
||||
// triggered the walk, every node in fs_topo and
|
||||
// interior_topo, and each FS's incoming /
|
||||
// outgoing degree. Helps localize whether the
|
||||
// missing edge came from extraction or a
|
||||
// downstream LLIR transform.
|
||||
if std::env::var("LUMINAL_DEBUG_FUSION_PANIC").is_ok() {
|
||||
eprintln!(
|
||||
"FusionStart panic: fe={} (kernel={:?})",
|
||||
node.index(),
|
||||
llir_graph.node_weight(node).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
}),
|
||||
);
|
||||
eprintln!(" fs_topo ({}):", fs_topo.len());
|
||||
for &f in &fs_topo {
|
||||
let in_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Incoming)
|
||||
.count();
|
||||
let out_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Outgoing)
|
||||
.count();
|
||||
let kn = llir_graph
|
||||
.node_weight(f)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(
|
||||
" fs={} kind={} in_deg={} out_deg={}",
|
||||
f.index(),
|
||||
kn,
|
||||
in_deg,
|
||||
out_deg,
|
||||
);
|
||||
}
|
||||
eprintln!(" interior_topo ({}):", interior_topo.len());
|
||||
for &i in &interior_topo {
|
||||
let kn = llir_graph
|
||||
.node_weight(i)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(" interior={} kind={}", i.index(), kn);
|
||||
}
|
||||
}
|
||||
panic!("FusionStart with no predecessor")
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -240,7 +290,7 @@ pub(crate) fn build_compile_units(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
fusedx_topo: interior_topo,
|
||||
elementwise_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
@@ -269,24 +319,54 @@ pub(crate) fn build_compile_units(
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-FusedX body templates.
|
||||
// Per-elementwise body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn fused_body(name: &str, locals: &[&str]) -> String {
|
||||
match name {
|
||||
"FusedSin" => format!("sinf({})", locals[0]),
|
||||
"FusedSqrt" => format!("sqrtf({})", locals[0]),
|
||||
"FusedExp" => format!("expf({})", locals[0]),
|
||||
"FusedExp2" => format!("exp2f({})", locals[0]),
|
||||
"FusedLog2" => format!("log2f({})", locals[0]),
|
||||
"FusedRecip" => format!("1.0f / {}", locals[0]),
|
||||
"FusedAdd" => format!("{} + {}", locals[0], locals[1]),
|
||||
"FusedMul" => format!("{} * {}", locals[0], locals[1]),
|
||||
other => panic!("region_codegen: unknown FusedX op {other}"),
|
||||
fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
|
||||
llir_graph
|
||||
.node_weight(node)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>())
|
||||
.is_some_and(|op| {
|
||||
(***op).downcast_ref::<CudaUnaryElementwise>().is_some()
|
||||
|| (***op).downcast_ref::<CudaBinaryElementwise>().is_some()
|
||||
})
|
||||
}
|
||||
|
||||
fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("static_cast<float>({local})")
|
||||
} else {
|
||||
local.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("{cuda_ty}({expr})")
|
||||
} else {
|
||||
expr.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
|
||||
let a = || elementwise_value(locals[0], dtype);
|
||||
let b = || elementwise_value(locals[1], dtype);
|
||||
match op {
|
||||
"Sin" => format!("sinf({})", a()),
|
||||
"Sqrt" => format!("sqrtf({})", a()),
|
||||
"Rsqrt" => format!("rsqrtf({})", a()),
|
||||
"Exp" => format!("expf({})", a()),
|
||||
"Exp2" => format!("exp2f({})", a()),
|
||||
"Log2" => format!("log2f({})", a()),
|
||||
"Recip" => format!("1.0f / {}", a()),
|
||||
"Sigmoid" => format!("1.0f / (1.0f + expf(-{}))", a()),
|
||||
"Add" => format!("{} + {}", a(), b()),
|
||||
"Mul" => format!("{} * {}", a(), b()),
|
||||
other => panic!("region_codegen: unknown elementwise op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +404,7 @@ pub(crate) fn compile_region(
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides, FusedX shape — all FusedX share `out_shape`, but their
|
||||
// FE strides and elementwise shapes.
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
@@ -334,6 +414,19 @@ pub(crate) fn compile_region(
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
for &elem_idx in ®ion.elementwise_topo {
|
||||
let elem_op = llir_graph[elem_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
if let Some(elem) = (***elem_op).downcast_ref::<CudaUnaryElementwise>() {
|
||||
all_vars.extend(elem.shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.in_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
} else if let Some(elem) = (***elem_op).downcast_ref::<CudaBinaryElementwise>() {
|
||||
all_vars.extend(elem.out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.a_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.b_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
@@ -359,19 +452,19 @@ pub(crate) fn compile_region(
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk FusedX in topo order emitting a
|
||||
// Body: read FS leaves, then walk elementwise nodes in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), FusedX get fs_nodes.len()..(+ fusedx_topo.len()).
|
||||
// 0..fs_nodes.len(), elementwise nodes get fs_nodes.len()..(+ elementwise_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.fusedx_topo.iter().enumerate() {
|
||||
for (i, &op_idx) in region.elementwise_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
@@ -394,12 +487,22 @@ pub(crate) fn compile_region(
|
||||
));
|
||||
}
|
||||
|
||||
// FusedX ops in topo order. Each looks up its predecessor locals
|
||||
// Elementwise ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.fusedx_topo {
|
||||
for &op_idx in ®ion.elementwise_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let op_name = op_ref.kernel_name();
|
||||
let (elem_name, elem_dtype) =
|
||||
if let Some(elem) = (***op_ref).downcast_ref::<CudaUnaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else if let Some(elem) = (***op_ref).downcast_ref::<CudaBinaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else {
|
||||
panic!(
|
||||
"region_codegen: expected Cuda*Elementwise op, got {}",
|
||||
op_ref.kernel_name()
|
||||
);
|
||||
};
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
@@ -418,15 +521,16 @@ pub(crate) fn compile_region(
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = fused_body(op_name, &inputs_ref);
|
||||
let expr = elementwise_body(elem_name, &inputs_ref, elem_dtype);
|
||||
let expr = elementwise_init_expr(&expr, elem_dtype, cuda_ty);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the FusedX feeding FE (its single incoming edge in
|
||||
// the region — a FusedX or, in degenerate single-FS regions which
|
||||
// FE write: pick the elementwise node feeding FE (its single incoming edge in
|
||||
// the region — an elementwise node or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
@@ -474,3 +578,63 @@ pub(crate) fn compile_region(
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::kernel::fusion::elementwise::CudaBinaryElementwise;
|
||||
use luminal::op::LLIROp;
|
||||
use luminal::prelude::petgraph::algo::toposort;
|
||||
|
||||
/// Helper: wrap a `KernelOp` in an `LLIROp` of the kernel dialect.
|
||||
fn llir_of(op: impl KernelOp + 'static) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(op) as Box<dyn KernelOp>)
|
||||
}
|
||||
|
||||
/// Reproducer for the `FusionStart with no predecessor` panic at
|
||||
/// `region_codegen.rs:232`. The egglog rolling pass + iterated mode
|
||||
/// (`LUMINAL_LOOP_ROLL_ITERATE=1`) has been observed to produce LLIR
|
||||
/// graphs where a `FusionStart` marker is reached as a region leaf
|
||||
/// during the FE→FS walk but has no incoming edge — meaning the
|
||||
/// region has nothing to read from. `build_compile_units` then
|
||||
/// panics when constructing `external_inputs` because every FS leaf
|
||||
/// is required to have exactly one external producer.
|
||||
///
|
||||
/// Until that path is fixed, this test pins the failure mode so a
|
||||
/// regression doesn't silently change the panic message or location.
|
||||
/// `should_panic` rather than `ignore` so it stays runnable in CI
|
||||
/// and surfaces if the panic ever moves.
|
||||
#[test]
|
||||
#[should_panic(expected = "FusionStart with no predecessor")]
|
||||
fn fusion_start_with_no_predecessor_panics() {
|
||||
// Minimal reproducer:
|
||||
//
|
||||
// (no input) ──▶ FusionStart ──▶ CudaBinaryElementwise ──▶ FusionEnd
|
||||
//
|
||||
// CudaBinaryElementwise is a binary op (n_inputs = 2) so a real region would
|
||||
// have two FS leaves. For this panic-shape test only the *first*
|
||||
// FS leaf needs a missing predecessor — `build_compile_units`
|
||||
// panics in `expect("FusionStart with no predecessor")` as soon
|
||||
// as any FS in `fs_topo` lacks one. We add only one FS edge so
|
||||
// CudaBinaryElementwise has a dangling second input slot, but that's fine:
|
||||
// we're testing the specific panic path inside `build_compile_units`,
|
||||
// not full kernel codegen.
|
||||
let mut llir: LLIRGraph = LLIRGraph::default();
|
||||
|
||||
let fs_node = llir.add_node(llir_of(FusionStart::default()));
|
||||
let fadd_node = llir.add_node(llir_of(CudaBinaryElementwise::default()));
|
||||
let fe_node = llir.add_node(llir_of(FusionEnd::default()));
|
||||
|
||||
// FusionStart → CudaBinaryElementwise → FusionEnd.
|
||||
llir.add_edge(fs_node, fadd_node, ());
|
||||
llir.add_edge(fadd_node, fe_node, ());
|
||||
|
||||
let topo = toposort(&llir, None).expect("LLIR cycle in test setup");
|
||||
let absorbed = globally_absorbed_markers(&llir);
|
||||
|
||||
// This is the call that panics with `FusionStart with no
|
||||
// predecessor` because `fs_node`'s incoming-edges iterator is
|
||||
// empty.
|
||||
let _ = build_compile_units(&topo, &llir, &absorbed);
|
||||
}
|
||||
}
|
||||
|
||||
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{
|
||||
KernelOp,
|
||||
hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
shape::flatten_strides,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct GenericMatmul {
|
||||
out_shape: Vec<Expression>,
|
||||
mul_shape: Vec<Expression>,
|
||||
k: Expression,
|
||||
lhs_strides: Vec<Expression>,
|
||||
rhs_strides: Vec<Expression>,
|
||||
sum_input_strides: Vec<Expression>,
|
||||
sum_iter_stride: Expression,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for GenericMatmul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"GenericMatmul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("mul_shape", ELIST),
|
||||
("k", EXPRESSION),
|
||||
("lhs_strides", ELIST),
|
||||
("rhs_strides", ELIST),
|
||||
("sum_input_strides", ELIST),
|
||||
("sum_iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?sum))
|
||||
)
|
||||
(
|
||||
(let ?generic (Op (GenericMatmul
|
||||
?out_shape
|
||||
?mul_shape
|
||||
?k
|
||||
?lhs_strides
|
||||
?rhs_strides
|
||||
?sum_input_strides
|
||||
?sum_iter_stride
|
||||
?out_strides
|
||||
?dt)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(union ?sum ?generic)
|
||||
(set (dtype ?generic) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"generic-matmul-cuda-mul-sum\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
(
|
||||
(delete (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?kernel_sum (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs))
|
||||
(= ?kernel_sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
((delete (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
mul_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
lhs_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
rhs_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
sum_input_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[5],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
sum_iter_stride: extract_expr(egraph, kind_children[6], expr_cache).unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[7], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[8]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for GenericMatmul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.all_dyn_vars();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_outputs = self.output_size();
|
||||
let sum_base_idx = flatten_strides(&self.out_shape, &self.sum_input_strides).to_kernel();
|
||||
let iter_offset = self.sum_iter_stride.to_kernel().replace("const_z", "i");
|
||||
let lhs_idx = flatten_strides(&self.mul_shape, &self.lhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let rhs_idx = flatten_strides(&self.mul_shape, &self.rhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
|
||||
let k = self.k.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_matmul({dtype} *out, const {dtype} *lhs, const {dtype} *rhs{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
if (const_z >= {n_outputs}) return;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long base_idx = {sum_base_idx};
|
||||
long long iters = {k};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
long long mul_idx = base_idx + {iter_offset};
|
||||
partial += static_cast<float>(lhs[{lhs_idx}]) * static_cast<float>(rhs[{rhs_idx}]);
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
float block_sum = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s >>= 1) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = ({dtype})block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}",
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.mul_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k.dyn_vars())
|
||||
.chain(self.lhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.rhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_input_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_iter_stride.dyn_vars())
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.output_size() * self.k * self.dtype.bits() * 2).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericMatmul"
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
427
crates/luminal_cuda_lite/src/kernel/matmul2d.rs
Normal file
427
crates/luminal_cuda_lite/src/kernel/matmul2d.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
//! Direct 2D matmul kernel — bypasses egglog rewrites, used as a custom op
|
||||
//! for matmul shapes where the cublaslt egg rules don't reliably fire.
|
||||
//!
|
||||
//! The cublaslt 2D rules in `host/cublaslt/cublaslt_*Cm_rewrite.egg` /
|
||||
//! `cublaslt_Rm*_rewrite.egg` are *supposed* to match any 2D matmul whose
|
||||
//! Mul + SumReduce broadcast lowering has the expected stride patterns,
|
||||
//! and the conditional matmul cleanup is *supposed* to delete the
|
||||
//! elementwise Mul + KernelSumReduce fallback whenever a cublaslt alternative
|
||||
//! exists. In practice both fail to fire reliably for the VAE's mid-block
|
||||
//! `AttnBlock` matmuls — at 1024² that lets the search occasionally pick
|
||||
//! the broadcast-Mul path for `q @ kᵀ`, generating a `(HW, HW, C) =
|
||||
//! (16384, 16384, 512)` ≈ 524 GiB single intermediate that OOMs the GPU.
|
||||
//!
|
||||
//! Same approach as `kernel::conv2d`: define a `KernelOp`, wrap it in a
|
||||
//! `CustomOp`, expose a tiny `pub fn` so callers don't see the
|
||||
//! `cx.custom_op` plumbing. This is opaque to egglog by design — we
|
||||
//! aren't trying to fuse with surrounding ops, just guarantee a sane
|
||||
//! lowering for the matmuls we know are problematic.
|
||||
//!
|
||||
//! The CUDA implementation is a textbook 2D-blocked SGEMM:
|
||||
//! * 16×16 output tile per block (256 threads)
|
||||
//! * Tiled load of A and B into shared memory in K-size chunks
|
||||
//! * Each thread accumulates one output element across all K-tiles
|
||||
//! * Optional bias broadcast along the M axis at write-out
|
||||
//! * `transpose_b` toggles between row-major B `(K, N)` and row-major
|
||||
//! B `(N, K)` (i.e. the `A @ Bᵀ` pattern that linear/projection
|
||||
//! layers use).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct 2D matmul `(M, K) × {(K, N) | (N, K)} → (M, N)` with optional
|
||||
/// per-output-column bias and an optional batch axis. A and output are
|
||||
/// always F32. B can be F32 or BF16; BF16 is converted to F32 on each
|
||||
/// load, which avoids materializing the cast as a separate intermediate
|
||||
/// tensor (important for the text encoder / transformer where the F32-
|
||||
/// cast weights would not fit in GPU memory). All shape parameters are
|
||||
/// static (baked into the CUDA source via #defines).
|
||||
///
|
||||
/// When `batch > 1` the kernel does `batch` independent 2D matmuls in
|
||||
/// parallel: A is `(batch, M, K)`, B is `(batch, *, *)` with the same
|
||||
/// per-batch shape, output is `(batch, M, N)`. All three are assumed
|
||||
/// contiguous row-major across batches (i.e. `a_batch_stride = M*K`,
|
||||
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
|
||||
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
|
||||
/// `(N,)` and broadcast across batches.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DKernel {
|
||||
pub m: usize,
|
||||
pub n: usize,
|
||||
pub k: usize,
|
||||
pub batch: usize,
|
||||
/// If `true`, B is interpreted as `(N, K)` row-major and accessed as
|
||||
/// `B[n][k]` (i.e. `A @ Bᵀ`). If `false`, B is `(K, N)` row-major and
|
||||
/// accessed as `B[k][n]` (i.e. `A @ B`).
|
||||
pub transpose_b: bool,
|
||||
pub has_bias: bool,
|
||||
/// Storage dtype of B. Currently F32 or BF16 are supported.
|
||||
pub weight_dtype: DType,
|
||||
}
|
||||
|
||||
const TILE: usize = 16;
|
||||
|
||||
impl KernelOp for Matmul2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let bias_param = if self.has_bias {
|
||||
", const float* __restrict__ bias"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let bias_add = if self.has_bias {
|
||||
" acc += bias[n];\n"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
|
||||
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
|
||||
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// Plus the per-batch offset (`b_batch_off`).
|
||||
let b_index_expr = if self.transpose_b {
|
||||
"b_batch_off + (b_n_base + tx) * K + (k0 + ty)"
|
||||
} else {
|
||||
"b_batch_off + (k0 + ty) * N + (b_n_base + tx)"
|
||||
};
|
||||
// Convert B's element to float on load. For BF16 we declare B as
|
||||
// `__nv_bfloat16*` and use `__bfloat162float`; for F32 it's a no-op.
|
||||
let (b_param_type, b_load_expr, bf16_include) = match self.weight_dtype {
|
||||
DType::F32 => (
|
||||
"const float* __restrict__ B",
|
||||
format!("B[{b_index_expr}]"),
|
||||
"",
|
||||
),
|
||||
DType::Bf16 => (
|
||||
"const __nv_bfloat16* __restrict__ B",
|
||||
format!("__bfloat162float(B[{b_index_expr}])"),
|
||||
"#include <cuda_bf16.h>\n",
|
||||
),
|
||||
other => panic!("Matmul2DKernel: unsupported weight_dtype {other:?}"),
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ A,
|
||||
{b_param_type}{bias_param}
|
||||
) {{
|
||||
const int M = {m};
|
||||
const int N = {n};
|
||||
const int K = {k};
|
||||
const int TILE = {tile};
|
||||
|
||||
__shared__ float As[{tile}][{tile}];
|
||||
__shared__ float Bs[{tile}][{tile}];
|
||||
|
||||
int bx = blockIdx.x; // tile column (n)
|
||||
int by = blockIdx.y; // tile row (m)
|
||||
int batch = blockIdx.z; // batch index (0..BATCH-1)
|
||||
int tx = threadIdx.x; // 0..TILE-1, output col within tile
|
||||
int ty = threadIdx.y; // 0..TILE-1, output row within tile
|
||||
|
||||
int m_global = by * TILE + ty;
|
||||
int n_global = bx * TILE + tx;
|
||||
|
||||
int a_m_base = by * TILE;
|
||||
int b_n_base = bx * TILE;
|
||||
|
||||
// Per-batch base pointer offsets (contiguous row-major across batches).
|
||||
int a_batch_off = batch * (M * K);
|
||||
int b_batch_off = batch * (K * N);
|
||||
int c_batch_off = batch * (M * N);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
int n_tiles = (K + TILE - 1) / TILE;
|
||||
for (int t = 0; t < n_tiles; ++t) {{
|
||||
int k0 = t * TILE;
|
||||
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
|
||||
int a_m = a_m_base + ty;
|
||||
int a_k = k0 + tx;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
|
||||
|
||||
// Load B tile depending on transpose_b
|
||||
int b_n_or_k = b_n_base + tx; // for transpose_b=true this is N; for =false this is N
|
||||
int b_k_or_k = k0 + ty; // similarly
|
||||
// We compute Bs[ty][tx] such that the inner loop reads Bs[k_local][n_local] = B[k][n].
|
||||
// For transpose_b=true (B is (N,K)): B[k][n] in math = B_storage[n][k] = B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// For transpose_b=false (B is (K,N)): B[k][n] in math = B_storage[k][n] = B[(k0+ty)*N + (b_n_base+tx)]
|
||||
bool b_in_bounds = ({transpose_b} ? (b_n_or_k < N && b_k_or_k < K)
|
||||
: (b_k_or_k < K && b_n_or_k < N));
|
||||
Bs[ty][tx] = b_in_bounds ? ({b_load_expr}) : 0.0f;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < {tile}; ++kk) {{
|
||||
acc += As[ty][kk] * Bs[kk][tx];
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
|
||||
if (m_global < M && n_global < N) {{
|
||||
int n = n_global;
|
||||
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
m = self.m,
|
||||
n = self.n,
|
||||
k = self.k,
|
||||
tile = TILE,
|
||||
transpose_b = self.transpose_b,
|
||||
b_load_expr = b_load_expr,
|
||||
b_param_type = b_param_type,
|
||||
bias_param = bias_param,
|
||||
bias_add = bias_add,
|
||||
bf16_include = bf16_include,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("matmul_2d_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let grid_x = self.n.div_ceil(TILE);
|
||||
let grid_y = self.m.div_ceil(TILE);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(grid_y),
|
||||
Expression::from(self.batch),
|
||||
),
|
||||
(
|
||||
Expression::from(TILE),
|
||||
Expression::from(TILE),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.m * self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// K elements from A (F32) + K elements from B (F32 or BF16) + maybe bias (F32).
|
||||
let b_bytes = match self.weight_dtype {
|
||||
DType::F32 => 4,
|
||||
DType::Bf16 => 2,
|
||||
_ => 4,
|
||||
};
|
||||
let bias_bytes = if self.has_bias { 4 } else { 0 };
|
||||
Expression::from(
|
||||
self.batch * self.m * self.n * (self.k * 4 + self.k * b_bytes + bias_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let per_out = self.k * 2 + if self.has_bias { 1 } else { 0 };
|
||||
Expression::from(self.batch * self.m * self.n * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Matmul2D"
|
||||
}
|
||||
}
|
||||
|
||||
/// CustomOp wrapper for [`Matmul2DKernel`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DCustom(pub Matmul2DKernel);
|
||||
|
||||
impl CustomOp for Matmul2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
|
||||
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
|
||||
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
|
||||
/// pattern produced by linear / projection layers (`x @ w.t()`).
|
||||
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
|
||||
/// `(N,)`, row-major F32 throughout.
|
||||
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
|
||||
}
|
||||
|
||||
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
|
||||
///
|
||||
/// Lowers as plain HLIR — `Cast(A, BF16) @ permute(B_bf16) → Cast(F32)`.
|
||||
/// The activation cast and output cast are tiny (M*K and M*N elements;
|
||||
/// the K=hidden weight stays BF16). The inner BF16 matmul matches the
|
||||
/// existing cublaslt rewrite rules and runs as
|
||||
/// `CUBLAS_COMPUTE_32F_FAST_16BF` — Hopper's native 2× BF16 path.
|
||||
pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "linear_no_bias_bf16_w expects F32 A");
|
||||
assert_eq!(
|
||||
b_bf16.dtype,
|
||||
DType::Bf16,
|
||||
"linear_no_bias_bf16_w expects BF16 B"
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b_bf16.dims();
|
||||
assert_eq!(a_dims.len(), 2);
|
||||
assert_eq!(b_dims.len(), 2);
|
||||
let a_bf16 = a.cast(DType::Bf16);
|
||||
let b_kn = b_bf16.permute((1, 0));
|
||||
a_bf16.matmul(b_kn).cast(DType::F32)
|
||||
}
|
||||
|
||||
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
|
||||
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
|
||||
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
fn matmul_inner(
|
||||
a: GraphTensor,
|
||||
b: GraphTensor,
|
||||
transpose_b: bool,
|
||||
bias: Option<GraphTensor>,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
|
||||
let weight_dtype = b.dtype;
|
||||
assert!(
|
||||
matches!(weight_dtype, DType::F32 | DType::Bf16),
|
||||
"matmul B must be F32 or BF16, got {weight_dtype:?}",
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b.dims();
|
||||
assert_eq!(
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
"matmul A/B rank mismatch: {} vs {}",
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
);
|
||||
assert!(
|
||||
a_dims.len() == 2 || a_dims.len() == 3,
|
||||
"matmul expects rank 2 or 3, got rank {}",
|
||||
a_dims.len(),
|
||||
);
|
||||
|
||||
let (batch, a_off) = if a_dims.len() == 3 {
|
||||
let ba = a_dims[0].to_usize().expect("batch dim must be static");
|
||||
let bb = b_dims[0].to_usize().expect("batch dim must be static");
|
||||
assert_eq!(
|
||||
ba, bb,
|
||||
"matmul batch dim mismatch: A batch={ba}, B batch={bb}"
|
||||
);
|
||||
(ba, 1)
|
||||
} else {
|
||||
(1, 0)
|
||||
};
|
||||
|
||||
let m = a_dims[a_off].to_usize().expect("M must be a static dim");
|
||||
let k_a = a_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (A) must be a static dim");
|
||||
let (n, k_b) = if transpose_b {
|
||||
// B per-batch is (N, K)
|
||||
let n = b_dims[a_off].to_usize().expect("N must be a static dim");
|
||||
let k = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
(n, k)
|
||||
} else {
|
||||
// B per-batch is (K, N)
|
||||
let k = b_dims[a_off]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
let n = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("N must be a static dim");
|
||||
(n, k)
|
||||
};
|
||||
assert_eq!(k_a, k_b, "matmul K mismatch: A K={k_a}, B K={k_b}");
|
||||
let k = k_a;
|
||||
|
||||
let has_bias = bias.is_some();
|
||||
if let Some(bias) = bias {
|
||||
let bdims = bias.dims();
|
||||
assert_eq!(bdims.len(), 1, "matmul bias must be 1D");
|
||||
assert_eq!(
|
||||
bdims[0].to_usize().expect("bias dim must be static"),
|
||||
n,
|
||||
"matmul bias size must equal N"
|
||||
);
|
||||
assert_eq!(bias.dtype, DType::F32, "matmul bias must be F32");
|
||||
}
|
||||
|
||||
let kern = Matmul2DKernel {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch,
|
||||
transpose_b,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
};
|
||||
let cx = unsafe { &mut *a.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
vec![a, b, bias]
|
||||
} else {
|
||||
vec![a, b]
|
||||
};
|
||||
if batch == 1 {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
|
||||
} else {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
|
||||
}
|
||||
}
|
||||
@@ -9,14 +9,31 @@ use luminal_tracing::schema::{
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod generic_matmul;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use generic_matmul::GenericMatmul;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
pub type Ops = (
|
||||
hlir::Ops,
|
||||
other_ops::Ops,
|
||||
conv2d::KernelConv2D,
|
||||
GenericMatmul,
|
||||
fusion::Ops,
|
||||
);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
189
crates/luminal_cuda_lite/src/kernel/rope.rs
Normal file
189
crates/luminal_cuda_lite/src/kernel/rope.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Fused RoPE (rotary position embedding) — interleaved-pair convention.
|
||||
//!
|
||||
//! Replaces flux2's 6-op RoPE chain (split / slice / squeeze / neg / concat /
|
||||
//! merge_dims / 4× cast / mul / add) with a single kernel launch per call.
|
||||
//! ~120 RoPE calls per forward pass at full DiT depth.
|
||||
//!
|
||||
//! Convention: `repeat_interleave_real=True` (Flux 2 / diffusers), so adjacent
|
||||
//! dim pairs rotate together. For an input `[a0, b0, a1, b1, ...]` and per-
|
||||
//! position `(cos, sin)`, the output is
|
||||
//! `out[2j] = x[2j] * cos[2j] - x[2j+1] * sin[2j]`
|
||||
//! `out[2j+1] = x[2j+1] * cos[2j+1] + x[2j] * sin[2j+1]`
|
||||
//!
|
||||
//! Layout: x `(S, H, D)`, cos/sin `(S, D)` (broadcast across H).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPEKernel {
|
||||
pub s: usize,
|
||||
pub h: usize,
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
const TPB: usize = 64;
|
||||
|
||||
impl KernelOp for RoPEKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let s = self.s;
|
||||
let h = self.h;
|
||||
let d = self.d;
|
||||
assert!(d.is_multiple_of(2), "RoPE head_dim must be even");
|
||||
let kernel = format!(
|
||||
r#"
|
||||
extern "C" __global__ void rope_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ cos_,
|
||||
const float* __restrict__ sin_
|
||||
) {{
|
||||
const int S = {s};
|
||||
const int H = {h};
|
||||
const int D = {d};
|
||||
int sh = blockIdx.x; // 0..S*H
|
||||
int s_idx = sh / H;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const float* xr = x + sh * D;
|
||||
const float* cosr = cos_ + s_idx * D;
|
||||
const float* sinr = sin_ + s_idx * D;
|
||||
float* yr = out + sh * D;
|
||||
|
||||
for (int i = tid; i < D; i += {TPB}) {{
|
||||
float xi = xr[i];
|
||||
float xpair;
|
||||
if ((i & 1) == 0) {{
|
||||
// even: paired with i+1, rotated value is -x[i+1]
|
||||
xpair = -xr[i + 1];
|
||||
}} else {{
|
||||
// odd: paired with i-1, rotated value is +x[i-1]
|
||||
xpair = xr[i - 1];
|
||||
}}
|
||||
yr[i] = xi * cosr[i] + xpair * sinr[i];
|
||||
}}
|
||||
}}
|
||||
"#
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("rope_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
"rope_kernel".to_string(),
|
||||
(
|
||||
Expression::from(s * h),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(TPB),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.s * self.h * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// x: full (S,H,D); cos/sin: (S,D) read H times each but cached.
|
||||
Expression::from(self.s * self.h * self.d * 4 + self.s * self.d * 4 * 2)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 4 per output element (mul, neg/load, mul, add).
|
||||
Expression::from(self.s * self.h * self.d * 4)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"RoPE"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPECustom(pub RoPEKernel);
|
||||
|
||||
impl CustomOp for RoPECustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE: `x` shape `(S, H, D)` F32, `cos`/`sin` shape `(S, D)` F32.
|
||||
/// Returns `(S, H, D)` F32.
|
||||
pub fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(x.dtype, DType::F32, "RoPE x must be F32");
|
||||
let cos = if cos.dtype == DType::F32 {
|
||||
cos
|
||||
} else {
|
||||
cos.cast(DType::F32)
|
||||
};
|
||||
let sin = if sin.dtype == DType::F32 {
|
||||
sin
|
||||
} else {
|
||||
sin.cast(DType::F32)
|
||||
};
|
||||
let x_dims = x.dims();
|
||||
assert_eq!(x_dims.len(), 3, "RoPE x must be 3-D (S, H, D)");
|
||||
let s = x_dims[0].to_usize().expect("RoPE: S must be static");
|
||||
let h = x_dims[1].to_usize().expect("RoPE: H must be static");
|
||||
let d = x_dims[2].to_usize().expect("RoPE: D must be static");
|
||||
let cos_dims = cos.dims();
|
||||
let sin_dims = sin.dims();
|
||||
assert_eq!(cos_dims.len(), 2, "RoPE cos must be 2-D (S, D)");
|
||||
assert_eq!(sin_dims.len(), 2, "RoPE sin must be 2-D (S, D)");
|
||||
assert_eq!(cos_dims[0].to_usize().unwrap(), s, "RoPE cos S mismatch");
|
||||
assert_eq!(cos_dims[1].to_usize().unwrap(), d, "RoPE cos D mismatch");
|
||||
assert_eq!(sin_dims[0].to_usize().unwrap(), s, "RoPE sin S mismatch");
|
||||
assert_eq!(sin_dims[1].to_usize().unwrap(), d, "RoPE sin D mismatch");
|
||||
|
||||
let kern = RoPEKernel { s, h, d };
|
||||
let cx = unsafe { &mut *x.graph_ref };
|
||||
cx.custom_op(RoPECustom(kern), vec![x, cos, sin], (s, h, d), DType::F32)
|
||||
}
|
||||
@@ -192,6 +192,32 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
/// LLIR node IDs of every kernel in this CudaGraphOp, in the order
|
||||
/// they execute inside the compiled CUDA graph. This is the
|
||||
/// toposort `kernel_to_host` used at compile time, preserved here
|
||||
/// so the runtime can compute live ranges that match real
|
||||
/// execution order: each kernel in `state.kernels` was added to
|
||||
/// the CUDA graph with `prev_graph_node` as its sole dependency,
|
||||
/// which serializes them.
|
||||
pub fn kernel_topo_order(&self) -> Vec<NodeIndex> {
|
||||
self.state.borrow().kernels.iter().map(|k| k.node).collect()
|
||||
}
|
||||
|
||||
/// Direct LLIR-node inputs of one kernel inside this CudaGraphOp.
|
||||
/// Used by the runtime's live-range pass to refine intra-graph
|
||||
/// consumer positions: a kernel's input can stop being live as
|
||||
/// soon as that specific kernel finishes, not when the whole
|
||||
/// CudaGraphOp finishes.
|
||||
pub fn kernel_inputs(&self, kernel_node: NodeIndex) -> Vec<NodeIndex> {
|
||||
self.state
|
||||
.borrow()
|
||||
.kernels
|
||||
.iter()
|
||||
.find(|k| k.node == kernel_node)
|
||||
.map(|k| k.inputs.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -316,8 +342,7 @@ impl CudaGraphOp {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Add" | "Embed" | "Gather" | "GenericMatmul" | "LessThan" | "Mod" | "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
@@ -814,7 +839,7 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
|
||||
// Compute the set of FS / FE / Cuda*Elementwise nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
@@ -974,7 +999,7 @@ pub fn kernel_to_host(
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior FusedX nodes that don't exist
|
||||
// those are interior elementwise nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
@@ -1139,7 +1164,7 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// FusedX) from the LLIR. Region codegen has already folded them into
|
||||
// Cuda*Elementwise) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
|
||||
@@ -34,6 +34,7 @@ fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I64 => "long long",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
|
||||
@@ -237,6 +237,7 @@ pub(crate) fn split_egraph_by_memory_limit(
|
||||
let mut split = splitter.split();
|
||||
|
||||
compact_egraph_after_prune(&mut split);
|
||||
validate_unique_loop_markers(&split);
|
||||
let stats = MemorySplitStats {
|
||||
original_enodes,
|
||||
split_enodes: split.enodes.len(),
|
||||
@@ -442,6 +443,9 @@ impl<'a> StateSplitter<'a> {
|
||||
}
|
||||
}
|
||||
"Op" => self.split_op_node(owner_class, node, label, children),
|
||||
label if direct_loop_marker(label) => {
|
||||
self.split_direct_loop_marker_node(owner_class, node, label.to_string(), children)
|
||||
}
|
||||
_ => {
|
||||
let Some((idx, child_class)) =
|
||||
first_child_with_sort_index(self.original, &children, "IR")
|
||||
@@ -479,6 +483,9 @@ impl<'a> StateSplitter<'a> {
|
||||
|
||||
let input_states = self.split_list_class(inputs_class);
|
||||
for kind_node in kind_nodes {
|
||||
let Some((kind_label, _)) = self.original.enodes.get(kind_node) else {
|
||||
continue;
|
||||
};
|
||||
let Some(kind) =
|
||||
kind_memory_for_node(self.original, &self.sort_by_name, kind_node, self.dyn_map)
|
||||
else {
|
||||
@@ -488,6 +495,33 @@ impl<'a> StateSplitter<'a> {
|
||||
continue;
|
||||
}
|
||||
let kind_split_class = self.kind_singleton_class(kind_node);
|
||||
if loop_op_kind(kind_label) {
|
||||
// Loop OpKinds are structural markers. Keep the marker singleton and
|
||||
// pick one feasible state for the data flowing through it.
|
||||
let Some((state, input_split_class)) = input_states
|
||||
.iter()
|
||||
.filter_map(|(input_state, input_split_class)| {
|
||||
let state = op_memory_state(kind, input_state)?;
|
||||
(state.peak <= self.limit).then(|| (state, input_split_class.clone()))
|
||||
})
|
||||
.min_by_key(|(state, _)| (state.peak, state.live))
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut split_children = children.clone();
|
||||
split_children[0] = kind_split_class;
|
||||
split_children[1] = input_split_class;
|
||||
self.add_ir_state_node(
|
||||
owner_class,
|
||||
state,
|
||||
label.clone(),
|
||||
split_children,
|
||||
source_node,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (input_state, input_split_class) in &input_states {
|
||||
let Some(state) = op_memory_state(kind, input_state) else {
|
||||
continue;
|
||||
@@ -509,6 +543,33 @@ impl<'a> StateSplitter<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_direct_loop_marker_node(
|
||||
&mut self,
|
||||
owner_class: &ClassId,
|
||||
source_node: &NodeId,
|
||||
label: String,
|
||||
children: Vec<ClassId>,
|
||||
) {
|
||||
let Some((idx, child_class)) = first_child_with_sort_index(self.original, &children, "IR")
|
||||
else {
|
||||
return;
|
||||
};
|
||||
// LoopStart/LoopEnd identity is part of the loop scaffold, so state
|
||||
// splitting must not clone the marker across child-state variants.
|
||||
let Some((state, state_class)) = self
|
||||
.split_ir_class(&child_class)
|
||||
.into_iter()
|
||||
.filter(|(state, _)| state.peak <= self.limit)
|
||||
.min_by_key(|(state, _)| (state.peak, state.live))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut split_children = children;
|
||||
split_children[idx] = state_class;
|
||||
self.add_ir_state_node(owner_class, state, label, split_children, source_node);
|
||||
}
|
||||
|
||||
fn split_list_class(&mut self, class: &ClassId) -> Vec<(ListMemoryState, ClassId)> {
|
||||
if let Some(states) = self.list_memo.get(class) {
|
||||
return states.clone();
|
||||
@@ -992,7 +1053,10 @@ fn choose_kind_node<'a>(egraph: &'a SerializedEGraph, kind_class: &ClassId) -> O
|
||||
};
|
||||
let is_kernel = |node: &&NodeId| -> bool {
|
||||
let label = &egraph.enodes[*node].0;
|
||||
label.starts_with("Kernel") || label.starts_with("Fused")
|
||||
label.starts_with("Kernel")
|
||||
|| label.starts_with("Cuda")
|
||||
|| label == "FusionStart"
|
||||
|| label == "FusionEnd"
|
||||
};
|
||||
|
||||
kind_enodes
|
||||
@@ -1079,12 +1143,94 @@ fn compact_egraph_after_prune(egraph: &mut SerializedEGraph) {
|
||||
}
|
||||
|
||||
fn zero_local_op_kind(kind: &str) -> bool {
|
||||
loop_op_kind(kind)
|
||||
}
|
||||
|
||||
fn loop_op_kind(kind: &str) -> bool {
|
||||
matches!(
|
||||
kind,
|
||||
"LoopInput" | "LoopInputStatic" | "LoopOutput" | "LoopOutputSelect"
|
||||
)
|
||||
}
|
||||
|
||||
fn direct_loop_marker(kind: &str) -> bool {
|
||||
matches!(kind, "LoopStart" | "LoopEnd")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct LoopMarkerKey {
|
||||
label: String,
|
||||
fields: Vec<String>,
|
||||
}
|
||||
|
||||
fn validate_unique_loop_markers(egraph: &SerializedEGraph) {
|
||||
let mut seen = FxHashMap::default();
|
||||
for node in egraph.enodes.keys() {
|
||||
for key in loop_marker_keys_for_node(egraph, node) {
|
||||
if let Some(previous) = seen.insert(key.clone(), node.clone()) {
|
||||
panic!(
|
||||
"CUDA memory splitter duplicated loop marker {key:?}: {previous:?} and {node:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn loop_marker_keys_for_node(egraph: &SerializedEGraph, node: &NodeId) -> Vec<LoopMarkerKey> {
|
||||
let Some((label, children)) = egraph.enodes.get(node) else {
|
||||
return Vec::new();
|
||||
};
|
||||
if direct_loop_marker(label) {
|
||||
return vec![LoopMarkerKey {
|
||||
label: label.clone(),
|
||||
fields: field_signature(egraph, children.iter().skip(1)),
|
||||
}];
|
||||
}
|
||||
if label != "Op" {
|
||||
return Vec::new();
|
||||
}
|
||||
let Some(kind_class) = children.first() else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Some((sort, kind_nodes)) = egraph.eclasses.get(kind_class) else {
|
||||
return Vec::new();
|
||||
};
|
||||
if sort != "OpKind" {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
kind_nodes
|
||||
.iter()
|
||||
.filter_map(|kind_node| {
|
||||
let (kind_label, kind_children) = egraph.enodes.get(kind_node)?;
|
||||
loop_op_kind(kind_label).then(|| LoopMarkerKey {
|
||||
label: kind_label.clone(),
|
||||
fields: field_signature(egraph, kind_children.iter()),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn field_signature<'a>(
|
||||
egraph: &SerializedEGraph,
|
||||
fields: impl Iterator<Item = &'a ClassId>,
|
||||
) -> Vec<String> {
|
||||
fields
|
||||
.map(|class| {
|
||||
let node_label = egraph
|
||||
.eclasses
|
||||
.get(class)
|
||||
.and_then(|(_, nodes)| {
|
||||
nodes
|
||||
.iter()
|
||||
.find_map(|node| egraph.enodes.get(node).map(|(label, _)| label.clone()))
|
||||
})
|
||||
.unwrap_or_else(|| "<missing>".to_string());
|
||||
format!("{}:{node_label}", class.as_ref())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cuda_sort_map() -> FxHashMap<String, SortDef> {
|
||||
<(crate::kernel::Ops, crate::host::Ops) as luminal::op::IntoEgglogOp>::into_vec()
|
||||
.into_iter()
|
||||
@@ -1104,7 +1250,7 @@ fn local_output_bytes<'a>(
|
||||
) -> Option<Expression> {
|
||||
match sort.name.as_str() {
|
||||
name if zero_local_op_kind(name) => Some(0.into()),
|
||||
name if name.starts_with("Fused") || name == "FusionStart" => Some(0.into()),
|
||||
name if name.starts_with("Cuda") || name == "FusionStart" => Some(0.into()),
|
||||
"KernelConstant" => Some(4.into()),
|
||||
"KernelIota" => Some(expr_field(egraph, sort, kind_children, "range", expr_cache)? * 4),
|
||||
"KernelLessThan" => Some(n_elements_field(
|
||||
@@ -1135,7 +1281,7 @@ fn local_output_bytes<'a>(
|
||||
let dtype = dtype_field(egraph, sort, kind_children, "dtype")?;
|
||||
Some(bytes_for_elements(size, dtype))
|
||||
}
|
||||
"cublaslt" => {
|
||||
"cublaslt" | "cublaslt_scaled" => {
|
||||
let batch = expr_field(egraph, sort, kind_children, "batch_count", expr_cache)?;
|
||||
let m = expr_field(egraph, sort, kind_children, "m", expr_cache)?;
|
||||
let n = expr_field(egraph, sort, kind_children, "n", expr_cache)?;
|
||||
@@ -1213,7 +1359,7 @@ fn n_elements_field<'a>(
|
||||
fn output_bytes_rules(sort: &SortDef) -> Vec<String> {
|
||||
match sort.name.as_str() {
|
||||
name if zero_local_op_kind(name) => vec![output_bytes_rule(sort, "(MNum 0)", "zero")],
|
||||
name if name.starts_with("Fused") || name == "FusionStart" => {
|
||||
name if name.starts_with("Cuda") || name == "FusionStart" => {
|
||||
vec![output_bytes_rule(sort, "(MNum 0)", "zero")]
|
||||
}
|
||||
"KernelConstant" => vec![output_bytes_rule(sort, "(MNum 4)", "f32-scalar")],
|
||||
@@ -1244,7 +1390,7 @@ fn output_bytes_rules(sort: &SortDef) -> Vec<String> {
|
||||
&["(= ?__cuda_elems (n_elements ?batch_shape))"],
|
||||
)],
|
||||
"KernelCast" => dtype_output_bytes_rules(sort, "size", "dtype"),
|
||||
"cublaslt" => {
|
||||
"cublaslt" | "cublaslt_scaled" => {
|
||||
dtype_output_bytes_rules_for_expr(sort, "(MMul (MMul ?batch_count ?m) ?n)", "d_dtype")
|
||||
}
|
||||
"GLUMoE" => vec![output_bytes_rule(
|
||||
@@ -1371,7 +1517,9 @@ fn output_bytes_rule_with_facts(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{cuda_memory_analysis_pass, estimate_graph_memory_bytes};
|
||||
use super::{
|
||||
cuda_memory_analysis_pass, estimate_graph_memory_bytes, loop_marker_keys_for_node,
|
||||
};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
EGraphChoiceSet, SerializedEGraph, count_choice_sets_up_to, random_initial_choice,
|
||||
@@ -1383,11 +1531,7 @@ mod tests {
|
||||
};
|
||||
|
||||
fn ops() -> Vec<std::sync::Arc<Box<dyn luminal::op::EgglogOp>>> {
|
||||
let mut ops = <(
|
||||
crate::kernel::hlir::Ops,
|
||||
crate::kernel::other_ops::Ops,
|
||||
crate::host::Ops,
|
||||
) as IntoEgglogOp>::into_vec();
|
||||
let mut ops = <(crate::kernel::Ops, crate::host::Ops) as IntoEgglogOp>::into_vec();
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
ops
|
||||
}
|
||||
@@ -1399,13 +1543,13 @@ mod tests {
|
||||
.expect("cuda memory pass should parse and run")
|
||||
}
|
||||
|
||||
fn kernel_add(name: &str, size: usize, a: &str, b: &str) -> String {
|
||||
fn kernel_mod(name: &str, size: &str, a: &str, b: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
(let {name}
|
||||
(Op
|
||||
(KernelAdd
|
||||
(ECons (MNum {size}) (ENil))
|
||||
(KernelMod
|
||||
(ECons {size} (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
@@ -1454,25 +1598,20 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_late_pass_runs_on_kernel_add() {
|
||||
fn cuda_memory_late_pass_runs_on_kernel_mod() {
|
||||
let ops = ops();
|
||||
let late_pass = cuda_memory_analysis_pass(&ops, None, &FxHashMap::default());
|
||||
let program = r#"
|
||||
let program = format!(
|
||||
r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
(let t2
|
||||
(Op
|
||||
(KernelAdd
|
||||
(ECons (MNum 4) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(F32))
|
||||
(ICons t0 (ICons t1 (INil)))))
|
||||
{}
|
||||
(let t3 (Output t2 2))
|
||||
"#;
|
||||
"#,
|
||||
kernel_mod("t2", "(MNum 4)", "t0", "t1"),
|
||||
);
|
||||
|
||||
run_egglog_with_late_passes(program, "t3", &ops, false, &[late_pass])
|
||||
run_egglog_with_late_passes(&program, "t3", &ops, false, &[late_pass])
|
||||
.expect("cuda memory pass should parse and run");
|
||||
}
|
||||
|
||||
@@ -1499,6 +1638,55 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_state_split_does_not_duplicate_loop_markers() {
|
||||
let program = format!(
|
||||
r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
{}
|
||||
{}
|
||||
(union small big)
|
||||
(let loop_start (LoopStart small 0 0 (MNum 2) (F32)))
|
||||
(let loop_end (LoopEnd small 0 0 (F32)))
|
||||
(let loop_input (Op (LoopInput 0 0 (F32)) (ICons small (ICons t0 (INil)))))
|
||||
(let loop_output (Op (LoopOutput 0 0 (F32)) (ICons small (INil))))
|
||||
(let loop_select (Op (LoopOutputSelect 0 0 0 (F32)) (ICons loop_output (INil))))
|
||||
(let out_start (Output loop_start 2))
|
||||
(let out_end (Output loop_end 3))
|
||||
(let out_input (Output loop_input 4))
|
||||
(let out_select (Output loop_select 5))
|
||||
(let out_a (OutputJoin out_start out_end))
|
||||
(let out_b (OutputJoin out_input out_select))
|
||||
(let out (OutputJoin out_a out_b))
|
||||
"#,
|
||||
kernel_mod("small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("big", "(MNum 8)", "t0", "t1"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(1024));
|
||||
let mut marker_counts = FxHashMap::<String, usize>::default();
|
||||
for node in egraph.enodes.keys() {
|
||||
for key in loop_marker_keys_for_node(&egraph, node) {
|
||||
*marker_counts.entry(key.label).or_default() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for marker in [
|
||||
"LoopStart",
|
||||
"LoopEnd",
|
||||
"LoopInput",
|
||||
"LoopOutput",
|
||||
"LoopOutputSelect",
|
||||
] {
|
||||
assert_eq!(
|
||||
marker_counts.get(marker).copied().unwrap_or_default(),
|
||||
1,
|
||||
"{marker} should not be duplicated by memory state splitting"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_estimates_peak_for_two_live_inputs() {
|
||||
let program = format!(
|
||||
@@ -1510,9 +1698,9 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 3))
|
||||
"#,
|
||||
kernel_add("left", 4, "t0", "t1"),
|
||||
kernel_add("right", 4, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left", "right"),
|
||||
kernel_mod("left", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("right", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left", "right"),
|
||||
);
|
||||
let egraph = run_memory_egraph(&program, "out", None);
|
||||
let mut rng = rand::rng();
|
||||
@@ -1546,7 +1734,7 @@ mod tests {
|
||||
(ICons dest (ICons indexes (ICons src (INil))))))
|
||||
(let out (Output scatter 4))
|
||||
"#,
|
||||
kernel_add("dest", 4, "t0", "t1"),
|
||||
kernel_mod("dest", "(MNum 4)", "t0", "t1"),
|
||||
);
|
||||
let egraph = run_memory_egraph(&program, "out", None);
|
||||
let mut rng = rand::rng();
|
||||
@@ -1569,8 +1757,8 @@ mod tests {
|
||||
(union small big)
|
||||
(let out (Output small 2))
|
||||
"#,
|
||||
kernel_add("small", 4, "t0", "t1"),
|
||||
kernel_add("big", 32, "t0", "t1"),
|
||||
kernel_mod("small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("big", "(MNum 32)", "t0", "t1"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(64));
|
||||
@@ -1590,22 +1778,17 @@ mod tests {
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', 4);
|
||||
let late_pass = cuda_memory_analysis_pass(&ops, Some(16), &dyn_map);
|
||||
let program = r#"
|
||||
let program = format!(
|
||||
r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
(let add
|
||||
(Op
|
||||
(KernelAdd
|
||||
(ECons (MVar "s") (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(F32))
|
||||
(ICons t0 (ICons t1 (INil)))))
|
||||
{}
|
||||
(let out (Output add 2))
|
||||
"#;
|
||||
"#,
|
||||
kernel_mod("add", "(MVar \"s\")", "t0", "t1"),
|
||||
);
|
||||
|
||||
let egraph = run_egglog_with_late_passes(program, "out", &ops, false, &[late_pass])
|
||||
let egraph = run_egglog_with_late_passes(&program, "out", &ops, false, &[late_pass])
|
||||
.expect("cuda memory pass should parse and run");
|
||||
assert_eq!(count_choice_sets_up_to(&egraph, 10), 1);
|
||||
|
||||
@@ -1628,9 +1811,9 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 3))
|
||||
"#,
|
||||
kernel_add("left", 12, "t0", "t1"),
|
||||
kernel_add("right", 12, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left", "right"),
|
||||
kernel_mod("left", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("right", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left", "right"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(64));
|
||||
@@ -1659,11 +1842,11 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 4))
|
||||
"#,
|
||||
kernel_add("left_small", 4, "t0", "t1"),
|
||||
kernel_add("left_medium", 8, "t0", "t1"),
|
||||
kernel_add("left_big", 12, "t0", "t1"),
|
||||
kernel_add("right_small", 4, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left_small", "right_small"),
|
||||
kernel_mod("left_small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("left_medium", "(MNum 8)", "t0", "t1"),
|
||||
kernel_mod("left_big", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("right_small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left_small", "right_small"),
|
||||
);
|
||||
|
||||
let uncapped_start = std::time::Instant::now();
|
||||
|
||||
@@ -80,6 +80,14 @@ struct PlannedBuffer {
|
||||
end: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct NonFiniteBufferReport {
|
||||
pub(crate) node: NodeIndex,
|
||||
pub(crate) index: usize,
|
||||
pub(crate) value: f32,
|
||||
}
|
||||
|
||||
/// Per-bucket compiled state. Each bucket holds its own executable graph,
|
||||
/// explicit runtime metadata, intermediate buffers, and node mappings.
|
||||
/// Weights (hlir_buffers) are shared.
|
||||
@@ -106,6 +114,9 @@ pub(crate) struct CompiledBucket {
|
||||
pub(crate) bucket_indices: FxHashMap<char, usize>,
|
||||
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
|
||||
pub(crate) hlir_synced: bool,
|
||||
/// Test/debug mode: give every intermediate a distinct arena range so
|
||||
/// post-execution diagnostics can inspect expired nodes without reuse noise.
|
||||
pub(crate) preserve_intermediate_buffers_for_debug: bool,
|
||||
}
|
||||
|
||||
impl CompiledBucket {
|
||||
@@ -130,6 +141,7 @@ impl CompiledBucket {
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
bucket_indices: FxHashMap::default(),
|
||||
hlir_synced: false,
|
||||
preserve_intermediate_buffers_for_debug: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,10 +237,96 @@ impl CudaRuntime {
|
||||
result::memcpy_dtod_async(dst_ptr, src.ptr(), src.len(), stream.cu_stream())
|
||||
.expect("cuMemcpyDtoDAsync failed");
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
dst
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn first_nonfinite_f32_buffer_in_nodes(
|
||||
&self,
|
||||
nodes: impl IntoIterator<Item = NodeIndex>,
|
||||
) -> Option<NonFiniteBufferReport> {
|
||||
let _ = self.cuda_stream.synchronize();
|
||||
let bucket = self.active();
|
||||
let mut checked = FxHashSet::default();
|
||||
|
||||
for node in nodes {
|
||||
let spec_node = resolve_logical_buffer_node(
|
||||
node,
|
||||
&bucket.logical_buffer_bytes,
|
||||
&bucket.output_alias_map,
|
||||
)
|
||||
.unwrap_or(node);
|
||||
if !checked.insert(spec_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(spec) = bucket.buffer_specs.get(&spec_node) else {
|
||||
continue;
|
||||
};
|
||||
if !matches!(spec.dtype, DType::F32) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(buf) = Self::resolve_runtime_buffer(
|
||||
bucket,
|
||||
&self.cuda_stream,
|
||||
&self.hlir_buffers,
|
||||
&self.external_buffers,
|
||||
&self.external_output_buffers,
|
||||
spec_node,
|
||||
) else {
|
||||
continue;
|
||||
};
|
||||
if buf.is_empty() || buf.len() % std::mem::size_of::<f32>() != 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let host_bytes = match buf.clone_dtoh(&self.cuda_stream) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let values: &[f32] = bytemuck::cast_slice(&host_bytes);
|
||||
if let Some((index, value)) = values
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.find(|(_, value)| !value.is_finite())
|
||||
{
|
||||
return Some(NonFiniteBufferReport {
|
||||
node: spec_node,
|
||||
index,
|
||||
value,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn first_nonfinite_f32_buffer(&self) -> Option<NonFiniteBufferReport> {
|
||||
let bucket = self.active();
|
||||
self.first_nonfinite_f32_buffer_in_nodes(
|
||||
bucket
|
||||
.buffer_specs
|
||||
.keys()
|
||||
.copied()
|
||||
.sorted_by_key(|node| node.index()),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn preserve_intermediate_buffers_for_debug(&mut self) {
|
||||
for bucket in &mut self.compiled_buckets {
|
||||
bucket.preserve_intermediate_buffers_for_debug = true;
|
||||
bucket.logical_buffer_offsets.clear();
|
||||
bucket.logical_buffer_bytes.clear();
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
bucket.arena = None;
|
||||
bucket.arena_bytes = 0;
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_runtime_buffer(
|
||||
bucket: &CompiledBucket,
|
||||
stream: &Arc<CudaStream>,
|
||||
@@ -287,7 +385,12 @@ impl CudaRuntime {
|
||||
let dev = f32s.to_cuda_input(&self.cuda_stream);
|
||||
self.hlir_buffers.insert(node, dev);
|
||||
}
|
||||
safetensors::Dtype::U8 | safetensors::Dtype::BF16 | safetensors::Dtype::F16 => {
|
||||
safetensors::Dtype::U8
|
||||
| safetensors::Dtype::BF16
|
||||
| safetensors::Dtype::F16
|
||||
| safetensors::Dtype::F8_E4M3
|
||||
| safetensors::Dtype::F8_E5M2
|
||||
| safetensors::Dtype::F8_E8M0 => {
|
||||
let bytes = tensor.data();
|
||||
let dev = bytes.to_cuda_input(&self.cuda_stream);
|
||||
self.hlir_buffers.insert(node, dev);
|
||||
@@ -646,7 +749,57 @@ impl CudaRuntime {
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Read an output buffer as i64. Strict: the buffer must already
|
||||
/// be `DType::I64`; no widening at the read boundary.
|
||||
pub fn get_i64(&self, id: impl ToId) -> Vec<i64> {
|
||||
let id = id.to_id();
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
let buf_dtype = bucket.buffer_specs.get(&data_id).map(|s| s.dtype);
|
||||
if !matches!(buf_dtype, Some(DType::I64)) {
|
||||
panic!(
|
||||
"get_i64: buffer dtype is {buf_dtype:?}, expected I64. \
|
||||
Add a `Cast(DType::I64)` before the Output."
|
||||
);
|
||||
}
|
||||
self.get_output_data(id)
|
||||
.chunks_exact(8)
|
||||
.map(|c| i64::from_ne_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Read an output buffer as f64. Strict: the buffer must already
|
||||
/// be `DType::F64`; no widening at the read boundary.
|
||||
pub fn get_f64(&self, id: impl ToId) -> Vec<f64> {
|
||||
let id = id.to_id();
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
let buf_dtype = bucket.buffer_specs.get(&data_id).map(|s| s.dtype);
|
||||
if !matches!(buf_dtype, Some(DType::F64)) {
|
||||
panic!(
|
||||
"get_f64: buffer dtype is {buf_dtype:?}, expected F64. \
|
||||
Add a `Cast(DType::F64)` before the Output."
|
||||
);
|
||||
}
|
||||
self.get_output_data(id)
|
||||
.chunks_exact(8)
|
||||
.map(|c| f64::from_ne_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Read an output buffer as f16. Strict: the buffer must already
|
||||
/// be `DType::F16`; no widening at the read boundary.
|
||||
pub fn get_f16(&self, id: impl ToId) -> Vec<f16> {
|
||||
let id = id.to_id();
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
let buf_dtype = bucket.buffer_specs.get(&data_id).map(|s| s.dtype);
|
||||
if !matches!(buf_dtype, Some(DType::F16)) {
|
||||
panic!(
|
||||
"get_f16: buffer dtype is {buf_dtype:?}, expected F16. \
|
||||
Add a `Cast(DType::F16)` before the Output."
|
||||
);
|
||||
}
|
||||
let bytes = self.get_output_data(id);
|
||||
let n = bytes.len() / 2;
|
||||
let cap = bytes.capacity() / 2;
|
||||
@@ -655,7 +808,19 @@ impl CudaRuntime {
|
||||
unsafe { Vec::from_raw_parts(ptr, n, cap) }
|
||||
}
|
||||
|
||||
/// Read an output buffer as bf16. Strict: the buffer must already
|
||||
/// be `DType::Bf16`; no widening at the read boundary.
|
||||
pub fn get_bf16(&self, id: impl ToId) -> Vec<bf16> {
|
||||
let id = id.to_id();
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
let buf_dtype = bucket.buffer_specs.get(&data_id).map(|s| s.dtype);
|
||||
if !matches!(buf_dtype, Some(DType::Bf16)) {
|
||||
panic!(
|
||||
"get_bf16: buffer dtype is {buf_dtype:?}, expected Bf16. \
|
||||
Add a `Cast(DType::Bf16)` before the Output."
|
||||
);
|
||||
}
|
||||
let bytes = self.get_output_data(id);
|
||||
let n = bytes.len() / 2;
|
||||
let cap = bytes.capacity() / 2;
|
||||
@@ -894,6 +1059,32 @@ impl CudaRuntime {
|
||||
let planned_logical_bytes = planned.iter().map(|buf| buf.bytes).sum::<usize>();
|
||||
let logical_peak = logical_interval_peak(&planned);
|
||||
|
||||
if bucket.preserve_intermediate_buffers_for_debug {
|
||||
planned.sort_by_key(|buf| buf.node.index());
|
||||
let mut arena_end = 0usize;
|
||||
for buf in &planned {
|
||||
let offset = align_up(arena_end, ARENA_ALIGNMENT);
|
||||
bucket.logical_buffer_offsets.insert(buf.node, offset);
|
||||
bucket.logical_buffer_bytes.insert(buf.node, buf.bytes);
|
||||
arena_end = offset + align_up(buf.bytes, ARENA_ALIGNMENT);
|
||||
}
|
||||
bucket.arena_bytes = arena_end;
|
||||
|
||||
if std::env::var_os("LUMINAL_CUDA_MEMORY_DEBUG").is_some() {
|
||||
eprintln!(
|
||||
" CUDA memory plan specs={total_spec_count} used={planned_logical_count} skipped={} spec_bytes={} used_bytes={} skipped_bytes={} logical_peak={} preserved_arena={} allocations={}",
|
||||
total_spec_count.saturating_sub(planned_logical_count),
|
||||
total_spec_bytes,
|
||||
planned_logical_bytes,
|
||||
total_spec_bytes.saturating_sub(planned_logical_bytes),
|
||||
logical_peak,
|
||||
bucket.arena_bytes,
|
||||
bucket.logical_buffer_offsets.len(),
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let mut arena_end = 0usize;
|
||||
let mut placed: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(planned.len());
|
||||
let mut placement_order = planned.iter().collect_vec();
|
||||
@@ -1178,7 +1369,7 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
fn late_egglog_passes(
|
||||
ops: &[Arc<Box<dyn luminal::op::EgglogOp>>],
|
||||
options: &luminal::graph::BuildSearchSpaceOptions,
|
||||
options: &luminal::graph::CompileOptions,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
|
||||
vec![crate::memory_analysis::cuda_memory_analysis_pass(
|
||||
@@ -1189,7 +1380,7 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
|
||||
fn estimate_graph_memory<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
choices: &luminal::egglog_utils::EGraphChoiceSet<'a>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Option<usize> {
|
||||
@@ -1343,8 +1534,8 @@ impl Runtime for CudaRuntime {
|
||||
&mut self,
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
// Clear active bucket's arena before loading new LLIR for profiling.
|
||||
if !self.compiled_buckets.is_empty() {
|
||||
@@ -1352,10 +1543,18 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
self.load_llir(llir_graph);
|
||||
self.profiling = true;
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
let profile_start = std::time::Instant::now();
|
||||
let mut durations = Vec::with_capacity(trials.max(1));
|
||||
for _ in 0..trials.max(1) {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
durations.push(start.elapsed());
|
||||
if timeout.is_some_and(|timeout| profile_start.elapsed() >= timeout) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.profiling = false;
|
||||
let duration = start.elapsed();
|
||||
let duration = durations.iter().sum::<std::time::Duration>() / durations.len() as u32;
|
||||
|
||||
let total_bytes: usize = self
|
||||
.last_kernel_stats
|
||||
@@ -1397,6 +1596,35 @@ impl Runtime for CudaRuntime {
|
||||
.filter(|n| n.to_dialect::<dyn HostOp>().is_some())
|
||||
.count()
|
||||
);
|
||||
let display = if std::env::var_os("LUMINAL_SEARCH_OP_NAMES").is_some() {
|
||||
let mut kernel_counts = std::collections::BTreeMap::<&'static str, usize>::new();
|
||||
let mut host_counts = std::collections::BTreeMap::<String, usize>::new();
|
||||
for node in llir_graph.node_weights() {
|
||||
if let Some(kernel) = node.to_dialect::<dyn KernelOp>() {
|
||||
*kernel_counts.entry(kernel.kernel_name()).or_default() += 1;
|
||||
}
|
||||
if let Some(host) = node.to_dialect::<dyn HostOp>() {
|
||||
let debug = format!("{:?}", host.as_ref().as_ref());
|
||||
let name = debug
|
||||
.split([' ', '{', '('])
|
||||
.next()
|
||||
.unwrap_or("HostOp")
|
||||
.to_string();
|
||||
*host_counts.entry(name).or_default() += 1;
|
||||
}
|
||||
}
|
||||
let kernel_summary = kernel_counts
|
||||
.iter()
|
||||
.map(|(name, count)| format!("{name}:{count}"))
|
||||
.join(",");
|
||||
let host_summary = host_counts
|
||||
.iter()
|
||||
.map(|(name, count)| format!("{name}:{count}"))
|
||||
.join(",");
|
||||
format!("{display} [Kernels: {kernel_summary}] [Hosts: {host_summary}]")
|
||||
} else {
|
||||
display
|
||||
};
|
||||
|
||||
(duration, display)
|
||||
}
|
||||
@@ -1417,35 +1645,6 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
Self::allocate_intermediate_buffers(bucket, &self.cuda_stream, dyn_map);
|
||||
// Cache HLIR input pointers
|
||||
if !self.changed_hlir.is_empty() || !bucket.hlir_synced {
|
||||
let hlir_nodes: Vec<NodeIndex> = if !bucket.hlir_synced {
|
||||
// First time this bucket is active since HLIR changed — sync all
|
||||
self.hlir_buffers.keys().copied().collect()
|
||||
} else {
|
||||
self.changed_hlir.iter().copied().collect()
|
||||
};
|
||||
for hlir_node in hlir_nodes {
|
||||
let Some(&llir_node) = bucket.hlir_to_llir.get(&hlir_node) else {
|
||||
continue;
|
||||
};
|
||||
let Some(input) = self.hlir_buffers.get(&hlir_node) else {
|
||||
continue;
|
||||
};
|
||||
let ptr = match input {
|
||||
CudaInput::Buffer(buf) => buf.device_ptr(&self.cuda_stream).0,
|
||||
CudaInput::Ptr(p) => *p,
|
||||
};
|
||||
bucket.cached_buffer_ptrs.insert(llir_node, ptr);
|
||||
}
|
||||
bucket.hlir_synced = true;
|
||||
// Only clear changed_hlir if single bucket (multi-bucket: others may need it)
|
||||
if self.compiled_buckets.len() == 1 {
|
||||
self.changed_hlir.clear();
|
||||
}
|
||||
}
|
||||
// Ensure all CUDA graphs are built (handles first execute and any missing graphs)
|
||||
self.prebuild_graphs(dyn_map);
|
||||
|
||||
@@ -1522,6 +1721,21 @@ impl Runtime for CudaRuntime {
|
||||
exec_op.internal.stats_name().unwrap_or("unknown")
|
||||
);
|
||||
});
|
||||
|
||||
#[cfg(test)]
|
||||
if std::env::var_os("LUMINAL_CUDA_CHECK_NONFINITE_INTERNAL").is_some() {
|
||||
let mut produced_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
produced_nodes.push(exec_op.output);
|
||||
if let Some(report) = self.first_nonfinite_f32_buffer_in_nodes(produced_nodes) {
|
||||
panic!(
|
||||
"CUDA execute produced non-finite buffer after {:?}: node={} index={} value={}",
|
||||
exec_op.internal.stats_name().unwrap_or("unknown"),
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Single sync at end - CUDA stream ordering guarantees sequential execution
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
@@ -1657,8 +1871,8 @@ impl CudaRuntime {
|
||||
//
|
||||
// The default assumption is "yes" for ordinary kernel ops
|
||||
// (Conv outputs, matmul outputs, etc). FusionStart and
|
||||
// Fused* are the exceptions — they're synthetic markers
|
||||
// that the fusion rewrites add inside a region; the
|
||||
// Cuda*Elementwise are the exceptions — they're synthetic
|
||||
// nodes that the fusion rewrites add inside a region; the
|
||||
// megakernel computes them in registers and never writes
|
||||
// to memory, so allocating a buffer would just be waste.
|
||||
//
|
||||
@@ -1673,12 +1887,12 @@ impl CudaRuntime {
|
||||
// an unrelated downstream op that lives in another region.
|
||||
//
|
||||
// Safe over-approximation: if the node is a FusionStart /
|
||||
// Fused* and *any* of its consumers is a FusionStart
|
||||
// Cuda*Elementwise and *any* of its consumers is a FusionStart
|
||||
// (which can only happen when that consumer is the leaf
|
||||
// of a different region) or a non-marker op (e.g. an
|
||||
// unfused Add/Mul reading the value directly), allocate a
|
||||
// buffer so cross-region reads have somewhere to land.
|
||||
let is_marker = kernel_name == "FusionStart" || kernel_name.starts_with("Fused");
|
||||
let is_marker = kernel_name == "FusionStart" || kernel_name.starts_with("Cuda");
|
||||
let has_external_consumer = is_marker
|
||||
&& llir_graph
|
||||
.neighbors_directed(node, Direction::Outgoing)
|
||||
|
||||
@@ -22,6 +22,10 @@ fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeInde
|
||||
(cx, a.id, b.id, c.id)
|
||||
}
|
||||
|
||||
fn bucket_options(buckets: &[DimBucket]) -> CompileOptions {
|
||||
CompileOptions::default().dim_buckets('s', buckets)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_dispatch_simple() {
|
||||
// Tests that bucketed compilation produces correct results for different dim values
|
||||
@@ -31,9 +35,10 @@ fn test_bucket_dispatch_simple() {
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Set dummy input for search
|
||||
@@ -41,7 +46,7 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -73,9 +78,10 @@ fn test_bucket_matmul_dynamic() {
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 8),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
@@ -85,7 +91,7 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -135,12 +141,12 @@ fn test_bucket_results_match_unbucketed() {
|
||||
// Non-bucketed run
|
||||
let (mut cx1, a1, b1) = build_dynamic_add_graph();
|
||||
cx1.set_dim('s', 3);
|
||||
cx1.build_search_space::<CudaRuntime>();
|
||||
cx1.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt1 = CudaRuntime::initialize(stream.clone());
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
|
||||
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -148,12 +154,11 @@ fn test_bucket_results_match_unbucketed() {
|
||||
// Bucketed run with bucket that covers s=3
|
||||
let (mut cx2, a2, b2) = build_dynamic_add_graph();
|
||||
cx2.set_dim('s', 3);
|
||||
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
|
||||
cx2.build_search_space::<CudaRuntime>();
|
||||
cx2.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 4)]));
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
|
||||
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -172,14 +177,16 @@ fn test_bucket_out_of_range_panics() {
|
||||
};
|
||||
|
||||
let (mut cx, a, _b) = build_dynamic_add_graph();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -197,14 +204,14 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
cx.set_dim('s', 2);
|
||||
|
||||
// No set_dim_buckets call
|
||||
// No bucket options
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -237,9 +244,10 @@ fn test_bucket_switch_preserves_weights() {
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
@@ -249,7 +257,7 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -297,15 +305,13 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 8)]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
@@ -323,8 +329,7 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
#[test]
|
||||
#[should_panic(expected = "Overlapping buckets")]
|
||||
fn test_bucket_overlapping_ranges_panics() {
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
let _ = bucket_options(&[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
fn conv2d_bias_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
let output_spatial_dims = unfolded.dims()[1..3].to_vec();
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out = out
|
||||
.split_dims(0, output_spatial_dims[1])
|
||||
.permute(&[2, 0, 1]);
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(1, out_dims[1]).expand_dim(2, out_dims[2])
|
||||
}
|
||||
|
||||
fn build_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_hlir(x, weight, bias, 3, 2).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_bias_padded_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
let zero = Expression::from(0);
|
||||
let pad = Expression::from(padding);
|
||||
let padded = x.pad(vec![(zero, zero), (pad, pad), (pad, pad)], 0.0);
|
||||
conv2d_bias_hlir(padded, weight, bias, kernel, kernel)
|
||||
}
|
||||
|
||||
fn build_padded_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_padded_hlir(x, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn nearest_upsample_2x_hlir(x: GraphTensor) -> GraphTensor {
|
||||
let stage1 = x.expand_dim(2, 2usize).merge_dims(1, 2);
|
||||
stage1.expand_dim(3, 2usize).merge_dims(2, 3)
|
||||
}
|
||||
|
||||
fn build_upsample_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 3usize, 4usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let up = nearest_upsample_2x_hlir(x);
|
||||
let out = conv2d_bias_padded_hlir(up, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv1x1_bias_hlir(x: GraphTensor, weight: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
let h = dims[1];
|
||||
let w = dims[2];
|
||||
let xt = x.permute(&[1, 2, 0]).merge_dims(0, 1);
|
||||
let out = xt.matmul(weight.t());
|
||||
let out = out.split_dims(0, w).permute(&[2, 0, 1]);
|
||||
out + bias.expand_dim(1, h).expand_dim(2, w)
|
||||
}
|
||||
|
||||
fn build_conv1x1_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv1x1_bias_hlir(x, weight, bias).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_matmul_without_conv_output_shape(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(0, out_dims[0])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_unfold_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate"
|
||||
);
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "Add").is_empty(),
|
||||
"generic conv2d cleanup should prune the final bias Add fallback"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_conv1x1_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate for 1x1 conv"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_requires_conv_output_shape() {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
conv2d_matmul_without_conv_output_shape(x, weight, bias, 3, 2).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"matmul+bias without [C_out,H_out,W_out] conv output shape should not match KernelConv2D"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 5 * 6).map(|i| i as f32 * 0.03 - 0.4).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 2)
|
||||
.map(|i| (i as f32 % 11.0) * 0.04 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.25_f32, -0.15, 0.05];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 5,
|
||||
w: 6,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 2,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_conv1x1_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.07 - 1.0).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2).map(|i| (i as f32 % 5.0) * 0.11 - 0.2).collect();
|
||||
let biases = vec![0.2_f32, -0.1, 0.4];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 1,
|
||||
kw: 1,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_padded_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_padded_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.05 - 0.5).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 13.0) * 0.03 - 0.17)
|
||||
.collect();
|
||||
let biases = vec![0.15_f32, -0.25, 0.35];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_upsample_view_input() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_upsample_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 * 0.09 - 0.8).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 17.0) * 0.025 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.05_f32, -0.1, 0.2];
|
||||
let upsampled = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
let expected = reference_conv2d(
|
||||
&upsampled,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 6,
|
||||
w: 8,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
struct ConvCase {
|
||||
c_in: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
c_out: usize,
|
||||
kh: usize,
|
||||
kw: usize,
|
||||
padding_h: usize,
|
||||
padding_w: usize,
|
||||
}
|
||||
|
||||
fn reference_nearest_upsample_2x(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0_f32; c * h * 2 * w * 2];
|
||||
for ci in 0..c {
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let value = input[ci * h * w + y * w + x];
|
||||
for dy in 0..2 {
|
||||
for dx in 0..2 {
|
||||
let oy = y * 2 + dy;
|
||||
let ox = x * 2 + dx;
|
||||
out[ci * h * 2 * w * 2 + oy * w * 2 + ox] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn reference_conv2d(input: &[f32], weight: &[f32], bias: &[f32], case: ConvCase) -> Vec<f32> {
|
||||
let ConvCase {
|
||||
c_in,
|
||||
h,
|
||||
w,
|
||||
c_out,
|
||||
kh,
|
||||
kw,
|
||||
padding_h,
|
||||
padding_w,
|
||||
} = case;
|
||||
let h_out = h + 2 * padding_h - kh + 1;
|
||||
let w_out = w + 2 * padding_w - kw + 1;
|
||||
let mut out = vec![0.0; c_out * h_out * w_out];
|
||||
for co in 0..c_out {
|
||||
for oh in 0..h_out {
|
||||
for ow in 0..w_out {
|
||||
let mut acc = bias[co];
|
||||
for ci in 0..c_in {
|
||||
for r in 0..kh {
|
||||
for s in 0..kw {
|
||||
let Some(ih) = (oh + r).checked_sub(padding_h) else {
|
||||
continue;
|
||||
};
|
||||
let Some(iw) = (ow + s).checked_sub(padding_w) else {
|
||||
continue;
|
||||
};
|
||||
if ih >= h || iw >= w {
|
||||
continue;
|
||||
}
|
||||
let input_idx = ci * h * w + ih * w + iw;
|
||||
let weight_idx = co * c_in * kh * kw + (ci * kh + r) * kw + s;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
out[co * h_out * w_out + oh * w_out + ow] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, "KernelConv2D");
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0xC0_2D00 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
ClassId, NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice,
|
||||
validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
@@ -11,7 +12,8 @@ use crate::{
|
||||
host::{
|
||||
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
|
||||
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
|
||||
cublaslt_scale_values, cublaslt_transpose_ops, cublaslt_type_tuple,
|
||||
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
|
||||
cublaslt_type_tuple,
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
@@ -443,6 +445,54 @@ fn cublaslt_rewrites_cover_batched_row_order_layout_pairs() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_cover_flux2_qk_transposed_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((8usize, 4usize));
|
||||
let k = cx.tensor((8usize, 4usize));
|
||||
let _out = q.matmul(k.t()).output();
|
||||
|
||||
assert_cublaslt_rewrite(cx, "flux2 q @ k.t()", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&("ROW", "COL", "ROW", "ROW"))
|
||||
|| cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_cover_flux2_linear_bias_epilogue() {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((8usize, 4usize));
|
||||
let weight = cx.tensor((6usize, 4usize));
|
||||
let bias = cx.tensor(6usize);
|
||||
let _out = (x.matmul(weight.t()) + bias.expand_dim(0, 8usize)).output();
|
||||
|
||||
assert_cublaslt_epilogue_rewrite(
|
||||
cx,
|
||||
"flux2 x @ weight.t() + bias",
|
||||
"BIAS",
|
||||
Some(("COL", "COL", "COL", "COL")),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_cleanup_prunes_flux2_broadcast_mul_fallback() {
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((8usize, 4usize));
|
||||
let k = cx.tensor((8usize, 4usize));
|
||||
let _out = q.matmul(k.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
assert!(
|
||||
!cublaslt_ir_nodes(egraph).is_empty(),
|
||||
"Flux2 q @ k.t() should have at least one cuBLASLt candidate"
|
||||
);
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "Mul").is_empty(),
|
||||
"cuBLASLt cleanup should prune the broadcast Mul fallback once a cuBLASLt candidate exists"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
|
||||
for case in LAYOUT_CASES {
|
||||
@@ -900,6 +950,196 @@ fn cublaslt_fp8_e4m3_beta_candidate_executes_2d_matmul_plus_f32_c() {
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA FP8 rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_fp8_scaled_candidate_executes_2d_matmul_f32_output() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !gpu_supports_cublaslt_fp8_launch(DType::F8E4M3) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.t();
|
||||
let scaled_a = (a / a_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let out =
|
||||
(scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n))).output();
|
||||
let expected_tuple = (
|
||||
DType::F8E4M3,
|
||||
DType::F8E4M3,
|
||||
DType::F32,
|
||||
DType::F32,
|
||||
"32F",
|
||||
DType::F32,
|
||||
);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "functional scaled fp8", |llir| {
|
||||
cublaslt_type_tuples(llir).contains(&expected_tuple)
|
||||
&& cublaslt_tensor_scale_input_tuples(llir).contains(&(true, true))
|
||||
&& cublaslt_transpose_op_tuples(llir).contains(&("T", "N"))
|
||||
&& cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
|
||||
let input_scale = 0.25f32;
|
||||
let weight_scale = 2.0f32;
|
||||
let (a_fp8_bytes, a_values) = fp8_exact_bytes(DType::F8E4M3, m * k, 7);
|
||||
let a_data = a_values
|
||||
.iter()
|
||||
.map(|value| value * input_scale)
|
||||
.collect::<Vec<_>>();
|
||||
let (b_bytes, b_storage_values) = fp8_exact_bytes(DType::F8E4M3, k * n, 9);
|
||||
let b_values = logical_b_from_column_major_storage(&b_storage_values, n, k);
|
||||
let mut expected = reference_matmul_2d(&a_values, &b_values, m, n, k);
|
||||
for value in &mut expected {
|
||||
*value *= input_scale * weight_scale;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(a_scale, vec![input_scale]);
|
||||
rt.set_data(b_scale, vec![weight_scale]);
|
||||
rt.set_data(b_input, b_bytes);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Keep the raw bytes live in the test construction: a_data was chosen so
|
||||
// the explicit scaled cast quantizes to these exact FP8 values.
|
||||
assert_eq!(a_fp8_bytes.len(), m * k);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_fp8_scaled_candidate_reaches_fused_output_scale_consumer() {
|
||||
let (m, n, k) = (16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.t();
|
||||
let side = cx.tensor((m, n));
|
||||
let scaled_a = (a / a_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let scaled_out = scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n));
|
||||
(scaled_out * side).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
dataflow_reachable_cublaslt_scaled_count(egraph) > 0,
|
||||
"scaled cuBLASLt must remain reachable when fusion growth consumes the output-scale multiply internally"
|
||||
);
|
||||
assert_eq!(
|
||||
dataflow_reachable_cublaslt_raw_fp8_count(egraph),
|
||||
0,
|
||||
"raw FP8 cuBLASLt must be deleted when a scaled equivalent covers the fused output-scale consumer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_fp8_scaled_candidates_reach_fused_mlp_consumer() {
|
||||
let (m, n, k) = (16, 32, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let gate_input_scale = cx.tensor(());
|
||||
let gate_weight_scale = cx.tensor(());
|
||||
let up_input_scale = cx.tensor(());
|
||||
let up_weight_scale = cx.tensor(());
|
||||
let gate_weight = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let up_weight = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
|
||||
let scaled_gate_a = (a / gate_input_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let gate = scaled_gate_a.matmul(gate_weight.t()).cast(DType::F32)
|
||||
* (gate_input_scale * gate_weight_scale).expand_rhs((m, n));
|
||||
let scaled_up_a = (a / up_input_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let up = scaled_up_a.matmul(up_weight.t()).cast(DType::F32)
|
||||
* (up_input_scale * up_weight_scale).expand_rhs((m, n));
|
||||
(gate.swish() * up).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
dataflow_reachable_cublaslt_scaled_count(egraph) >= 2,
|
||||
"scaled cuBLASLt candidates must remain reachable through fused MLP gate/up consumers"
|
||||
);
|
||||
assert_eq!(
|
||||
dataflow_reachable_cublaslt_raw_fp8_count(egraph),
|
||||
0,
|
||||
"raw FP8 cuBLASLt must be deleted when a scaled equivalent covers the fused MLP consumer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA FP8 rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_fp8_scaled_candidate_executes_batched_matmul_f32_output() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !gpu_supports_cublaslt_fp8_launch(DType::F8E4M3) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (batch, m, n, k) = (2, 16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((batch, n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.transpose(1, 2);
|
||||
let scaled_a = (a / a_scale.expand_rhs((batch, m, k))).cast(DType::F8E4M3);
|
||||
let lhs = scaled_a.expand_dim(2, n);
|
||||
let rhs = b.permute((0, 2, 1)).expand_dim(1, m);
|
||||
let mul = unchecked_mul_same_shape(lhs, rhs, DType::F8E4M3);
|
||||
let matmul = mul.sum(3).cast(DType::F32);
|
||||
let out = (matmul * (a_scale * b_scale).expand_rhs((batch, m, n))).output();
|
||||
let expected_tuple = (
|
||||
DType::F8E4M3,
|
||||
DType::F8E4M3,
|
||||
DType::F32,
|
||||
DType::F32,
|
||||
"32F",
|
||||
DType::F32,
|
||||
);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "functional scaled batched fp8", |llir| {
|
||||
cublaslt_type_tuples(llir).contains(&expected_tuple)
|
||||
&& cublaslt_tensor_scale_input_tuples(llir).contains(&(true, true))
|
||||
&& cublaslt_transpose_op_tuples(llir).contains(&("T", "N"))
|
||||
&& cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
|
||||
let input_scale = 0.5f32;
|
||||
let weight_scale = 1.5f32;
|
||||
let (a_fp8_bytes, a_values) = fp8_exact_bytes(DType::F8E4M3, batch * m * k, 11);
|
||||
let a_data = a_values
|
||||
.iter()
|
||||
.map(|value| value * input_scale)
|
||||
.collect::<Vec<_>>();
|
||||
let (b_bytes, b_storage_values) = fp8_exact_bytes(DType::F8E4M3, batch * k * n, 13);
|
||||
let b_values = logical_b_from_batched_column_major_storage(&b_storage_values, batch, n, k);
|
||||
let mut expected = reference_matmul_batched(&a_values, &b_values, batch, m, n, k);
|
||||
for value in &mut expected {
|
||||
*value *= input_scale * weight_scale;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(a_scale, vec![input_scale]);
|
||||
rt.set_data(b_scale, vec![weight_scale]);
|
||||
rt.set_data(b_input, b_bytes);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(a_fp8_bytes.len(), batch * m * k);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
fn cublaslt_fp8_candidate_executes_2d_matmul_f32_output(a_dtype: DType, b_dtype: DType) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -2168,6 +2408,85 @@ fn cublaslt_row_order_candidate_executes_2d_layout_pairs() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "large row-order CUDA functional repro for llama lm_head shape"]
|
||||
fn cublaslt_row_order_candidate_executes_large_lm_head_like_projection() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (m, n, k) = (1, 128_256, 64);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b_input = cx.tensor((n, k));
|
||||
let b = b_input.t();
|
||||
let out = a.matmul(b).output();
|
||||
let expected_orders = ("ROW", "COL", "ROW", "ROW");
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "lm_head-like row-order", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&expected_orders)
|
||||
&& cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0x1A11_A000, -0.5, 0.5);
|
||||
let b_data = random_f32_vec(n * k, 0x1A11_B000, -0.5, 0.5);
|
||||
let mut expected = vec![0.0f32; m * n];
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for kk in 0..k {
|
||||
sum += a_data[kk] * b_data[col * k + kk];
|
||||
}
|
||||
expected[col] = sum;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(b_input, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "large row-order CUDA functional repro for llama MLP residual beta=1 shape"]
|
||||
fn cublaslt_row_order_beta_one_candidate_executes_llama_mlp_residual_like_projection() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (m, n, k) = (1, 4096, 64);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b_input = cx.tensor((n, k));
|
||||
let b = b_input.t();
|
||||
let c = cx.tensor((m, n));
|
||||
let out = (a.matmul(b) + c).output();
|
||||
let expected_orders = ("ROW", "COL", "ROW", "ROW");
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mlp residual row-order", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&expected_orders)
|
||||
&& cublaslt_scale_value_tuples(llir).contains(&(1.0, 1.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0x1A12_A000, -0.5, 0.5);
|
||||
let b_data = random_f32_vec(n * k, 0x1A12_B000, -0.5, 0.5);
|
||||
let c_data = random_f32_vec(m * n, 0x1A12_C000, -0.5, 0.5);
|
||||
let mut expected = c_data.clone();
|
||||
for col in 0..n {
|
||||
for kk in 0..k {
|
||||
expected[col] += a_data[kk] * b_data[col * k + kk];
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(b_input, b_data);
|
||||
rt.set_data(c, c_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA functional candidate sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_row_order_candidate_executes_batched_row_major_matmul() {
|
||||
@@ -2617,7 +2936,7 @@ fn extract_forced_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) -> LLIRGraph {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -2672,7 +2991,7 @@ fn assert_no_forced_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -2721,7 +3040,7 @@ fn assert_no_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -2762,10 +3081,17 @@ fn assert_no_cublaslt_llir_where(
|
||||
}
|
||||
|
||||
fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
let cublaslt_kind_classes = egraph
|
||||
op_ir_nodes(egraph, "cublaslt")
|
||||
.into_iter()
|
||||
.chain(op_ir_nodes(egraph, "cublaslt_scaled"))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == "cublaslt")
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -2776,12 +3102,93 @@ fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| cublaslt_kind_classes.contains(kind)))
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_scaled_count(egraph: &SerializedEGraph) -> usize {
|
||||
dataflow_reachable_cublaslt_count(egraph, true)
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_raw_fp8_count(egraph: &SerializedEGraph) -> usize {
|
||||
dataflow_reachable_cublaslt_count(egraph, false)
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_count(egraph: &SerializedEGraph, scaled: bool) -> usize {
|
||||
let reachable = dataflow_reachable_ir_classes(egraph);
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(node, (label, children))| {
|
||||
label == "Op"
|
||||
&& reachable.contains(&egraph.node_to_class[*node])
|
||||
&& children.first().is_some_and(|kind_class| {
|
||||
egraph
|
||||
.eclasses
|
||||
.get(kind_class)
|
||||
.is_some_and(|(_, kind_nodes)| {
|
||||
kind_nodes.iter().any(|kind_node| {
|
||||
egraph.enodes.get(kind_node).is_some_and(|(kind_label, _)| {
|
||||
if scaled {
|
||||
kind_label == "cublaslt_scaled"
|
||||
} else {
|
||||
kind_label == "cublaslt"
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
fn dataflow_reachable_ir_classes(egraph: &SerializedEGraph) -> FxHashSet<ClassId> {
|
||||
let mut reachable = FxHashSet::default();
|
||||
let mut stack = egraph.roots.clone();
|
||||
while let Some(class) = stack.pop() {
|
||||
if !reachable.insert(class.clone()) {
|
||||
continue;
|
||||
}
|
||||
let Some((sort, nodes)) = egraph.eclasses.get(&class) else {
|
||||
continue;
|
||||
};
|
||||
for node in nodes {
|
||||
let Some((label, children)) = egraph.enodes.get(node) else {
|
||||
continue;
|
||||
};
|
||||
match (sort.as_str(), label.as_str()) {
|
||||
("IR", "Output") => {
|
||||
if let Some(child) = children.first() {
|
||||
stack.push(child.clone());
|
||||
}
|
||||
}
|
||||
("IR", "OutputJoin") => stack.extend(children.iter().cloned()),
|
||||
("IR", "Op") => {
|
||||
if let Some(inputs) = children.get(1) {
|
||||
stack.push(inputs.clone());
|
||||
}
|
||||
}
|
||||
("IR", _) => {
|
||||
for child in children {
|
||||
if egraph
|
||||
.eclasses
|
||||
.get(child)
|
||||
.is_some_and(|(child_sort, _)| child_sort == "IR")
|
||||
{
|
||||
stack.push(child.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
("IList", "ICons") => stack.extend(children.iter().cloned()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
reachable
|
||||
}
|
||||
|
||||
fn llir_has_cublaslt(llir: &LLIRGraph) -> bool {
|
||||
!cublaslt_type_tuples(llir).is_empty()
|
||||
}
|
||||
@@ -2800,6 +3207,13 @@ fn cublaslt_scale_value_tuples(llir: &LLIRGraph) -> Vec<CublasLtScaleValues> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cublaslt_tensor_scale_input_tuples(llir: &LLIRGraph) -> Vec<(bool, bool)> {
|
||||
llir.node_weights()
|
||||
.filter_map(|op| op.to_dialect::<dyn HostOp>())
|
||||
.filter_map(|host_op| cublaslt_tensor_scale_inputs(host_op.as_ref().as_ref()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cublaslt_epilogues(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_weights()
|
||||
.filter_map(|op| op.to_dialect::<dyn HostOp>())
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
|
||||
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
|
||||
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
|
||||
//! 3. Mask op correctness (GPU): `ComputeAttnMask` produces the right (s, c) mask.
|
||||
//! 3. Mask helper correctness (GPU): the primitive-op `test_compute_attn_mask` builder produces the right (s, c) mask.
|
||||
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
|
||||
//! compared against a luminal-compiled reference attention graph.
|
||||
//!
|
||||
@@ -18,7 +18,7 @@ use luminal::op::{EgglogOp, IntoEgglogOp};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::host::flashinfer::FlashInferAttention;
|
||||
use crate::host::{ComputeAttnMask, DeviceBuffer, HostOp};
|
||||
use crate::host::{DeviceBuffer, HostOp};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::get_cuda_stream;
|
||||
|
||||
@@ -83,13 +83,13 @@ fn run_reference_attention(
|
||||
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
|
||||
cx.set_dim('s', batch_size);
|
||||
cx.set_dim('c', context_len);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
@@ -285,106 +285,6 @@ fn flashinfer_op_sort_shape() {
|
||||
assert!(dbg.contains("FlashInferAttention"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_registers() {
|
||||
assert!(
|
||||
ops_contains_sort("ComputeAttnMask"),
|
||||
"ComputeAttnMask is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 2: ComputeAttnMask correctness ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_matches_cpu_reference() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// 2 sequences, seq0 length=3, seq1 length=2 → s=2 queries (one per seq, decode),
|
||||
// c=5 total context tokens (3+2).
|
||||
let s_dim = 2usize;
|
||||
let c_dim = 5usize;
|
||||
let q_pos: Vec<i32> = vec![2, 1]; // last position in each seq
|
||||
let qo_indptr: Vec<i32> = vec![0, 1, 2];
|
||||
let kv_indptr: Vec<i32> = vec![0, 3, 5];
|
||||
let r = kv_indptr.len();
|
||||
|
||||
let q_pos_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(q_pos.as_ptr() as *const u8, q_pos.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let qo_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(qo_indptr.as_ptr() as *const u8, qo_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let kv_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(kv_indptr.as_ptr() as *const u8, kv_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let out_bytes = s_dim * c_dim * 4;
|
||||
let out_buf = unsafe { stream.alloc::<u8>(out_bytes).unwrap() };
|
||||
|
||||
let op = ComputeAttnMask {
|
||||
s_dim: Expression::from(s_dim),
|
||||
c_dim: Expression::from(c_dim),
|
||||
};
|
||||
|
||||
let q_pos_n = NodeIndex::new(0);
|
||||
let qo_n = NodeIndex::new(1);
|
||||
let kv_n = NodeIndex::new(2);
|
||||
let out_n = NodeIndex::new(3);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
buffers.insert(
|
||||
q_pos_n,
|
||||
DeviceBuffer::new(q_pos_buf.device_ptr(&stream).0, q_pos.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
qo_n,
|
||||
DeviceBuffer::new(qo_buf.device_ptr(&stream).0, qo_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
kv_n,
|
||||
DeviceBuffer::new(kv_buf.device_ptr(&stream).0, kv_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
out_n,
|
||||
DeviceBuffer::new(out_buf.device_ptr(&stream).0, out_bytes),
|
||||
);
|
||||
|
||||
let inputs = [q_pos_n, qo_n, kv_n];
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('r', r);
|
||||
|
||||
op.execute(&stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.unwrap();
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let host_bytes = stream.clone_dtoh(&out_buf).unwrap();
|
||||
let mask: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
|
||||
// Expected: query 0 (q_pos=2, seq 0) attends to ctx [0, 3) i.e. mask[0, 0..3]=0;
|
||||
// query 1 (q_pos=1, seq 1) attends to ctx [3, 5) i.e. mask[1, 3..5]=0.
|
||||
// Everywhere else is -1e10.
|
||||
let mut expected = vec![-1e10f32; s_dim * c_dim];
|
||||
for j in 0..3 {
|
||||
expected[0 * c_dim + j] = 0.0;
|
||||
}
|
||||
for j in 3..5 {
|
||||
expected[1 * c_dim + j] = 0.0;
|
||||
}
|
||||
|
||||
assert_eq!(mask, expected);
|
||||
}
|
||||
|
||||
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
@@ -527,7 +427,7 @@ fn test_indptr_to_request_idx(
|
||||
n: Expression,
|
||||
) -> GraphTensor {
|
||||
let r = indptr.dims1();
|
||||
let indices = graph.arange(n.clone()).expand_dim(1, r.clone());
|
||||
let indices = graph.arange(n).expand_dim(1, r);
|
||||
let indptr_2d = indptr.expand_dim(0, n);
|
||||
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
|
||||
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
|
||||
@@ -541,13 +441,13 @@ fn test_compute_attn_mask(
|
||||
c: Expression,
|
||||
) -> GraphTensor {
|
||||
let s = q_pos.dims1();
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s.clone());
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c.clone());
|
||||
let c_arange = graph.arange(c.clone());
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s);
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c);
|
||||
let c_arange = graph.arange(c);
|
||||
let c_kv_start = kv_indptr.gather(c_request);
|
||||
let c_local_pos = c_arange - c_kv_start;
|
||||
let q_req_2d = q_request.expand_dim(1, c.clone());
|
||||
let c_req_2d = c_request.expand_dim(0, s.clone());
|
||||
let q_req_2d = q_request.expand_dim(1, c);
|
||||
let c_req_2d = c_request.expand_dim(0, s);
|
||||
let same = q_req_2d.eq(c_req_2d);
|
||||
let c_pos_2d = c_local_pos.expand_dim(0, s);
|
||||
let qp_2d = q_pos.expand_dim(1, c);
|
||||
@@ -577,6 +477,7 @@ fn scatter_rows(
|
||||
|
||||
/// Handles to every named input of the paged-attention test graph, returned
|
||||
/// alongside the graph so the GA-selection test can `set_data` on each one.
|
||||
#[allow(dead_code)]
|
||||
struct PagedAttnHandles {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
@@ -878,7 +779,7 @@ fn flashinfer_extraction_reachable_from_search_space() {
|
||||
cx.set_dim('s', 1usize);
|
||||
cx.set_dim('c', 16usize);
|
||||
cx.set_dim('r', 2usize);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx
|
||||
.egraph()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use as_any::Downcast;
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::fusion::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{
|
||||
TOLERANCE_SAFETY_FACTOR, dtype_epsilon, random_f32_vec, test_binary_cuda, test_unary_cuda,
|
||||
@@ -86,7 +88,7 @@ fn test_unary_fusion_preserves_output() {
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single marker region containing all three FusedX ops.
|
||||
// reachable as a single marker region containing all three elementwise ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
@@ -104,7 +106,7 @@ fn test_three_unary_ops_fuse() {
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single marker region containing all
|
||||
// four FusedX ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
// four elementwise ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
@@ -291,7 +293,7 @@ struct FusedRegion {
|
||||
/// Helper: collect every distinct fused region reachable across many random
|
||||
/// extractions of the search space.
|
||||
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
@@ -317,8 +319,15 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
|
||||
let name_of = |idx: NodeIndex| -> Option<String> {
|
||||
llir.node_weight(idx).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>()
|
||||
.map(|k| k.kernel_name().to_string())
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| {
|
||||
if let Some(elem) = (***k).downcast_ref::<CudaUnaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else if let Some(elem) = (***k).downcast_ref::<CudaBinaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else {
|
||||
k.kernel_name().to_string()
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
@@ -343,12 +352,13 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
|
||||
// Resolve chains of nested FusionStart wrappers (cascade artifact)
|
||||
// to the real external source. A FusionStart whose incoming neighbor
|
||||
// is itself a FusionStart — or a FusionEnd whose region is fully
|
||||
// inside ours — is a cascade layer, not a new external tensor.
|
||||
// is itself a FusionStart is a cascade layer, not a new external
|
||||
// tensor. A FusionEnd predecessor is a real external region output
|
||||
// in the generic singleton-region model, so do not walk through it.
|
||||
let resolve_source = |mut n: NodeIndex| -> NodeIndex {
|
||||
loop {
|
||||
match name_of(n).as_deref() {
|
||||
Some("FusionStart") | Some("FusionEnd") => {
|
||||
Some("FusionStart") => {
|
||||
let mut inc = llir.neighbors_directed(n, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(p) => n = p,
|
||||
@@ -379,15 +389,6 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
let mut inc =
|
||||
llir.neighbors_directed(pred, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(src_node)
|
||||
if name_of(src_node).as_deref() == Some("FusionEnd") =>
|
||||
{
|
||||
// Merge adjacent regions — treat the FS/FE
|
||||
// pair as internal; walk past the upstream
|
||||
// FE into its region.
|
||||
visited.insert(src_node);
|
||||
stack.push(src_node);
|
||||
}
|
||||
Some(src_node) => {
|
||||
start_sources.insert(resolve_source(src_node));
|
||||
}
|
||||
@@ -467,6 +468,15 @@ fn test_single_binary_does_not_fuse_alone() {
|
||||
fn test_chain_of_binaries_fuses() {
|
||||
// `(a + b) * c`: three external inputs collapse into one region with
|
||||
// internal [Add, Mul] and 3 FusionStarts.
|
||||
//
|
||||
// Requires BB family, which is opt-in at runtime via
|
||||
// LUMINAL_FUSION_FAMILIES. Set it before the graph build so the rules
|
||||
// emitted from FusionEnd::rewrites include the B-B pair-fuse rules.
|
||||
// SAFETY: tests run in parallel; we set this before constructing the
|
||||
// Graph, and never unset, so concurrent tests just see BB on.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
@@ -520,6 +530,13 @@ fn test_unary_then_binary_fuses() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Subsume in grow rules (introduced to bound the BB partial-FE explosion)
|
||||
// means a multi-consumer producer can no longer be fused into the same
|
||||
// region as all its consumers — only one branch wins. The diamond's `t`
|
||||
// has two consumers, so the structural "one 5-op region" outcome is no
|
||||
// longer guaranteed. Numerical correctness still holds (see
|
||||
// test_diamond_dag_preserves_output).
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_diamond_dag_fuses() {
|
||||
// The canonical diamond-DAG example agreed with the user:
|
||||
// t = a + b; u = exp2(t); v = sin(t); w = u * a; out = w + v
|
||||
@@ -650,6 +667,7 @@ fn test_diamond_dag_preserves_output() {
|
||||
// ---- Marker invariant tests ----
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_has_exactly_one_end() {
|
||||
// Design invariant: a fused region always has exactly one FusionEnd.
|
||||
// Uses the diamond DAG so there's real fan-in/out inside the region.
|
||||
@@ -677,6 +695,7 @@ fn test_fused_region_has_exactly_one_end() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_starts_match_distinct_external_tensors() {
|
||||
// Design invariant: FusionStart count == number of distinct external input
|
||||
// tensors, NOT number of edges crossing the boundary. In the diamond DAG
|
||||
@@ -768,6 +787,10 @@ fn test_pair_fuse_binary_to_binary_rhs() {
|
||||
// Pair-fuse B→B (RHS variant): `c * (a + b)`. The inner binary feeds the
|
||||
// outer binary's B input, exercising the mirror direction of the rule
|
||||
// covered by test_chain_of_binaries_fuses.
|
||||
// See test_chain_of_binaries_fuses for the LUMINAL_FUSION_FAMILIES note.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
@@ -809,6 +832,7 @@ fn test_grow_fe_to_binary_rhs() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume two-FE merge shape; numerical correctness preserved"]
|
||||
fn test_merge_two_regions_at_outer_binary() {
|
||||
// Merge: `(sin(a) + b) + (sqrt(c) + d)`. Each side independently pair-fuses
|
||||
// U→B on its own (the unary gives the inner Add a fusion partner that
|
||||
|
||||
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericMatmul");
|
||||
let names = llir_kernel_names(&llir);
|
||||
|
||||
assert!(
|
||||
names.contains(&"GenericMatmul"),
|
||||
"expected generic matmul fallback, kernels: {names:?}"
|
||||
);
|
||||
assert!(
|
||||
!names.contains(&"Mul") && !names.contains(&"SumReduce"),
|
||||
"generic matmul should prune the broadcast multiply/sum fallback, kernels: {names:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_executes_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let stream = get_cuda_stream().expect("CUDA device required for GenericMatmul execution test");
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, attn_data.as_slice());
|
||||
rt.set_data(weight, weight_data.as_slice());
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
assert!(
|
||||
rt.kernel_names().contains(&"GenericMatmul"),
|
||||
"expected GenericMatmul to be selected, kernels: {:?}",
|
||||
rt.kernel_names()
|
||||
);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output.id);
|
||||
|
||||
let mut expected = vec![0.0; seq * out_dim];
|
||||
for token in 0..seq {
|
||||
for out_col in 0..out_dim {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..hidden {
|
||||
let head = inner / head_dim;
|
||||
let dim = inner % head_dim;
|
||||
let attn_idx = head * seq * head_dim + token * head_dim + dim;
|
||||
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
|
||||
}
|
||||
expected[token * out_dim + out_col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let x = ((i * 37 + 11) % 97) as f32 / 97.0;
|
||||
x * scale + bias
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, kernel_name);
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -5,12 +5,16 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod conv2d_rewrite;
|
||||
#[cfg(test)]
|
||||
mod cublaslt_rewrite_tests;
|
||||
#[cfg(test)]
|
||||
mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod generic_matmul_rewrite;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
@@ -19,4 +23,8 @@ mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod rope_test;
|
||||
#[cfg(test)]
|
||||
mod search_equivalence_fuzz;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
@@ -83,7 +83,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -143,7 +143,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
let proj_w = cx.tensor((proj_dim, hidden));
|
||||
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -219,7 +219,7 @@ fn fuzz_layer_no_attn(
|
||||
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
|
||||
let out = (x + mlp_out).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -305,7 +305,7 @@ fn fuzz_layer_no_attn(
|
||||
}
|
||||
|
||||
/// Test a SwiGLU MLP with HLIR-only to specifically verify
|
||||
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
|
||||
/// the HLIR matmul decomposition (elementwise Mul + KernelSumReduce).
|
||||
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -318,7 +318,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -481,7 +481,7 @@ mod gemma {
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 800u64;
|
||||
@@ -518,7 +518,7 @@ mod gemma {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -641,7 +641,7 @@ mod qwen {
|
||||
let embedding = cx.tensor((VOCAB, HIDDEN));
|
||||
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 1300u64;
|
||||
@@ -655,7 +655,7 @@ mod qwen {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -256,10 +256,10 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, 10);
|
||||
rt = cx.search(rt, CompileOptions::new(10));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out_dim0 = rt.get_i32(sorted_dim0.id);
|
||||
let out_dim1 = rt.get_i32(sorted_dim1.id);
|
||||
@@ -424,7 +424,7 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
let e = (d + c).relu();
|
||||
let out = e.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().unwrap();
|
||||
let ops = cx.egglog_ops().unwrap();
|
||||
|
||||
@@ -592,7 +592,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let token_data: Vec<i32> = random_i32_vec(seq_len, seed, 0, vocab_size as i32 - 1);
|
||||
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Test that measures bandwidth utilization for a large element-wise add kernel.
|
||||
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
|
||||
/// This demonstrates that generic fused Add can achieve reasonable bandwidth with large tensors.
|
||||
#[test]
|
||||
pub fn kernel_add_bandwidth_test() {
|
||||
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
|
||||
@@ -27,11 +27,11 @@ pub fn kernel_add_bandwidth_test() {
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -40,7 +40,7 @@ pub fn kernel_add_bandwidth_test() {
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Print stats
|
||||
println!("\n=== Large KernelAdd Bandwidth Test ===");
|
||||
println!("\n=== Large Fused Add Bandwidth Test ===");
|
||||
println!(
|
||||
"Tensor size: {} elements ({} MB per tensor)",
|
||||
size,
|
||||
|
||||
@@ -2,16 +2,13 @@ use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
use crate::{host::moe::GLUMoE, runtime::CudaRuntime};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const HIDDEN: usize = 32;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const MOE_INTERMEDIATE: usize = 12;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
@@ -58,6 +55,7 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_values = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
@@ -71,9 +69,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +128,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
@@ -172,30 +170,51 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
fn search_space_contains(cx: &Graph, op_name: &str) -> bool {
|
||||
let egraph = cx.egraph().expect("test should build an e-graph");
|
||||
|
||||
for (label, children) in egraph.enodes.values() {
|
||||
if label != "Op" {
|
||||
continue;
|
||||
}
|
||||
let Some(kind_eclass) = children.first() else {
|
||||
continue;
|
||||
};
|
||||
let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) else {
|
||||
continue;
|
||||
};
|
||||
if kind_enodes
|
||||
.iter()
|
||||
.any(|kind_node| egraph.enodes[kind_node].0 == op_name)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn assert_glumoe_in_search_space(cx: &Graph) {
|
||||
assert!(
|
||||
search_space_contains(cx, "GLUMoE"),
|
||||
"GLUMoE was not in the e-graph search space"
|
||||
);
|
||||
}
|
||||
|
||||
fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
if include_glumoe {
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
@@ -214,25 +233,27 @@ fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
if include_glumoe {
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
@@ -257,54 +278,58 @@ fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
let expected = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
let actual = run_qwen_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
let expected = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let actual = run_gemma_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
115
crates/luminal_cuda_lite/src/tests/rope_test.rs
Normal file
115
crates/luminal_cuda_lite/src/tests/rope_test.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::{
|
||||
graph::{CompileOptions, Graph},
|
||||
op::Runtime,
|
||||
};
|
||||
|
||||
use crate::{kernel::apply_rope, runtime::CudaRuntime};
|
||||
|
||||
fn cpu_rope(x: &[f32], cos: &[f32], sin: &[f32], s: usize, h: usize, d: usize) -> Vec<f32> {
|
||||
assert!(d.is_multiple_of(2));
|
||||
let mut out = vec![0.0f32; s * h * d];
|
||||
for si in 0..s {
|
||||
for hi in 0..h {
|
||||
for i in 0..d {
|
||||
let xi = x[si * h * d + hi * d + i];
|
||||
let xpair = if i % 2 == 0 {
|
||||
-x[si * h * d + hi * d + i + 1]
|
||||
} else {
|
||||
x[si * h * d + hi * d + i - 1]
|
||||
};
|
||||
let c = cos[si * d + i];
|
||||
let sn = sin[si * d + i];
|
||||
out[si * h * d + hi * d + i] = xi * c + xpair * sn;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_matches_cpu_reference() {
|
||||
let s = 8;
|
||||
let h = 4;
|
||||
let d = 32;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
let x_data: Vec<f32> = (0..s * h * d).map(|i| ((i as f32) * 0.013).sin()).collect();
|
||||
let cos_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).cos()).collect();
|
||||
let sin_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).sin()).collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-5, "max abs error {max_err} too high");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_flux2_shape() {
|
||||
// Flux 2 transformer attention: S=1536 (img+txt), H=48, D=128.
|
||||
let s = 1536;
|
||||
let h = 48;
|
||||
let d = 128;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(11);
|
||||
let x_data: Vec<f32> = (0..s * h * d)
|
||||
.map(|_| rng.random_range(-2.0..2.0_f32))
|
||||
.collect();
|
||||
let cos_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
let sin_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope flux2: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-4, "max abs error {max_err} too high");
|
||||
}
|
||||
374
crates/luminal_cuda_lite/src/tests/search_equivalence_fuzz.rs
Normal file
374
crates/luminal_cuda_lite/src/tests/search_equivalence_fuzz.rs
Normal file
@@ -0,0 +1,374 @@
|
||||
//! End-to-end e-graph search-space equivalence fuzz tests.
|
||||
//!
|
||||
//! These tests do not compare against a hand-written reference. They assert the
|
||||
//! stronger search invariant: every selectable LLIR graph from the same e-graph
|
||||
//! must produce finite, numerically close outputs for the same runtime inputs.
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[path = "../../../../examples/llama/src/model.rs"]
|
||||
mod llama_model;
|
||||
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
|
||||
use super::utilities::{CudaSearchEquivalenceFuzzer, get_cuda_stream, random_f32_vec};
|
||||
|
||||
const SEARCH_EQUIV_SAMPLES: usize = 32;
|
||||
|
||||
fn random_bf16_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<bf16> {
|
||||
random_f32_vec(n, seed, low, high)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llama_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const CTX: usize = 3;
|
||||
const SLOTS: usize = 4;
|
||||
|
||||
let config = llama_model::LlamaConfig {
|
||||
layers: 2,
|
||||
hidden: 32,
|
||||
intermediate: 64,
|
||||
head_dim: 8,
|
||||
kv_groups: 2,
|
||||
vocab_size: 64,
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim('s', SEQ);
|
||||
cx.set_dim('c', CTX);
|
||||
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = llama_model::KVCache::new_with_config(&mut cx, SLOTS, config);
|
||||
let llama = llama_model::Llama::init_with_config(&mut cx, config);
|
||||
|
||||
let (logits, cache_outputs) =
|
||||
llama.forward(input, q_pos, scatter_idx, gather_idx, attn_mask, &kv_cache);
|
||||
let logits = logits.output();
|
||||
let mut fuzzer = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x5EED_1234)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 5e-2, 5e-2);
|
||||
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
|
||||
let k_out = k_out.output();
|
||||
let v_out = v_out.output();
|
||||
fuzzer = fuzzer.output_f32(k_out.id, format!("layer{layer}.k_cache"), 3e-3, 3e-3);
|
||||
fuzzer = fuzzer.output_f32(v_out.id, format!("layer{layer}.v_cache"), 3e-3, 3e-3);
|
||||
}
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0x11A_AA55);
|
||||
fuzzer = fuzzer
|
||||
.input_i32(input.id, vec![3, 17])
|
||||
.input_i32(q_pos.id, vec![1, 2])
|
||||
.input_i32(scatter_idx.id, vec![1, 2])
|
||||
.input_i32(gather_idx.id, vec![0, 1, 2])
|
||||
.input_f32(attn_mask.id, vec![0.0, 0.0, -1e4, 0.0, 0.0, 0.0]);
|
||||
|
||||
let kv_dim = config.kv_dim();
|
||||
for tensor in kv_cache.tensors() {
|
||||
fuzzer = fuzzer.input_f32(tensor.id, vec![0.0; SLOTS * kv_dim]);
|
||||
}
|
||||
for tensor in llama.parameter_tensors() {
|
||||
let elements = tensor
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|dim| dim.to_usize().expect("tiny llama test uses static params"))
|
||||
.product::<usize>();
|
||||
let data = (0..elements)
|
||||
.map(|_| rng.random_range(-0.08f32..0.08f32))
|
||||
.collect::<Vec<_>>();
|
||||
fuzzer = fuzzer.input_f32(tensor.id, data);
|
||||
}
|
||||
|
||||
let report = fuzzer.run();
|
||||
eprintln!("llama search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemma_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
const Q_DIM: usize = 24;
|
||||
const INTERMEDIATE: usize = 64;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let attn_norm_w = cx.tensor(HIDDEN);
|
||||
let post_attn_norm_w = cx.tensor(HIDDEN);
|
||||
let pre_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let post_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let proj_w = cx.tensor((Q_DIM, HIDDEN));
|
||||
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
|
||||
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
|
||||
let normed = rms_norm(input, attn_norm_w, EPS);
|
||||
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
|
||||
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
|
||||
let x = input + attn_normed;
|
||||
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
|
||||
let mlp_out =
|
||||
(gemma_gelu(ff_normed.matmul(w_gate.t())) * ff_normed.matmul(w_up.t())).matmul(w_down.t());
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x6E4D_4DAA)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 101, -0.15, 0.15))
|
||||
.input_f32(attn_norm_w.id, random_f32_vec(HIDDEN, 102, 0.7, 1.3))
|
||||
.input_f32(post_attn_norm_w.id, random_f32_vec(HIDDEN, 103, 0.7, 1.3))
|
||||
.input_f32(pre_ff_norm_w.id, random_f32_vec(HIDDEN, 104, 0.7, 1.3))
|
||||
.input_f32(post_ff_norm_w.id, random_f32_vec(HIDDEN, 105, 0.7, 1.3))
|
||||
.input_f32(proj_w.id, random_f32_vec(Q_DIM * HIDDEN, 106, -0.08, 0.08))
|
||||
.input_f32(
|
||||
o_proj_w.id,
|
||||
random_f32_vec(HIDDEN * Q_DIM, 107, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_gate.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 108, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_up.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 109, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_down.id,
|
||||
random_f32_vec(HIDDEN * INTERMEDIATE, 110, -0.08, 0.08),
|
||||
)
|
||||
.output_f32(out.id, "gemma_block", 5e-3, 5e-3)
|
||||
.run();
|
||||
eprintln!("gemma search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x0DEE_55EE)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.input_f32(
|
||||
router_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 201, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(
|
||||
expert_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 202, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(router_scale.id, random_f32_vec(HIDDEN, 203, 0.7, 1.3))
|
||||
.input_f32(
|
||||
router_proj.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 204, -0.2, 0.2),
|
||||
)
|
||||
.input_f32(
|
||||
per_expert_scale.id,
|
||||
random_f32_vec(NUM_EXPERTS, 205, 0.5, 1.5),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 206, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 207, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "gemma_moe_block", 5e-2, 5e-2)
|
||||
.run();
|
||||
eprintln!("moe search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_native_reference_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = input.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = input.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_weights = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let input_exp = input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = input_exp
|
||||
.matmul(gate_up_gathered.transpose(2, 3))
|
||||
.squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x51A7_E5ED)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.native_reference()
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 301, -0.15, 0.15))
|
||||
.input_f32(
|
||||
router.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 302, -0.2, 0.2),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 303, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 304, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "qwen_swiglu_moe_native_reference", 6e-2, 6e-2)
|
||||
.run();
|
||||
eprintln!("moe native-reference fuzz report: {report:?}");
|
||||
}
|
||||
@@ -267,7 +267,7 @@ fn test_mini_transformer_layer() {
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
|
||||
|
||||
// Use minimal search iterations to avoid excessive graph rewriting
|
||||
// which can cause float drift through softmax/RMSNorm reordering
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -303,7 +303,7 @@ fn test_mini_transformer_two_layers() {
|
||||
let x = layer1.forward(input);
|
||||
let out = layer2.forward(x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -361,7 +361,7 @@ fn test_transformer_multi_seed() {
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
|
||||
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -394,7 +394,7 @@ fn test_rms_norm_cuda() {
|
||||
let weight = cx.tensor(HIDDEN);
|
||||
let out = rms_norm(input, weight, 1e-5).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 1, -0.5, 0.5);
|
||||
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
|
||||
.collect();
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(weight, weight_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -433,7 +433,7 @@ fn test_self_attention_cuda() {
|
||||
let wo = cx.tensor((HIDDEN, HIDDEN));
|
||||
let out = self_attention(input, wq, wk, wv, wo).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 10, -0.5, 0.5);
|
||||
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
|
||||
rt.set_data(wk, wk_data.clone());
|
||||
rt.set_data(wv, wv_data.clone());
|
||||
rt.set_data(wo, wo_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -479,7 +479,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 20, -0.5, 0.5);
|
||||
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -526,11 +526,11 @@ fn test_rolled_chained_scalar_muls() {
|
||||
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
|
||||
let out = (chained + x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use itertools::Itertools;
|
||||
use luminal::egglog_utils::{
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
|
||||
validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use num_traits::{Num, Signed};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
@@ -128,6 +133,498 @@ pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CudaFuzzInput {
|
||||
F32(NodeIndex, Vec<f32>),
|
||||
Bf16(NodeIndex, Vec<bf16>),
|
||||
I32(NodeIndex, Vec<i32>),
|
||||
}
|
||||
|
||||
impl CudaFuzzInput {
|
||||
fn apply(&self, rt: &mut CudaRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_native(&self, rt: &mut NativeRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct F32OutputCheck {
|
||||
pub id: NodeIndex,
|
||||
pub name: String,
|
||||
pub rtol: f32,
|
||||
pub atol: f32,
|
||||
}
|
||||
|
||||
impl F32OutputCheck {
|
||||
pub fn new(id: NodeIndex, name: impl Into<String>, rtol: f32, atol: f32) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name: name.into(),
|
||||
rtol,
|
||||
atol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchEquivalenceFuzzConfig {
|
||||
pub seed: u64,
|
||||
pub samples: usize,
|
||||
pub generation_size: usize,
|
||||
pub mutations: usize,
|
||||
pub max_attempts: usize,
|
||||
pub build_options: CompileOptions,
|
||||
pub reference: SearchEquivalenceReference,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SearchEquivalenceReference {
|
||||
FirstCudaExtraction,
|
||||
NativeRuntime,
|
||||
}
|
||||
|
||||
impl Default for SearchEquivalenceFuzzConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
seed: 0,
|
||||
samples: 32,
|
||||
generation_size: 16,
|
||||
mutations: 2,
|
||||
max_attempts: 1_000,
|
||||
build_options: CompileOptions::default(),
|
||||
reference: SearchEquivalenceReference::FirstCudaExtraction,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SearchEquivalenceFuzzReport {
|
||||
pub tested: usize,
|
||||
pub skipped_invalid: usize,
|
||||
}
|
||||
|
||||
struct ChoiceRun {
|
||||
outputs: Vec<Vec<f32>>,
|
||||
llir_summary: String,
|
||||
}
|
||||
|
||||
pub struct CudaSearchEquivalenceFuzzer<'a> {
|
||||
cx: &'a mut Graph,
|
||||
stream: &'a Arc<cudarc::driver::CudaStream>,
|
||||
inputs: Vec<CudaFuzzInput>,
|
||||
outputs: Vec<F32OutputCheck>,
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
}
|
||||
|
||||
impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
pub fn new(cx: &'a mut Graph, stream: &'a Arc<cudarc::driver::CudaStream>) -> Self {
|
||||
Self {
|
||||
cx,
|
||||
stream,
|
||||
inputs: Vec::new(),
|
||||
outputs: Vec::new(),
|
||||
config: SearchEquivalenceFuzzConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.config.seed = seed;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn samples(mut self, samples: usize) -> Self {
|
||||
self.config.samples = samples;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn generation_size(mut self, generation_size: usize) -> Self {
|
||||
self.config.generation_size = generation_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mutations(mut self, mutations: usize) -> Self {
|
||||
self.config.mutations = mutations;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_options(mut self, build_options: CompileOptions) -> Self {
|
||||
self.config.build_options = build_options;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn native_reference(mut self) -> Self {
|
||||
self.config.reference = SearchEquivalenceReference::NativeRuntime;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_f32(mut self, id: NodeIndex, data: Vec<f32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::F32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_bf16(mut self, id: NodeIndex, data: Vec<bf16>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::Bf16(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_i32(mut self, id: NodeIndex, data: Vec<i32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::I32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn output_f32(
|
||||
mut self,
|
||||
id: NodeIndex,
|
||||
name: impl Into<String>,
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
) -> Self {
|
||||
self.outputs.push(F32OutputCheck::new(id, name, rtol, atol));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn run(self) -> SearchEquivalenceFuzzReport {
|
||||
fuzz_cuda_search_space_equivalence(
|
||||
self.cx,
|
||||
self.stream,
|
||||
&self.inputs,
|
||||
&self.outputs,
|
||||
self.config,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// End-to-end search-space equivalence fuzzing for CUDA.
|
||||
///
|
||||
/// This builds the normal CUDA e-graph search space, extracts random selectable
|
||||
/// LLIR graphs, runs each with identical inputs, and verifies every requested
|
||||
/// f32 output matches the first valid extraction. The reference is intentionally
|
||||
/// another selected LLIR graph, not a hand-written CPU implementation: this
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge, including
|
||||
/// candidates that produce non-finite outputs.
|
||||
pub fn fuzz_cuda_search_space_equivalence(
|
||||
cx: &mut Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
) -> SearchEquivalenceFuzzReport {
|
||||
assert!(
|
||||
!outputs.is_empty(),
|
||||
"fuzz harness needs at least one output"
|
||||
);
|
||||
|
||||
let native_reference_outputs = if config.reference == SearchEquivalenceReference::NativeRuntime
|
||||
{
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_with_rng(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::new(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
input.apply_native(&mut native_rt);
|
||||
}
|
||||
native_rt.execute(&cx.dyn_map);
|
||||
Some(
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| native_rt.get_f32(out.id).clone())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(config.build_options);
|
||||
|
||||
let egraph = cx.egraph().expect("search space should be built");
|
||||
let ops = cx.egglog_ops().expect("search ops should be built");
|
||||
let seed = if native_reference_outputs.is_some() {
|
||||
config.seed.wrapping_add(0xC0DA_C0DA)
|
||||
} else {
|
||||
config.seed
|
||||
};
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let mut prev_selected = FxHashSet::default();
|
||||
let mut base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
|
||||
let mut skipped_invalid = 0usize;
|
||||
let reference_is_cuda = native_reference_outputs.is_none();
|
||||
let (reference_hash, reference_outputs, reference_llir_summary, mut tested) =
|
||||
if let Some(reference_outputs) = native_reference_outputs {
|
||||
(0, reference_outputs, None, 0usize)
|
||||
} else {
|
||||
let mut attempts = 0usize;
|
||||
let (reference_hash, reference_run) = loop {
|
||||
attempts += 1;
|
||||
if attempts > config.max_attempts {
|
||||
panic!(
|
||||
"failed to extract a valid reference LLIR after {} attempts",
|
||||
config.max_attempts
|
||||
);
|
||||
}
|
||||
if validate_choice_set(egraph, &base, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
} else {
|
||||
let hash = hash_choice_set(&base);
|
||||
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
|
||||
Ok(run) => break (hash, run),
|
||||
Err(err) => panic!("reference candidate hash={hash} failed: {err}"),
|
||||
}
|
||||
}
|
||||
base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
};
|
||||
(
|
||||
reference_hash,
|
||||
reference_run.outputs,
|
||||
Some(reference_run.llir_summary),
|
||||
1usize,
|
||||
)
|
||||
};
|
||||
|
||||
let mut attempts = 0usize;
|
||||
while tested < config.samples && attempts < config.max_attempts {
|
||||
attempts += 1;
|
||||
let mut candidates = extract_generation(
|
||||
egraph,
|
||||
&base,
|
||||
config.generation_size,
|
||||
config.mutations,
|
||||
&mut prev_selected,
|
||||
&mut rng,
|
||||
);
|
||||
if candidates.is_empty() {
|
||||
let next = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&next));
|
||||
candidates.push(next);
|
||||
}
|
||||
|
||||
for candidate in candidates {
|
||||
if tested >= config.samples {
|
||||
break;
|
||||
}
|
||||
let candidate_hash = hash_choice_set(&candidate);
|
||||
if reference_is_cuda && candidate_hash == reference_hash {
|
||||
continue;
|
||||
}
|
||||
if validate_choice_set(egraph, &candidate, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate_run = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
|
||||
assert_fuzz_outputs_close(
|
||||
outputs,
|
||||
&reference_outputs,
|
||||
&candidate_run.outputs,
|
||||
&candidate_run.llir_summary,
|
||||
reference_llir_summary.as_deref(),
|
||||
reference_hash,
|
||||
candidate_hash,
|
||||
);
|
||||
base = candidate;
|
||||
tested += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
tested, config.samples,
|
||||
"only tested {tested}/{} LLIR samples before exhausting attempts",
|
||||
config.samples
|
||||
);
|
||||
SearchEquivalenceFuzzReport {
|
||||
tested,
|
||||
skipped_invalid,
|
||||
}
|
||||
}
|
||||
|
||||
fn run_choice_outputs<'a>(
|
||||
cx: &'a Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
) -> Result<ChoiceRun, String> {
|
||||
let egraph = cx.egraph().ok_or("search space was not built")?;
|
||||
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let mut llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
choices.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
let llir_summary = summarize_llir(&llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
rt.preserve_intermediate_buffers_for_debug();
|
||||
for input in inputs {
|
||||
input.apply(&mut rt);
|
||||
}
|
||||
if std::env::var_os("LUMINAL_FUZZ_DUMP_LAST_LLIR").is_some() {
|
||||
let _ = std::fs::write("/tmp/luminal_fuzz_last_candidate_llir.txt", &llir_summary);
|
||||
}
|
||||
rt.execute(&cx.dyn_map);
|
||||
let topo_order = toposort(&llir_graph, None).map_err(|cycle| {
|
||||
format!(
|
||||
"extracted LLIR contains cycle at node {:?}",
|
||||
cycle.node_id()
|
||||
)
|
||||
})?;
|
||||
if let Some(report) = rt.first_nonfinite_f32_buffer_in_nodes(topo_order) {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
return Err(format!(
|
||||
"LLIR produced non-finite F32 buffer node={} index={} value={} op={}; llir={dump_path}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
));
|
||||
}
|
||||
|
||||
let values = outputs
|
||||
.iter()
|
||||
.map(|out| rt.get_f32(out.id))
|
||||
.collect::<Vec<_>>();
|
||||
for (spec, values) in outputs.iter().zip(&values) {
|
||||
if let Some((idx, value)) = values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, value)| !value.is_finite())
|
||||
{
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let internal = rt
|
||||
.first_nonfinite_f32_buffer()
|
||||
.map(|report| {
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
format!(
|
||||
"; first observed non-finite buffer node={} index={} value={} op={}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"output {} produced non-finite value {value} at index {idx}{internal}; llir={dump_path}",
|
||||
spec.name
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(ChoiceRun {
|
||||
outputs: values,
|
||||
llir_summary,
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_fuzz_outputs_close(
|
||||
outputs: &[F32OutputCheck],
|
||||
expected: &[Vec<f32>],
|
||||
actual: &[Vec<f32>],
|
||||
candidate_llir_summary: &str,
|
||||
reference_llir_summary: Option<&str>,
|
||||
reference_hash: u64,
|
||||
candidate_hash: u64,
|
||||
) {
|
||||
for ((spec, expected), actual) in outputs.iter().zip(expected.iter()).zip(actual.iter()) {
|
||||
assert_eq!(
|
||||
expected.len(),
|
||||
actual.len(),
|
||||
"output {} length mismatch for candidate hash={candidate_hash} reference hash={reference_hash}",
|
||||
spec.name
|
||||
);
|
||||
let mut max_abs = 0.0f32;
|
||||
let mut max_rel = 0.0f32;
|
||||
let mut worst = 0usize;
|
||||
for (i, (&a, &b)) in actual.iter().zip(expected.iter()).enumerate() {
|
||||
assert!(
|
||||
a.is_finite(),
|
||||
"output {} candidate hash={candidate_hash} produced non-finite value {a} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
assert!(
|
||||
b.is_finite(),
|
||||
"output {} reference hash={reference_hash} produced non-finite value {b} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
let abs = (a - b).abs();
|
||||
let rel = abs / b.abs().max(1e-12);
|
||||
if abs > max_abs {
|
||||
max_abs = abs;
|
||||
max_rel = rel;
|
||||
worst = i;
|
||||
}
|
||||
if abs > spec.atol + spec.rtol * b.abs() {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, candidate_llir_summary);
|
||||
if let Some(reference_llir_summary) = reference_llir_summary {
|
||||
let _ = std::fs::write(
|
||||
"/tmp/luminal_fuzz_bad_reference_llir.txt",
|
||||
reference_llir_summary,
|
||||
);
|
||||
}
|
||||
panic!(
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={} candidate_llir={dump_path}",
|
||||
spec.name,
|
||||
spec.atol + spec.rtol * b.abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"fuzz output {} ok: candidate hash={candidate_hash} max_abs={max_abs} max_rel={max_rel} worst={worst}",
|
||||
spec.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_llir(llir_graph: &LLIRGraph) -> String {
|
||||
llir_graph
|
||||
.node_indices()
|
||||
.map(|idx| {
|
||||
let inputs = llir_graph
|
||||
.edges_directed(idx, Direction::Incoming)
|
||||
.sorted_by_key(|edge| edge.id())
|
||||
.map(|edge| edge.source().index().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!("{} <- [{}]: {:?}", idx.index(), inputs, &llir_graph[idx])
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
@@ -199,12 +696,12 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
let a = cx.tensor(shape.clone());
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
@@ -272,14 +769,14 @@ pub fn test_binary_cuda<T: TestDType>(
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = a_generator(a_elements, seed);
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
@@ -339,7 +836,7 @@ pub fn test_mod(
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
|
||||
@@ -347,7 +844,7 @@ pub fn test_mod(
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
@@ -1,22 +1,32 @@
|
||||
[package]
|
||||
name = "luminal_metal"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
description = "Metal backend for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
metal = "0.31"
|
||||
metal = { version = "0.31", features = ["mps"] }
|
||||
objc = "0.2"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
half = "2.7.1"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
tracing = "0.1.43"
|
||||
safetensors = "0.7.0"
|
||||
memmap2 = "0.9.9"
|
||||
bytemuck = "1.24.0"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
rustc-hash = "2.1"
|
||||
tokenizers = "0.22.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
|
||||
641
crates/luminal_metal/examples/llama_1b.rs
Normal file
641
crates/luminal_metal/examples/llama_1b.rs
Normal file
@@ -0,0 +1,641 @@
|
||||
use hf_hub::api::sync::Api;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::{CompileOptions, DimBucket, Graph},
|
||||
prelude::{F32Pow, GraphTensor, Runtime},
|
||||
};
|
||||
use luminal_metal::MetalRuntime;
|
||||
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
|
||||
use luminal_tracing::luminal_filter;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
error::Error,
|
||||
io::Write,
|
||||
path::PathBuf,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "unsloth/Llama-3.2-1B-Instruct";
|
||||
const MAX_SEQ_LEN: usize = 2048;
|
||||
const GEN_TOKENS: usize = 96;
|
||||
const SEARCH_GRAPHS: usize = 100;
|
||||
const SEARCH_MEMORY_MIB: usize = 1536;
|
||||
const PROMPT: &str = "In one short paragraph, explain neural networks using the words layers, neurons, learning, and data.";
|
||||
|
||||
const LAYERS: usize = 16;
|
||||
const HIDDEN: usize = 2048;
|
||||
const INTERMEDIATE: usize = 8192;
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_HEADS: usize = 32;
|
||||
const N_KV_HEADS: usize = 8;
|
||||
const KV_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const VOCAB_SIZE: usize = 128256;
|
||||
const RMS_NORM_EPS: f32 = 1e-5;
|
||||
const ROPE_THETA: f32 = 500_000.0;
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
fn prepare_hf_model() -> Result<PathBuf, Box<dyn Error>> {
|
||||
let repo = Api::new()?.model(REPO_ID.to_string());
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
repo.get("model.safetensors")?;
|
||||
Ok(tokenizer_path.parent().unwrap().to_path_buf())
|
||||
}
|
||||
|
||||
fn llama3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct StepProfile {
|
||||
total: Duration,
|
||||
execute: Duration,
|
||||
get_logits: Duration,
|
||||
cache_roundtrip: Duration,
|
||||
}
|
||||
|
||||
fn avg_ms(duration: Duration, n: usize) -> f64 {
|
||||
if n == 0 {
|
||||
0.0
|
||||
} else {
|
||||
duration.as_secs_f64() * 1e3 / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, repetition_penalty: f32) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
|
||||
let mut mask = vec![-1e10f32; q_pos.len() * context_len];
|
||||
for (qi, &pos) in q_pos.iter().enumerate() {
|
||||
for ci in 0..context_len {
|
||||
if ci <= pos {
|
||||
mask[qi * context_len + ci] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
mask
|
||||
}
|
||||
|
||||
struct KVCache {
|
||||
k_caches: Vec<GraphTensor>,
|
||||
v_caches: Vec<GraphTensor>,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
v_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
}
|
||||
Self { k_caches, v_caches }
|
||||
}
|
||||
}
|
||||
|
||||
struct Llama {
|
||||
embedding: GraphTensor,
|
||||
layers: Vec<LlamaLayer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_norm: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
q_pos,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.embedding.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamaLayer {
|
||||
up: GraphTensor,
|
||||
gate: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: GraphTensor,
|
||||
o_proj: GraphTensor,
|
||||
attn_rms: LayerNorm,
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ HEAD_DIM as f32;
|
||||
let inv_freqs = ROPE_THETA.pow(freqs).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
|
||||
let x1 = input.slice((.., .., HEAD_DIM / 2..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
|
||||
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, KV_DIM);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let masked_scores = scores + attn_mask.expand_dim(0, N_HEADS);
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_model_step(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut MetalRuntime,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
tokens: &[u32],
|
||||
q_pos: &[i32],
|
||||
scatter_idx: &[i32],
|
||||
gather_idx: &[i32],
|
||||
attn_mask: &[f32],
|
||||
) -> (Vec<f32>, StepProfile) {
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', tokens.len());
|
||||
cx.set_dim('c', gather_idx.len());
|
||||
|
||||
runtime.set_data(input, tokens.iter().map(|t| *t as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(q_pos_t, q_pos.to_vec());
|
||||
runtime.set_data(scatter_idx_t, scatter_idx.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather_idx.to_vec());
|
||||
runtime.set_data(attn_mask_t, attn_mask.to_vec());
|
||||
runtime.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let execute_start = Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let execute = execute_start.elapsed();
|
||||
|
||||
let logits_start = Instant::now();
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let get_logits = logits_start.elapsed();
|
||||
|
||||
let cache_start = Instant::now();
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let cache_roundtrip = cache_start.elapsed();
|
||||
|
||||
(
|
||||
logits_data,
|
||||
StepProfile {
|
||||
total: start.elapsed(),
|
||||
execute,
|
||||
get_logits,
|
||||
cache_roundtrip,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.try_init();
|
||||
|
||||
let model_dir = prepare_hf_model()?;
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(llama3_chat_prompt(PROMPT), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = KVCache::new(&mut cx, MAX_SEQ_LEN);
|
||||
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
&kv_cache,
|
||||
);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let search_c = 16.min(max_context).max(2);
|
||||
let build_options = CompileOptions::default()
|
||||
.max_memory_mib(SEARCH_MEMORY_MIB)
|
||||
.dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
)
|
||||
.dim_buckets(
|
||||
'c',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_context).representative(search_c),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let egraph_start = Instant::now();
|
||||
cx.build_search_space::<MetalRuntime>(build_options);
|
||||
println!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
println!("Loading weights...");
|
||||
let load_start = Instant::now();
|
||||
let mut runtime = MetalRuntime::initialize(());
|
||||
runtime.load_safetensors(&cx, model_dir.join("model.safetensors").to_str().unwrap());
|
||||
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
|
||||
|
||||
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let compile_start = Instant::now();
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_c);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(q_pos_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, CompileOptions::new(SEARCH_GRAPHS));
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut context_len = 0usize;
|
||||
let mut profiles = Vec::new();
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty = 1.05;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, GEN_TOKENS
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut next_token = None;
|
||||
if GEN_TOKENS > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&prompt_tokens,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&mask,
|
||||
);
|
||||
context_len = prompt_len;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
|
||||
while generated < GEN_TOKENS {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
let gather_idx = (0..=context_len as i32).collect::<Vec<_>>();
|
||||
let mask = causal_mask(&[context_len], context_len + 1);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
context_len += 1;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!();
|
||||
|
||||
let ttft = profiles.first().map(|p| p.total).unwrap_or_default();
|
||||
let decode_steps = profiles.len().saturating_sub(1);
|
||||
let decode_total: Duration = profiles.iter().skip(1).map(|p| p.total).sum();
|
||||
println!(" TTFT: {:.2} ms", ttft.as_secs_f64() * 1e3);
|
||||
println!(" TPOT: {:.2} ms", avg_ms(decode_total, decode_steps));
|
||||
|
||||
let execute_total: Duration = profiles.iter().map(|p| p.execute).sum();
|
||||
let logits_total: Duration = profiles.iter().map(|p| p.get_logits).sum();
|
||||
let cache_total: Duration = profiles.iter().map(|p| p.cache_roundtrip).sum();
|
||||
println!(
|
||||
" Profile: n={}, exec={:.2} ms, logits={:.2} ms, cache={:.2} ms",
|
||||
profiles.len(),
|
||||
avg_ms(execute_total, profiles.len()),
|
||||
avg_ms(logits_total, profiles.len()),
|
||||
avg_ms(cache_total, profiles.len()),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, bytes_to_native_data, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
@@ -31,10 +31,42 @@ impl DynBackend for MetalDynBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Reject dtypes the Metal kernel emitters don't support.
|
||||
///
|
||||
/// Metal codegen has no native 64-bit integer or 64-bit float paths.
|
||||
/// Reaching the kernel emitter with one of these dtypes used to panic deep
|
||||
/// in MSL generation with an unhelpful error; surfacing a clean message
|
||||
/// at translate-time lets the user fall back to CPU or pick a narrower
|
||||
/// dtype before any Metal compilation runs.
|
||||
fn reject_unsupported_dtype(graph: &Graph) -> Result<(), String> {
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
match input.dtype {
|
||||
DType::I64 | DType::F64 => {
|
||||
return Err(format!(
|
||||
"Metal backend does not support {:?} (input `{}`). \
|
||||
Metal codegen has no native 64-bit kernels; either \
|
||||
narrow the dtype (e.g. `.to(torch.int32)` / \
|
||||
`.to(torch.float32)`) before the boundary or \
|
||||
compile with the CPU / CUDA backend.",
|
||||
input.dtype, input.label
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn metal_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
reject_unsupported_dtype(graph)?;
|
||||
compile_backend::<MetalRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|
||||
@@ -1,227 +1,5 @@
|
||||
use super::{MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum MetalMatmulFamily {
|
||||
#[default]
|
||||
Naive,
|
||||
RegularTiled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulDescriptor {
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub batch_shape: Vec<Expression>,
|
||||
pub lhs_strides: Vec<Expression>,
|
||||
pub rhs_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub transpose_lhs: bool,
|
||||
pub transpose_rhs: bool,
|
||||
}
|
||||
|
||||
impl MatmulDescriptor {
|
||||
pub fn from_mul_and_sum(
|
||||
mul_info: &MetalMulInfo,
|
||||
sum_info: &MetalSumReduceInfo,
|
||||
) -> Option<Self> {
|
||||
let zero = Expression::from(0);
|
||||
let z = Expression::from('z');
|
||||
|
||||
let is_simple_2d_matmul = mul_info.shape.len() == 3
|
||||
&& sum_info.shape.len() == 2
|
||||
&& mul_info.a_strides.len() == 3
|
||||
&& mul_info.b_strides.len() == 3
|
||||
&& sum_info.strides.len() == 2
|
||||
&& mul_info.shape[0] == sum_info.shape[0]
|
||||
&& mul_info.shape[1] == sum_info.shape[1]
|
||||
&& mul_info.shape[2] == sum_info.iters
|
||||
&& mul_info.a_strides[1] == zero
|
||||
&& mul_info.a_strides[2] == z
|
||||
&& mul_info.b_strides[0] == zero
|
||||
&& mul_info.b_strides[1] == z
|
||||
&& sum_info.strides[1] == z
|
||||
&& sum_info.iter_stride == z;
|
||||
|
||||
if !is_simple_2d_matmul {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
m: sum_info.shape[0],
|
||||
n: sum_info.shape[1],
|
||||
k: sum_info.iters,
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: mul_info.a_strides.clone(),
|
||||
rhs_strides: mul_info.b_strides.clone(),
|
||||
out_strides: sum_info.strides.clone(),
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulPlan {
|
||||
pub family: MetalMatmulFamily,
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub lda: Expression,
|
||||
pub ldb: Expression,
|
||||
pub ldd: Expression,
|
||||
pub batch_size: u32,
|
||||
pub batch_stride_a: u32,
|
||||
pub batch_stride_b: u32,
|
||||
pub batch_stride_d: u32,
|
||||
pub bm: u16,
|
||||
pub bn: u16,
|
||||
pub bk: u16,
|
||||
pub wm: u16,
|
||||
pub wn: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct MetalMatmulPlanner;
|
||||
|
||||
impl MetalMatmulPlanner {
|
||||
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
|
||||
let family = if desc.batch_shape.is_empty()
|
||||
&& desc.m.as_num().is_some_and(|m| m >= 32)
|
||||
&& desc.n.as_num().is_some_and(|n| n >= 32)
|
||||
&& desc.k.as_num().is_some_and(|k| k >= 32)
|
||||
{
|
||||
MetalMatmulFamily::RegularTiled
|
||||
} else {
|
||||
MetalMatmulFamily::Naive
|
||||
};
|
||||
MatmulPlan {
|
||||
family,
|
||||
m: desc.m,
|
||||
n: desc.n,
|
||||
k: desc.k,
|
||||
lda: desc.lhs_strides[0],
|
||||
ldb: desc.rhs_strides[2],
|
||||
ldd: desc.out_strides[0],
|
||||
batch_size: 1,
|
||||
batch_stride_a: 0,
|
||||
batch_stride_b: 0,
|
||||
batch_stride_d: 0,
|
||||
bm: 16,
|
||||
bn: 16,
|
||||
bk: 8,
|
||||
wm: 2,
|
||||
wn: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn descriptor_recovers_simple_2d_matmul() {
|
||||
let mul = MetalMulInfo {
|
||||
shape: vec![
|
||||
Expression::from(4),
|
||||
Expression::from(8),
|
||||
Expression::from(16),
|
||||
],
|
||||
a_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
b_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
output_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from('z') * 8,
|
||||
Expression::from('z'),
|
||||
],
|
||||
};
|
||||
let sum = MetalSumReduceInfo {
|
||||
shape: vec![Expression::from(4), Expression::from(8)],
|
||||
strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
iters: Expression::from(16),
|
||||
iter_stride: Expression::from('z'),
|
||||
};
|
||||
|
||||
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
|
||||
assert_eq!(desc.m, Expression::from(4));
|
||||
assert_eq!(desc.n, Expression::from(8));
|
||||
assert_eq!(desc.k, Expression::from(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_keeps_small_problems_on_naive_path() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(4),
|
||||
n: Expression::from(8),
|
||||
k: Expression::from(16),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::Naive);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
assert_eq!(plan.lda, Expression::from('z') * 16);
|
||||
assert_eq!(plan.ldb, Expression::from('z') * 8);
|
||||
assert_eq!(plan.ldd, Expression::from('z') * 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_promotes_large_problems_to_regular_tiled() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(64),
|
||||
n: Expression::from(64),
|
||||
k: Expression::from(64),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 64,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 64,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
}
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum MPSMatrixLayout {
|
||||
RowMajor,
|
||||
TransposedRowMajor,
|
||||
}
|
||||
|
||||
@@ -6,10 +6,127 @@ pub use ops::*;
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device,
|
||||
foreign_types::ForeignTypeRef, mps,
|
||||
};
|
||||
use objc::rc::StrongPtr;
|
||||
use objc::runtime::Object;
|
||||
use objc::{class, msg_send, sel, sel_impl};
|
||||
use std::cell::RefCell;
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct MpsMatrixDescriptorKey {
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
row_bytes: u64,
|
||||
data_type: isize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct MpsMatmulKey {
|
||||
transpose_lhs: bool,
|
||||
transpose_rhs: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: u64,
|
||||
beta: u64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MpsKernelCache {
|
||||
matrix_descriptors: FxHashMap<MpsMatrixDescriptorKey, StrongPtr>,
|
||||
matmul_kernels: FxHashMap<MpsMatmulKey, StrongPtr>,
|
||||
}
|
||||
|
||||
impl MpsKernelCache {
|
||||
pub(crate) fn matrix_descriptor(
|
||||
&mut self,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
row_bytes: u64,
|
||||
dtype: DType,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatrixDescriptorKey {
|
||||
rows,
|
||||
cols,
|
||||
row_bytes,
|
||||
data_type: Self::mps_data_type(dtype),
|
||||
};
|
||||
let descriptor = self
|
||||
.matrix_descriptors
|
||||
.entry(key)
|
||||
.or_insert_with(|| unsafe {
|
||||
let descriptor: *mut Object = msg_send![
|
||||
class!(MPSMatrixDescriptor),
|
||||
matrixDescriptorWithRows: rows
|
||||
columns: cols
|
||||
rowBytes: row_bytes as usize
|
||||
dataType: key.data_type
|
||||
];
|
||||
StrongPtr::retain(descriptor)
|
||||
});
|
||||
**descriptor
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn matrix_multiplication(
|
||||
&mut self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
transpose_lhs: bool,
|
||||
transpose_rhs: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatmulKey {
|
||||
transpose_lhs,
|
||||
transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha: alpha.to_bits(),
|
||||
beta: beta.to_bits(),
|
||||
};
|
||||
let kernel = self.matmul_kernels.entry(key).or_insert_with(|| unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: transpose_lhs
|
||||
transposeRight: transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: alpha
|
||||
beta: beta
|
||||
];
|
||||
StrongPtr::new(kernel)
|
||||
});
|
||||
**kernel
|
||||
}
|
||||
|
||||
fn mps_data_type(dtype: DType) -> isize {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => mps::MPSDataType::Float32 as isize,
|
||||
DType::F16 => mps::MPSDataType::Float16 as isize,
|
||||
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MetalEncodeContext<'a> {
|
||||
pub(crate) command_buffer: &'a CommandBufferRef,
|
||||
pub(crate) dyn_buffer: &'a Buffer,
|
||||
pub(crate) mps_cache: &'a RefCell<MpsKernelCache>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalMulInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
@@ -32,7 +149,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
device: &Device,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) -> ComputePipelineState;
|
||||
) -> Option<ComputePipelineState>;
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.first().copied().unwrap_or(DType::F32)
|
||||
@@ -40,7 +157,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
fn encode(
|
||||
fn encode_compute(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
pipeline: &ComputePipelineState,
|
||||
@@ -49,6 +166,25 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
);
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn encode(
|
||||
&self,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_input_dtypes: &[DType],
|
||||
_output_dtype: DType,
|
||||
) {
|
||||
let pipeline = pipeline.expect("compute pipeline not compiled");
|
||||
let encoder = context.command_buffer.new_compute_command_encoder();
|
||||
let dyn_idx = inputs.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(context.dyn_buffer), 0);
|
||||
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Performance Metrics for MBU/MFU Calculation
|
||||
// ========================================================================
|
||||
@@ -73,6 +209,10 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_matmul(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "luminal_nn"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
@@ -166,8 +166,8 @@ mod tests {
|
||||
let indices = cx.tensor(3).as_dtype(DType::Int);
|
||||
let result = gather_rows(data, indices, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
|
||||
rt.set_data(
|
||||
@@ -192,8 +192,8 @@ mod tests {
|
||||
let dest = cx.tensor((4, 3));
|
||||
let result = scatter_rows(src, indices, dest, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
|
||||
rt.set_data(indices.id, vec![1, 3]);
|
||||
@@ -218,8 +218,8 @@ mod tests {
|
||||
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
|
||||
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
|
||||
@@ -271,8 +271,8 @@ mod tests {
|
||||
let k_cache_new = k_cache_new.output();
|
||||
let v_cache_new = v_cache_new.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
|
||||
rt.set_data(q.id, vec![1., 0., 1., 0.]);
|
||||
@@ -344,8 +344,8 @@ mod tests {
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
|
||||
// K cached at slot 0: [1, 0]
|
||||
@@ -416,8 +416,8 @@ mod tests {
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
// Cache has 1 token at slot 0
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
|
||||
@@ -61,7 +61,8 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -70,7 +71,7 @@ impl MoE {
|
||||
mod tests {
|
||||
use super::MoE;
|
||||
use luminal::prelude::*;
|
||||
use rand::{rng, Rng};
|
||||
use rand::{Rng, rng};
|
||||
|
||||
fn random_vec(n: usize) -> Vec<f32> {
|
||||
let mut r = rng();
|
||||
@@ -182,8 +183,8 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
let input_data = vec![1.0, 2.0, 3.0];
|
||||
// Router strongly favors expert 0
|
||||
@@ -237,8 +238,8 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
let input_data = vec![1.0, 1.0];
|
||||
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
|
||||
@@ -291,8 +292,8 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
let input_data = vec![
|
||||
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
|
||||
@@ -348,8 +349,8 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
let input_data = random_vec(in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
@@ -393,8 +394,8 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
|
||||
let input_data = random_vec(batch * in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
@@ -478,7 +479,8 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -855,8 +855,6 @@ Two important details:
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
|
||||
@@ -431,7 +431,7 @@ def main() -> None:
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
|
||||
max_new_tokens = 100
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
|
||||
@@ -8,7 +8,7 @@ echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py"
|
||||
CUDA_TESTS="tests/"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -5,6 +5,7 @@ use luminal::{
|
||||
visualization::ToDot,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::typed_data::TypedData;
|
||||
@@ -73,22 +74,13 @@ fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usiz
|
||||
Some((var, candidate))
|
||||
}
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
/// Convert luminal `DType` to a PT2 dtype code via `TorchDType`. Panics
|
||||
/// for luminal-specific dtypes that have no PyTorch counterpart (`I4`,
|
||||
/// `U4`, the F6 / F4 families, ...).
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
match dtype {
|
||||
DType::U8 => 1,
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
DType::F64 => 8,
|
||||
DType::Bool => 12,
|
||||
DType::Bf16 => 13,
|
||||
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
|
||||
}
|
||||
crate::torch_dtype::TorchDType::try_from(dtype)
|
||||
.map(|t| t.code())
|
||||
.unwrap_or_else(|d| panic!("luminal_dtype_to_pt2_code: unsupported dtype {d:?}"))
|
||||
}
|
||||
|
||||
/// Common intermediate result from translating a model graph.
|
||||
@@ -98,7 +90,12 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -124,7 +121,9 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -151,17 +150,21 @@ impl CompiledGraph {
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
let WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
} = weight_data;
|
||||
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
// Build compile args from WeightData.
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weight_data
|
||||
.weights
|
||||
weights: weights
|
||||
.iter()
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
};
|
||||
|
||||
// Create backend via the factory directly
|
||||
@@ -380,7 +383,7 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Register a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
@@ -441,7 +444,7 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// Register a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
@@ -476,10 +479,7 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
self.output_dtypes.clone()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
@@ -504,6 +504,65 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Read an output as f16 (returned as raw little-endian bytes —
|
||||
/// Python has no native f16, so the caller bit-casts via
|
||||
/// `torch.frombuffer(..., dtype=torch.float16)`). Strict: the
|
||||
/// producer node must already be `DType::F16`; no widening at
|
||||
/// the read boundary.
|
||||
fn get_output_f16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = self.runtime.get_output_f16(*node_id);
|
||||
let bytes: &[u8] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
|
||||
Ok(PyBytes::new(py, bytes))
|
||||
}
|
||||
|
||||
/// Read an output as bf16 (returned as raw little-endian bytes —
|
||||
/// caller bit-casts via `torch.frombuffer(..., dtype=torch.
|
||||
/// bfloat16)`). Strict: the producer node must already be
|
||||
/// `DType::Bf16`; no widening at the read boundary.
|
||||
fn get_output_bf16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = self.runtime.get_output_bf16(*node_id);
|
||||
let bytes: &[u8] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
|
||||
Ok(PyBytes::new(py, bytes))
|
||||
}
|
||||
|
||||
/// Read an output as i64. Strict: the producer node must already
|
||||
/// be `DType::I64`; no widening at the read boundary.
|
||||
fn get_output_i64(&self, name: &str) -> PyResult<Vec<i64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_i64(*node_id))
|
||||
}
|
||||
|
||||
/// Read an output as f64. Strict: the producer node must already
|
||||
/// be `DType::F64`; no widening at the read boundary.
|
||||
fn get_output_f64(&self, name: &str) -> PyResult<Vec<f64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_f64(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
|
||||
120
crates/luminal_python/rust/src/dim_arith.rs
Normal file
120
crates/luminal_python/rust/src/dim_arith.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
//! Canonical-form helpers for dimension `Expression` arithmetic — used
|
||||
//! by the translator to keep shape arithmetic syntactically consistent
|
||||
//! across code paths.
|
||||
//!
|
||||
//! `Expression` equality is syntactic; `a * 8` and `8 * a` are distinct
|
||||
//! objects despite being mathematically equal. When two translator code
|
||||
//! paths build the same logical dim via differently-ordered
|
||||
//! multiplications, downstream `assert_eq!(self.dims(), rhs.dims())`
|
||||
//! checks in `GraphTensor::Add` / `Sub` / `Mul` / `Rem` panic. These
|
||||
//! helpers solve that at the construction site: every shape product
|
||||
//! goes through `product_of_dims`, which sorts the operand list before
|
||||
//! folding, so two callers passing the operands in different orders
|
||||
//! produce identical `Expression`s.
|
||||
//!
|
||||
//! Lives in `luminal_python` (rather than upstream `luminal::shape`) so
|
||||
//! the change is contained to the translator. luminal-core callers of
|
||||
//! `gather_elements` / `scatter_elements` / `scatter_nd` historically
|
||||
//! pass concrete dims, so they don't need this; the translator-local
|
||||
//! lowerings in `translator::movement_dynamic` do.
|
||||
//!
|
||||
//! The ordering matches what `pt2_expr.rs::normalize_mul_expr` was
|
||||
//! using locally before being promoted here — see that file for the
|
||||
//! original canonical-sort logic.
|
||||
|
||||
use luminal::prelude::Expression;
|
||||
|
||||
/// Sort key for the canonical commutative ordering. Sorts by RPN-term
|
||||
/// count first so single-term operands (variables, literals) sort
|
||||
/// before compound subexpressions; ties broken by debug repr so two
|
||||
/// single-term operands have a stable alphabetic order.
|
||||
///
|
||||
/// O(n) string alloc per compare — only call on shape products, never
|
||||
/// per-element in a kernel.
|
||||
#[inline]
|
||||
pub(crate) fn commutative_key(expr: &Expression) -> (usize, String) {
|
||||
(expr.len(), format!("{expr:?}"))
|
||||
}
|
||||
|
||||
/// Order `(a, b)` so the canonically-smaller expression is first.
|
||||
#[inline]
|
||||
pub(crate) fn sort_pair(a: Expression, b: Expression) -> (Expression, Expression) {
|
||||
if commutative_key(&a) <= commutative_key(&b) {
|
||||
(a, b)
|
||||
} else {
|
||||
(b, a)
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply two dim expressions with canonical operand ordering.
|
||||
#[inline]
|
||||
pub(crate) fn mul_dims(a: Expression, b: Expression) -> Expression {
|
||||
let (a, b) = sort_pair(a, b);
|
||||
a * b
|
||||
}
|
||||
|
||||
/// Add two dim expressions with canonical operand ordering.
|
||||
#[inline]
|
||||
pub(crate) fn add_dims(a: Expression, b: Expression) -> Expression {
|
||||
let (a, b) = sort_pair(a, b);
|
||||
a + b
|
||||
}
|
||||
|
||||
/// Product of a sequence of dim expressions. Operands are sorted
|
||||
/// canonically before folding so callers passing the same logical
|
||||
/// dim set in different orders produce identical `Expression`s.
|
||||
/// Empty sequence → `Expression::from(1usize)`.
|
||||
pub(crate) fn product_of_dims<I>(dims: I) -> Expression
|
||||
where
|
||||
I: IntoIterator<Item = Expression>,
|
||||
{
|
||||
let mut v: Vec<Expression> = dims.into_iter().collect();
|
||||
v.sort_by_key(commutative_key);
|
||||
v.into_iter()
|
||||
.fold(Expression::from(1usize), |acc, d| acc * d)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mul_dims_canonicalises_commutative_order() {
|
||||
let a = Expression::from('a');
|
||||
let n = Expression::from(8i64);
|
||||
assert_eq!(mul_dims(a, n), mul_dims(n, a));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn product_of_dims_independent_of_input_order() {
|
||||
let a = Expression::from('a');
|
||||
let b = Expression::from('b');
|
||||
let n = Expression::from(8i64);
|
||||
let p1 = product_of_dims([a, n, b]);
|
||||
let p2 = product_of_dims([n, b, a]);
|
||||
let p3 = product_of_dims([b, a, n]);
|
||||
assert_eq!(p1, p2);
|
||||
assert_eq!(p1, p3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_product_is_one() {
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(product_of_dims(empty), Expression::from(1usize));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_numeric_types_canonicalise_together() {
|
||||
// `pt2_util` builds with `Expression::from(usize)` while tests /
|
||||
// direct callers reach for `i64`. The two literal paths must
|
||||
// produce identical reprs or `product_of_dims` will sort them
|
||||
// into different positions and we lose the canonical-form
|
||||
// guarantee across call sites.
|
||||
assert_eq!(Expression::from(8usize), Expression::from(8i64));
|
||||
let a = Expression::from('a');
|
||||
assert_eq!(
|
||||
product_of_dims([Expression::from(8usize), a]),
|
||||
product_of_dims([Expression::from(8i64), a]),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
mod compiled_graph;
|
||||
mod dim_arith;
|
||||
pub mod torch_dtype;
|
||||
pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
mod pt2_expr;
|
||||
mod pt2_parser;
|
||||
mod pt2_schema;
|
||||
mod pt2_util;
|
||||
@@ -12,17 +15,32 @@ use compiled_graph::CompiledGraph;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyCapsule;
|
||||
use std::collections::HashMap;
|
||||
use torch_dtype::TorchDType;
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(_torch_dtype_codes, m)?)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `{variant_name: pt2_code}` for every `TorchDType` variant. The Python
|
||||
/// parity test (`tests/test_torch_dtype_parity.py`) consumes this and
|
||||
/// asserts every entry matches `torch._export.serde.schema.ScalarType.<name>
|
||||
/// .value` — drift fails CI rather than silently miscompiling at runtime.
|
||||
#[pyfunction]
|
||||
fn _torch_dtype_codes() -> HashMap<&'static str, u32> {
|
||||
TorchDType::ALL
|
||||
.iter()
|
||||
.map(|v| (v.name(), v.code()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory capsule helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -6,10 +6,11 @@ use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_expr::parse_sympy_expr;
|
||||
use crate::pt2_parser;
|
||||
use crate::pt2_schema;
|
||||
use crate::translator;
|
||||
use crate::typed_data::TypedData;
|
||||
use crate::{pt2_parser, pt2_util};
|
||||
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
@@ -21,7 +22,7 @@ fn resolve_dim_sizes(
|
||||
sizes
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
let s = e.as_expr.expr_str.trim();
|
||||
// Try the full sympy-style parse first so compound forms like
|
||||
@@ -45,7 +46,7 @@ fn resolve_dim_sizes(
|
||||
.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(|h| Expression::from(h as usize))
|
||||
.map(Expression::from)
|
||||
})
|
||||
.unwrap_or_else(|| Expression::from(1usize))
|
||||
}
|
||||
@@ -53,139 +54,6 @@ fn resolve_dim_sizes(
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
|
||||
///
|
||||
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
|
||||
///
|
||||
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
|
||||
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
|
||||
/// * `Integer(N)` / `Number(N)` — concrete int.
|
||||
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
|
||||
///
|
||||
/// Returns `None` for anything else so the caller can fall back to a less
|
||||
/// precise representation rather than committing a wrong expression.
|
||||
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
|
||||
let s = s.trim();
|
||||
if s.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Bare integer literal — `srepr` doesn't usually emit this at the top
|
||||
// level (it wraps in `Integer(...)`), but accept it for robustness.
|
||||
if let Ok(n) = s.parse::<i64>() {
|
||||
return Some(Expression::from(n as usize));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(s)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
// Body is `'name', positive=True, integer=True` etc. Pull the
|
||||
// first quoted token as the name.
|
||||
let name = extract_first_quoted(body)?;
|
||||
sym_to_char.get(&name).map(|c| Expression::from(*c))
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let n: i64 = body.trim().parse().ok()?;
|
||||
Some(Expression::from(n as usize))
|
||||
}
|
||||
"Mul" | "Add" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
|
||||
for p in iter {
|
||||
let rhs = parse_sympy_expr(p, sym_to_char)?;
|
||||
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into (head, body); returns None if not in that form.
|
||||
fn split_head(s: &str) -> Option<(&str, &str)> {
|
||||
let open = s.find('(')?;
|
||||
if !s.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&s[..open], &s[open + 1..s.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list,
|
||||
/// e.g. `'s77', positive=True` → `s77`.
|
||||
fn extract_first_quoted(s: &str) -> Option<String> {
|
||||
let bytes = s.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(s[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
|
||||
/// dimensional information).
|
||||
fn split_top_level_args(s: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = s.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = s[start..i].trim();
|
||||
// Drop `key=value` kwargs — they're metadata sympy uses
|
||||
// for pretty-printing, not arguments to the operator.
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = s[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
// sympy kwargs are bare identifiers like `positive`, `integer`.
|
||||
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
@@ -262,10 +130,13 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,14 +152,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -503,52 +374,10 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes to TypedData using PT2 dtype numbering.
|
||||
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
|
||||
/// Converts i64/f64/i16 to the closest luminal-native representation.
|
||||
/// Convert raw bytes to `TypedData` using PT2 dtype numbering. Thin
|
||||
/// wrapper around `TypedData::from_pytorch_bytes` — the dtype dispatch
|
||||
/// (including the narrow-int panic and unknown-code rejection) lives
|
||||
/// there, so this site stays a one-liner that just clones the slice.
|
||||
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
match dtype {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
|
||||
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
|
||||
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
|
||||
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
|
||||
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
|
||||
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
|
||||
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
|
||||
|
||||
// i64 → i32 (truncate, matching luminal's Int type)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen to luminal's Int)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
_ => {
|
||||
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
|
||||
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
|
||||
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
|
||||
}
|
||||
}
|
||||
TypedData::from_pytorch_bytes(bytes.to_vec(), dtype)
|
||||
}
|
||||
|
||||
699
crates/luminal_python/rust/src/pt2_expr.rs
Normal file
699
crates/luminal_python/rust/src/pt2_expr.rs
Normal file
@@ -0,0 +1,699 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_parser::SymDimMap;
|
||||
use crate::pt2_schema::RangeConstraint;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) struct ExprBounds {
|
||||
pub(crate) min: Option<i64>,
|
||||
pub(crate) max: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct ParsedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
impl ParsedExpr {
|
||||
fn exact(expr: Expression, value: i64) -> Self {
|
||||
Self {
|
||||
expr,
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct BoundedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal `Expression`.
|
||||
///
|
||||
/// Supports the subset of sympy heads PT2 emits for symbolic shape metadata.
|
||||
pub(crate) fn parse_sympy_expr(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr, sym_to_char, &HashMap::new())
|
||||
}
|
||||
|
||||
pub(crate) fn parse_sympy_expr_with_ranges(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_inner(expr, sym_to_char, ranges).map(|parsed| parsed.expr)
|
||||
}
|
||||
|
||||
pub(crate) fn sym_char_ranges(sym_map: &SymDimMap) -> FxHashMap<char, ExprBounds> {
|
||||
sym_map
|
||||
.sym_to_char
|
||||
.iter()
|
||||
.map(|(sym_name, sym_char)| {
|
||||
let range = sym_map.ranges.get(sym_name);
|
||||
let min = range
|
||||
.and_then(|range| range.min_val)
|
||||
.map(|min| min.max(0))
|
||||
.or(Some(0));
|
||||
let max = range
|
||||
.and_then(|range| range.max_val)
|
||||
.filter(|max| *max >= 0);
|
||||
(*sym_char, ExprBounds { min, max })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn simplify_expr_with_ranges(
|
||||
expr: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Expression {
|
||||
simplify_bound_expr(expr, sym_ranges).expr
|
||||
}
|
||||
|
||||
pub(crate) fn same_expr_with_ranges(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
let lhs = simplify_bound_expr(lhs, sym_ranges);
|
||||
let rhs = simplify_bound_expr(rhs, sym_ranges);
|
||||
lhs.expr == rhs.expr
|
||||
|| lhs.expr.egglog_equal(rhs.expr)
|
||||
|| (exact_value(lhs) == exact_value(rhs) && exact_value(lhs).is_some())
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_equal_expr(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Option<Expression> {
|
||||
if !same_expr_with_ranges(lhs, rhs, sym_ranges) {
|
||||
return None;
|
||||
}
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, sym_ranges);
|
||||
Some(if lhs_simplified.len() <= rhs_simplified.len() {
|
||||
lhs_simplified
|
||||
} else {
|
||||
rhs_simplified
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_sympy_expr_inner(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<ParsedExpr> {
|
||||
let expr = expr.trim();
|
||||
if expr.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Ok(value) = expr.parse::<i64>() {
|
||||
return Some(ParsedExpr::exact(Expression::from(value), value));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(expr)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
let name = extract_first_quoted(body)?;
|
||||
let bounds = infer_symbol_bounds(body, ranges.get(&name));
|
||||
sym_to_char.get(&name).map(|c| ParsedExpr {
|
||||
expr: Expression::from(*c),
|
||||
bounds,
|
||||
})
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let value = body.trim().parse::<i64>().ok()?;
|
||||
Some(ParsedExpr::exact(Expression::from(value), value))
|
||||
}
|
||||
"NegativeOne" => Some(ParsedExpr::exact(Expression::from(-1i64), -1)),
|
||||
"Zero" => Some(ParsedExpr::exact(Expression::from(0i64), 0)),
|
||||
"One" => Some(ParsedExpr::exact(Expression::from(1i64), 1)),
|
||||
"Mul" | "Add" | "Min" | "Max" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr_inner(iter.next()?, sym_to_char, ranges)?;
|
||||
for part in iter {
|
||||
let rhs = parse_sympy_expr_inner(part, sym_to_char, ranges)?;
|
||||
acc = match head {
|
||||
"Mul" => ParsedExpr {
|
||||
expr: normalize_mul_expr(acc.expr, rhs.expr),
|
||||
bounds: mul_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Add" => ParsedExpr {
|
||||
expr: normalize_add_expr(acc.expr, rhs.expr),
|
||||
bounds: add_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Min" => reduce_min(acc, rhs),
|
||||
"Max" => reduce_max(acc, rhs),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
"FloorDiv" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr / rhs.expr,
|
||||
bounds: div_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
"Mod" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr % rhs.expr,
|
||||
bounds: mod_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_symbol_bounds(body: &str, range: Option<&RangeConstraint>) -> ExprBounds {
|
||||
let mut bounds = ExprBounds::default();
|
||||
if body.contains("positive=True") {
|
||||
bounds.min = Some(1);
|
||||
} else if body.contains("nonnegative=True") {
|
||||
bounds.min = Some(0);
|
||||
}
|
||||
if let Some(range) = range {
|
||||
bounds.min = match (bounds.min, range.min_val) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
(None, Some(rhs)) => Some(rhs),
|
||||
(lhs, None) => lhs,
|
||||
};
|
||||
bounds.max = range.max_val;
|
||||
}
|
||||
bounds
|
||||
}
|
||||
|
||||
fn exact_expr(value: i64) -> BoundedExpr {
|
||||
BoundedExpr {
|
||||
expr: Expression::from(value),
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn exact_value(expr: BoundedExpr) -> Option<i64> {
|
||||
expr.expr.as_num().or({
|
||||
(expr.bounds.min == expr.bounds.max)
|
||||
.then_some(expr.bounds.min)
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn exact_bound_value(bounds: ExprBounds) -> Option<i64> {
|
||||
(bounds.min == bounds.max).then_some(bounds.min).flatten()
|
||||
}
|
||||
|
||||
fn with_bounds(expr: Expression, bounds: ExprBounds) -> BoundedExpr {
|
||||
BoundedExpr { expr, bounds }
|
||||
}
|
||||
|
||||
fn bool_bounds() -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: Some(0),
|
||||
max: Some(1),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expr(expr: Expression) -> Expression {
|
||||
if expr.len() <= 16 {
|
||||
expr.simplify()
|
||||
} else {
|
||||
expr
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_add_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
normalize_expr(crate::dim_arith::add_dims(lhs, rhs))
|
||||
}
|
||||
|
||||
fn normalize_mul_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
normalize_expr(crate::dim_arith::mul_dims(lhs, rhs))
|
||||
}
|
||||
|
||||
fn checked_add_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_add(rhs))
|
||||
}
|
||||
|
||||
fn checked_sub_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_sub(rhs))
|
||||
}
|
||||
|
||||
fn checked_mul_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_mul(rhs))
|
||||
}
|
||||
|
||||
fn add_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_add_opt(lhs.min, rhs.min),
|
||||
max: checked_add_opt(lhs.max, rhs.max),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) >= 0 && rhs.min.unwrap_or(i64::MIN) >= 0 {
|
||||
return ExprBounds {
|
||||
min: checked_mul_opt(lhs.min, rhs.min),
|
||||
max: checked_mul_opt(lhs.max, rhs.max),
|
||||
};
|
||||
}
|
||||
ExprBounds::default()
|
||||
}
|
||||
|
||||
fn sub_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_sub_opt(lhs.min, rhs.max),
|
||||
max: checked_sub_opt(lhs.max, rhs.min),
|
||||
}
|
||||
}
|
||||
|
||||
fn div_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
let (Some(rhs_min), Some(rhs_max)) = (rhs.min, rhs.max) else {
|
||||
return ExprBounds::default();
|
||||
};
|
||||
if rhs_min <= 0 || rhs_max <= 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
ExprBounds {
|
||||
min: lhs.min.and_then(|lhs_min| lhs_min.checked_div(rhs_max)),
|
||||
max: lhs.max.and_then(|lhs_max| lhs_max.checked_div(rhs_min)),
|
||||
}
|
||||
}
|
||||
|
||||
fn mod_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) < 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
match exact_bound_value(rhs) {
|
||||
Some(rhs_exact) if rhs_exact > 0 => ExprBounds {
|
||||
min: Some(0),
|
||||
max: rhs_exact.checked_sub(1),
|
||||
},
|
||||
_ => ExprBounds::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_min(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.min(rhs.expr),
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_max(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.max(rhs.expr),
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn min_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn max_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_is_offset_by_small_const(lhs: Expression, rhs: Expression) -> bool {
|
||||
(1..=8).any(|delta| lhs.egglog_equal(rhs + delta))
|
||||
}
|
||||
|
||||
fn split_add_const(expr: Expression) -> Option<(i64, Expression)> {
|
||||
let terms = expr.terms.read();
|
||||
if terms.len() >= 3 && terms.last() == Some(&Term::Add) {
|
||||
if let Some(Term::Num(n)) = terms.first() {
|
||||
return Some((*n, Expression::new(terms[1..terms.len() - 1].to_vec())));
|
||||
}
|
||||
if let Some(Term::Num(n)) = terms.get(terms.len() - 2) {
|
||||
return Some((*n, Expression::new(terms[..terms.len() - 2].to_vec())));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn simplify_add(lhs: BoundedExpr, rhs: BoundedExpr) -> BoundedExpr {
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) => rhs.expr,
|
||||
(_, Some(0)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs + rhs),
|
||||
(_, Some(rhs)) => normalize_add_expr(lhs.expr, Expression::from(rhs)),
|
||||
(Some(lhs), _) => normalize_add_expr(Expression::from(lhs), rhs.expr),
|
||||
_ => normalize_add_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
with_bounds(expr, add_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_sub(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return exact_expr(0);
|
||||
}
|
||||
let expr = match exact_value(rhs) {
|
||||
Some(0) => lhs.expr,
|
||||
Some(rhs_const) => {
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr) {
|
||||
normalize_expr(lhs_base + (lhs_const - rhs_const))
|
||||
} else {
|
||||
normalize_expr(lhs.expr - rhs_const)
|
||||
}
|
||||
}
|
||||
None => normalize_expr(lhs.expr - rhs.expr),
|
||||
};
|
||||
with_bounds(expr, sub_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_min(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = min_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.min(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_max(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = max_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.max(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_bound_expr(expr: Expression, sym_ranges: &FxHashMap<char, ExprBounds>) -> BoundedExpr {
|
||||
let mut stack: Vec<BoundedExpr> = Vec::new();
|
||||
let terms = expr.terms.read().clone();
|
||||
for term in terms {
|
||||
match term {
|
||||
Term::Num(n) => stack.push(exact_expr(n)),
|
||||
Term::Var(c) => stack.push(with_bounds(
|
||||
Expression::from(c),
|
||||
sym_ranges.get(&c).copied().unwrap_or_default(),
|
||||
)),
|
||||
Term::Add => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_add(lhs, rhs));
|
||||
}
|
||||
Term::Sub => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_sub(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Mul => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(0)) => Expression::from(0),
|
||||
(Some(1), _) => rhs.expr,
|
||||
(_, Some(1)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs * rhs),
|
||||
_ => normalize_mul_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mul_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Div | Term::CeilDiv => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(_, Some(0), _) => Expression::from(0),
|
||||
(_, _, Some(1)) => lhs.expr,
|
||||
(Term::Div, Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs / rhs),
|
||||
(Term::CeilDiv, Some(lhs), Some(rhs)) if rhs > 0 => {
|
||||
Expression::from(if lhs % rhs != 0 {
|
||||
lhs / rhs + 1
|
||||
} else {
|
||||
lhs / rhs
|
||||
})
|
||||
}
|
||||
(Term::Div, _, _) => normalize_expr(lhs.expr / rhs.expr),
|
||||
(Term::CeilDiv, _, _) => normalize_expr(lhs.expr.ceil_div(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, div_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Mod => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(1)) => Expression::from(0),
|
||||
(Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs % rhs),
|
||||
_ => normalize_expr(lhs.expr % rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mod_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Min => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_min(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Max => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_max(lhs, rhs, sym_ranges));
|
||||
}
|
||||
term @ (Term::And | Term::Or | Term::Gte | Term::Lt) => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(Term::And, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 && rhs != 0) as i64)
|
||||
}
|
||||
(Term::And, _, _) => normalize_expr(lhs.expr & rhs.expr),
|
||||
(Term::Or, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 || rhs != 0) as i64)
|
||||
}
|
||||
(Term::Or, _, _) => normalize_expr(lhs.expr | rhs.expr),
|
||||
(Term::Gte, Some(lhs), Some(rhs)) => Expression::from((lhs >= rhs) as i64),
|
||||
(Term::Gte, _, _) => normalize_expr(lhs.expr.gte(rhs.expr)),
|
||||
(Term::Lt, Some(lhs), Some(rhs)) => Expression::from((lhs < rhs) as i64),
|
||||
(Term::Lt, _, _) => normalize_expr(lhs.expr.lt(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, bool_bounds()));
|
||||
}
|
||||
}
|
||||
}
|
||||
stack
|
||||
.pop()
|
||||
.unwrap_or(with_bounds(expr, ExprBounds::default()))
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into `(head, body)`.
|
||||
fn split_head(expr: &str) -> Option<(&str, &str)> {
|
||||
let open = expr.find('(')?;
|
||||
if !expr.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&expr[..open], &expr[open + 1..expr.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list.
|
||||
fn extract_first_quoted(expr: &str) -> Option<String> {
|
||||
let bytes = expr.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(expr[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split a sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Drops `key=value` kwargs.
|
||||
fn split_top_level_args(expr: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = expr.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = expr[start..i].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = expr[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
return !key.is_empty() && key.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
|
||||
}
|
||||
false
|
||||
}
|
||||
@@ -15,7 +15,16 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
pub min_val: i64,
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
fn same_dim(lhs: Expression, rhs: Expression) -> bool {
|
||||
lhs == rhs || lhs.simplify() == rhs.simplify() || lhs.egglog_equal(rhs)
|
||||
}
|
||||
|
||||
/// Binary operation type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryOp {
|
||||
@@ -51,7 +55,7 @@ pub fn broadcast_binary(mut a: GraphTensor, mut b: GraphTensor) -> (GraphTensor,
|
||||
let a_dim = a.shape.dims[i];
|
||||
let b_dim = b.shape.dims[i];
|
||||
|
||||
if a_dim == b_dim {
|
||||
if same_dim(a_dim, b_dim) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -110,29 +114,17 @@ pub fn resolve_neg1_dim(target: &[i64], current_dims: &[Expression]) -> Vec<Expr
|
||||
}
|
||||
|
||||
if let Some(idx) = neg1_idx {
|
||||
let mut total = Expression::from(1usize);
|
||||
for d in current_dims {
|
||||
total *= *d;
|
||||
}
|
||||
if let (Some(total_val), Some(_)) = (
|
||||
{
|
||||
let mut t = 1i64;
|
||||
let mut all_concrete = true;
|
||||
for d in current_dims {
|
||||
if let Some(v) = d.to_usize() {
|
||||
t *= v as i64;
|
||||
} else {
|
||||
all_concrete = false;
|
||||
}
|
||||
}
|
||||
if all_concrete { Some(t) } else { None }
|
||||
},
|
||||
Some(known_product),
|
||||
) {
|
||||
result[idx] = Expression::from((total_val / known_product) as usize);
|
||||
} else {
|
||||
result[idx] = total / Expression::from(known_product as usize);
|
||||
}
|
||||
result[idx] = match current_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize())
|
||||
.collect::<Option<Vec<_>>>()
|
||||
{
|
||||
Some(vs) => Expression::from(vs.iter().product::<usize>() / known_product as usize),
|
||||
None => {
|
||||
crate::dim_arith::product_of_dims(current_dims.iter().copied())
|
||||
/ Expression::from(known_product as usize)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
result
|
||||
@@ -181,11 +173,12 @@ pub fn resolve_neg1_dim_exprs(
|
||||
if input_symbolic.is_empty() {
|
||||
result[idx] = Expression::from((input_concrete / target_concrete) as usize);
|
||||
} else {
|
||||
let mut expr = Expression::from((input_concrete / target_concrete) as usize);
|
||||
for s in &input_symbolic {
|
||||
expr *= *s;
|
||||
}
|
||||
result[idx] = expr;
|
||||
let mut operands: Vec<Expression> = Vec::with_capacity(input_symbolic.len() + 1);
|
||||
operands.push(Expression::from(
|
||||
(input_concrete / target_concrete) as usize,
|
||||
));
|
||||
operands.extend(input_symbolic.iter().copied());
|
||||
result[idx] = crate::dim_arith::product_of_dims(operands);
|
||||
}
|
||||
|
||||
result
|
||||
@@ -194,16 +187,29 @@ pub fn resolve_neg1_dim_exprs(
|
||||
}
|
||||
}
|
||||
|
||||
/// Map torch dtype integer (PT2 format) to luminal DType.
|
||||
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
|
||||
/// Map a PT2 dtype code to luminal `DType`. Panics for variants the IR
|
||||
/// doesn't model as first-class types (narrow ints `Byte` / `Char` /
|
||||
/// `Short`, the complex family, the float8 family) and for unknown
|
||||
/// codes — better to fail loudly at the translator boundary than to
|
||||
/// silently widen and lie about the user's dtype.
|
||||
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
|
||||
match dtype {
|
||||
6 => DType::F16,
|
||||
7 => DType::F32,
|
||||
8 => DType::F32, // float64 → F32 (no F64 in luminal)
|
||||
13 => DType::Bf16,
|
||||
12 => DType::Bool,
|
||||
1..=5 => DType::Int, // uint8, int8, int16, int32, int64
|
||||
_ => DType::F32,
|
||||
let t = crate::torch_dtype::TorchDType::from_code(dtype)
|
||||
.unwrap_or_else(|c| panic!("torch_dtype_int_to_luminal: unknown PT2 dtype code {c}"));
|
||||
match t {
|
||||
crate::torch_dtype::TorchDType::Byte
|
||||
| crate::torch_dtype::TorchDType::Char
|
||||
| crate::torch_dtype::TorchDType::Short => panic!(
|
||||
"torch_dtype_int_to_luminal: PT2 dtype {} (code {}) isn't a first-class \
|
||||
IR type yet — cast to torch.int32 at the call site, or wait for the \
|
||||
narrower-int IR follow-up.",
|
||||
t.name(),
|
||||
t.code(),
|
||||
),
|
||||
other => DType::try_from(other).unwrap_or_else(|t| {
|
||||
panic!(
|
||||
"torch_dtype_int_to_luminal: {} isn't a first-class luminal IR type",
|
||||
t.name()
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
235
crates/luminal_python/rust/src/torch_dtype.rs
Normal file
235
crates/luminal_python/rust/src/torch_dtype.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
//! Typed mirror of PyTorch's PT2 export-schema `ScalarType` enum.
|
||||
//!
|
||||
//! The PT2 export pipeline wire-serializes tensor dtypes as `u32` codes drawn
|
||||
//! from `torch._export.serde.schema.ScalarType` (an `IntEnum` on the Python
|
||||
//! side). Three sites in this crate used to carry duplicate raw-`u32` match
|
||||
//! arms with the canonical numbering hand-rolled in each — silent miscompile
|
||||
//! risk when PyTorch renumbers or adds a code. This module collapses those
|
||||
//! sites onto one typed enum and pins the numbering with a parity test that
|
||||
//! asserts every Rust variant matches `torch._export.serde.schema.ScalarType`
|
||||
//! at CI time (see `crates/luminal_python/tests/test_torch_dtype_parity.py`).
|
||||
//!
|
||||
//! Note: PyTorch's C++ `c10::ScalarType` uses a different numbering than the
|
||||
//! PT2 schema (PT2 reserves 0 for `Unknown`); we bind to the **PT2 schema**,
|
||||
//! not the c10 header, because that is what flows over our wire.
|
||||
|
||||
use luminal::prelude::DType;
|
||||
|
||||
/// PT2 export-schema dtype code. Discriminants match
|
||||
/// `torch._export.serde.schema.ScalarType` variant values exactly; drift is
|
||||
/// caught by `tests/test_torch_dtype_parity.py`.
|
||||
#[repr(u32)]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum TorchDType {
|
||||
Unknown = 0,
|
||||
Byte = 1,
|
||||
Char = 2,
|
||||
Short = 3,
|
||||
Int = 4,
|
||||
Long = 5,
|
||||
Half = 6,
|
||||
Float = 7,
|
||||
Double = 8,
|
||||
ComplexHalf = 9,
|
||||
ComplexFloat = 10,
|
||||
ComplexDouble = 11,
|
||||
Bool = 12,
|
||||
BFloat16 = 13,
|
||||
Uint16 = 28,
|
||||
Float8E4m3Fn = 29,
|
||||
Float8E5m2 = 30,
|
||||
Float8E4m3Fnuz = 31,
|
||||
Float8E5m2Fnuz = 32,
|
||||
}
|
||||
|
||||
impl TorchDType {
|
||||
/// All variants, in declaration order. Used by the pyo3-exported parity
|
||||
/// table and by tests; add new variants here when PyTorch adds them.
|
||||
pub const ALL: &'static [TorchDType] = &[
|
||||
TorchDType::Unknown,
|
||||
TorchDType::Byte,
|
||||
TorchDType::Char,
|
||||
TorchDType::Short,
|
||||
TorchDType::Int,
|
||||
TorchDType::Long,
|
||||
TorchDType::Half,
|
||||
TorchDType::Float,
|
||||
TorchDType::Double,
|
||||
TorchDType::ComplexHalf,
|
||||
TorchDType::ComplexFloat,
|
||||
TorchDType::ComplexDouble,
|
||||
TorchDType::Bool,
|
||||
TorchDType::BFloat16,
|
||||
TorchDType::Uint16,
|
||||
TorchDType::Float8E4m3Fn,
|
||||
TorchDType::Float8E5m2,
|
||||
TorchDType::Float8E4m3Fnuz,
|
||||
TorchDType::Float8E5m2Fnuz,
|
||||
];
|
||||
|
||||
/// Canonical wire code (matches `ScalarType.<name>.value` in Python).
|
||||
#[inline]
|
||||
pub fn code(self) -> u32 {
|
||||
self as u32
|
||||
}
|
||||
|
||||
/// PyTorch schema variant name (e.g. `"LONG"`, `"BFLOAT16"`). Used by the
|
||||
/// parity test to align Rust variants with `ScalarType.<name>`.
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
TorchDType::Unknown => "UNKNOWN",
|
||||
TorchDType::Byte => "BYTE",
|
||||
TorchDType::Char => "CHAR",
|
||||
TorchDType::Short => "SHORT",
|
||||
TorchDType::Int => "INT",
|
||||
TorchDType::Long => "LONG",
|
||||
TorchDType::Half => "HALF",
|
||||
TorchDType::Float => "FLOAT",
|
||||
TorchDType::Double => "DOUBLE",
|
||||
TorchDType::ComplexHalf => "COMPLEXHALF",
|
||||
TorchDType::ComplexFloat => "COMPLEXFLOAT",
|
||||
TorchDType::ComplexDouble => "COMPLEXDOUBLE",
|
||||
TorchDType::Bool => "BOOL",
|
||||
TorchDType::BFloat16 => "BFLOAT16",
|
||||
TorchDType::Uint16 => "UINT16",
|
||||
TorchDType::Float8E4m3Fn => "FLOAT8E4M3FN",
|
||||
TorchDType::Float8E5m2 => "FLOAT8E5M2",
|
||||
TorchDType::Float8E4m3Fnuz => "FLOAT8E4M3FNUZ",
|
||||
TorchDType::Float8E5m2Fnuz => "FLOAT8E5M2FNUZ",
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from a wire code. `Err(code)` if the code isn't a known PyTorch
|
||||
/// variant — the caller decides whether to panic with context or fall
|
||||
/// through to a non-PT2 path.
|
||||
pub fn from_code(code: u32) -> Result<Self, u32> {
|
||||
for v in Self::ALL {
|
||||
if v.code() == code {
|
||||
return Ok(*v);
|
||||
}
|
||||
}
|
||||
Err(code)
|
||||
}
|
||||
}
|
||||
|
||||
/// PyTorch dtype → luminal `DType`. `Err(self)` for variants luminal's IR
|
||||
/// doesn't model as first-class types — the narrow ints (`Byte` / `Char` /
|
||||
/// `Short`), the complex family, and the float8 NUZ variants. `DType::U8`,
|
||||
/// `DType::I8`, `DType::I16` exist on the luminal side but the IR has no
|
||||
/// kernels / codegen for them, so we refuse the conversion here rather
|
||||
/// than silently producing a buffer the kernels can't actually run.
|
||||
/// Boundary code panics with the variant name on `Err`; cf.
|
||||
/// `typed_data::from_pytorch_bytes`, `pt2_util::torch_dtype_int_to_luminal`.
|
||||
impl TryFrom<TorchDType> for DType {
|
||||
type Error = TorchDType;
|
||||
fn try_from(t: TorchDType) -> Result<Self, Self::Error> {
|
||||
Ok(match t {
|
||||
TorchDType::Int => DType::Int,
|
||||
TorchDType::Long => DType::I64,
|
||||
TorchDType::Half => DType::F16,
|
||||
TorchDType::Float => DType::F32,
|
||||
TorchDType::Double => DType::F64,
|
||||
TorchDType::Bool => DType::Bool,
|
||||
TorchDType::BFloat16 => DType::Bf16,
|
||||
TorchDType::Float8E4m3Fn => DType::F8E4M3,
|
||||
TorchDType::Float8E5m2 => DType::F8E5M2,
|
||||
TorchDType::Byte
|
||||
| TorchDType::Char
|
||||
| TorchDType::Short
|
||||
| TorchDType::Uint16
|
||||
| TorchDType::Unknown
|
||||
| TorchDType::ComplexHalf
|
||||
| TorchDType::ComplexFloat
|
||||
| TorchDType::ComplexDouble
|
||||
| TorchDType::Float8E4m3Fnuz
|
||||
| TorchDType::Float8E5m2Fnuz => return Err(t),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// luminal `DType` → PyTorch dtype. `Err(dtype)` for luminal-specific
|
||||
/// variants without a first-class PyTorch counterpart — the narrow ints
|
||||
/// (`U8` / `I8` / `I16` / `U16`), the sub-byte / exotic widths (`I4`,
|
||||
/// `U4`, `F6E2M3`, ...), and `TF32`.
|
||||
///
|
||||
/// `TF32` is a compute-mode hint inside luminal, not a storage dtype on
|
||||
/// the PyTorch side (PyTorch has no `torch.tf32`); silently mapping it to
|
||||
/// `Float` would hand PyTorch an f32 buffer that the caller had been
|
||||
/// tracking as TF32 inside luminal. Refuse instead — a real cast to
|
||||
/// `DType::F32` upstream is the explicit way to bridge.
|
||||
impl TryFrom<DType> for TorchDType {
|
||||
type Error = DType;
|
||||
fn try_from(d: DType) -> Result<Self, Self::Error> {
|
||||
Ok(match d {
|
||||
DType::F32 => TorchDType::Float,
|
||||
DType::F64 => TorchDType::Double,
|
||||
DType::F16 => TorchDType::Half,
|
||||
DType::Bf16 => TorchDType::BFloat16,
|
||||
DType::Int => TorchDType::Int,
|
||||
DType::I64 => TorchDType::Long,
|
||||
DType::Bool => TorchDType::Bool,
|
||||
DType::F8E4M3 => TorchDType::Float8E4m3Fn,
|
||||
DType::F8E5M2 => TorchDType::Float8E5m2,
|
||||
_ => return Err(d),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn roundtrip_codes() {
|
||||
for v in TorchDType::ALL {
|
||||
assert_eq!(TorchDType::from_code(v.code()).unwrap(), *v);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supported_dtypes_roundtrip() {
|
||||
// Only the variants luminal's IR models as first-class can
|
||||
// roundtrip cleanly. Narrow ints (`U8` / `I8` / `I16` / `U16`)
|
||||
// are intentionally excluded — see the `TryFrom` impls.
|
||||
for d in [
|
||||
DType::F32,
|
||||
DType::F64,
|
||||
DType::F16,
|
||||
DType::Bf16,
|
||||
DType::Int,
|
||||
DType::I64,
|
||||
DType::Bool,
|
||||
] {
|
||||
let t = TorchDType::try_from(d).expect("known DType");
|
||||
let back = DType::try_from(t).expect("known TorchDType");
|
||||
assert_eq!(d, back, "roundtrip mismatch for {d:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn narrow_ints_refuse_conversion() {
|
||||
// Forward (PyTorch → luminal) and reverse (luminal → PyTorch)
|
||||
// both refuse the narrow-int variants; downstream sites translate
|
||||
// the `Err` into a typed panic with the variant name.
|
||||
for t in [TorchDType::Byte, TorchDType::Char, TorchDType::Short] {
|
||||
assert!(DType::try_from(t).is_err(), "expected Err for {t:?}");
|
||||
}
|
||||
for d in [
|
||||
DType::U8,
|
||||
DType::I8,
|
||||
DType::I16,
|
||||
DType::U16,
|
||||
// TF32 is a luminal-internal compute-mode hint, not a PyTorch
|
||||
// storage dtype — refuse to silently alias it as `Float`.
|
||||
DType::TF32,
|
||||
] {
|
||||
assert!(TorchDType::try_from(d).is_err(), "expected Err for {d:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_code_errors() {
|
||||
assert!(TorchDType::from_code(99).is_err());
|
||||
assert!(TorchDType::from_code(14).is_err()); // gap in PT2 numbering
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,40 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, same_expr_with_ranges, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
fn normalize_equal_dims(
|
||||
a: &mut GraphTensor,
|
||||
b: &mut GraphTensor,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..a.shape.len() {
|
||||
let lhs = a.shape.dims[i];
|
||||
let rhs = b.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs, rhs, sym_ranges) {
|
||||
a.shape.dims[i] = canonical;
|
||||
b.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn same_dims(
|
||||
lhs: &[Expression],
|
||||
rhs: &[Expression],
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
lhs.len() == rhs.len()
|
||||
&& lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.all(|(lhs, rhs)| same_expr_with_ranges(*lhs, *rhs, sym_ranges))
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -13,7 +42,18 @@ impl<'a> Translator<'a> {
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let (mut a, mut b) = broadcast_binary(a, b);
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
normalize_equal_dims(&mut a, &mut b, &sym_ranges);
|
||||
let lhs_dims = a.dims();
|
||||
let rhs_dims = b.dims();
|
||||
if !same_dims(&lhs_dims, &rhs_dims, &sym_ranges) {
|
||||
anyhow::bail!(
|
||||
"binary op {} still has mismatched dims after broadcast: lhs={lhs_dims:?} rhs={rhs_dims:?} inputs={:?}",
|
||||
node.target,
|
||||
node.inputs
|
||||
);
|
||||
}
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
@@ -21,6 +61,12 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -32,6 +78,13 @@ impl<'a> Translator<'a> {
|
||||
op: BinaryOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -54,4 +107,47 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / scalar,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_symbolic_scalar_op(
|
||||
&mut self,
|
||||
a: GraphTensor,
|
||||
val: Expression,
|
||||
op: BinaryOp,
|
||||
) -> GraphTensor {
|
||||
match op {
|
||||
BinaryOp::Add => a + val,
|
||||
BinaryOp::Mul => a * val,
|
||||
BinaryOp::Sub => a - val,
|
||||
BinaryOp::Div => a / val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pt2_expr::simplify_expr_with_ranges;
|
||||
|
||||
#[test]
|
||||
fn simplifies_mark_dynamic_slice_shapes_using_lower_bound() {
|
||||
let a = Expression::from('a');
|
||||
let lhs = (a.min(1) + a).min(a + 1) - 1;
|
||||
let rhs = (a.min(1) + a).min(a);
|
||||
let sym_ranges = [(
|
||||
'a',
|
||||
ExprBounds {
|
||||
min: Some(2),
|
||||
max: None,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, &sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, &sym_ranges);
|
||||
|
||||
assert_eq!(lhs_simplified, Expression::from('a'));
|
||||
assert_eq!(rhs_simplified, Expression::from('a'));
|
||||
assert!(same_expr_with_ranges(lhs, rhs, &sym_ranges));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,8 +389,11 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -147,6 +148,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -173,7 +175,11 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let exp = self.get_float_arg(node, 1)?;
|
||||
a.pow(exp as f32)
|
||||
if (exp - 2.0).abs() < f64::EPSILON {
|
||||
a * a
|
||||
} else {
|
||||
a.pow(exp as f32)
|
||||
}
|
||||
}
|
||||
"torch.ops.aten.pow.Tensor_Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -219,6 +225,16 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -236,6 +252,13 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -274,18 +297,27 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
a.cumsum(dim)
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -381,6 +413,17 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -444,6 +487,28 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
mod movement;
|
||||
mod movement_dynamic;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
mod unary;
|
||||
@@ -17,6 +18,7 @@ use anyhow::{Context, Result};
|
||||
use luminal::graph::Graph;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_expr::parse_sympy_expr_with_ranges;
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
@@ -279,13 +281,13 @@ impl<'a> Translator<'a> {
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(ints) = arg.as_ints() {
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v as usize)).collect());
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v)).collect());
|
||||
}
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
return entries
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.with_context(|| format!("Cannot resolve sym_int: {}", s.as_name)),
|
||||
@@ -318,17 +320,13 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn dim_size_to_expr(&self, dim: &DimSize) -> Result<Expression> {
|
||||
match dim {
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
DimSize::Expr(e) => {
|
||||
let sym_name = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
.with_context(|| format!("Cannot parse symbol: {}", e.as_expr.expr_str))?;
|
||||
let c = self
|
||||
.sym_map
|
||||
.sym_to_char
|
||||
.get(&sym_name)
|
||||
.with_context(|| format!("Unknown symbol: {sym_name}"))?;
|
||||
Ok(Expression::from(*c))
|
||||
}
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
DimSize::Expr(e) => self.resolve_expr_value(&e.as_expr).with_context(|| {
|
||||
format!(
|
||||
"Cannot resolve symbolic dimension expression: {}",
|
||||
e.as_expr.expr_str
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,10 +337,9 @@ impl<'a> Translator<'a> {
|
||||
.get("as_expr")
|
||||
.and_then(|e| e.get("expr_str"))
|
||||
.and_then(|s| s.as_str())
|
||||
&& let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
&& let Some(expr) = self.resolve_expr_str(expr_str)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
return Some(expr);
|
||||
}
|
||||
if let Some(hint) = val
|
||||
.get("as_expr")
|
||||
@@ -350,7 +347,7 @@ impl<'a> Translator<'a> {
|
||||
.and_then(|h| h.get("as_int"))
|
||||
.and_then(|v| v.as_i64())
|
||||
{
|
||||
return Some(Expression::from(hint as usize));
|
||||
return Some(Expression::from(hint));
|
||||
}
|
||||
}
|
||||
None
|
||||
@@ -358,21 +355,32 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn resolve_arg_as_expression(&self, arg: &Argument) -> Option<Expression> {
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Some(Expression::from(v as usize));
|
||||
return Some(Expression::from(v));
|
||||
}
|
||||
if let Some(name) = arg.as_sym_int_name() {
|
||||
return self.resolve_sym_int(name);
|
||||
}
|
||||
if let Argument::Expr(e) = arg {
|
||||
if let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = e.as_expr.hint.as_ref().and_then(|h| h.as_int()) {
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
return self.resolve_expr_value(&e.as_expr);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_str(&self, expr_str: &str) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr_str, &self.sym_map.sym_to_char, &self.sym_map.ranges)
|
||||
.or_else(|| {
|
||||
crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
.and_then(|sym| self.sym_map.sym_to_char.get(&sym).copied())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_value(&self, expr: &ExprValue) -> Option<Expression> {
|
||||
self.resolve_expr_str(&expr.expr_str).or_else(|| {
|
||||
expr.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
@@ -11,6 +13,25 @@ const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
fn normalize_concat_dims(
|
||||
lhs: &mut GraphTensor,
|
||||
rhs: &mut GraphTensor,
|
||||
skip_dim: Option<usize>,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..lhs.shape.len() {
|
||||
if Some(i) == skip_dim {
|
||||
continue;
|
||||
}
|
||||
let lhs_dim = lhs.shape.dims[i];
|
||||
let rhs_dim = rhs.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs_dim, rhs_dim, sym_ranges) {
|
||||
lhs.shape.dims[i] = canonical;
|
||||
rhs.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -120,6 +141,47 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -160,8 +222,17 @@ impl<'a> Translator<'a> {
|
||||
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
for t in &tensors[1..] {
|
||||
result = result.concat_along(*t, dim);
|
||||
let mut next = *t;
|
||||
normalize_concat_dims(&mut result, &mut next, Some(dim), &sym_ranges);
|
||||
|
||||
let lhs_axis = result.dims()[dim];
|
||||
let rhs_axis = next.dims()[dim];
|
||||
let mut lhs_padded = result.pad_along(0, rhs_axis, dim, 0.);
|
||||
let mut rhs_padded = next.pad_along(lhs_axis, 0, dim, 0.);
|
||||
normalize_concat_dims(&mut lhs_padded, &mut rhs_padded, None, &sym_ranges);
|
||||
result = lhs_padded + rhs_padded;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -235,7 +306,11 @@ impl<'a> Translator<'a> {
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[first_non_none_dim] = idx_dim_size;
|
||||
expanded.shape.expand(target);
|
||||
return Ok(source.gather_elements(expanded, first_non_none_dim));
|
||||
return Ok(super::movement_dynamic::pt2_gather_elements(
|
||||
source,
|
||||
expanded,
|
||||
first_non_none_dim,
|
||||
));
|
||||
}
|
||||
} else {
|
||||
bail!(
|
||||
@@ -333,6 +408,17 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
@@ -344,7 +430,12 @@ impl<'a> Translator<'a> {
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
let result = super::movement_dynamic::pt2_gather_elements(a, normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -353,7 +444,12 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
let src = self.get_input_tensor(node, 3)?;
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
Ok(super::movement_dynamic::pt2_scatter_elements(
|
||||
a,
|
||||
indices.cast(DType::Int),
|
||||
src,
|
||||
dim,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -376,7 +472,12 @@ impl<'a> Translator<'a> {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
Ok(super::movement_dynamic::pt2_scatter_elements(
|
||||
a,
|
||||
indices.cast(DType::Int),
|
||||
value,
|
||||
dim,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -421,7 +522,7 @@ impl<'a> Translator<'a> {
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
Ok(super::movement_dynamic::pt2_scatter_nd(a, indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
}
|
||||
|
||||
231
crates/luminal_python/rust/src/translator/movement_dynamic.rs
Normal file
231
crates/luminal_python/rust/src/translator/movement_dynamic.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
//! Symbolic-dim-safe `gather_elements` / `scatter_elements` / `scatter_nd`
|
||||
//! lowerings for the PT2 translator.
|
||||
//!
|
||||
//! The luminal-core versions in `luminal::frontend::movement` require
|
||||
//! concrete shape dims — they call `d.to_usize().expect(...)` on every
|
||||
//! input dim and panic at translate-time when `torch.compile` hands us a
|
||||
//! batch dim, sequence-length dim, or any other dynamic dim. PT2's whole
|
||||
//! point is dynamic shapes, so we re-implement the same three ops here
|
||||
//! using `Expression`-typed shape arithmetic and only call luminal-core
|
||||
//! primitives that already accept `Expression`s (`Graph::constant`,
|
||||
//! `Graph::iota`, `flatten_strides`, `ShapeTracker::new(Vec<Expression>)`,
|
||||
//! `expand_dim`, `expand_rhs`, `flatten`, `slice_along`, `squeeze`,
|
||||
//! `cast`, `scatter`, `gather`).
|
||||
//!
|
||||
//! Every shape product flows through `crate::dim_arith::product_of_dims`
|
||||
//! so the `Expression`s we build are canonical: two callers that produce
|
||||
//! the same logical dim via differently-ordered multiplications end up
|
||||
//! with byte-identical `Expression`s. Without this, downstream dim-equality
|
||||
//! asserts in luminal-core's `Add` / `Sub` (see `src/frontend/binary.rs`)
|
||||
//! panic on `a*8` ≠ `8*a` after these helpers feed into broadcast paths.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::dim_arith::product_of_dims;
|
||||
|
||||
/// Row-major strides as `Expression`s. `stride[i] = prod(dims[i+1..])`.
|
||||
fn row_major_strides(dims: &[Expression]) -> Vec<Expression> {
|
||||
let rank = dims.len();
|
||||
(0..rank)
|
||||
.map(|i| product_of_dims(dims[i + 1..].iter().copied()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build the additive non-axis contribution to a flat index over a
|
||||
/// rank-`rank` output of shape `out_shape`. The axis dim contributes
|
||||
/// 0; every other dim `d` contributes `iota_d * strides[d]`. Materialised
|
||||
/// via one `Graph::iota` call with `flatten_strides(out_shape, axis_exprs)`
|
||||
/// — same pattern luminal core uses, just with `Expression` throughout.
|
||||
fn non_axis_flat(
|
||||
graph: &mut Graph,
|
||||
out_shape: &[Expression],
|
||||
strides: &[Expression],
|
||||
axis: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = out_shape.len();
|
||||
let axis_exprs: Vec<Expression> = (0..rank)
|
||||
.map(|d| {
|
||||
if d == axis {
|
||||
Expression::from(0)
|
||||
} else {
|
||||
Expression::from('z') * strides[d]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
graph.iota(flatten_strides(out_shape, &axis_exprs), out_shape.to_vec())
|
||||
}
|
||||
|
||||
/// Wrap negative axis indices into `[0, axis_dim)`. Equivalent to
|
||||
/// `if idx < 0 { idx + axis_dim } else { idx }` in tensor form.
|
||||
fn normalize_negative_index(indices: GraphTensor, axis_dim: Expression) -> GraphTensor {
|
||||
let idx_f32 = indices.cast(DType::F32);
|
||||
let zero = idx_f32
|
||||
.graph()
|
||||
.constant_float(0.0)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let adj = idx_f32
|
||||
.graph()
|
||||
.constant(axis_dim)
|
||||
.cast(DType::F32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_neg = idx_f32.lt(zero).cast(DType::F32);
|
||||
(idx_f32 + (is_neg * adj)).cast(DType::Int)
|
||||
}
|
||||
|
||||
/// Translator-local `gather_elements` that accepts symbolic shape dims.
|
||||
/// Mirrors `GraphTensor::gather_elements` semantics but uses
|
||||
/// `Expression`-typed shape arithmetic and only calls symbol-safe
|
||||
/// luminal-core primitives.
|
||||
///
|
||||
/// `output[i0,..,ik] = self[i0,..,i_{axis-1}, indices[i0,..,ik], i_{axis+1},..,ik]`
|
||||
pub fn pt2_gather_elements(data: GraphTensor, indexes: GraphTensor, axis: usize) -> GraphTensor {
|
||||
let dims = data.dims();
|
||||
let out_shape: Vec<Expression> = indexes.dims();
|
||||
let strides = row_major_strides(&dims);
|
||||
|
||||
let idx_normalized = normalize_negative_index(indexes, dims[axis]);
|
||||
let non_axis_flat = non_axis_flat(data.graph(), &out_shape, &strides, axis);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(strides[axis])
|
||||
.expand_rhs(idx_normalized.shape);
|
||||
let flat_idx = non_axis_flat + idx_normalized * stride_tensor;
|
||||
|
||||
data.gather(flat_idx)
|
||||
}
|
||||
|
||||
/// Translator-local `scatter_elements` that accepts symbolic shape dims.
|
||||
/// Same semantics as `GraphTensor::scatter_elements`.
|
||||
pub fn pt2_scatter_elements(
|
||||
data: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
updates: GraphTensor,
|
||||
axis: usize,
|
||||
) -> GraphTensor {
|
||||
let data_dims = data.dims();
|
||||
let idx_shape: Vec<Expression> = indices.dims();
|
||||
let strides = row_major_strides(&data_dims);
|
||||
|
||||
let idx_normalized = normalize_negative_index(indices, data_dims[axis]);
|
||||
let non_axis_flat = non_axis_flat(data.graph(), &idx_shape, &strides, axis);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(strides[axis])
|
||||
.expand_rhs(idx_normalized.shape);
|
||||
let flat_dest = non_axis_flat + idx_normalized * stride_tensor;
|
||||
|
||||
let flat_dest_1d = flat_dest.flatten();
|
||||
let flat_updates = updates.flatten();
|
||||
let flat_data = data.flatten();
|
||||
|
||||
let output_flat = flat_updates.scatter(flat_dest_1d, flat_data);
|
||||
|
||||
// View-only reshape back to data shape; the buffer is already laid
|
||||
// out row-major from the scatter, so swapping the tracker is safe.
|
||||
let mut result = output_flat;
|
||||
result.shape = ShapeTracker::new(data_dims);
|
||||
result
|
||||
}
|
||||
|
||||
/// Translator-local `scatter_nd` that accepts symbolic shape dims.
|
||||
/// Mirrors `GraphTensor::scatter_nd` semantics.
|
||||
pub fn pt2_scatter_nd(
|
||||
data: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
updates: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let indices = indices.cast(DType::Int);
|
||||
let data_dims = data.dims();
|
||||
let data_rank = data_dims.len();
|
||||
let idx_dims = indices.dims();
|
||||
let idx_rank = idx_dims.len();
|
||||
|
||||
// The last dim of indices is the index width K — it must be
|
||||
// concrete at translate-time because it controls how many
|
||||
// contribution terms we build statically. HuggingFace's MoE
|
||||
// accumulator (the path that brought us here via `index_put`)
|
||||
// always passes a literal; non-HF callers with a SymInt K would
|
||||
// need a different lowering.
|
||||
let k = idx_dims[idx_rank - 1]
|
||||
.to_usize()
|
||||
.expect("scatter_nd: indices innermost dim (K) must be concrete");
|
||||
assert!(k <= data_rank, "scatter_nd: K must be <= data rank");
|
||||
|
||||
// Batch shape = indices shape without last dim.
|
||||
let batch_shape: Vec<Expression> = idx_dims[..idx_rank - 1].to_vec();
|
||||
let batch_numel = product_of_dims(batch_shape.iter().copied());
|
||||
|
||||
// Trailing shape = data_shape[K..]
|
||||
let trailing_shape: Vec<Expression> = data_dims[k..].to_vec();
|
||||
let trailing_numel = product_of_dims(trailing_shape.iter().copied());
|
||||
|
||||
let data_strides = row_major_strides(&data_dims);
|
||||
|
||||
// Flatten batch dims of indices to [batch_numel, K] via view reshape.
|
||||
let mut indices_flat = indices;
|
||||
if idx_rank > 2 {
|
||||
indices_flat.shape = ShapeTracker::new(vec![batch_numel, Expression::from(k)]);
|
||||
}
|
||||
|
||||
let mut flat_base: Option<GraphTensor> = None;
|
||||
for (k_dim, stride) in data_strides.iter().copied().enumerate().take(k) {
|
||||
let idx_k = indices_flat.slice_along(k_dim..k_dim + 1, indices_flat.dims().len() - 1);
|
||||
let idx_k = idx_k.squeeze(idx_k.dims().len() - 1);
|
||||
|
||||
let stride_tensor = data.graph().constant(stride).expand_rhs(idx_k.shape);
|
||||
let contribution = idx_k * stride_tensor;
|
||||
|
||||
flat_base = Some(match flat_base {
|
||||
Some(fb) => fb + contribution,
|
||||
None => contribution,
|
||||
});
|
||||
}
|
||||
let flat_base = flat_base.unwrap();
|
||||
|
||||
// Trailing-numel concreteness drives whether we need the expand-and-fold
|
||||
// path. If trailing_shape is empty OR its numel collapses to 1, the flat
|
||||
// base is already the full destination index.
|
||||
let trailing_is_unit = trailing_shape.is_empty() || trailing_numel.to_usize() == Some(1);
|
||||
let mut full_flat_dest = if trailing_is_unit {
|
||||
flat_base
|
||||
} else {
|
||||
let mut base_expanded = flat_base.expand_dim(1, trailing_numel);
|
||||
|
||||
let trailing_rank = trailing_shape.len();
|
||||
for (ti, d) in (k..data_rank).enumerate() {
|
||||
let ar = data.graph().arange(data_dims[d]);
|
||||
let mut ar_shaped = ar;
|
||||
for _ in ti + 1..trailing_rank {
|
||||
let n = ar_shaped.dims().len();
|
||||
ar_shaped = ar_shaped.expand_dim(n, 1);
|
||||
}
|
||||
for _ in 0..ti {
|
||||
ar_shaped = ar_shaped.expand_dim(0, 1);
|
||||
}
|
||||
ar_shaped.shape.expand(trailing_shape.clone());
|
||||
let mut ar_flat = ar_shaped;
|
||||
ar_flat.shape = ShapeTracker::new(vec![trailing_numel]);
|
||||
ar_flat = ar_flat.expand_dim(0, batch_numel);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(data_strides[d])
|
||||
.expand_rhs(ar_flat.shape);
|
||||
base_expanded += ar_flat * stride_tensor;
|
||||
}
|
||||
base_expanded
|
||||
};
|
||||
|
||||
full_flat_dest = full_flat_dest.flatten();
|
||||
|
||||
let flat_updates = updates.flatten();
|
||||
let flat_data = data.flatten();
|
||||
|
||||
let output_flat = flat_updates.scatter(full_flat_dest, flat_data);
|
||||
|
||||
let mut result = output_flat;
|
||||
result.shape = ShapeTracker::new(data_dims);
|
||||
result
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user