mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
33 Commits
strided-in
...
worktree-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
028c7cb484 | ||
|
|
3a3cd04958 | ||
|
|
d21f55ed78 | ||
|
|
b2bd91f594 | ||
|
|
35ebf0c7c7 | ||
|
|
dea8a3e7aa | ||
|
|
439648a649 | ||
|
|
2d858829c7 | ||
|
|
6673d1d935 | ||
|
|
65f3cceaa1 | ||
|
|
f925431ad5 | ||
|
|
33ff774d62 | ||
|
|
ea04149691 | ||
|
|
aaeefeee8c | ||
|
|
0b917abd03 | ||
|
|
d9a5fcfe9f | ||
|
|
64eb2641fd | ||
|
|
dbdb31523c | ||
|
|
da84f1a5a3 | ||
|
|
322b85fd95 | ||
|
|
a590942274 | ||
|
|
cfbdef2569 | ||
|
|
de2e820f48 | ||
|
|
30f067fa94 | ||
|
|
ee0456d5bc | ||
|
|
b6403ec1be | ||
|
|
bfbefc2fe1 | ||
|
|
0e2ea24e46 | ||
|
|
d03a41ec96 | ||
|
|
8aa9f14741 | ||
|
|
1460e6a3ee | ||
|
|
a138db0236 | ||
|
|
6a17670244 |
4
.github/workflows/modal-examples.yml
vendored
4
.github/workflows/modal-examples.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 70
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
|
||||
example: [llama, gemma, qwen, qwen3_moe]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
|
||||
2
.github/workflows/test-core.yml
vendored
2
.github/workflows/test-core.yml
vendored
@@ -21,4 +21,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
|
||||
2
.github/workflows/test-cuda.yml
vendored
2
.github/workflows/test-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/test-metal.yml
vendored
2
.github/workflows/test-metal.yml
vendored
@@ -16,4 +16,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
4
.github/workflows/test-python-cuda.yml
vendored
4
.github/workflows/test-python-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 7200 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
2
.github/workflows/test-python-native.yml
vendored
2
.github/workflows/test-python-native.yml
vendored
@@ -23,6 +23,6 @@ jobs:
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --profile release
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -37,3 +37,9 @@ __pycache__/
|
||||
dist/
|
||||
build/
|
||||
uv.lock
|
||||
|
||||
# TTFT benchmark SQLite database (per-machine state)
|
||||
benchmarks/ttft/bench.db
|
||||
benchmarks/ttft/bench.db-journal
|
||||
benchmarks/ttft/bench.db-wal
|
||||
benchmarks/ttft/bench.db-shm
|
||||
|
||||
@@ -25,7 +25,6 @@ generational-box = "0.5.6"
|
||||
serde_json = "1.0.140"
|
||||
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-reports = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
|
||||
tracing = "0.1.43"
|
||||
paste = "1.0.15"
|
||||
|
||||
117
benchmarks/ttft/bench_python_baseline.py
Normal file
117
benchmarks/ttft/bench_python_baseline.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Pure HuggingFace/PyTorch TTFT + TPOT bench. Prints a JSON line on stdout.
|
||||
|
||||
Measures:
|
||||
TTFT — sum of single-token forward-pass durations over the prompt, using
|
||||
a StaticCache. Methodology matches bench_python_luminal.py and the
|
||||
rust path so the cross-path comparison is apples-to-apples.
|
||||
TPOT — average time per output token during KV-cache greedy decode.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, measure_tpot, static_cache_config
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--dtype", default="float32", choices=["float32", "bfloat16", "float16"])
|
||||
ap.add_argument("--decode-tokens", type=int, default=50,
|
||||
help="Number of tokens to generate for TPOT measurement (0 = skip).")
|
||||
ap.add_argument("--max-cache-len", type=int, default=256,
|
||||
help="StaticCache max sequence length.")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def measure_ttft() -> float:
|
||||
"""Sum of per-token forward-pass durations over prompt_tokens steps."""
|
||||
kv = make_cache()
|
||||
# Eager init at position 0 to satisfy StaticCache.lazy_initialization.
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([pos], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
result = {
|
||||
"path": "python_baseline",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"note": "sequential per-token, StaticCache KV cache",
|
||||
}
|
||||
|
||||
if args.decode_tokens > 0:
|
||||
tpot_samples_ms = measure_tpot(model, input_ids, device, args.decode_tokens)
|
||||
tpot_ms = sum(tpot_samples_ms) / len(tpot_samples_ms)
|
||||
result["decode_tokens"] = args.decode_tokens
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["tpot_ms_samples"] = tpot_samples_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
196
benchmarks/ttft/bench_python_luminal.py
Normal file
196
benchmarks/ttft/bench_python_luminal.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Python -> Luminal TTFT + TPOT bench via torch.compile(backend=luminal_backend).
|
||||
|
||||
Methodology mirrors examples/llama (the Rust path):
|
||||
- One eager prefill step initialises the StaticCache (required by transformers'
|
||||
StaticCache.lazy_initialization) before compilation.
|
||||
- TTFT: run one forward pass per prompt token sequentially, each advancing
|
||||
cache_position by 1; sum durations.
|
||||
- TPOT: run --decode-tokens more single-token passes; average durations.
|
||||
- StaticCache pre-allocates K/V buffers up to max_cache_len; no growing allocation.
|
||||
|
||||
Prints a BENCH_RESULT JSON line on stdout.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, static_cache_config
|
||||
from luminal import luminal_backend
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument(
|
||||
"--search-iters",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Egraph search iterations (matches examples/llama default of 500).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--decode-tokens",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Tokens to generate for TPOT measurement (0 = skip TPOT).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-cache-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="StaticCache max sequence length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
help="Torch dtype for model + StaticCache.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Step 0: run ONE eager prefill to initialise the cache tensors and call
|
||||
# mark_static_address (required by transformers' StaticCache before compile).
|
||||
cache = make_cache()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=cache, cache_position=torch.tensor([0], device=device))
|
||||
|
||||
# Compile for a single-token input — same graph is reused for every step.
|
||||
# Compilation happens on the first call after the eager init above.
|
||||
t0 = time.perf_counter()
|
||||
compiled = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": args.search_iters},
|
||||
)
|
||||
cache_position = torch.tensor([1], dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=cache, cache_position=cache_position)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
compile_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
gc.collect()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def one_step(pos: int, kv_cache):
|
||||
cache_pos = torch.tensor([pos], dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=kv_cache, cache_position=cache_pos)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def measure_ttft():
|
||||
"""Sum of per-token forward-pass durations over prompt_tokens steps.
|
||||
|
||||
Uses a fresh cache so each TTFT measurement is independent.
|
||||
"""
|
||||
kv = make_cache()
|
||||
# Eager init for this fresh cache (required before compiled can run on it).
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv, cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
# Step 0 was the eager init above; measure from step 1 to prompt_tokens.
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
one_step(pos, kv)
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
def measure_tpot(n, start_pos: int):
|
||||
"""Average single-token forward-pass duration over n decode steps."""
|
||||
kv = make_cache()
|
||||
# Eager init
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv, cache_position=torch.tensor([0], device=device))
|
||||
# One warmup step.
|
||||
one_step(1, kv)
|
||||
step_times_ms = []
|
||||
for i in range(n):
|
||||
pos = start_pos + i
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
one_step(pos, kv)
|
||||
step_times_ms.append((time.perf_counter() - t0) * 1000.0)
|
||||
return step_times_ms
|
||||
|
||||
# Warmups before timing TTFT (all run after compilation is complete).
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
tpot_ms_samples = []
|
||||
if args.decode_tokens > 0:
|
||||
tpot_ms_samples = measure_tpot(args.decode_tokens, start_pos=prompt_tokens)
|
||||
|
||||
tpot_ms = sum(tpot_ms_samples) / len(tpot_ms_samples) if tpot_ms_samples else None
|
||||
throughput_tps = (1000.0 / tpot_ms) if tpot_ms else None
|
||||
|
||||
result = {
|
||||
"path": "python_luminal",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"compile_ms": compile_ms,
|
||||
"search_iters": args.search_iters,
|
||||
"decode_tokens": args.decode_tokens if args.decode_tokens > 0 else None,
|
||||
"tpot_ms": tpot_ms,
|
||||
"tpot_ms_samples": tpot_ms_samples,
|
||||
"throughput_tps": throughput_tps,
|
||||
"note": "sequential per-token, StaticCache KV cache",
|
||||
}
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
138
benchmarks/ttft/bench_python_torch_compile.py
Normal file
138
benchmarks/ttft/bench_python_torch_compile.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Vanilla torch.compile TTFT + TPOT bench. Prints a JSON line on stdout.
|
||||
|
||||
Uses the default inductor backend (torch.compile without a custom backend).
|
||||
TTFT uses sequential per-token prefill with a StaticCache so the methodology
|
||||
matches bench_python_baseline.py, bench_python_luminal.py, and the rust path.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, measure_tpot, static_cache_config
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--dtype", default="float32", choices=["float32", "bfloat16", "float16"])
|
||||
ap.add_argument(
|
||||
"--decode-tokens", type=int, default=50,
|
||||
help="Number of tokens to generate for TPOT measurement (0 = skip).",
|
||||
)
|
||||
ap.add_argument("--max-cache-len", type=int, default=256,
|
||||
help="StaticCache max sequence length.")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Eager init on the uncompiled model so the StaticCache buffers get
|
||||
# registered (mark_static_address) before torch.compile traces them.
|
||||
init_cache = make_cache()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=init_cache,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
|
||||
compiled = torch.compile(model)
|
||||
|
||||
# First compiled call triggers JIT compilation; time it as compile_ms.
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=init_cache,
|
||||
cache_position=torch.tensor([1], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
compile_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
def measure_ttft() -> float:
|
||||
"""Sum of per-token compiled-forward durations over prompt_tokens steps."""
|
||||
kv = make_cache()
|
||||
# Fresh cache needs eager init via the uncompiled model first.
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([pos], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
result = {
|
||||
"path": "python_torch_compile",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"compile_ms": compile_ms,
|
||||
"note": "sequential per-token, StaticCache KV cache (torch.compile inductor)",
|
||||
}
|
||||
|
||||
if args.decode_tokens > 0:
|
||||
tpot_samples_ms = measure_tpot(compiled, input_ids, device, args.decode_tokens)
|
||||
tpot_ms = sum(tpot_samples_ms) / len(tpot_samples_ms)
|
||||
result["decode_tokens"] = args.decode_tokens
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["tpot_ms_samples"] = tpot_samples_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
94
benchmarks/ttft/bench_utils.py
Normal file
94
benchmarks/ttft/bench_utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Shared helpers for the Python benchmark scripts."""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class _CfgWithoutKvShared:
|
||||
"""Wrapper that hides `num_kv_shared_layers` from a HF config.
|
||||
|
||||
transformers 5.6 has a bug in StaticCache.__init__:
|
||||
if hasattr(config, "num_kv_shared_layers"):
|
||||
layer_types = layer_types[: -config.num_kv_shared_layers]
|
||||
For configs where the attribute is 0 (e.g. Gemma-4), `[:-0]` returns an
|
||||
empty list, leaving StaticCache with zero layer slots, and the LM's
|
||||
first `past_key_values.update(..., layer_idx=0)` raises IndexError.
|
||||
|
||||
This wrapper makes `hasattr(...)` return False so the bad branch never
|
||||
fires. Used via `static_cache_config(config)` below.
|
||||
"""
|
||||
__slots__ = ("_inner",)
|
||||
|
||||
def __init__(self, inner):
|
||||
object.__setattr__(self, "_inner", inner)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "num_kv_shared_layers":
|
||||
raise AttributeError(name)
|
||||
return getattr(self._inner, name)
|
||||
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
return _CfgWithoutKvShared(self._inner.get_text_config(*args, **kwargs))
|
||||
|
||||
|
||||
def static_cache_config(config):
|
||||
"""Return a config suitable for `StaticCache(config=..., ...)`.
|
||||
|
||||
Two normalizations:
|
||||
1. Multimodal wrappers (Gemma4ForConditionalGeneration, ...) nest the
|
||||
actual LM config under `.text_config`. Pass that, not the wrapper,
|
||||
so layer/head counts match the inner LM.
|
||||
2. If the resulting config has `num_kv_shared_layers == 0`, wrap it to
|
||||
hide the attribute (works around the transformers 5.6 slice bug).
|
||||
"""
|
||||
cfg = getattr(config, "text_config", config)
|
||||
if getattr(cfg, "num_kv_shared_layers", None) == 0:
|
||||
cfg = _CfgWithoutKvShared(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def encode_prompt(tokenizer, prompt: str, device):
|
||||
"""Tokenize prompt using chat template if available, falling back to raw tokenization."""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
try:
|
||||
encoded = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
except (ValueError, AttributeError):
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
if hasattr(encoded, "input_ids"):
|
||||
return encoded.input_ids.to(device)
|
||||
if isinstance(encoded, dict):
|
||||
return encoded["input_ids"].to(device)
|
||||
return encoded.to(device)
|
||||
|
||||
|
||||
def measure_tpot(model, input_ids, device, decode_tokens: int) -> list[float]:
|
||||
"""Prefill once with KV cache, then time each subsequent single-token decode step."""
|
||||
with torch.no_grad():
|
||||
out = model(input_ids, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
out = model(next_id, past_key_values=past, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
step_times_ms = []
|
||||
for _ in range(decode_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = model(next_id, past_key_values=past, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
step_times_ms.append((time.perf_counter() - t0) * 1000.0)
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
return step_times_ms
|
||||
92
benchmarks/ttft/benchmarks.toml
Normal file
92
benchmarks/ttft/benchmarks.toml
Normal file
@@ -0,0 +1,92 @@
|
||||
[ur_test]
|
||||
models = ["llama-8b", "qwen3-4b", "gemma3-4b", "gemma4-moe", "qwen3-moe"]
|
||||
# 3-point sweep (low/mid/high). The previous list [5, 10, 20, 50, 100, 500]
|
||||
# spent ~62 extra minutes on s=5/s=20/s=50 with little additional information.
|
||||
search_sweep_iters = [10, 100, 500]
|
||||
|
||||
[configs.llama-8b]
|
||||
model = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
rust_package = "llama"
|
||||
search_iters = 500
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 50
|
||||
# On-disk weights are bf16-majority. fp32 upcast doubled python_luminal's
|
||||
# egglog Search peak past the 525 GB unified pool and triggered SIGKILLs on
|
||||
# gemma3-4b (and same risk here). bf16 matches rust's load path.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.as_fast_as_possible]
|
||||
prompt = "The"
|
||||
search_iters = 1
|
||||
iters = 1
|
||||
warmups = 0
|
||||
decode_tokens = 5
|
||||
|
||||
[configs.qwen3-4b]
|
||||
model = "Qwen/Qwen3-4B"
|
||||
rust_package = "qwen"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# bf16-majority on-disk; see llama-8b note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.gemma3-4b]
|
||||
model = "unsloth/gemma-3-4b-it"
|
||||
rust_package = "gemma"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# bf16-majority on-disk; see llama-8b note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.gemma4-moe]
|
||||
model = "google/gemma-4-26B-A4B"
|
||||
rust_package = "gemma4_moe"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# 26B params at fp32 = 104 GB → OOM on a 94 GB GPU. Use bf16 (matches the
|
||||
# on-disk safetensors dtype) so the python paths can actually load.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.qwen3-moe]
|
||||
model = "Qwen/Qwen3-30B-A3B"
|
||||
rust_package = "qwen3_moe"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# 30B params at fp32 = 120 GB → OOM. See gemma4-moe note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.llama-8b-const]
|
||||
model = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
rust_package = "llama"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 500
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
|
||||
[configs.qwen3-4b-const]
|
||||
model = "Qwen/Qwen3-4B"
|
||||
rust_package = "qwen"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
|
||||
[configs.gemma3-4b-const]
|
||||
model = "unsloth/gemma-3-4b-it"
|
||||
rust_package = "gemma"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
610
benchmarks/ttft/dashboard.html
Normal file
610
benchmarks/ttft/dashboard.html
Normal file
@@ -0,0 +1,610 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal · Benchmark Dashboard</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Geist:wght@300;400;500;600&family=Geist+Mono:wght@300;400;500&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
html { -webkit-font-smoothing: antialiased; scroll-behavior: smooth; }
|
||||
|
||||
body {
|
||||
font-family: 'Geist', system-ui, sans-serif;
|
||||
background: #030712;
|
||||
color: #d7d8d9;
|
||||
min-height: 100vh;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
/* ── NAV ── */
|
||||
nav {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 50;
|
||||
height: 56px;
|
||||
background: rgba(8, 15, 17, 0.92);
|
||||
backdrop-filter: blur(8px);
|
||||
-webkit-backdrop-filter: blur(8px);
|
||||
border-bottom: 1px solid #2d3335;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 24px;
|
||||
gap: 0;
|
||||
}
|
||||
.nav-brand {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
letter-spacing: 0.05em;
|
||||
color: #2faa6e;
|
||||
text-decoration: none;
|
||||
}
|
||||
.nav-dot {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background: #2faa6e;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
animation: pulse-glow 2s ease-in-out infinite;
|
||||
}
|
||||
.nav-sep {
|
||||
color: #2d3335;
|
||||
margin: 0 14px;
|
||||
font-size: 18px;
|
||||
font-weight: 300;
|
||||
}
|
||||
.nav-page {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}
|
||||
|
||||
@keyframes pulse-glow {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.35; }
|
||||
}
|
||||
|
||||
/* ── MAIN ── */
|
||||
main {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 40px 24px 80px;
|
||||
}
|
||||
|
||||
/* ── PAGE HEADER ── */
|
||||
.page-header {
|
||||
margin-bottom: 40px;
|
||||
padding-bottom: 32px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
}
|
||||
.page-eyebrow {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.page-title {
|
||||
font-size: 30px;
|
||||
font-weight: 500;
|
||||
letter-spacing: -0.025em;
|
||||
color: #d7d8d9;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.page-meta {
|
||||
font-size: 14px;
|
||||
color: #7e8385;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.meta-sep {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
color: #2d3335;
|
||||
margin: 0 10px;
|
||||
}
|
||||
.meta-val {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 13px;
|
||||
color: #5b5f61;
|
||||
}
|
||||
|
||||
/* ── LEGEND STRIP ── */
|
||||
.legend-strip {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
.legend-pill {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #a1a4a5;
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
padding: 4px 10px;
|
||||
}
|
||||
.legend-swatch {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── SECTIONS ── */
|
||||
section { margin-bottom: 48px; }
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 10px;
|
||||
margin-bottom: 16px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.section-eyebrow {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #404647;
|
||||
}
|
||||
.section-title {
|
||||
font-size: 18px;
|
||||
font-weight: 500;
|
||||
color: #d7d8d9;
|
||||
letter-spacing: -0.01em;
|
||||
}
|
||||
.section-title .unit {
|
||||
color: #7e8385;
|
||||
font-weight: 400;
|
||||
}
|
||||
.section-tag {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
background: #162322;
|
||||
border: 1px solid #1c372e;
|
||||
padding: 2px 8px;
|
||||
border-radius: 2px;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
/* ── CHART GRID ── */
|
||||
.chart-grid {
|
||||
display: grid;
|
||||
gap: 10px;
|
||||
}
|
||||
.chart-card {
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
transition: border-color 150ms;
|
||||
min-width: 0;
|
||||
}
|
||||
.chart-card:hover { border-color: #404647; }
|
||||
.chart-card-header {
|
||||
padding: 10px 14px 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.model-tag {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.06em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}
|
||||
|
||||
/* ── FOOTER ── */
|
||||
footer {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 20px 24px;
|
||||
border-top: 1px solid #1c2225;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.section-divider {
|
||||
border: none;
|
||||
border-top: 1px solid #1c2225;
|
||||
margin: 8px 0 40px;
|
||||
}
|
||||
.sweep-hint {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.chart-grid { grid-template-columns: 1fr !important; }
|
||||
.page-title { font-size: 22px; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<nav>
|
||||
<a class="nav-brand" href="https://luminal.com">
|
||||
<span class="nav-dot"></span>luminal
|
||||
</a>
|
||||
<span class="nav-sep">/</span>
|
||||
<span class="nav-page">benchmarks</span>
|
||||
</nav>
|
||||
|
||||
<main>
|
||||
|
||||
<header class="page-header">
|
||||
<p class="page-eyebrow">performance · time-series</p>
|
||||
<h1 class="page-title">Benchmark Dashboard</h1>
|
||||
<div class="page-meta">
|
||||
<span>Last updated</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">May 01, 2026 · 18:56</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">1 run in history</span>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="legend-strip">
|
||||
<div class="legend-pill"><span class="legend-swatch" style="background:#5b5f61"></span>HF Baseline</div><div class="legend-pill"><span class="legend-swatch" style="background:#3b82f6"></span>torch.compile</div><div class="legend-pill"><span class="legend-swatch" style="background:#a855f7"></span>luminal backend</div><div class="legend-pill"><span class="legend-swatch" style="background:#e8855a"></span>Rust (luminal)</div>
|
||||
</div>
|
||||
|
||||
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">TTFT <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Time to first token (ms)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [705.9654394979589], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [307.66548847896047], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [461.48114453535527], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [1026.86], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [869.2860195587855], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [298.27259748708457], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [485.3892414830625], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [398.58], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [951.1196144158021], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [300.9451600664761], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [404.43], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [837.3980740143452], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [245.510076492792], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [1565.540504961973], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [460.077923577046], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [21002.791983017232], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [662.07], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">TPOT <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Time per output token (ms)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [34.15271903970279], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [171.7862353892997], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [23.078908618772402], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [51.64], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [47.71483448566869], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [468.56868775503244], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [26.90318431414198], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [40.62], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [52.498737201676704], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [2197.426627812092], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [38.99], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [83.64427039632574], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [654.9649795080768], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [84.527321747737], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [753.0061075551203], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [1166.8824461026816], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [60.08], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">Time to Search <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Search time (sec)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [18.760145067994017], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [95.96263545705006], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [84.45343], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [4.680963660997804], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [45.345814052037895], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [19.92977], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [26.649526304972824], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [156.84164], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [38.81582092499593], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [8.341281775035895], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [111.70731823903043], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [80.83241000000001], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<hr class='section-divider'>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">TTFT <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [470.7036415056791, 460.72837291285396, 472.43661794345826], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [751.03, 1038.34, 453.16], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [465.02652901108377, 465.9317950136028, 495.75577257201076], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [398.44, 390.08, 559.29], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [388.19, 436.49, 386.13], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [21002.663500519702, 21018.686580006033, 21034.366824431345], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [656.7, 540.37, 542.34], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">TPOT <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [23.540849717101082, 23.101884137140587, 23.610779400914907], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [38.2, 51.92, 24.09], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [25.875402649398893, 25.884080055402592, 27.492373346467502], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [40.64, 39.98, 55.37], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [37.47, 41.95, 37.25], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [1166.6714247548953, 1167.2746865515364, 1168.7990181031637], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [59.6, 48.79, 48.88], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">Time to Search <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [28.428826077957638, 43.57440591201885, 95.52432684396626], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [15.14307, 30.12727, 84.87889], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [37.92102829599753, 54.08867314597592, 118.29659596900456], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [12.448030000000001, 27.06796, 81.89342], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [102.18644, 186.34269, 498.48983000000004], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [93.47603664599592, 132.266081985028, 298.05094401398674], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [25.48138, 47.5342, 134.79345], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
</main>
|
||||
|
||||
<footer>
|
||||
<span>luminal · benchmark dashboard</span>
|
||||
<span>generated May 01, 2026 · 18:56</span>
|
||||
</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
242
benchmarks/ttft/db.py
Normal file
242
benchmarks/ttft/db.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""SQLite persistence for TTFT/TPOT benchmark runs.
|
||||
|
||||
Two tables:
|
||||
runs — one row per orchestrator invocation
|
||||
results — many rows per run, one per (path, config) combination
|
||||
|
||||
`results` carries every field that today's BENCH_RESULT JSON record carries.
|
||||
Per-iteration sample arrays (`ttft_ms_samples`, `tpot_ms_samples`) are kept as
|
||||
JSON TEXT — they're archival, no consumer aggregates over them.
|
||||
|
||||
The default DB path is benchmarks/ttft/bench.db (gitignored). Schema is
|
||||
created lazily on first connect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
DEFAULT_DB_PATH = BENCH_DIR / "bench.db"
|
||||
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
timestamp TEXT NOT NULL,
|
||||
git_commit TEXT,
|
||||
git_branch TEXT,
|
||||
gpu_name TEXT,
|
||||
gpu_driver TEXT,
|
||||
gpu_vram_mb INTEGER,
|
||||
cuda_version TEXT,
|
||||
mode TEXT NOT NULL -- 'single' | 'all-configs' | 'search-sweep' | 'ur-test' | 'ur-test-fast'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL REFERENCES runs(run_id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
model_key TEXT,
|
||||
config TEXT NOT NULL,
|
||||
device TEXT,
|
||||
dtype TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
iters INTEGER,
|
||||
decode_tokens INTEGER,
|
||||
search_iters INTEGER,
|
||||
ttft_ms REAL,
|
||||
ttft_ms_mean REAL,
|
||||
tpot_ms REAL,
|
||||
throughput_tps REAL,
|
||||
compile_ms REAL,
|
||||
note TEXT,
|
||||
error TEXT,
|
||||
ttft_ms_samples TEXT,
|
||||
tpot_ms_samples TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_results_run ON results(run_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_path ON results(path);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_config ON results(config);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_modelk ON results(model_key);
|
||||
"""
|
||||
|
||||
|
||||
# Columns that map 1:1 from a BENCH_RESULT record dict into `results`.
|
||||
_SCALAR_RESULT_COLS = (
|
||||
"path", "model", "model_key", "config",
|
||||
"device", "dtype",
|
||||
"prompt_tokens", "iters", "decode_tokens", "search_iters",
|
||||
"ttft_ms", "ttft_ms_mean", "tpot_ms", "throughput_tps", "compile_ms",
|
||||
"note", "error",
|
||||
)
|
||||
_SAMPLE_COLS = ("ttft_ms_samples", "tpot_ms_samples")
|
||||
_ALL_RESULT_COLS = ("run_id",) + _SCALAR_RESULT_COLS + _SAMPLE_COLS
|
||||
|
||||
|
||||
def connect(path: str | Path = DEFAULT_DB_PATH) -> sqlite3.Connection:
|
||||
"""Open (or create) the bench DB and ensure the schema exists."""
|
||||
p = Path(path)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(p)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.executescript(_SCHEMA)
|
||||
return conn
|
||||
|
||||
|
||||
def insert_run(
|
||||
conn: sqlite3.Connection,
|
||||
*,
|
||||
run_id: str,
|
||||
timestamp: str,
|
||||
mode: str,
|
||||
git_commit: str | None = None,
|
||||
git_branch: str | None = None,
|
||||
gpu_name: str | None = None,
|
||||
gpu_driver: str | None = None,
|
||||
gpu_vram_mb: int | None = None,
|
||||
cuda_version: str | None = None,
|
||||
if_exists: str = "ignore",
|
||||
) -> str:
|
||||
"""Insert a run row. if_exists='ignore' (default) leaves an existing
|
||||
row untouched; 'replace' overwrites."""
|
||||
verb = {"ignore": "INSERT OR IGNORE", "replace": "INSERT OR REPLACE"}[if_exists]
|
||||
conn.execute(
|
||||
f"""{verb} INTO runs
|
||||
(run_id, timestamp, git_commit, git_branch,
|
||||
gpu_name, gpu_driver, gpu_vram_mb, cuda_version, mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(run_id, timestamp, git_commit, git_branch,
|
||||
gpu_name, gpu_driver, gpu_vram_mb, cuda_version, mode),
|
||||
)
|
||||
return run_id
|
||||
|
||||
|
||||
def insert_result(conn: sqlite3.Connection, run_id: str, record: dict[str, Any]) -> int:
|
||||
"""Insert one BENCH_RESULT-shaped record under the given run_id."""
|
||||
values = [run_id]
|
||||
for col in _SCALAR_RESULT_COLS:
|
||||
values.append(record.get(col))
|
||||
for col in _SAMPLE_COLS:
|
||||
v = record.get(col)
|
||||
values.append(json.dumps(v) if v is not None else None)
|
||||
placeholders = ", ".join(["?"] * len(_ALL_RESULT_COLS))
|
||||
cols = ", ".join(_ALL_RESULT_COLS)
|
||||
cur = conn.execute(
|
||||
f"INSERT INTO results ({cols}) VALUES ({placeholders})",
|
||||
values,
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def insert_results(conn: sqlite3.Connection, run_id: str, records: Iterable[dict[str, Any]]) -> int:
|
||||
"""Bulk-insert; returns count."""
|
||||
n = 0
|
||||
for r in records:
|
||||
insert_result(conn, run_id, r)
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
def latest_run_id(conn: sqlite3.Connection) -> str | None:
|
||||
row = conn.execute(
|
||||
"SELECT run_id FROM runs ORDER BY timestamp DESC, run_id DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return row["run_id"] if row else None
|
||||
|
||||
|
||||
def load_run(conn: sqlite3.Connection, run_id: str) -> dict[str, Any] | None:
|
||||
row = conn.execute("SELECT * FROM runs WHERE run_id = ?", (run_id,)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def load_runs(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
||||
"""All runs, oldest → newest."""
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM runs ORDER BY timestamp ASC, run_id ASC"
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def _row_to_record(row: sqlite3.Row) -> dict[str, Any]:
|
||||
"""Convert a results row into a BENCH_RESULT-shaped dict, stripping NULLs
|
||||
so consumers see the same shape they did with JSON."""
|
||||
out: dict[str, Any] = {}
|
||||
for col in _SCALAR_RESULT_COLS:
|
||||
v = row[col]
|
||||
if v is not None:
|
||||
out[col] = v
|
||||
for col in _SAMPLE_COLS:
|
||||
v = row[col]
|
||||
if v is not None:
|
||||
out[col] = json.loads(v)
|
||||
return out
|
||||
|
||||
|
||||
def load_results(conn: sqlite3.Connection, run_id: str) -> list[dict[str, Any]]:
|
||||
"""All results for one run, in insertion order."""
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM results WHERE run_id = ? ORDER BY id ASC", (run_id,)
|
||||
).fetchall()
|
||||
return [_row_to_record(r) for r in rows]
|
||||
|
||||
|
||||
def load_history(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
||||
"""Mirror the legacy gen_dashboard.load_history() shape:
|
||||
[{"meta": {...}, "results": [...], "sweep": [...]}], sorted oldest→newest.
|
||||
Splits results vs sweep by config-startswith('s=')."""
|
||||
out = []
|
||||
for run in load_runs(conn):
|
||||
run_id = run["run_id"]
|
||||
meta = {
|
||||
"run_id": run_id,
|
||||
"timestamp": run["timestamp"],
|
||||
"git_commit": run["git_commit"] or "?",
|
||||
"git_branch": run["git_branch"] or "?",
|
||||
}
|
||||
if run["gpu_name"] is not None:
|
||||
meta["gpu_name"] = run["gpu_name"]
|
||||
if run["gpu_driver"] is not None:
|
||||
meta["gpu_driver"] = run["gpu_driver"]
|
||||
if run["gpu_vram_mb"] is not None:
|
||||
meta["gpu_vram_mb"] = run["gpu_vram_mb"]
|
||||
if run["cuda_version"] is not None:
|
||||
meta["cuda_version"] = run["cuda_version"]
|
||||
|
||||
records = load_results(conn, run_id)
|
||||
comparison, sweep = [], []
|
||||
for r in records:
|
||||
(sweep if r.get("config", "").startswith("s=") else comparison).append(r)
|
||||
out.append({"meta": meta, "results": comparison, "sweep": sweep})
|
||||
return out
|
||||
|
||||
|
||||
# ── self-test ────────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
|
||||
# In-memory smoke test: round-trip one record.
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.executescript(_SCHEMA)
|
||||
insert_run(conn, run_id="test", timestamp="2026-04-27T00:00:00", mode="single")
|
||||
insert_result(conn, "test", {
|
||||
"path": "rust",
|
||||
"model": "test-model",
|
||||
"config": "default",
|
||||
"ttft_ms": 12.34,
|
||||
"ttft_ms_samples": [12.0, 12.5, 12.3],
|
||||
"search_iters": 500,
|
||||
})
|
||||
[row] = load_results(conn, "test")
|
||||
assert row["path"] == "rust", row
|
||||
assert row["ttft_ms"] == 12.34, row
|
||||
assert row["ttft_ms_samples"] == [12.0, 12.5, 12.3], row
|
||||
assert latest_run_id(conn) == "test"
|
||||
print("db.py smoke test ok")
|
||||
832
benchmarks/ttft/gen_dashboard.py
Normal file
832
benchmarks/ttft/gen_dashboard.py
Normal file
@@ -0,0 +1,832 @@
|
||||
"""Time-series benchmark dashboard generator.
|
||||
|
||||
Reads every run from the SQLite DB (benchmarks/ttft/bench.db) and produces a
|
||||
single standalone HTML file with Plotly.js charts styled to match luminal.com.
|
||||
|
||||
Layout:
|
||||
TTFT over time → one chart per model, lines = execution paths
|
||||
TPOT over time → same
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/ttft/gen_dashboard.py [--db PATH] [--out FILE]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import db
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# Path colours – kept distinct against the dark green Luminal accent
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#5b5f61", # muted slate
|
||||
"python_torch_compile": "#3b82f6", # blue (luminal accent palette)
|
||||
"python_luminal": "#a855f7", # purple (luminal accent palette)
|
||||
"rust": "#e8855a", # warm orange – Rust brand feel
|
||||
}
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "HF Baseline",
|
||||
"python_torch_compile": "torch.compile",
|
||||
"python_luminal": "luminal backend",
|
||||
"rust": "Rust (luminal)",
|
||||
}
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
|
||||
# (key, short label, y-axis label, scale, axis ticksuffix)
|
||||
# scale is applied to raw value before plotting (e.g. ms → sec via 0.001).
|
||||
METRICS = [
|
||||
("ttft_ms", "TTFT", "Time to first token (ms)", 1.0, " ms"),
|
||||
("tpot_ms", "TPOT", "Time per output token (ms)", 1.0, " ms"),
|
||||
("compile_ms", "Time to Search", "Search time (sec)", 0.001, " sec"),
|
||||
]
|
||||
|
||||
|
||||
# ── data loading ─────────────────────────────────────────────────────────────
|
||||
|
||||
def load_history(db_path: Path) -> list[dict]:
|
||||
"""Return [{"meta", "results", "sweep"}, …] from the bench DB,
|
||||
oldest→newest. Same shape the legacy JSON loader returned."""
|
||||
if not Path(db_path).exists():
|
||||
return []
|
||||
conn = db.connect(db_path)
|
||||
return db.load_history(conn)
|
||||
|
||||
|
||||
def build_series(runs: list[dict]) -> tuple[dict, list[str], list[str]]:
|
||||
"""Returns (data, run_ids, run_labels).
|
||||
|
||||
- data[model][path][metric] = [(run_id, value, commit, ts), ...]
|
||||
`run_id` is the categorical x value; `ts` is kept for tooltip formatting.
|
||||
- run_ids: chronological list of every run that appears in the comparison data.
|
||||
- run_labels: parallel to run_ids; "MMM DD · HH:MM" for nice axis ticks.
|
||||
|
||||
The categorical x-axis (one column per run_id) replaces the previous
|
||||
`type: date` axis. With multiple runs on the same day, the date axis
|
||||
silently stacked them on one column; the category axis spaces them
|
||||
evenly so each run is visually distinct.
|
||||
"""
|
||||
data: dict = {}
|
||||
seen_run_ids: list[str] = []
|
||||
seen_ts: dict[str, str] = {}
|
||||
|
||||
for run in runs:
|
||||
run_id = run["meta"]["run_id"]
|
||||
ts = run["meta"]["timestamp"]
|
||||
commit = run["meta"].get("git_commit", "?")
|
||||
had_data = False
|
||||
for r in run["results"]:
|
||||
if r.get("error") or r.get("ttft_ms") is None:
|
||||
continue
|
||||
model = r.get("config", r.get("model", "unknown"))
|
||||
path = r.get("path", "unknown")
|
||||
data.setdefault(model, {}).setdefault(path, {})
|
||||
for metric, _, _, scale, _ in METRICS:
|
||||
val = r.get(metric)
|
||||
if val is not None:
|
||||
data[model][path].setdefault(metric, []).append(
|
||||
(run_id, val * scale, commit, ts)
|
||||
)
|
||||
had_data = True
|
||||
if had_data and run_id not in seen_ts:
|
||||
seen_run_ids.append(run_id)
|
||||
seen_ts[run_id] = ts
|
||||
|
||||
run_ids = sorted(seen_run_ids, key=lambda rid: seen_ts.get(rid, rid))
|
||||
run_labels = []
|
||||
for rid in run_ids:
|
||||
ts = seen_ts.get(rid, rid)
|
||||
try:
|
||||
run_labels.append(datetime.fromisoformat(ts).strftime("%b %d · %H:%M"))
|
||||
except ValueError:
|
||||
run_labels.append(rid[:16].replace("T", " "))
|
||||
return data, run_ids, run_labels
|
||||
|
||||
|
||||
def build_sweep_series(runs: list[dict]) -> tuple[dict, list[str]]:
|
||||
"""Collect sweep records from ALL runs for 3D charting.
|
||||
|
||||
Returns:
|
||||
data[model_key][path][metric][run_id] = {
|
||||
"label": str, # short date label for Y axis
|
||||
"commit": str,
|
||||
"points": [(iters, ms), …] # sorted by iters
|
||||
}
|
||||
run_ids: list[str] in chronological order (oldest → newest)
|
||||
"""
|
||||
data: dict = {}
|
||||
run_ids: list[str] = []
|
||||
|
||||
for run in runs:
|
||||
if not run.get("sweep"):
|
||||
continue
|
||||
run_id = run["meta"]["run_id"]
|
||||
commit = run["meta"].get("git_commit", "?")
|
||||
try:
|
||||
label = datetime.fromisoformat(run["meta"]["timestamp"]).strftime("%b %d")
|
||||
except ValueError:
|
||||
label = run_id[:10]
|
||||
if run_id not in run_ids:
|
||||
run_ids.append(run_id)
|
||||
|
||||
for r in run["sweep"]:
|
||||
if r.get("error"):
|
||||
continue
|
||||
n = r.get("search_iters")
|
||||
if n is None:
|
||||
cfg = r.get("config", "")
|
||||
if cfg.startswith("s="):
|
||||
try:
|
||||
n = int(cfg[2:])
|
||||
except ValueError:
|
||||
continue
|
||||
if n is None:
|
||||
continue
|
||||
model_key = r.get("model_key", "unknown")
|
||||
path = r.get("path", "unknown")
|
||||
for metric, _, _, scale, _ in METRICS:
|
||||
val = r.get(metric)
|
||||
if val is None:
|
||||
continue
|
||||
(data
|
||||
.setdefault(model_key, {})
|
||||
.setdefault(path, {})
|
||||
.setdefault(metric, {})
|
||||
.setdefault(run_id, {"label": label, "commit": commit, "points": []})
|
||||
["points"].append((n, val * scale)))
|
||||
|
||||
# Sort points within each run by search_iters
|
||||
for mk in data:
|
||||
for path in data[mk]:
|
||||
for metric in data[mk][path]:
|
||||
for run_id in data[mk][path][metric]:
|
||||
data[mk][path][metric][run_id]["points"].sort(key=lambda x: x[0])
|
||||
|
||||
return data, run_ids
|
||||
|
||||
|
||||
# ── chart building ────────────────────────────────────────────────────────────
|
||||
|
||||
def _traces_json(path_data: dict, metric: str, show_legend: bool, unit: str = " ms") -> str:
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
if path not in path_data or metric not in path_data[path]:
|
||||
continue
|
||||
pts = path_data[path][metric]
|
||||
# pts: list of (run_id, val, commit, ts)
|
||||
trace = {
|
||||
"x": [p[0] for p in pts],
|
||||
"y": [p[1] for p in pts],
|
||||
"customdata": [[p[2], p[3]] for p in pts],
|
||||
"type": "scatter",
|
||||
"mode": "lines+markers",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"line": {"color": PATH_COLORS.get(path, "#aaa"), "width": 2},
|
||||
"marker": {"size": 7, "symbol": "circle"},
|
||||
"connectgaps": False,
|
||||
"showlegend": show_legend,
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)}</b><br>"
|
||||
"%{customdata[1]}<br>"
|
||||
f"%{{y:.1f}}{unit}<br>"
|
||||
"<span style='color:#7e8385'>commit %{customdata[0]}</span>"
|
||||
"<extra></extra>"
|
||||
),
|
||||
}
|
||||
traces.append(trace)
|
||||
return json.dumps(traces)
|
||||
|
||||
|
||||
_CHART_LAYOUT = {
|
||||
"plot_bgcolor": "#0d1416",
|
||||
"paper_bgcolor": "#141b1d",
|
||||
"font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"},
|
||||
"margin": {"t": 16, "b": 48, "l": 52, "r": 12},
|
||||
"height": 280,
|
||||
"xaxis": {
|
||||
# Categorical: one column per run, evenly spaced. Same-day runs
|
||||
# used to collapse on a date axis; this keeps every run distinct.
|
||||
"type": "category",
|
||||
"categoryorder": "array", # categoryarray injected per chart
|
||||
"color": "#5b5f61",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"tickfont": {"size": 11, "family": "Geist Mono, monospace"},
|
||||
"tickangle": -30,
|
||||
"automargin": True,
|
||||
"zeroline": False,
|
||||
},
|
||||
"yaxis": {
|
||||
"rangemode": "tozero",
|
||||
"color": "#5b5f61",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"tickfont": {"size": 11, "family": "Geist Mono, monospace"},
|
||||
"ticksuffix": " ms",
|
||||
"zeroline": False,
|
||||
},
|
||||
"legend": {
|
||||
"orientation": "h",
|
||||
"y": -0.28,
|
||||
"x": 0,
|
||||
"font": {"size": 11, "color": "#a1a4a5"},
|
||||
"bgcolor": "rgba(0,0,0,0)",
|
||||
},
|
||||
"hoverlabel": {
|
||||
"bgcolor": "#1c2225",
|
||||
"bordercolor":"#2d3335",
|
||||
"font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _chart_card(div_id: str, model: str, traces_json: str, show_legend: bool,
|
||||
run_ids: list[str], run_labels: list[str], unit: str = " ms") -> str:
|
||||
layout = dict(_CHART_LAYOUT)
|
||||
xaxis = {
|
||||
**layout["xaxis"],
|
||||
"categoryarray": run_ids,
|
||||
"tickvals": run_ids,
|
||||
"ticktext": run_labels,
|
||||
}
|
||||
layout = {**layout,
|
||||
"xaxis": xaxis,
|
||||
"yaxis": {**layout["yaxis"], "ticksuffix": unit}}
|
||||
if not show_legend:
|
||||
layout = {**layout, "legend": {**layout["legend"], "visible": False},
|
||||
"margin": {**layout["margin"], "b": 16}}
|
||||
return f"""<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">{model}</span>
|
||||
</div>
|
||||
<div id="{div_id}"></div>
|
||||
<script>
|
||||
Plotly.newPlot("{div_id}", {traces_json}, {json.dumps(layout)},
|
||||
{{responsive: true, displayModeBar: false}});
|
||||
</script>
|
||||
</div>"""
|
||||
|
||||
|
||||
def _sweep_3d_traces_json(model_data: dict, metric: str, run_ids: list[str], unit: str = " ms") -> str:
|
||||
"""One scatter3d trace per (path, run) — same colour per path, stacked by run on Y."""
|
||||
traces = []
|
||||
path_legend_shown: set[str] = set()
|
||||
|
||||
for run_id in run_ids:
|
||||
for path in PATH_ORDER:
|
||||
run_map = model_data.get(path, {}).get(metric, {})
|
||||
if run_id not in run_map:
|
||||
continue
|
||||
entry = run_map[run_id]
|
||||
pts = entry["points"]
|
||||
label = entry["label"]
|
||||
commit = entry["commit"]
|
||||
color = PATH_COLORS.get(path, "#aaa")
|
||||
show_legend = path not in path_legend_shown
|
||||
path_legend_shown.add(path)
|
||||
|
||||
traces.append({
|
||||
"type": "scatter3d",
|
||||
"mode": "lines+markers",
|
||||
"x": [p[0] for p in pts], # search iters
|
||||
"y": [label] * len(pts), # run label (categorical)
|
||||
"z": [p[1] for p in pts], # value (already scaled by build_sweep_series)
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"legendgroup": path,
|
||||
"showlegend": show_legend,
|
||||
"line": {"color": color, "width": 5},
|
||||
"marker": {"color": color, "size": 4},
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)}</b><br>"
|
||||
f"s=%{{x}} iters<br>%{{z:.1f}}{unit}<br>"
|
||||
f"{label} · {commit}"
|
||||
"<extra></extra>"
|
||||
),
|
||||
})
|
||||
|
||||
# Cross-run wire lines: for each path, connect same-budget points across
|
||||
# runs. Makes regressions at a fixed search budget visible as a kink in the
|
||||
# wireframe. Dashed + thinner than the per-run curves; legendgroup matches
|
||||
# the path so toggling one toggles both.
|
||||
for path in PATH_ORDER:
|
||||
metric_runs = model_data.get(path, {}).get(metric, {})
|
||||
if len(metric_runs) < 2:
|
||||
continue
|
||||
color = PATH_COLORS.get(path, "#aaa")
|
||||
# by_budget[iters] -> list of (run_label, value) in chronological order
|
||||
by_budget: dict = {}
|
||||
for run_id in run_ids:
|
||||
if run_id not in metric_runs:
|
||||
continue
|
||||
entry = metric_runs[run_id]
|
||||
for iters, val in entry["points"]:
|
||||
by_budget.setdefault(iters, []).append((entry["label"], val))
|
||||
for budget, items in sorted(by_budget.items()):
|
||||
if len(items) < 2:
|
||||
continue
|
||||
traces.append({
|
||||
"type": "scatter3d",
|
||||
"mode": "lines",
|
||||
"x": [budget] * len(items),
|
||||
"y": [it[0] for it in items],
|
||||
"z": [it[1] for it in items],
|
||||
"legendgroup": path,
|
||||
"showlegend": False,
|
||||
"line": {"color": color, "width": 2, "dash": "dash"},
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)} @ s={budget}</b><br>"
|
||||
f"%{{y}}: %{{z:.1f}}{unit}"
|
||||
"<extra></extra>"
|
||||
),
|
||||
})
|
||||
return json.dumps(traces)
|
||||
|
||||
|
||||
_SWEEP_3D_LAYOUT = {
|
||||
"paper_bgcolor": "#141b1d",
|
||||
"font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11},
|
||||
"height": 420,
|
||||
"margin": {"t": 20, "b": 0, "l": 0, "r": 0},
|
||||
"legend": {
|
||||
"orientation": "h",
|
||||
"y": -0.05,
|
||||
"x": 0,
|
||||
"font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"},
|
||||
"bgcolor": "rgba(0,0,0,0)",
|
||||
},
|
||||
"hoverlabel": {
|
||||
"bgcolor": "#1c2225",
|
||||
"bordercolor": "#2d3335",
|
||||
"font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"},
|
||||
},
|
||||
"scene": {
|
||||
"bgcolor": "#0d1416",
|
||||
"xaxis": {
|
||||
"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"type": "log",
|
||||
"tickvals": [5, 10, 20, 50, 100, 500],
|
||||
"ticktext": ["5", "10", "20", "50", "100", "500"],
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"zerolinecolor": "#2d3335",
|
||||
},
|
||||
"yaxis": {
|
||||
"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
},
|
||||
"zaxis": {
|
||||
"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"rangemode": "tozero",
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"ticksuffix": " ms",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
},
|
||||
"camera": {
|
||||
"eye": {"x": 1.6, "y": -1.6, "z": 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _sweep_3d_card(div_id: str, model: str, traces_json: str, unit: str = " ms") -> str:
|
||||
layout = {**_SWEEP_3D_LAYOUT,
|
||||
"scene": {**_SWEEP_3D_LAYOUT["scene"],
|
||||
"zaxis": {**_SWEEP_3D_LAYOUT["scene"]["zaxis"],
|
||||
"title": {**_SWEEP_3D_LAYOUT["scene"]["zaxis"]["title"],
|
||||
"text": unit.strip()},
|
||||
"ticksuffix": unit}}}
|
||||
return f"""<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">{model}</span>
|
||||
</div>
|
||||
<div id="{div_id}"></div>
|
||||
<script>
|
||||
Plotly.newPlot("{div_id}", {traces_json}, {json.dumps(layout)},
|
||||
{{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]}});
|
||||
</script>
|
||||
</div>"""
|
||||
|
||||
|
||||
# ── HTML assembly ─────────────────────────────────────────────────────────────
|
||||
|
||||
def build_html(runs: list[dict], data: dict,
|
||||
run_ids: list[str], run_labels: list[str],
|
||||
sweep_data: dict | None = None,
|
||||
sweep_run_ids: list[str] | None = None) -> str:
|
||||
# Preserve insertion order of models as seen across runs
|
||||
models = list(dict.fromkeys(
|
||||
r["config"]
|
||||
for run in runs
|
||||
for r in run["results"]
|
||||
if not r.get("config", "").startswith("s=") and not r.get("error")
|
||||
))
|
||||
|
||||
last_ts = ""
|
||||
if runs:
|
||||
raw = runs[-1]["meta"]["timestamp"]
|
||||
try:
|
||||
last_ts = datetime.fromisoformat(raw).strftime("%b %d, %Y · %H:%M")
|
||||
except ValueError:
|
||||
last_ts = raw[:16].replace("T", " ")
|
||||
|
||||
n_runs = len(runs)
|
||||
|
||||
sections_html = ""
|
||||
for metric_key, metric_label, ylabel, _scale, unit in METRICS:
|
||||
active_models = [
|
||||
m for m in models
|
||||
if any(metric_key in data.get(m, {}).get(p, {}) for p in PATH_ORDER)
|
||||
]
|
||||
if not active_models:
|
||||
continue
|
||||
|
||||
cards_html = ""
|
||||
first = True
|
||||
for model in active_models:
|
||||
path_data = data.get(model, {})
|
||||
div_id = f"c_{metric_key}_{model.replace('-','_').replace('.','_')}"
|
||||
traces = _traces_json(path_data, metric_key, show_legend=first, unit=unit)
|
||||
cards_html += _chart_card(div_id, model, traces, show_legend=first,
|
||||
run_ids=run_ids, run_labels=run_labels, unit=unit)
|
||||
first = False
|
||||
|
||||
n = len(active_models)
|
||||
# Clamp columns so charts don't get too narrow; wrap at 4
|
||||
cols = min(n, 4)
|
||||
sections_html += f"""
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">{metric_label} <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">{ylabel}</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat({cols}, 1fr)">
|
||||
{cards_html}
|
||||
</div>
|
||||
</section>"""
|
||||
|
||||
# ── sweep sections (3D) ──────────────────────────────────────────────────
|
||||
sweep_sections_html = ""
|
||||
if sweep_data and sweep_run_ids:
|
||||
sweep_models = list(sweep_data.keys())
|
||||
for metric_key, metric_label, ylabel, _scale, unit in METRICS:
|
||||
active = [
|
||||
m for m in sweep_models
|
||||
if any(
|
||||
run_id in sweep_data[m].get(p, {}).get(metric_key, {})
|
||||
for p in PATH_ORDER
|
||||
for run_id in sweep_run_ids
|
||||
)
|
||||
]
|
||||
if not active:
|
||||
continue
|
||||
cards_html = ""
|
||||
for model in active:
|
||||
div_id = f"sw_{metric_key}_{model.replace('-','_').replace('.','_')}"
|
||||
traces = _sweep_3d_traces_json(sweep_data[model], metric_key, sweep_run_ids, unit=unit)
|
||||
cards_html += _sweep_3d_card(div_id, model, traces, unit=unit)
|
||||
cols = min(len(active), 4)
|
||||
run_count = len(sweep_run_ids)
|
||||
sweep_sections_html += f"""
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">{metric_label} <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">{run_count} run{"s" if run_count != 1 else ""}</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat({cols}, 1fr)">
|
||||
{cards_html}
|
||||
</div>
|
||||
</section>"""
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal · Benchmark Dashboard</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Geist:wght@300;400;500;600&family=Geist+Mono:wght@300;400;500&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
*, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
|
||||
html {{ -webkit-font-smoothing: antialiased; scroll-behavior: smooth; }}
|
||||
|
||||
body {{
|
||||
font-family: 'Geist', system-ui, sans-serif;
|
||||
background: #030712;
|
||||
color: #d7d8d9;
|
||||
min-height: 100vh;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
|
||||
/* ── NAV ── */
|
||||
nav {{
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 50;
|
||||
height: 56px;
|
||||
background: rgba(8, 15, 17, 0.92);
|
||||
backdrop-filter: blur(8px);
|
||||
-webkit-backdrop-filter: blur(8px);
|
||||
border-bottom: 1px solid #2d3335;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 24px;
|
||||
gap: 0;
|
||||
}}
|
||||
.nav-brand {{
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
letter-spacing: 0.05em;
|
||||
color: #2faa6e;
|
||||
text-decoration: none;
|
||||
}}
|
||||
.nav-dot {{
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background: #2faa6e;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
animation: pulse-glow 2s ease-in-out infinite;
|
||||
}}
|
||||
.nav-sep {{
|
||||
color: #2d3335;
|
||||
margin: 0 14px;
|
||||
font-size: 18px;
|
||||
font-weight: 300;
|
||||
}}
|
||||
.nav-page {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}}
|
||||
|
||||
@keyframes pulse-glow {{
|
||||
0%, 100% {{ opacity: 1; }}
|
||||
50% {{ opacity: 0.35; }}
|
||||
}}
|
||||
|
||||
/* ── MAIN ── */
|
||||
main {{
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 40px 24px 80px;
|
||||
}}
|
||||
|
||||
/* ── PAGE HEADER ── */
|
||||
.page-header {{
|
||||
margin-bottom: 40px;
|
||||
padding-bottom: 32px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
}}
|
||||
.page-eyebrow {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.page-title {{
|
||||
font-size: 30px;
|
||||
font-weight: 500;
|
||||
letter-spacing: -0.025em;
|
||||
color: #d7d8d9;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.page-meta {{
|
||||
font-size: 14px;
|
||||
color: #7e8385;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.meta-sep {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
color: #2d3335;
|
||||
margin: 0 10px;
|
||||
}}
|
||||
.meta-val {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 13px;
|
||||
color: #5b5f61;
|
||||
}}
|
||||
|
||||
/* ── LEGEND STRIP ── */
|
||||
.legend-strip {{
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-bottom: 32px;
|
||||
}}
|
||||
.legend-pill {{
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #a1a4a5;
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
padding: 4px 10px;
|
||||
}}
|
||||
.legend-swatch {{
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}}
|
||||
|
||||
/* ── SECTIONS ── */
|
||||
section {{ margin-bottom: 48px; }}
|
||||
.section-header {{
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 10px;
|
||||
margin-bottom: 16px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.section-eyebrow {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #404647;
|
||||
}}
|
||||
.section-title {{
|
||||
font-size: 18px;
|
||||
font-weight: 500;
|
||||
color: #d7d8d9;
|
||||
letter-spacing: -0.01em;
|
||||
}}
|
||||
.section-title .unit {{
|
||||
color: #7e8385;
|
||||
font-weight: 400;
|
||||
}}
|
||||
.section-tag {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
background: #162322;
|
||||
border: 1px solid #1c372e;
|
||||
padding: 2px 8px;
|
||||
border-radius: 2px;
|
||||
margin-left: auto;
|
||||
}}
|
||||
|
||||
/* ── CHART GRID ── */
|
||||
.chart-grid {{
|
||||
display: grid;
|
||||
gap: 10px;
|
||||
}}
|
||||
.chart-card {{
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
transition: border-color 150ms;
|
||||
min-width: 0;
|
||||
}}
|
||||
.chart-card:hover {{ border-color: #404647; }}
|
||||
.chart-card-header {{
|
||||
padding: 10px 14px 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}}
|
||||
.model-tag {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.06em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}}
|
||||
|
||||
/* ── FOOTER ── */
|
||||
footer {{
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 20px 24px;
|
||||
border-top: 1px solid #1c2225;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}}
|
||||
|
||||
.section-divider {{
|
||||
border: none;
|
||||
border-top: 1px solid #1c2225;
|
||||
margin: 8px 0 40px;
|
||||
}}
|
||||
.sweep-hint {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
margin-bottom: 12px;
|
||||
}}
|
||||
|
||||
@media (max-width: 768px) {{
|
||||
.chart-grid {{ grid-template-columns: 1fr !important; }}
|
||||
.page-title {{ font-size: 22px; }}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<nav>
|
||||
<a class="nav-brand" href="https://luminal.com">
|
||||
<span class="nav-dot"></span>luminal
|
||||
</a>
|
||||
<span class="nav-sep">/</span>
|
||||
<span class="nav-page">benchmarks</span>
|
||||
</nav>
|
||||
|
||||
<main>
|
||||
|
||||
<header class="page-header">
|
||||
<p class="page-eyebrow">performance · time-series</p>
|
||||
<h1 class="page-title">Benchmark Dashboard</h1>
|
||||
<div class="page-meta">
|
||||
<span>Last updated</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">{last_ts}</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">{n_runs} run{"s" if n_runs != 1 else ""} in history</span>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="legend-strip">
|
||||
{"".join(
|
||||
f'<div class="legend-pill"><span class="legend-swatch" style="background:{PATH_COLORS[p]}"></span>{PATH_LABELS[p]}</div>'
|
||||
for p in PATH_ORDER
|
||||
)}
|
||||
</div>
|
||||
|
||||
{sections_html}
|
||||
{"<hr class='section-divider'>" + sweep_sections_html if sweep_sections_html else ""}
|
||||
|
||||
</main>
|
||||
|
||||
<footer>
|
||||
<span>luminal · benchmark dashboard</span>
|
||||
<span>generated {last_ts}</span>
|
||||
</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
# ── entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help=f"SQLite bench DB (default: {db.DEFAULT_DB_PATH})")
|
||||
ap.add_argument("--out", default=str(BENCH_DIR / "dashboard.html"),
|
||||
help="Output HTML file")
|
||||
args = ap.parse_args()
|
||||
|
||||
runs = load_history(Path(args.db))
|
||||
if not runs:
|
||||
print(f"No runs found in {args.db}. Run --ur-test (or backfill) first.")
|
||||
return
|
||||
|
||||
data, run_ids, run_labels = build_series(runs)
|
||||
sweep_data, sweep_run_ids = build_sweep_series(runs)
|
||||
html = build_html(runs, data, run_ids, run_labels, sweep_data, sweep_run_ids)
|
||||
Path(args.out).write_text(html)
|
||||
|
||||
print(f"wrote {args.out} ({len(runs)} runs, {sum(len(v) for v in data.values())} model×path series)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
349
benchmarks/ttft/gen_report.py
Normal file
349
benchmarks/ttft/gen_report.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate a standalone HTML benchmark report from a single benchmark run.
|
||||
|
||||
Usage:
|
||||
python3 gen_report.py [--db PATH] [--run RUN_ID] [--out report.html] [--title "..."]
|
||||
|
||||
Sections are split out of a single run automatically:
|
||||
- per-model_key, "comparison" (configs not matching s=N) → grouped bar chart
|
||||
- per-model_key, "sweep" (configs matching s=N) → line chart (log X)
|
||||
For runs without model_key (e.g. single-config runs), one section per detected
|
||||
shape is produced instead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import db
|
||||
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "HF Baseline",
|
||||
"python_torch_compile": "torch.compile",
|
||||
"python_luminal": "luminal backend",
|
||||
"rust": "Rust (luminal)",
|
||||
}
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#888888",
|
||||
"python_torch_compile": "#5ab552",
|
||||
"python_luminal": "#4c9ed9",
|
||||
"rust": "#d97a4c",
|
||||
}
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _fmt(v, decimals=1, suffix=""):
|
||||
return f"{v:.{decimals}f}{suffix}" if v is not None else "—"
|
||||
|
||||
def _section_title(path: Path) -> str:
|
||||
stem = path.stem.replace("_", " ").replace("-", " ")
|
||||
return stem.title()
|
||||
|
||||
def _is_sweep(configs: list[str]) -> bool:
|
||||
return bool(configs) and all(re.fullmatch(r"s=\d+", c) for c in configs)
|
||||
|
||||
def _group_by_config(results: list[dict]) -> dict[str, dict[str, dict]]:
|
||||
"""Return {config: {path: result_dict}}."""
|
||||
out: dict[str, dict[str, dict]] = {}
|
||||
for r in results:
|
||||
cfg = r.get("config", "default")
|
||||
out.setdefault(cfg, {})[r["path"]] = r
|
||||
return out
|
||||
|
||||
|
||||
# ── chart builders (return Plotly figure dicts) ───────────────────────────────
|
||||
|
||||
def _bar_figure(by_config: dict, metric: str, title: str,
|
||||
scale: float = 1.0, unit: str = "ms") -> dict:
|
||||
configs = list(by_config.keys())
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
ys, texts = [], []
|
||||
for cfg in configs:
|
||||
r = by_config[cfg].get(path)
|
||||
raw = r.get(metric) if r and not r.get("error") else None
|
||||
v = raw * scale if raw is not None else None
|
||||
ys.append(v if v is not None else 0)
|
||||
texts.append(f"{v:.1f} {unit}" if v is not None else "n/a")
|
||||
if any(y > 0 for y in ys):
|
||||
traces.append({
|
||||
"type": "bar",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"x": configs,
|
||||
"y": ys,
|
||||
"text": texts,
|
||||
"textposition": "outside",
|
||||
"marker": {"color": PATH_COLORS.get(path, "#aaaaaa")},
|
||||
"hovertemplate": "%{x}<br>" + PATH_LABELS.get(path, path)
|
||||
+ f": %{{y:.1f}} {unit}<extra></extra>",
|
||||
})
|
||||
return {
|
||||
"data": traces,
|
||||
"layout": {
|
||||
"title": title,
|
||||
"yaxis": {"title": unit, "rangemode": "tozero"},
|
||||
"barmode": "group",
|
||||
"legend": {"orientation": "h", "y": -0.2},
|
||||
"margin": {"t": 50, "b": 80},
|
||||
"plot_bgcolor": "#fafafa",
|
||||
"paper_bgcolor": "#ffffff",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _line_figure(by_config: dict, metric: str, title: str,
|
||||
scale: float = 1.0, unit: str = "ms") -> dict:
|
||||
"""Line chart for sweep data. Config names are 's=N'; X = N (log scale)."""
|
||||
def _iter(cfg):
|
||||
m = re.fullmatch(r"s=(\d+)", cfg)
|
||||
return int(m.group(1)) if m else 0
|
||||
|
||||
configs_sorted = sorted(by_config.keys(), key=_iter)
|
||||
xs = [_iter(c) for c in configs_sorted]
|
||||
|
||||
paths_present = {p for cfg in by_config.values() for p in cfg}
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
if path not in paths_present:
|
||||
continue
|
||||
ys = []
|
||||
for cfg in configs_sorted:
|
||||
r = by_config[cfg].get(path)
|
||||
raw = r.get(metric) if r and not r.get("error") else None
|
||||
ys.append(raw * scale if raw is not None else None)
|
||||
if any(y is not None for y in ys):
|
||||
traces.append({
|
||||
"type": "scatter",
|
||||
"mode": "lines+markers",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"x": xs,
|
||||
"y": ys,
|
||||
"marker": {"size": 8, "color": PATH_COLORS.get(path, "#aaaaaa")},
|
||||
"line": {"color": PATH_COLORS.get(path, "#aaaaaa"), "width": 2},
|
||||
"hovertemplate": "iters=%{x}<br>" + PATH_LABELS.get(path, path)
|
||||
+ f": %{{y:.1f}} {unit}<extra></extra>",
|
||||
})
|
||||
return {
|
||||
"data": traces,
|
||||
"layout": {
|
||||
"title": title,
|
||||
"xaxis": {"title": "Search iterations", "type": "log",
|
||||
"tickvals": xs, "ticktext": [str(x) for x in xs]},
|
||||
"yaxis": {"title": unit, "rangemode": "tozero"},
|
||||
"legend": {"orientation": "h", "y": -0.25},
|
||||
"margin": {"t": 50, "b": 90},
|
||||
"plot_bgcolor": "#fafafa",
|
||||
"paper_bgcolor": "#ffffff",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── table builder ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _table_html(results: list[dict]) -> str:
|
||||
rows = []
|
||||
for r in sorted(results, key=lambda r: (r.get("config", ""), PATH_ORDER.index(r["path"]) if r["path"] in PATH_ORDER else 99)):
|
||||
error = r.get("error")
|
||||
style = ' style="background:#fff0f0"' if error else ""
|
||||
path_label = PATH_LABELS.get(r["path"], r["path"])
|
||||
cfg = r.get("config", "—")
|
||||
ttft = _fmt(r.get("ttft_ms"), 1, " ms")
|
||||
tpot = _fmt(r.get("tpot_ms"), 1, " ms")
|
||||
tput = _fmt(r.get("throughput_tps"), 1, " tok/s")
|
||||
comp = _fmt(r.get("compile_ms"), 0, " ms") if r.get("compile_ms") else "—"
|
||||
ptok = str(r.get("prompt_tokens", "—"))
|
||||
note = (r.get("error") or r.get("note") or "")[:90]
|
||||
note_style = ' style="color:#c00"' if error else ' style="color:#777"'
|
||||
rows.append(
|
||||
f'<tr{style}>'
|
||||
f'<td>{path_label}</td><td>{cfg}</td>'
|
||||
f'<td>{ttft}</td><td>{tpot}</td><td>{tput}</td>'
|
||||
f'<td>{comp}</td><td>{ptok}</td>'
|
||||
f'<td{note_style}>{note}</td>'
|
||||
f'</tr>'
|
||||
)
|
||||
return (
|
||||
'<table>'
|
||||
'<thead><tr>'
|
||||
'<th>Path</th><th>Config</th>'
|
||||
'<th>TTFT</th><th>TPOT</th><th>Throughput</th>'
|
||||
'<th>Compile</th><th>Prompt tokens</th><th>Note</th>'
|
||||
'</tr></thead>'
|
||||
'<tbody>' + "\n".join(rows) + '</tbody>'
|
||||
'</table>'
|
||||
)
|
||||
|
||||
|
||||
# ── section builder ───────────────────────────────────────────────────────────
|
||||
|
||||
def _section_html(sec_id: str, title: str, results: list[dict], fig_counter: list) -> str:
|
||||
by_config = _group_by_config(results)
|
||||
configs = list(by_config.keys())
|
||||
sweep = _is_sweep(configs)
|
||||
|
||||
models = list(dict.fromkeys(r.get("model", "") for r in results if r.get("model")))
|
||||
model_str = ", ".join(models) if models else "—"
|
||||
prompt_tokens = list(dict.fromkeys(r.get("prompt_tokens") for r in results if r.get("prompt_tokens")))
|
||||
tok_str = "/".join(str(t) for t in prompt_tokens) + " prompt tokens" if prompt_tokens else ""
|
||||
|
||||
builder = _line_figure if sweep else _bar_figure
|
||||
ttft_fig = builder(by_config, "ttft_ms", "TTFT")
|
||||
has_tpot = any(r.get("tpot_ms") is not None for r in results if not r.get("error"))
|
||||
tpot_fig = builder(by_config, "tpot_ms", "TPOT") if has_tpot else None
|
||||
has_compile = any(r.get("compile_ms") is not None and r.get("compile_ms") > 0
|
||||
for r in results if not r.get("error"))
|
||||
compile_fig = (builder(by_config, "compile_ms", "Time to Search",
|
||||
scale=0.001, unit="sec")
|
||||
if has_compile else None)
|
||||
|
||||
def chart_div(fig):
|
||||
n = fig_counter[0]
|
||||
fig_counter[0] += 1
|
||||
return (
|
||||
f'<div id="fig{n}" class="chart"></div>'
|
||||
f'<script>Plotly.newPlot("fig{n}", {json.dumps(fig["data"])}, {json.dumps(fig["layout"])}, {{responsive:true}});</script>'
|
||||
)
|
||||
|
||||
charts_html = f'<div class="charts-row">{chart_div(ttft_fig)}'
|
||||
if tpot_fig:
|
||||
charts_html += chart_div(tpot_fig)
|
||||
if compile_fig:
|
||||
charts_html += chart_div(compile_fig)
|
||||
charts_html += '</div>'
|
||||
|
||||
return f"""
|
||||
<section id="{sec_id}">
|
||||
<h2>{title}</h2>
|
||||
<p class="meta">{model_str}{" · " + tok_str if tok_str else ""} · {len(results)} results</p>
|
||||
{charts_html}
|
||||
{_table_html(results)}
|
||||
</section>
|
||||
"""
|
||||
|
||||
|
||||
# ── full page ─────────────────────────────────────────────────────────────────
|
||||
|
||||
CSS = """
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: system-ui, sans-serif; background: #f0f2f5; color: #222; }
|
||||
header { background: #1a1a2e; color: #fff; padding: 1rem 2rem;
|
||||
position: sticky; top: 0; z-index: 100; display: flex;
|
||||
align-items: center; gap: 2rem; }
|
||||
header h1 { font-size: 1.2rem; white-space: nowrap; }
|
||||
nav a { color: #a0c4ff; text-decoration: none; font-size: 0.9rem;
|
||||
padding: 0.3rem 0.7rem; border-radius: 4px; white-space: nowrap; }
|
||||
nav a:hover { background: rgba(255,255,255,0.15); }
|
||||
main { max-width: 1400px; margin: 0 auto; padding: 2rem; display: flex;
|
||||
flex-direction: column; gap: 2.5rem; }
|
||||
section { background: #fff; border-radius: 8px; padding: 1.5rem 2rem;
|
||||
box-shadow: 0 1px 4px rgba(0,0,0,.08); }
|
||||
h2 { font-size: 1.3rem; margin-bottom: 0.4rem; }
|
||||
.meta { color: #666; font-size: 0.85rem; margin-bottom: 1.2rem; }
|
||||
.charts-row { display: flex; gap: 1.5rem; flex-wrap: wrap; margin-bottom: 1.5rem; }
|
||||
.chart { flex: 1; min-width: 340px; height: 360px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 0.82rem; }
|
||||
thead tr { background: #f5f5f5; }
|
||||
th, td { padding: 0.45rem 0.7rem; text-align: left;
|
||||
border-bottom: 1px solid #e8e8e8; }
|
||||
th { font-weight: 600; white-space: nowrap; }
|
||||
tr:last-child td { border-bottom: none; }
|
||||
tr:hover { background: #fafafa; }
|
||||
"""
|
||||
|
||||
def _build_html(sections: list[tuple[str, str, list[dict]]], title: str) -> str:
|
||||
nav_links = "".join(f'<a href="#{sid}">{stitle}</a>' for sid, stitle, _ in sections)
|
||||
fig_counter = [0]
|
||||
body = "".join(_section_html(sid, stitle, results, fig_counter)
|
||||
for sid, stitle, results in sections)
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>{title}</title>
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>{CSS}</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>{title}</h1>
|
||||
<nav>{nav_links}</nav>
|
||||
</header>
|
||||
<main>{body}</main>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _sections_for_run(results: list[dict]) -> list[tuple[str, str, list[dict]]]:
|
||||
"""Split a single run's results into (sec_id, title, records) sections.
|
||||
|
||||
Splits first by model_key (NULL → 'results'), then within each by
|
||||
sweep-vs-comparison based on config 's=N' shape."""
|
||||
by_key: dict[str | None, list[dict]] = {}
|
||||
for r in results:
|
||||
by_key.setdefault(r.get("model_key"), []).append(r)
|
||||
|
||||
sections: list[tuple[str, str, list[dict]]] = []
|
||||
for key, recs in by_key.items():
|
||||
comp, sweep = [], []
|
||||
for r in recs:
|
||||
(sweep if str(r.get("config", "")).startswith("s=") else comp).append(r)
|
||||
prefix = (key or "results").replace("-", "_").replace(".", "_")
|
||||
title_prefix = key or "Results"
|
||||
if comp:
|
||||
sections.append((f"{prefix}_comparison",
|
||||
f"{title_prefix} comparison".strip().title(),
|
||||
comp))
|
||||
if sweep:
|
||||
sections.append((f"{prefix}_sweep",
|
||||
f"{title_prefix} sweep".strip().title(),
|
||||
sweep))
|
||||
return sections
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help=f"SQLite bench DB (default: {db.DEFAULT_DB_PATH})")
|
||||
ap.add_argument("--run", default=None,
|
||||
help="Run ID to render (default: latest run in DB)")
|
||||
ap.add_argument("--out", default=None,
|
||||
help="Output HTML path (default: report.html in benchmarks/ttft/)")
|
||||
ap.add_argument("--title", default="Luminal TTFT Benchmark Report",
|
||||
help="Page title and heading")
|
||||
args = ap.parse_args()
|
||||
|
||||
if not Path(args.db).exists():
|
||||
print(f"DB not found: {args.db}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
conn = db.connect(args.db)
|
||||
run_id = args.run or db.latest_run_id(conn)
|
||||
if run_id is None:
|
||||
print(f"No runs in {args.db}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
results = db.load_results(conn, run_id)
|
||||
if not results:
|
||||
print(f"No results for run {run_id}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
sections = _sections_for_run(results)
|
||||
if not sections:
|
||||
print(f"No section data for run {run_id}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
out = Path(args.out) if args.out else Path(__file__).parent / "report.html"
|
||||
html = _build_html(sections, f"{args.title} — {run_id}")
|
||||
out.write_text(html)
|
||||
print(f"wrote {out} (run {run_id}, {len(sections)} sections, {len(results)} results)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
148
benchmarks/ttft/report.html
Normal file
148
benchmarks/ttft/report.html
Normal file
@@ -0,0 +1,148 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal TTFT Benchmark Report — 2026-05-01T18-56-26-996695</title>
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: system-ui, sans-serif; background: #f0f2f5; color: #222; }
|
||||
header { background: #1a1a2e; color: #fff; padding: 1rem 2rem;
|
||||
position: sticky; top: 0; z-index: 100; display: flex;
|
||||
align-items: center; gap: 2rem; }
|
||||
header h1 { font-size: 1.2rem; white-space: nowrap; }
|
||||
nav a { color: #a0c4ff; text-decoration: none; font-size: 0.9rem;
|
||||
padding: 0.3rem 0.7rem; border-radius: 4px; white-space: nowrap; }
|
||||
nav a:hover { background: rgba(255,255,255,0.15); }
|
||||
main { max-width: 1400px; margin: 0 auto; padding: 2rem; display: flex;
|
||||
flex-direction: column; gap: 2.5rem; }
|
||||
section { background: #fff; border-radius: 8px; padding: 1.5rem 2rem;
|
||||
box-shadow: 0 1px 4px rgba(0,0,0,.08); }
|
||||
h2 { font-size: 1.3rem; margin-bottom: 0.4rem; }
|
||||
.meta { color: #666; font-size: 0.85rem; margin-bottom: 1.2rem; }
|
||||
.charts-row { display: flex; gap: 1.5rem; flex-wrap: wrap; margin-bottom: 1.5rem; }
|
||||
.chart { flex: 1; min-width: 340px; height: 360px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 0.82rem; }
|
||||
thead tr { background: #f5f5f5; }
|
||||
th, td { padding: 0.45rem 0.7rem; text-align: left;
|
||||
border-bottom: 1px solid #e8e8e8; }
|
||||
th { font-weight: 600; white-space: nowrap; }
|
||||
tr:last-child td { border-bottom: none; }
|
||||
tr:hover { background: #fafafa; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>Luminal TTFT Benchmark Report — 2026-05-01T18-56-26-996695</h1>
|
||||
<nav><a href="#llama_8b_comparison">Llama-8B Comparison</a><a href="#llama_8b_sweep">Llama-8B Sweep</a><a href="#qwen3_4b_comparison">Qwen3-4B Comparison</a><a href="#qwen3_4b_sweep">Qwen3-4B Sweep</a><a href="#gemma3_4b_comparison">Gemma3-4B Comparison</a><a href="#gemma3_4b_sweep">Gemma3-4B Sweep</a><a href="#gemma4_moe_comparison">Gemma4-Moe Comparison</a><a href="#gemma4_moe_sweep">Gemma4-Moe Sweep</a><a href="#qwen3_moe_comparison">Qwen3-Moe Comparison</a><a href="#qwen3_moe_sweep">Qwen3-Moe Sweep</a></nav>
|
||||
</header>
|
||||
<main>
|
||||
<section id="llama_8b_comparison">
|
||||
<h2>Llama-8B Comparison</h2>
|
||||
<p class="meta">NousResearch/Meta-Llama-3-8B-Instruct · 21 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig0" class="chart"></div><script>Plotly.newPlot("fig0", [{"type": "bar", "name": "HF Baseline", "x": ["llama-8b"], "y": [705.9654394979589], "text": ["706.0 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [307.66548847896047], "text": ["307.7 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [461.48114453535527], "text": ["461.5 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [1026.86], "text": ["1026.9 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig1" class="chart"></div><script>Plotly.newPlot("fig1", [{"type": "bar", "name": "HF Baseline", "x": ["llama-8b"], "y": [34.15271903970279], "text": ["34.2 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [171.7862353892997], "text": ["171.8 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [23.078908618772402], "text": ["23.1 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [51.64], "text": ["51.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig2" class="chart"></div><script>Plotly.newPlot("fig2", [{"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [18.760145067994017], "text": ["18.8 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [95.96263545705006], "text": ["96.0 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [84.45343], "text": ["84.5 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>llama-8b</td><td>706.0 ms</td><td>34.2 ms</td><td>29.3 tok/s</td><td>—</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>llama-8b</td><td>307.7 ms</td><td>171.8 ms</td><td>5.8 tok/s</td><td>18760 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>llama-8b</td><td>461.5 ms</td><td>23.1 ms</td><td>43.3 tok/s</td><td>95963 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>llama-8b</td><td>1026.9 ms</td><td>51.6 ms</td><td>19.4 tok/s</td><td>84453 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="llama_8b_sweep">
|
||||
<h2>Llama-8B Sweep</h2>
|
||||
<p class="meta">NousResearch/Meta-Llama-3-8B-Instruct · 21 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig3" class="chart"></div><script>Plotly.newPlot("fig3", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [470.7036415056791, 460.72837291285396, 472.43661794345826], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [751.03, 1038.34, 453.16], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig4" class="chart"></div><script>Plotly.newPlot("fig4", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [23.540849717101082, 23.101884137140587, 23.610779400914907], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [38.2, 51.92, 24.09], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig5" class="chart"></div><script>Plotly.newPlot("fig5", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [28.428826077957638, 43.57440591201885, 95.52432684396626], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [15.14307, 30.12727, 84.87889], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>470.7 ms</td><td>23.5 ms</td><td>42.5 tok/s</td><td>28429 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>751.0 ms</td><td>38.2 ms</td><td>26.2 tok/s</td><td>15143 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>460.7 ms</td><td>23.1 ms</td><td>43.3 tok/s</td><td>43574 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>1038.3 ms</td><td>51.9 ms</td><td>19.3 tok/s</td><td>30127 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>472.4 ms</td><td>23.6 ms</td><td>42.4 tok/s</td><td>95524 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>453.2 ms</td><td>24.1 ms</td><td>41.5 tok/s</td><td>84879 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_4b_comparison">
|
||||
<h2>Qwen3-4B Comparison</h2>
|
||||
<p class="meta">Qwen/Qwen3-4B · 19/11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig6" class="chart"></div><script>Plotly.newPlot("fig6", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-4b"], "y": [869.2860195587855], "text": ["869.3 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [298.27259748708457], "text": ["298.3 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [485.3892414830625], "text": ["485.4 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [398.58], "text": ["398.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig7" class="chart"></div><script>Plotly.newPlot("fig7", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-4b"], "y": [47.71483448566869], "text": ["47.7 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [468.56868775503244], "text": ["468.6 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [26.90318431414198], "text": ["26.9 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [40.62], "text": ["40.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig8" class="chart"></div><script>Plotly.newPlot("fig8", [{"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [4.680963660997804], "text": ["4.7 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [45.345814052037895], "text": ["45.3 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [19.92977], "text": ["19.9 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>qwen3-4b</td><td>869.3 ms</td><td>47.7 ms</td><td>21.0 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>qwen3-4b</td><td>298.3 ms</td><td>468.6 ms</td><td>2.1 tok/s</td><td>4681 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>qwen3-4b</td><td>485.4 ms</td><td>26.9 ms</td><td>37.2 tok/s</td><td>45346 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>qwen3-4b</td><td>398.6 ms</td><td>40.6 ms</td><td>24.6 tok/s</td><td>19930 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_4b_sweep">
|
||||
<h2>Qwen3-4B Sweep</h2>
|
||||
<p class="meta">Qwen/Qwen3-4B · 19/11 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig9" class="chart"></div><script>Plotly.newPlot("fig9", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [465.02652901108377, 465.9317950136028, 495.75577257201076], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [398.44, 390.08, 559.29], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig10" class="chart"></div><script>Plotly.newPlot("fig10", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [25.875402649398893, 25.884080055402592, 27.492373346467502], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [40.64, 39.98, 55.37], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig11" class="chart"></div><script>Plotly.newPlot("fig11", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [37.92102829599753, 54.08867314597592, 118.29659596900456], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [12.448030000000001, 27.06796, 81.89342], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>465.0 ms</td><td>25.9 ms</td><td>38.6 tok/s</td><td>37921 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>398.4 ms</td><td>40.6 ms</td><td>24.6 tok/s</td><td>12448 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>465.9 ms</td><td>25.9 ms</td><td>38.6 tok/s</td><td>54089 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>390.1 ms</td><td>40.0 ms</td><td>25.0 tok/s</td><td>27068 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>495.8 ms</td><td>27.5 ms</td><td>36.4 tok/s</td><td>118297 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>559.3 ms</td><td>55.4 ms</td><td>18.1 tok/s</td><td>81893 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma3_4b_comparison">
|
||||
<h2>Gemma3-4B Comparison</h2>
|
||||
<p class="meta">unsloth/gemma-3-4b-it · 19/11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig12" class="chart"></div><script>Plotly.newPlot("fig12", [{"type": "bar", "name": "HF Baseline", "x": ["gemma3-4b"], "y": [951.1196144158021], "text": ["951.1 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [300.9451600664761], "text": ["300.9 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [404.43], "text": ["404.4 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig13" class="chart"></div><script>Plotly.newPlot("fig13", [{"type": "bar", "name": "HF Baseline", "x": ["gemma3-4b"], "y": [52.498737201676704], "text": ["52.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [2197.426627812092], "text": ["2197.4 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [38.99], "text": ["39.0 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig14" class="chart"></div><script>Plotly.newPlot("fig14", [{"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [26.649526304972824], "text": ["26.6 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [156.84164], "text": ["156.8 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>gemma3-4b</td><td>951.1 ms</td><td>52.5 ms</td><td>19.0 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>gemma3-4b</td><td>300.9 ms</td><td>2197.4 ms</td><td>0.5 tok/s</td><td>26650 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>gemma3-4b</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>gemma3-4b</td><td>404.4 ms</td><td>39.0 ms</td><td>25.6 tok/s</td><td>156842 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma3_4b_sweep">
|
||||
<h2>Gemma3-4B Sweep</h2>
|
||||
<p class="meta">unsloth/gemma-3-4b-it · 11 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig15" class="chart"></div><script>Plotly.newPlot("fig15", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [388.19, 436.49, 386.13], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig16" class="chart"></div><script>Plotly.newPlot("fig16", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [37.47, 41.95, 37.25], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig17" class="chart"></div><script>Plotly.newPlot("fig17", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [102.18644, 186.34269, 498.48983000000004], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr style="background:#fff0f0"><td>luminal backend</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>388.2 ms</td><td>37.5 ms</td><td>26.7 tok/s</td><td>102186 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>s=100</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>436.5 ms</td><td>42.0 ms</td><td>23.8 tok/s</td><td>186343 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>s=500</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>386.1 ms</td><td>37.2 ms</td><td>26.8 tok/s</td><td>498490 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma4_moe_comparison">
|
||||
<h2>Gemma4-Moe Comparison</h2>
|
||||
<p class="meta">google/gemma-4-26B-A4B · 11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig18" class="chart"></div><script>Plotly.newPlot("fig18", [{"type": "bar", "name": "HF Baseline", "x": ["gemma4-moe"], "y": [837.3980740143452], "text": ["837.4 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [245.510076492792], "text": ["245.5 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig19" class="chart"></div><script>Plotly.newPlot("fig19", [{"type": "bar", "name": "HF Baseline", "x": ["gemma4-moe"], "y": [83.64427039632574], "text": ["83.6 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [654.9649795080768], "text": ["655.0 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig20" class="chart"></div><script>Plotly.newPlot("fig20", [{"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [38.81582092499593], "text": ["38.8 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>gemma4-moe</td><td>837.4 ms</td><td>83.6 ms</td><td>12.0 tok/s</td><td>—</td><td>11</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>gemma4-moe</td><td>245.5 ms</td><td>655.0 ms</td><td>1.5 tok/s</td><td>38816 ms</td><td>11</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>gemma4-moe</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code -9</td></tr>
|
||||
<tr style="background:#fff0f0"><td>Rust (luminal)</td><td>gemma4-moe</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">rust bench failed with code -9</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma4_moe_sweep">
|
||||
<h2>Gemma4-Moe Sweep</h2>
|
||||
<p class="meta">google/gemma-4-26B-A4B · 2 results</p>
|
||||
<div class="charts-row"><div id="fig21" class="chart"></div><script>Plotly.newPlot("fig21", [], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10], "ticktext": ["10"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr style="background:#fff0f0"><td>luminal backend</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code -9</td></tr>
|
||||
<tr style="background:#fff0f0"><td>Rust (luminal)</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">rust bench failed with code -9</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_moe_comparison">
|
||||
<h2>Qwen3-Moe Comparison</h2>
|
||||
<p class="meta">Qwen/Qwen3-30B-A3B · 19 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig22" class="chart"></div><script>Plotly.newPlot("fig22", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-moe"], "y": [1565.540504961973], "text": ["1565.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [460.077923577046], "text": ["460.1 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [21002.791983017232], "text": ["21002.8 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [662.07], "text": ["662.1 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig23" class="chart"></div><script>Plotly.newPlot("fig23", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-moe"], "y": [84.527321747737], "text": ["84.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [753.0061075551203], "text": ["753.0 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [1166.8824461026816], "text": ["1166.9 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [60.08], "text": ["60.1 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig24" class="chart"></div><script>Plotly.newPlot("fig24", [{"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [8.341281775035895], "text": ["8.3 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [111.70731823903043], "text": ["111.7 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [80.83241000000001], "text": ["80.8 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>qwen3-moe</td><td>1565.5 ms</td><td>84.5 ms</td><td>11.8 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>qwen3-moe</td><td>460.1 ms</td><td>753.0 ms</td><td>1.3 tok/s</td><td>8341 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>qwen3-moe</td><td>21002.8 ms</td><td>1166.9 ms</td><td>0.9 tok/s</td><td>111707 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>qwen3-moe</td><td>662.1 ms</td><td>60.1 ms</td><td>16.6 tok/s</td><td>80832 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_moe_sweep">
|
||||
<h2>Qwen3-Moe Sweep</h2>
|
||||
<p class="meta">Qwen/Qwen3-30B-A3B · 19 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig25" class="chart"></div><script>Plotly.newPlot("fig25", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [21002.663500519702, 21018.686580006033, 21034.366824431345], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [656.7, 540.37, 542.34], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig26" class="chart"></div><script>Plotly.newPlot("fig26", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [1166.6714247548953, 1167.2746865515364, 1168.7990181031637], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [59.6, 48.79, 48.88], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig27" class="chart"></div><script>Plotly.newPlot("fig27", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [93.47603664599592, 132.266081985028, 298.05094401398674], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [25.48138, 47.5342, 134.79345], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>21002.7 ms</td><td>1166.7 ms</td><td>0.9 tok/s</td><td>93476 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>656.7 ms</td><td>59.6 ms</td><td>16.8 tok/s</td><td>25481 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>21018.7 ms</td><td>1167.3 ms</td><td>0.9 tok/s</td><td>132266 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>540.4 ms</td><td>48.8 ms</td><td>20.5 tok/s</td><td>47534 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>21034.4 ms</td><td>1168.8 ms</td><td>0.9 tok/s</td><td>298051 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>542.3 ms</td><td>48.9 ms</td><td>20.5 tok/s</td><td>134793 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
683
benchmarks/ttft/run.py
Normal file
683
benchmarks/ttft/run.py
Normal file
@@ -0,0 +1,683 @@
|
||||
"""TTFT + TPOT benchmark orchestrator.
|
||||
|
||||
Runs four paths in isolated subprocesses:
|
||||
1. python_baseline — HuggingFace / PyTorch eager on CUDA
|
||||
2. python_torch_compile — torch.compile(model) inductor backend
|
||||
3. python_luminal — torch.compile(model, backend=luminal_backend)
|
||||
4. rust — examples/<package> binary (luminal_cuda_lite)
|
||||
|
||||
Use --config to select a named configuration, or --all-configs to run every
|
||||
entry in CONFIGS. All output is written to the SQLite bench DB
|
||||
(benchmarks/ttft/bench.db); the TUI / dashboard / report read from there.
|
||||
|
||||
Notes on comparability:
|
||||
- python_baseline: single chunked forward for TTFT; KV-cache decode for TPOT.
|
||||
- python_torch_compile: inductor, same chunked prefill as baseline; first
|
||||
call triggers JIT compilation (recorded separately as compile_ms).
|
||||
- python_luminal: sequential per-token prefill with StaticCache; TPOT via
|
||||
autoregressive decode steps.
|
||||
- rust: sequential per-token prefill; TTFT = sum of prefill step durations.
|
||||
Steady-state execution only — compile / egraph-search time excluded from TTFT but
|
||||
recorded separately as compile_ms for all paths that support it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
try:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
except ImportError:
|
||||
raise ImportError("Python 3.11+ or 'pip install tomli' required to load benchmarks.toml")
|
||||
|
||||
import db
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
REPO_ROOT = BENCH_DIR.parent.parent
|
||||
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
_CONFIG_PATH = BENCH_DIR / "benchmarks.toml"
|
||||
with open(_CONFIG_PATH, "rb") as _f:
|
||||
_BENCH_CONFIG = tomllib.load(_f)
|
||||
|
||||
# Named benchmark configurations. Each entry overrides any subset of the
|
||||
# CLI defaults; explicit CLI flags always take precedence over the config.
|
||||
CONFIGS: dict = _BENCH_CONFIG["configs"]
|
||||
UR_TEST_MODELS: list = _BENCH_CONFIG["ur_test"]["models"]
|
||||
SEARCH_SWEEP_ITERS: list = _BENCH_CONFIG["ur_test"]["search_sweep_iters"]
|
||||
|
||||
SWEEP_CONFIG_PREFIX = "s="
|
||||
|
||||
BENCH_LINE = re.compile(r"^BENCH_RESULT (.*)$", re.MULTILINE)
|
||||
RUST_TTFT_LINE = re.compile(r"TTFT:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_TPOT_LINE = re.compile(r"TPOT:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_COMPILE_LINE = re.compile(r"COMPILE:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_PROMPT_LINE = re.compile(r"Prompt:\s*(\d+)\s*tokens")
|
||||
|
||||
|
||||
def _stream(proc, tee_prefix):
|
||||
"""Drain subprocess stdout, tee-ing to our stdout line-by-line. Returns full stdout."""
|
||||
buf = []
|
||||
assert proc.stdout is not None
|
||||
for line in proc.stdout:
|
||||
buf.append(line)
|
||||
sys.stdout.write(f"[{tee_prefix}] {line}")
|
||||
sys.stdout.flush()
|
||||
proc.wait()
|
||||
return "".join(buf)
|
||||
|
||||
|
||||
_MEM_LOG_PATH = os.environ.get("BENCH_MEM_LOG", "/tmp/bench_mem_snapshots.log")
|
||||
|
||||
|
||||
def _snapshot_memory(label: str) -> None:
|
||||
"""Append a host+GPU memory snapshot to BENCH_MEM_LOG. Cheap, never raises."""
|
||||
try:
|
||||
ts = datetime.datetime.now().isoformat(timespec="seconds")
|
||||
meminfo_keys = ("MemTotal", "MemFree", "MemAvailable", "Cached", "Slab", "SReclaimable")
|
||||
meminfo = {}
|
||||
with open("/proc/meminfo") as f:
|
||||
for line in f:
|
||||
k, _, rest = line.partition(":")
|
||||
if k in meminfo_keys:
|
||||
meminfo[k] = rest.strip().split()[0] # kB
|
||||
try:
|
||||
gpu = subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=memory.used,memory.free,memory.total",
|
||||
"--format=csv,noheader,nounits"],
|
||||
stderr=subprocess.DEVNULL, text=True, timeout=5,
|
||||
).strip().splitlines()[0]
|
||||
except Exception:
|
||||
gpu = "n/a"
|
||||
parent_rss = "?"
|
||||
try:
|
||||
with open(f"/proc/{os.getpid()}/status") as f:
|
||||
for line in f:
|
||||
if line.startswith("VmRSS:"):
|
||||
parent_rss = line.split()[1]
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
host_str = " ".join(f"{k}={meminfo.get(k, '?')}kB" for k in meminfo_keys)
|
||||
with open(_MEM_LOG_PATH, "a") as f:
|
||||
f.write(f"{ts} [{label}] parent_rss={parent_rss}kB {host_str} gpu(used,free,total MiB)={gpu}\n")
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"[mem-snapshot warn] {e}\n")
|
||||
|
||||
|
||||
def _cargo_env():
|
||||
"""Return env dict with ~/.cargo/bin prepended to PATH."""
|
||||
cargo_bin = str(Path.home() / ".cargo" / "bin")
|
||||
path = os.environ.get("PATH", "")
|
||||
if cargo_bin not in path:
|
||||
path = f"{cargo_bin}:{path}"
|
||||
return {**os.environ, "PATH": path}
|
||||
|
||||
|
||||
def run_rust(_prompt, package="llama", env_vars=None):
|
||||
print(f"\n=== Running: rust (examples/{package}) ===", flush=True)
|
||||
cmd = ["cargo", "run", "--release", "-p", package]
|
||||
env = _cargo_env()
|
||||
if env_vars:
|
||||
env.update(env_vars)
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=REPO_ROOT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
output = _stream(proc, "rust")
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"rust bench failed with code {proc.returncode}")
|
||||
m = RUST_TTFT_LINE.search(output)
|
||||
if not m:
|
||||
raise RuntimeError("could not find 'TTFT: X ms' in rust stdout")
|
||||
ttft_ms = float(m.group(1))
|
||||
result = {
|
||||
"path": "rust",
|
||||
"model": DEFAULT_MODEL,
|
||||
"ttft_ms": ttft_ms,
|
||||
"note": "sum of per-token prefill durations",
|
||||
}
|
||||
m_compile = RUST_COMPILE_LINE.search(output)
|
||||
if m_compile:
|
||||
result["compile_ms"] = float(m_compile.group(1))
|
||||
m_tpot = RUST_TPOT_LINE.search(output)
|
||||
if m_tpot:
|
||||
tpot_ms = float(m_tpot.group(1))
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
m_prompt = RUST_PROMPT_LINE.search(output)
|
||||
if m_prompt:
|
||||
result["prompt_tokens"] = int(m_prompt.group(1))
|
||||
return result
|
||||
|
||||
|
||||
def run_python_script(name, extra_args):
|
||||
script = BENCH_DIR / name
|
||||
print(f"\n=== Running: {script.name} ===", flush=True)
|
||||
cmd = [sys.executable, str(script), *extra_args]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=REPO_ROOT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env={**os.environ},
|
||||
)
|
||||
output = _stream(proc, script.stem)
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"{script.name} failed with code {proc.returncode}")
|
||||
m = BENCH_LINE.search(output)
|
||||
if not m:
|
||||
raise RuntimeError(f"no BENCH_RESULT line in {script.name} output")
|
||||
return json.loads(m.group(1))
|
||||
|
||||
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "Python\n(HF baseline)",
|
||||
"python_torch_compile": "Python\n(torch.compile)",
|
||||
"python_luminal": "Python → Rust\n(luminal_backend)",
|
||||
"rust": "Rust\n(examples/llama)",
|
||||
}
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#888888",
|
||||
"python_torch_compile": "#5ab552",
|
||||
"python_luminal": "#4c9ed9",
|
||||
"rust": "#d97a4c",
|
||||
}
|
||||
|
||||
|
||||
def run_one_config(config_name, settings, global_skip, inter_path_cooldown=0):
|
||||
"""Run all four paths for one config. Returns list of result dicts tagged with 'config'."""
|
||||
model = settings["model"]
|
||||
rust_package = settings["rust_package"]
|
||||
prompt = settings["prompt"]
|
||||
iters = settings["iters"]
|
||||
warmups = settings["warmups"]
|
||||
decode_tokens = settings["decode_tokens"]
|
||||
search_iters = settings["search_iters"]
|
||||
dtype = settings.get("dtype", "float32")
|
||||
skip = set(global_skip) | set(settings.get("skip", []))
|
||||
|
||||
common_py = [
|
||||
"--model", model,
|
||||
"--prompt", prompt,
|
||||
"--iters", str(iters),
|
||||
"--warmups", str(warmups),
|
||||
"--decode-tokens", str(decode_tokens),
|
||||
"--dtype", dtype,
|
||||
]
|
||||
luminal_py = common_py + ["--search-iters", str(search_iters)]
|
||||
|
||||
rust_env = {"SEARCH_GRAPHS": str(search_iters), "PROMPT": prompt, "ITERS": str(iters)}
|
||||
|
||||
results = []
|
||||
first_path = True
|
||||
for path, fn in [
|
||||
("python_baseline", lambda: run_python_script("bench_python_baseline.py", common_py)),
|
||||
("python_torch_compile", lambda: run_python_script("bench_python_torch_compile.py", common_py)),
|
||||
("python_luminal", lambda: run_python_script("bench_python_luminal.py", luminal_py)),
|
||||
("rust", lambda: run_rust(prompt, package=rust_package, env_vars=rust_env)),
|
||||
]:
|
||||
if path in skip:
|
||||
continue
|
||||
if not first_path and inter_path_cooldown > 0:
|
||||
print(f" [cooldown {inter_path_cooldown}s]", flush=True)
|
||||
time.sleep(inter_path_cooldown)
|
||||
first_path = False
|
||||
_snapshot_memory(f"{config_name}/{path} BEFORE")
|
||||
try:
|
||||
r = fn()
|
||||
r["config"] = config_name
|
||||
r["model"] = model # ensure correct model is always tagged
|
||||
if path in ("python_luminal", "rust"):
|
||||
r["search_iters"] = search_iters
|
||||
results.append(r)
|
||||
except Exception as e:
|
||||
print(f"\n[WARN] {config_name}/{path} failed: {e}", flush=True)
|
||||
results.append({
|
||||
"path": path,
|
||||
"config": config_name,
|
||||
"model": model,
|
||||
"error": str(e),
|
||||
"ttft_ms": None,
|
||||
})
|
||||
_snapshot_memory(f"{config_name}/{path} AFTER")
|
||||
return results
|
||||
|
||||
|
||||
def plot(results, out_path):
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Group by config so each config gets its own subplot column.
|
||||
configs_seen: list[str] = []
|
||||
by_config: dict[str, dict] = {}
|
||||
for r in results:
|
||||
cfg = r.get("config", "default")
|
||||
if cfg not in by_config:
|
||||
configs_seen.append(cfg)
|
||||
by_config[cfg] = {}
|
||||
by_config[cfg][r["path"]] = r
|
||||
|
||||
has_tpot = any(
|
||||
r.get("tpot_ms") is not None
|
||||
for r in results
|
||||
if not r.get("error")
|
||||
)
|
||||
nrows = 2 if has_tpot else 1
|
||||
ncols = len(configs_seen)
|
||||
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.5 * nrows), squeeze=False)
|
||||
|
||||
for col, cfg in enumerate(configs_seen):
|
||||
by_path = by_config[cfg]
|
||||
present = [p for p in PATH_ORDER if p in by_path]
|
||||
|
||||
def _bar(ax, title, ylabel, key):
|
||||
raw = [by_path[p].get(key) for p in present]
|
||||
ys = [v if v is not None else 0.0 for v in raw]
|
||||
cs = [PATH_COLORS.get(p, "#aaaaaa") if raw[i] is not None else "#cccccc"
|
||||
for i, p in enumerate(present)]
|
||||
xs = [PATH_LABELS.get(p, p) for p in present]
|
||||
bars = ax.bar(xs, ys, color=cs)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_title(f"{title} — {cfg}")
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
for b, v in zip(bars, raw):
|
||||
if v is not None:
|
||||
ax.text(b.get_x() + b.get_width() / 2, v, f"{v:.0f} ms",
|
||||
ha="center", va="bottom", fontsize=9)
|
||||
|
||||
_bar(axes[0][col], "TTFT", "Time to first token (ms)", "ttft_ms")
|
||||
if has_tpot:
|
||||
_bar(axes[1][col], "TPOT", "Time per output token (ms)", "tpot_ms")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(out_path, dpi=150)
|
||||
print(f"wrote {out_path}")
|
||||
|
||||
|
||||
def run_ur_test(args, conn, run_id):
|
||||
"""The ur-test: all 4 paths at default budget + full search sweep, for each model.
|
||||
|
||||
Inserts each result into the DB as it is produced so a mid-run crash still
|
||||
leaves partial data behind.
|
||||
"""
|
||||
all_results = []
|
||||
|
||||
for model_idx, model_key in enumerate(UR_TEST_MODELS):
|
||||
s = _settings_for_config(model_key, args)
|
||||
|
||||
if model_idx > 0:
|
||||
print(f"\n [cooldown 30s between models]", flush=True)
|
||||
time.sleep(30)
|
||||
|
||||
# ── Phase 1: comparison — all 4 paths at the model's default search budget ──
|
||||
print(f"\n{'='*60}\nUR-TEST COMPARISON: {model_key}\n{'='*60}", flush=True)
|
||||
comp_results = run_one_config(model_key, s, args.skip, inter_path_cooldown=20)
|
||||
for r in comp_results:
|
||||
r["model_key"] = model_key
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
all_results.extend(comp_results)
|
||||
|
||||
# ── Phase 2: search sweep — python_luminal + rust across all budgets ──
|
||||
if args.no_sweep:
|
||||
continue
|
||||
print(f"\n{'='*60}\nUR-TEST SWEEP: {model_key}\n{'='*60}", flush=True)
|
||||
sweep_skip_base = set(args.skip) | {"python_baseline", "python_torch_compile"}
|
||||
# Memory peak in egglog Search grows monotonically with search-iters.
|
||||
# If a path SIGKILLs (-9) at budget N, every higher budget will too —
|
||||
# skip it to avoid wasting another ~hour per model on guaranteed OOMs.
|
||||
oom_paths: set[str] = set()
|
||||
for n in SEARCH_SWEEP_ITERS:
|
||||
print(f" [cooldown 20s before s={n}]", flush=True)
|
||||
time.sleep(20)
|
||||
sweep_skip = list(sweep_skip_base | oom_paths)
|
||||
if oom_paths:
|
||||
print(f" [skip-on-prior-OOM] {sorted(oom_paths)} OOM'd at lower budget; skipping at s={n}", flush=True)
|
||||
sweep_s = {**s, "search_iters": n}
|
||||
results_n = run_one_config(f"s={n}", sweep_s, sweep_skip, inter_path_cooldown=20)
|
||||
for r in results_n:
|
||||
r["model_key"] = model_key # preserve ur-test model identity for dashboard
|
||||
db.insert_result(conn, run_id, r)
|
||||
if "code -9" in (r.get("error") or ""):
|
||||
oom_paths.add(r["path"])
|
||||
conn.commit()
|
||||
all_results.extend(results_n)
|
||||
|
||||
print("\nGenerate report with:")
|
||||
print(f" python3 benchmarks/ttft/gen_report.py --db benchmarks/ttft/bench.db --run {run_id} \\")
|
||||
print(" --out benchmarks/ttft/report.html")
|
||||
print("\nGenerate dashboard with:")
|
||||
print(" python3 benchmarks/ttft/gen_dashboard.py --out benchmarks/ttft/dashboard.html")
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def _git_info():
|
||||
"""Return (short_commit, branch) from the repo, or ('unknown', 'unknown') if unavailable."""
|
||||
try:
|
||||
commit = subprocess.check_output(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
cwd=REPO_ROOT, stderr=subprocess.DEVNULL, text=True,
|
||||
).strip()
|
||||
branch = subprocess.check_output(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
cwd=REPO_ROOT, stderr=subprocess.DEVNULL, text=True,
|
||||
).strip()
|
||||
return commit, branch
|
||||
except Exception:
|
||||
return "unknown", "unknown"
|
||||
|
||||
|
||||
def _gpu_info() -> dict:
|
||||
"""Return GPU metadata from nvidia-smi, or empty dict if unavailable."""
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=name,driver_version,memory.total",
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
).strip()
|
||||
if not out:
|
||||
return {}
|
||||
parts = [p.strip() for p in out.splitlines()[0].split(",")]
|
||||
if len(parts) < 3:
|
||||
return {}
|
||||
return {
|
||||
"gpu_name": parts[0],
|
||||
"gpu_driver": parts[1],
|
||||
"gpu_vram_mb": int(parts[2]),
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _cuda_version() -> str:
|
||||
"""Return CUDA version string from nvidia-smi, or 'unknown'."""
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["nvidia-smi", "--query", "--display=COMPUTE"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
for line in out.splitlines():
|
||||
if "CUDA Version" in line:
|
||||
return line.split(":")[-1].strip()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["nvidia-smi"], stderr=subprocess.DEVNULL, text=True
|
||||
)
|
||||
import re as _re
|
||||
m = _re.search(r"CUDA Version:\s*([\d.]+)", out)
|
||||
if m:
|
||||
return m.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _record_run(conn, mode):
|
||||
"""Insert a `runs` row capturing this orchestrator invocation. Returns run_id.
|
||||
|
||||
Uses microsecond resolution in the run_id so two invocations within the
|
||||
same wallclock second never collide on the runs PRIMARY KEY (insert_run
|
||||
defaults to OR IGNORE, which would otherwise silently merge them and
|
||||
corrupt history). Microseconds also let the dashboard plot back-to-back
|
||||
runs at distinct x-positions instead of stacking them on one date label.
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
run_id = now.strftime("%Y-%m-%dT%H-%M-%S-%f")
|
||||
commit, branch = _git_info()
|
||||
db.insert_run(
|
||||
conn,
|
||||
run_id=run_id,
|
||||
timestamp=now.isoformat(),
|
||||
mode=mode,
|
||||
git_commit=commit,
|
||||
git_branch=branch,
|
||||
cuda_version=_cuda_version(),
|
||||
**_gpu_info(),
|
||||
)
|
||||
conn.commit()
|
||||
return run_id
|
||||
|
||||
|
||||
def _settings_from_args(args):
|
||||
"""Build a settings dict from parsed CLI args."""
|
||||
return {
|
||||
"model": args.model,
|
||||
"rust_package": args.rust_package,
|
||||
"prompt": args.prompt,
|
||||
"iters": args.iters,
|
||||
"warmups": args.warmups,
|
||||
"decode_tokens": args.decode_tokens,
|
||||
"search_iters": args.search_iters,
|
||||
"dtype": args.dtype,
|
||||
"skip": [],
|
||||
}
|
||||
|
||||
|
||||
def _settings_for_config(config_name, args):
|
||||
"""Merge CONFIGS[config_name] over CLI arg defaults."""
|
||||
cfg = CONFIGS[config_name]
|
||||
return {
|
||||
"model": cfg.get("model", args.model),
|
||||
"rust_package": cfg.get("rust_package", args.rust_package),
|
||||
"prompt": cfg.get("prompt", args.prompt),
|
||||
"iters": cfg.get("iters", args.iters),
|
||||
"warmups": cfg.get("warmups", args.warmups),
|
||||
"decode_tokens":cfg.get("decode_tokens",args.decode_tokens),
|
||||
"search_iters": cfg.get("search_iters", args.search_iters),
|
||||
"dtype": cfg.get("dtype", args.dtype),
|
||||
"skip": cfg.get("skip", []),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
choices=list(CONFIGS),
|
||||
default=None,
|
||||
help="Named benchmark configuration. Sets parameter defaults; explicit flags override.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--all-configs",
|
||||
action="store_true",
|
||||
dest="all_configs",
|
||||
help="Run every entry in CONFIGS into a single run_id in the DB.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--search-sweep",
|
||||
action="store_true",
|
||||
dest="search_sweep",
|
||||
help=(
|
||||
"Run python_luminal + rust across all SEARCH_SWEEP_ITERS budgets "
|
||||
f"({SEARCH_SWEEP_ITERS}). Uses --config (default: llama-8b) as the base settings."
|
||||
),
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-configs",
|
||||
nargs="*",
|
||||
default=[],
|
||||
choices=list(CONFIGS),
|
||||
dest="skip_configs",
|
||||
metavar="CONFIG",
|
||||
help="Config names to exclude when using --all-configs.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-sweep",
|
||||
action="store_true",
|
||||
dest="no_sweep",
|
||||
help=(
|
||||
"With --ur-test: skip the search-budget sweep phase and only run "
|
||||
"the 4-path comparison for each model. ~1.5 hr instead of ~5 hr."
|
||||
),
|
||||
)
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--rust-package", default="llama", dest="rust_package",
|
||||
help="Cargo package name for the rust bench (examples/<name>).")
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--skip", nargs="*", default=[],
|
||||
choices=["rust", "python_luminal", "python_baseline", "python_torch_compile"])
|
||||
ap.add_argument("--out", default=str(BENCH_DIR / "ttft.png"))
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help="SQLite database file (default: benchmarks/ttft/bench.db).")
|
||||
ap.add_argument("--run", default=None, dest="run",
|
||||
help="With --render-only: run_id to render (default: latest).")
|
||||
ap.add_argument(
|
||||
"--decode-tokens", type=int, default=50,
|
||||
help="Tokens to generate for TPOT measurement (0 = skip TPOT).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--search-iters", type=int, default=500,
|
||||
help="Egraph search iterations for the python_luminal path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dtype", default="float32",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
help="Torch dtype for the python paths. Configs may override per-model.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--render-only", action="store_true",
|
||||
help="Skip running benches; render an existing run from the DB. "
|
||||
"Use --run RUN_ID to pick a specific run, otherwise the latest is used.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--ur-test", action="store_true", dest="ur_test",
|
||||
help=(
|
||||
f"The mega-test: run all 4 paths at default budget + full search sweep "
|
||||
f"({SEARCH_SWEEP_ITERS}) for each of {UR_TEST_MODELS}."
|
||||
),
|
||||
)
|
||||
|
||||
# Pre-parse to apply named config as argparse defaults so explicit CLI
|
||||
# flags still override them.
|
||||
pre, _ = ap.parse_known_args()
|
||||
if pre.config and not (pre.all_configs or getattr(pre, "search_sweep", False)):
|
||||
cfg = CONFIGS[pre.config]
|
||||
ap.set_defaults(**{k: v for k, v in cfg.items() if k not in ("skip",)})
|
||||
args = ap.parse_args()
|
||||
if pre.config and not args.all_configs and not args.search_sweep:
|
||||
for path in CONFIGS[pre.config].get("skip", []):
|
||||
if path not in args.skip:
|
||||
args.skip.append(path)
|
||||
|
||||
conn = db.connect(args.db)
|
||||
|
||||
if args.render_only:
|
||||
run_id = args.run or db.latest_run_id(conn)
|
||||
if run_id is None:
|
||||
sys.exit(f"--render-only: no runs found in {args.db}")
|
||||
results = db.load_results(conn, run_id)
|
||||
if not results:
|
||||
sys.exit(f"--render-only: no results found for run {run_id} in {args.db}")
|
||||
print(f"rendering run {run_id} ({len(results)} results)")
|
||||
else:
|
||||
mode = (
|
||||
("ur-test-fast" if args.no_sweep else "ur-test") if args.ur_test
|
||||
else "search-sweep" if args.search_sweep
|
||||
else "all-configs" if args.all_configs
|
||||
else "single"
|
||||
)
|
||||
run_id = _record_run(conn, mode)
|
||||
print(f"run_id: {run_id} → {args.db}")
|
||||
|
||||
if args.ur_test:
|
||||
results = run_ur_test(args, conn, run_id)
|
||||
elif args.search_sweep:
|
||||
results = []
|
||||
# Base settings come from --config (default: llama-8b) or bare CLI args.
|
||||
base = (
|
||||
_settings_for_config(args.config, args)
|
||||
if args.config
|
||||
else _settings_for_config("llama-8b", args)
|
||||
)
|
||||
sweep_skip = set(args.skip) | {"python_baseline", "python_torch_compile"}
|
||||
for i, n in enumerate(SEARCH_SWEEP_ITERS):
|
||||
if i > 0:
|
||||
print(f" [cooldown 20s — letting CUDA free previous model memory]", flush=True)
|
||||
time.sleep(20)
|
||||
print(f"\n{'='*60}\nSEARCH SWEEP: s={n}\n{'='*60}", flush=True)
|
||||
s = {**base, "search_iters": n}
|
||||
rs = run_one_config(f"s={n}", s, list(sweep_skip))
|
||||
for r in rs:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
results.extend(rs)
|
||||
elif args.all_configs:
|
||||
results = []
|
||||
for config_name in CONFIGS:
|
||||
if config_name in args.skip_configs:
|
||||
continue
|
||||
print(f"\n{'='*60}\nCONFIG: {config_name}\n{'='*60}", flush=True)
|
||||
settings = _settings_for_config(config_name, args)
|
||||
rs = run_one_config(config_name, settings, args.skip)
|
||||
for r in rs:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
results.extend(rs)
|
||||
else:
|
||||
config_name = args.config or "default"
|
||||
settings = (
|
||||
_settings_for_config(args.config, args)
|
||||
if args.config
|
||||
else _settings_from_args(args)
|
||||
)
|
||||
results = run_one_config(config_name, settings, args.skip)
|
||||
for r in results:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
|
||||
# Summary
|
||||
configs_in_results = list(dict.fromkeys(r.get("config", "default") for r in results))
|
||||
for cfg in configs_in_results:
|
||||
group = [r for r in results if r.get("config", "default") == cfg]
|
||||
print(f"\nSummary ({cfg}):")
|
||||
for r in group:
|
||||
if r.get("error"):
|
||||
print(f" {r['path']:>22}: FAILED — {r['error']}")
|
||||
continue
|
||||
if r.get("ttft_ms") is None:
|
||||
print(f" {r['path']:>22}: no data")
|
||||
continue
|
||||
compile_ms = r.get("compile_ms")
|
||||
compile_str = f" compile {compile_ms:.0f} ms" if compile_ms is not None else ""
|
||||
tpot = r.get("tpot_ms")
|
||||
tput = r.get("throughput_tps")
|
||||
tpot_str = f" TPOT {tpot:.2f} ms ({tput:.1f} tok/s)" if tpot is not None else ""
|
||||
print(f" {r['path']:>22}: TTFT {r['ttft_ms']:.2f} ms{compile_str}{tpot_str}")
|
||||
|
||||
plot(results, args.out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
benchmarks/ttft/run.sh
Executable file
7
benchmarks/ttft/run.sh
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
# TTFT benchmark entrypoint. Runs via uv against the luminal_python venv.
|
||||
set -e
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
REPO_ROOT="$( cd "$SCRIPT_DIR/../.." && pwd )"
|
||||
cd "$REPO_ROOT/crates/luminal_python"
|
||||
exec uv run python "$SCRIPT_DIR/run.py" "$@"
|
||||
BIN
benchmarks/ttft/ttft.png
Normal file
BIN
benchmarks/ttft/ttft.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 272 KiB |
@@ -28,7 +28,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -47,7 +47,6 @@ def run_cargo_test():
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
@@ -21,79 +18,6 @@ hf_cache = modal.Volume.from_name(
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"llama": [
|
||||
"complex system modeled after the structure and function of the human brain",
|
||||
],
|
||||
"gemma": [
|
||||
"recognize pictures of cats",
|
||||
"little detectives looking for specific features",
|
||||
],
|
||||
"qwen": [
|
||||
"computational model inspired by the structure and function of the human brain",
|
||||
],
|
||||
"qwen3_moe": [
|
||||
"The capital of France is Paris",
|
||||
],
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
normalized_output = normalize_output(output)
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
@@ -115,7 +39,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
@@ -124,17 +48,16 @@ def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
subprocess.run(
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
validate_output(example, output)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
@@ -24,7 +23,6 @@ memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
lru = "0.16.2"
|
||||
libc = "0.2"
|
||||
libloading = "0.8"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
use std::{collections::BTreeMap, sync::Arc, time::Instant};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::prelude::egglog::{ast::Span, prelude::RustSpan};
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
base::{base_cleanup_egglog, base_expression_egglog},
|
||||
hlir_to_egglog,
|
||||
},
|
||||
hlir::HLIROps,
|
||||
op::{EgglogOp, IntoEgglogOp, Runtime},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
const DEFAULT_PASSES: usize = 256;
|
||||
const EGGLOG_RULESETS: &[&str] = &[
|
||||
"matmul_flatten",
|
||||
"kernel_lower",
|
||||
"direct_kernel",
|
||||
"kernel_specialize",
|
||||
"buffer_reuse",
|
||||
"matmul_backend",
|
||||
"glumoe",
|
||||
"fusion_pair",
|
||||
"fusion_grow",
|
||||
"fusion_merge",
|
||||
];
|
||||
const MOE_SEQ: usize = 2;
|
||||
const MOE_HIDDEN: usize = 16;
|
||||
const MOE_NUM_EXPERTS: usize = 8;
|
||||
const MOE_TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const GEMMA_RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Backend {
|
||||
Native,
|
||||
Cuda,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Mode {
|
||||
Current,
|
||||
Steps,
|
||||
FullDefault,
|
||||
FullCycle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Case {
|
||||
Mul,
|
||||
UnaryChain(usize),
|
||||
Gelu,
|
||||
Softmax,
|
||||
LayerNorm,
|
||||
Matmul,
|
||||
Attention,
|
||||
QwenMoe,
|
||||
GemmaMoe,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
backend: Backend,
|
||||
mode: Mode,
|
||||
case: Case,
|
||||
passes: usize,
|
||||
cleanup: bool,
|
||||
skip_roll: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args {
|
||||
backend: Backend::Cuda,
|
||||
mode: Mode::Current,
|
||||
case: Case::Gelu,
|
||||
passes: DEFAULT_PASSES,
|
||||
cleanup: true,
|
||||
skip_roll: false,
|
||||
};
|
||||
|
||||
let mut iter = std::env::args().skip(1);
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--backend" => {
|
||||
args.backend = match iter.next().as_deref() {
|
||||
Some("native") => Backend::Native,
|
||||
Some("cuda") => Backend::Cuda,
|
||||
other => panic!("invalid --backend {other:?}; use native|cuda"),
|
||||
};
|
||||
}
|
||||
"--mode" => {
|
||||
args.mode = match iter.next().as_deref() {
|
||||
Some("current") => Mode::Current,
|
||||
Some("steps") => Mode::Steps,
|
||||
Some("full-default") => Mode::FullDefault,
|
||||
Some("full-cycle") => Mode::FullCycle,
|
||||
other => panic!(
|
||||
"invalid --mode {other:?}; use current|steps|full-default|full-cycle"
|
||||
),
|
||||
};
|
||||
}
|
||||
"--case" => {
|
||||
args.case = parse_case(&iter.next().expect("missing --case value"));
|
||||
}
|
||||
"--passes" => {
|
||||
args.passes = iter
|
||||
.next()
|
||||
.expect("missing --passes value")
|
||||
.parse()
|
||||
.expect("invalid --passes value");
|
||||
}
|
||||
"--no-cleanup" => args.cleanup = false,
|
||||
"--skip-roll" => args.skip_roll = true,
|
||||
"--help" | "-h" => {
|
||||
println!(
|
||||
"Usage: egglog_saturation [OPTIONS]\n\
|
||||
\n\
|
||||
Options:\n\
|
||||
--backend native|cuda default: cuda\n\
|
||||
--mode current|steps|full-default|full-cycle\n\
|
||||
--case mul|unary-chain:N|gelu|softmax|layer-norm|matmul|attention|qwen-moe|gemma-moe\n\
|
||||
--passes N default: 256\n\
|
||||
--no-cleanup omit backend/HLIR cleanup rules\n\
|
||||
--skip-roll skip auto loop rolling prepass"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => panic!("unknown argument {other}; use --help"),
|
||||
}
|
||||
}
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
fn parse_case(s: &str) -> Case {
|
||||
if let Some(n) = s.strip_prefix("unary-chain:") {
|
||||
return Case::UnaryChain(n.parse().expect("invalid unary-chain length"));
|
||||
}
|
||||
match s {
|
||||
"mul" => Case::Mul,
|
||||
"gelu" => Case::Gelu,
|
||||
"softmax" => Case::Softmax,
|
||||
"layer-norm" | "layer_norm" => Case::LayerNorm,
|
||||
"matmul" => Case::Matmul,
|
||||
"attention" => Case::Attention,
|
||||
"qwen-moe" | "qwen_moe" => Case::QwenMoe,
|
||||
"gemma-moe" | "gemma_moe" => Case::GemmaMoe,
|
||||
other => panic!("unknown case {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_case(case: Case) -> Graph {
|
||||
let mut cx = Graph::new();
|
||||
let out = match case {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor((64, 64));
|
||||
x * x
|
||||
}
|
||||
Case::UnaryChain(n) => {
|
||||
let mut x = cx.tensor((64, 64));
|
||||
for i in 0..n {
|
||||
x = match i % 6 {
|
||||
0 => x.sin(),
|
||||
1 => x.sqrt(),
|
||||
2 => x.reciprocal(),
|
||||
3 => x.exp2(),
|
||||
4 => x.log2(),
|
||||
_ => x * 1.125,
|
||||
};
|
||||
}
|
||||
x
|
||||
}
|
||||
Case::Gelu => cx.tensor((64, 64)).gelu(),
|
||||
Case::Softmax => cx.tensor((128, 128)).softmax(1),
|
||||
Case::LayerNorm => cx.tensor((128, 128)).layer_norm(1, 1e-5),
|
||||
Case::Matmul => {
|
||||
let a = cx.tensor((32, 64));
|
||||
let b = cx.tensor((64, 32));
|
||||
a.matmul(b)
|
||||
}
|
||||
Case::Attention => {
|
||||
let q = cx.tensor((64, 32));
|
||||
let k = cx.tensor((64, 32));
|
||||
let v = cx.tensor((64, 32));
|
||||
let scores = q.matmul(k.permute((1, 0))) * (1.0 / 32.0_f32.sqrt());
|
||||
scores.softmax(1).matmul(v)
|
||||
}
|
||||
Case::QwenMoe => build_qwen_moe(&mut cx),
|
||||
Case::GemmaMoe => build_gemma_moe(&mut cx),
|
||||
};
|
||||
let _ = out.output();
|
||||
cx
|
||||
}
|
||||
|
||||
fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let x = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let router_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let expert_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router_scale = cx.tensor(MOE_HIDDEN);
|
||||
let router_proj = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let per_expert_scale = cx.tensor(MOE_NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, GEMMA_RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (MOE_HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, MOE_TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
weights.gather(exp_base + exp_within)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
let mut ir_variants = Vec::new();
|
||||
let mut opkind_variants = Vec::new();
|
||||
for op in ops {
|
||||
let sort = op.sort();
|
||||
let variant = format!(
|
||||
"({} {})",
|
||||
sort.name,
|
||||
sort.fields.iter().map(|field| &field.sort).join(" ")
|
||||
);
|
||||
match sort.class.as_str() {
|
||||
"IR" => ir_variants.push(variant),
|
||||
"OpKind" => opkind_variants.push(variant),
|
||||
other => panic!("unknown sort class {other} for {}", sort.name),
|
||||
}
|
||||
}
|
||||
let extra_ir = ops.iter().flat_map(|op| op.ir_defs()).unique().join("\n");
|
||||
format!(
|
||||
"
|
||||
(datatype*
|
||||
(IR
|
||||
(OutputJoin IR IR)
|
||||
(Op OpKind IList)
|
||||
{extra_ir}
|
||||
{}
|
||||
)
|
||||
(OpKind
|
||||
{}
|
||||
)
|
||||
(IList
|
||||
(ICons IR IList)
|
||||
(INil)
|
||||
)
|
||||
)
|
||||
(function dtype (IR) DType :merge new)
|
||||
",
|
||||
ir_variants.join("\n"),
|
||||
opkind_variants.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
ops.iter()
|
||||
.filter(|op| op.cleanup())
|
||||
.map(|op| {
|
||||
let sort = op.sort();
|
||||
let fields = (0..sort.fields.len())
|
||||
.map(|i| (b'a' + i as u8) as char)
|
||||
.join(" ");
|
||||
if sort.class == "OpKind" {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
((delete (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m ({} {fields})))
|
||||
((delete ({} {fields})))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
}
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn setup_program(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let rewrites = ops
|
||||
.iter()
|
||||
.flat_map(|op| op.rewrites())
|
||||
.map(|rule| rule.to_egglog_string())
|
||||
.join("\n");
|
||||
[
|
||||
EGGLOG_RULESETS
|
||||
.iter()
|
||||
.map(|ruleset| format!("(ruleset {ruleset})"))
|
||||
.join("\n"),
|
||||
base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
base_cleanup_egglog(),
|
||||
rewrites,
|
||||
program.to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn producer_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run matmul_flatten)
|
||||
(run kernel_lower)
|
||||
(run direct_kernel)
|
||||
(run kernel_specialize)
|
||||
(run buffer_reuse)
|
||||
(run matmul_backend)
|
||||
(run glumoe)
|
||||
(run fusion_pair)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fusion_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run fusion_grow)
|
||||
(run fusion_merge)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn split_cycle() -> Vec<(&'static str, String)> {
|
||||
vec![
|
||||
("producers", format!("(saturate {})", producer_schedule())),
|
||||
("fusion", format!("(saturate {})", fusion_schedule())),
|
||||
]
|
||||
}
|
||||
|
||||
fn split_cycle_schedule() -> String {
|
||||
format!(
|
||||
"(seq
|
||||
(saturate {})
|
||||
(saturate {})
|
||||
)",
|
||||
producer_schedule(),
|
||||
fusion_schedule()
|
||||
)
|
||||
}
|
||||
|
||||
fn phase(egraph: &mut egglog::EGraph, name: &str, schedule: &str) -> bool {
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let command = format!("(run-schedule {schedule})");
|
||||
let outputs = egraph
|
||||
.parse_and_run_program(None, &command)
|
||||
.unwrap_or_else(|err| panic!("failed phase {name} schedule {schedule}: {err}"));
|
||||
let elapsed = start.elapsed();
|
||||
let after = egraph.num_tuples();
|
||||
let report = outputs
|
||||
.into_iter()
|
||||
.find_map(|output| match output {
|
||||
egglog::CommandOutput::RunSchedule(report) => Some(report),
|
||||
_ => None,
|
||||
})
|
||||
.expect("run-schedule did not return a report");
|
||||
let mut rules = report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(rule, time)| {
|
||||
(
|
||||
rule.to_string(),
|
||||
*time,
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.get(rule)
|
||||
.copied()
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
rules.sort_by_key(|(_, time, matches)| (std::cmp::Reverse(*time), std::cmp::Reverse(*matches)));
|
||||
let matches = report.num_matches_per_rule.values().sum::<usize>();
|
||||
println!(
|
||||
"phase {name:<18} {elapsed_ms:>8.2} ms | tuples {before} -> {after} ({delta:+}) | updated={updated} | iters={iters} | matches={matches}",
|
||||
elapsed_ms = elapsed.as_secs_f64() * 1000.0,
|
||||
delta = after as isize - before as isize,
|
||||
updated = report.updated,
|
||||
iters = report.iterations.len(),
|
||||
);
|
||||
for (rule, time, matches) in rules
|
||||
.into_iter()
|
||||
.filter(|(_, time, matches)| !time.is_zero() || *matches > 0)
|
||||
.take(8)
|
||||
{
|
||||
println!(
|
||||
" rule {rule:<82} {ms:>8.2} ms | matches {matches}",
|
||||
ms = time.as_secs_f64() * 1000.0,
|
||||
);
|
||||
}
|
||||
report.updated
|
||||
}
|
||||
|
||||
fn serialize_summary(egraph: &mut egglog::EGraph, root: &str) {
|
||||
let (sort, value) = egraph.eval_expr(&egglog::var!(root.to_string())).unwrap();
|
||||
let output = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
let mut classes = std::collections::BTreeSet::new();
|
||||
let mut top_ops = BTreeMap::<String, usize>::new();
|
||||
let mut nodes = 0usize;
|
||||
for node in output.egraph.nodes.values().filter(|node| !node.subsumed) {
|
||||
nodes += 1;
|
||||
classes.insert(node.eclass.clone());
|
||||
*top_ops.entry(node.op.clone()).or_default() += 1;
|
||||
}
|
||||
let top_ops = top_ops
|
||||
.into_iter()
|
||||
.sorted_by_key(|(_, count)| std::cmp::Reverse(*count))
|
||||
.take(12)
|
||||
.map(|(op, count)| format!("{op}={count}"))
|
||||
.join(", ");
|
||||
println!(
|
||||
"serialize nodes={nodes} classes={} roots={} top_ops={top_ops}",
|
||||
classes.len(),
|
||||
output.egraph.root_eclasses.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn run(args: Args) {
|
||||
let mut graph = build_case(args.case);
|
||||
let rolled = if args.skip_roll {
|
||||
0
|
||||
} else {
|
||||
graph.auto_roll_loops_prepass()
|
||||
};
|
||||
let (program, root) = hlir_to_egglog(&graph);
|
||||
|
||||
let mut ops = match args.backend {
|
||||
Backend::Native => <NativeRuntime as Runtime>::Ops::into_vec(),
|
||||
Backend::Cuda => <CudaRuntime as Runtime>::Ops::into_vec(),
|
||||
};
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup = args.cleanup && matches!(args.backend, Backend::Cuda);
|
||||
let setup = setup_program(&program, &ops, cleanup);
|
||||
|
||||
println!(
|
||||
"case={:?} backend={:?} mode={:?} passes={} cleanup={} rolled={} hlir_nodes={} setup_lines={} setup_bytes={} root={root}",
|
||||
args.case,
|
||||
args.backend,
|
||||
args.mode,
|
||||
args.passes,
|
||||
cleanup,
|
||||
rolled,
|
||||
graph.graph.node_count(),
|
||||
setup.lines().count(),
|
||||
setup.len(),
|
||||
);
|
||||
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let commands = egraph.parser.get_program_from_string(None, &setup).unwrap();
|
||||
egraph.run_program(commands).unwrap();
|
||||
println!(
|
||||
"setup {:>8.2} ms | tuples {before} -> {} ({:+})",
|
||||
start.elapsed().as_secs_f64() * 1000.0,
|
||||
egraph.num_tuples(),
|
||||
egraph.num_tuples() as isize - before as isize,
|
||||
);
|
||||
|
||||
match args.mode {
|
||||
Mode::Current | Mode::Steps => {
|
||||
for pass in 1..=args.passes {
|
||||
let mut updated = false;
|
||||
for (name, schedule) in split_cycle() {
|
||||
updated |= phase(&mut egraph, &format!("{pass:03} {name}"), &schedule);
|
||||
}
|
||||
if matches!(args.mode, Mode::Current) && !updated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Mode::FullDefault => {
|
||||
phase(&mut egraph, "expr", "(saturate expr)");
|
||||
phase(&mut egraph, "dtype", "(saturate dtype_prop)");
|
||||
phase(&mut egraph, "default-full", "(saturate (run))");
|
||||
}
|
||||
Mode::FullCycle => {
|
||||
phase(
|
||||
&mut egraph,
|
||||
"cycle-full",
|
||||
&format!("(saturate {})", split_cycle_schedule()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
phase(&mut egraph, "final expr", "(saturate expr)");
|
||||
if cleanup {
|
||||
phase(&mut egraph, "cleanup", "(saturate cleanup)");
|
||||
}
|
||||
phase(&mut egraph, "base cleanup", "(saturate base_cleanup)");
|
||||
serialize_summary(&mut egraph, &root);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
run(parse_args());
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
//! ComputeAttnMask — fused op that computes the paged attention mask from indptrs.
|
||||
//!
|
||||
//! This op exists so the indptr tensors (qo_indptr, kv_indptr) are visible in the
|
||||
//! same e-graph chunk as the attention pattern, letting the FlashInfer egglog rule
|
||||
//! capture them directly.
|
||||
//!
|
||||
//! Inputs (3): q_pos (s,) Int, qo_indptr (r,) Int, kv_indptr (r,) Int.
|
||||
//! Output: mask (s, c) F32 where mask[i, j] = 0.0 (attend) or -1e10 (block).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, HLIROp, LLIROp},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaStream, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Computes the paged attention mask from indptr arrays.
|
||||
///
|
||||
/// The mask encodes both request-membership and causality:
|
||||
/// `mask[i, j] = 0.0` if query `i` and context `j` belong to the same request AND
|
||||
/// context `j`'s local position is `<= q_pos[i]`; `-1e10` otherwise.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ComputeAttnMask {
|
||||
pub s_dim: Expression,
|
||||
pub c_dim: Expression,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ComputeAttnMask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ComputeAttnMask(s={}, c={})", self.s_dim, self.c_dim)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for ComputeAttnMask {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (ComputeAttnMask {} {}) (ICons {} (ICons {} (ICons {} (INil)))))",
|
||||
self.s_dim.to_egglog(),
|
||||
self.c_dim.to_egglog(),
|
||||
inputs[0].1, // q_pos
|
||||
inputs[1].1, // qo_indptr
|
||||
inputs[2].1, // kv_indptr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for ComputeAttnMask {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"ComputeAttnMask",
|
||||
&[("s_dim", EXPRESSION), ("c_dim", EXPRESSION)],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrites — inserted directly by model code.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let s_dim = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let c_dim = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let op = Self { s_dim, c_dim };
|
||||
let llir_op = LLIROp::new::<dyn HostOp>(Box::new(op) as Box<dyn HostOp>);
|
||||
(llir_op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for ComputeAttnMask {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 3 {
|
||||
anyhow::bail!(
|
||||
"ComputeAttnMask expects 3 inputs (q_pos, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let s = self
|
||||
.s_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask s_dim unresolved"))?;
|
||||
let c = self
|
||||
.c_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask c_dim unresolved"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask requires dynamic dim 'r'"))?;
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("ComputeAttnMask missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_pos_buf = get_buf("q_pos", inputs[0])?;
|
||||
let qo_indptr_buf = get_buf("qo_indptr", inputs[1])?;
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[2])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
let q_pos = dtoh_i32(stream, q_pos_buf.ptr(), s)?;
|
||||
let qo_indptr = dtoh_i32(stream, qo_indptr_buf.ptr(), r)?;
|
||||
let kv_indptr = dtoh_i32(stream, kv_indptr_buf.ptr(), r)?;
|
||||
|
||||
let mut mask = vec![-1e10f32; s * c];
|
||||
for i in 0..s {
|
||||
let q_req = indptr_to_request(&qo_indptr, i as i32);
|
||||
for j in 0..c {
|
||||
let c_req = indptr_to_request(&kv_indptr, j as i32);
|
||||
if q_req == c_req && q_req >= 0 {
|
||||
let c_local = j as i32 - kv_indptr[c_req as usize];
|
||||
if c_local <= q_pos[i] {
|
||||
mask[i * c + j] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mask_bytes =
|
||||
unsafe { std::slice::from_raw_parts(mask.as_ptr() as *const u8, mask.len() * 4) };
|
||||
unsafe {
|
||||
let res = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
out_buf.ptr(),
|
||||
mask_bytes.as_ptr() as *const std::ffi::c_void,
|
||||
mask_bytes.len(),
|
||||
);
|
||||
if res != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
|
||||
anyhow::bail!("ComputeAttnMask cuMemcpyHtoD failed: {res:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.s_dim * self.c_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("ComputeAttnMask")
|
||||
}
|
||||
}
|
||||
|
||||
fn dtoh_i32(stream: &Arc<CudaStream>, dev_ptr: u64, len: usize) -> anyhow::Result<Vec<i32>> {
|
||||
let mut host = vec![0u8; len * std::mem::size_of::<i32>()];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, dev_ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let v = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host);
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut i32, len, len)
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Given an indptr array `[0, a, b, ...]`, find which segment `idx` belongs to.
|
||||
/// Returns `count(indptr[i] <= idx) - 1`.
|
||||
fn indptr_to_request(indptr: &[i32], idx: i32) -> i32 {
|
||||
indptr.iter().filter(|&&v| v <= idx).count() as i32 - 1
|
||||
}
|
||||
@@ -19,9 +19,9 @@ use crate::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
)
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
)
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
@@ -121,7 +116,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
@@ -129,21 +123,17 @@
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?n ; ldd
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
|
||||
@@ -1,428 +0,0 @@
|
||||
; Fuse a row-major Add on top of an existing cuBLASLt matmul into
|
||||
; D = alpha * A * B + beta * C.
|
||||
;
|
||||
; The existing matmul rewrites view Luminal's row-major output [m,n] as a
|
||||
; column-major cuBLASLt matrix [n,m]. A row-major C input with logical strides
|
||||
; [row_stride, 1] therefore maps to ldc=row_stride. This lets a C slice from a
|
||||
; wider parent tensor use a larger ldc while D keeps the matmul output layout.
|
||||
; cuBLASLt requires out-of-place C and D to have the same matrix order, so these
|
||||
; beta rules only fuse C layouts that map to the current COL-ordered D layout.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched c plus matmul beta"
|
||||
)
|
||||
|
||||
; ROW-ordered D beta fusions. These pair with cublaslt_row_order_rewrite.egg,
|
||||
; where the cuBLASLt problem dimensions match Luminal's logical output [m,n].
|
||||
; A row-major C input with logical strides [row_stride, 1] maps directly to a
|
||||
; ROW-ordered cuBLASLt C[m,n] descriptor with ldc=row_stride.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched c plus matmul beta"
|
||||
)
|
||||
@@ -1,614 +0,0 @@
|
||||
; cuBLASLt epilogue rewrites.
|
||||
;
|
||||
; ReLU in the frontend lowers through maximum_f32(0.0):
|
||||
;
|
||||
; (matmul < 0) * 0 + cast(cast((-cast(matmul < 0) + 1) as bool) as f32) * matmul
|
||||
;
|
||||
; These rules fuse that expression back into CUBLASLT_EPILOGUE_RELU.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu bias epilogue"
|
||||
)
|
||||
|
||||
; Canonical tanh-approx GELU can also appear directly as:
|
||||
;
|
||||
; x * sigmoid(1.5957691216 * x * (1 + 0.044715 * x * x))
|
||||
;
|
||||
; Match that sigmoid form and fuse it into the cuBLASLt GELU epilogues.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu bias epilogue"
|
||||
)
|
||||
|
||||
; This first slice fuses column-bias adds into CUBLASLT_EPILOGUE_BIAS for the
|
||||
; older COL-ordered output view. In that view Luminal's logical [m,n] output is
|
||||
; represented as a cuBLASLt [n,m] matrix, so cuBLASLt's row-broadcast bias maps
|
||||
; to the common logical column bias of length n.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d column bias plus matmul epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column bias plus matmul epilogue"
|
||||
)
|
||||
@@ -1,345 +0,0 @@
|
||||
; FP8 support is narrower than "any FP8 x any FP8". cuBLASLt's regular FP8
|
||||
; matmul table supports these A/B descriptor pairs for F32 outputs:
|
||||
; E4M3 x E4M3
|
||||
; E4M3 x E5M2
|
||||
; E5M2 x E4M3
|
||||
; and requires TN format on Ada/Hopper-class GPUs. These rules therefore match
|
||||
; row-major x column-major Luminal matmuls, which the existing COL-order lowering
|
||||
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
|
||||
; transb=N.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 batched row-major x column-major f32 output"
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
; Mixed output dtype rewrites for cuBLASLt.
|
||||
;
|
||||
; The first mixed mode we need for low-precision matmuls is:
|
||||
;
|
||||
; D[f32] = A[fp16/bf16] * B[fp16/bf16]
|
||||
;
|
||||
; Luminal graphs express this today as a Cast(F32) around a low-precision
|
||||
; matmul. cuBLASLt can write the f32 output directly, so expose that candidate
|
||||
; before beta fusion tries to consume an f32 C input.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F16) (F16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt f16 matmul cast f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (Bf16) (Bf16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt bf16 matmul cast f32 output"
|
||||
)
|
||||
@@ -1,452 +0,0 @@
|
||||
; Natural cuBLASLt row-order output rewrites. These keep Luminal's logical
|
||||
; output C[m,n] as a cuBLASLt ROW-ordered D[m,n] instead of using the older
|
||||
; swapped COL-ordered D[n,m] view. A and B orders mirror their matched logical
|
||||
; layouts, so this family is the legal base for future ROW-ordered beta fusions.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x column-major"
|
||||
)
|
||||
@@ -1,316 +0,0 @@
|
||||
; Scalar alpha/beta rewrites for cuBLASLt. These rules target scalar constants
|
||||
; expanded across the matmul/add shape, i.e. zero strides on every logical axis.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; alpha=1.0 hash-conses ?fused == ?matmul; the union merges Mul into ?matmul's eclass and saturate diverges.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; See 2d alpha scale: alpha=1.0 makes (saturate ...) diverge.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c plus matmul beta"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,124 +0,0 @@
|
||||
# FlashInfer Integration
|
||||
|
||||
FlashInfer replaces the multi-op attention pattern (Q×K^T → scale → mask → softmax → ×V) with a single fused GPU kernel via [FlashInfer](https://github.com/flashinfer-ai/flashinfer)'s batch decode and batch prefill APIs.
|
||||
|
||||
## Current State
|
||||
|
||||
**Working:**
|
||||
- Egglog rewrite rule matches any GQA paged attention pattern (model-agnostic shapes)
|
||||
- GA search selects FlashInfer when it wins profiling — verified on Llama 3 8B (32 layers) and Qwen 3 4B (36 layers)
|
||||
- **BatchDecode** (s=1): fp32 natively — FlashInfer's decode kernel uses scalar vectorized dot products, no tensor cores
|
||||
- **BatchPrefill**: template-instantiated for fp16 but **not callable from fp32** — FlashInfer's prefill kernel requires tensor core MMA (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically only operate on 16-bit types; the C API stubs return -1 for fp32; will be enabled when native fp16/bf16 pipeline is added
|
||||
- Decode handles all cases in the current fp32 pipeline (prefill uses cuBLAS attention via dim bucketing)
|
||||
- Indptr-based mask: `qo_indptr` and `kv_indptr` are computed in-graph so the egglog rule can see them in the same chunk as the attention ops
|
||||
|
||||
**Not yet implemented:**
|
||||
- Native fp16 / bf16 pipeline (would eliminate the cast overhead in prefill)
|
||||
- Page sizes > 1
|
||||
|
||||
---
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
src/host/flashinfer/
|
||||
flashinfer_attention.egg — egglog rewrite rule (pattern match → FlashInferAttention)
|
||||
mod.rs — FlashInferAttention op (EgglogOp + HostOp impl)
|
||||
jit.rs — JIT compilation: nvcc wrapper.cu → .so, dlopen, fn pointers
|
||||
find_indptrs.rs — walks the mask e-graph node to locate qo_indptr / kv_indptr inputs
|
||||
wrapper.cu — CUDA: FlashInfer template instantiation + helper kernels
|
||||
wrapper.h — C API header for wrapper.cu
|
||||
README.md — this file
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. Egglog Pattern Matching
|
||||
|
||||
The rule in `flashinfer_attention.egg` matches the structural pattern of paged GQA attention:
|
||||
|
||||
```
|
||||
Gather(K_cache, idx) → GQA broadcast (Mul×1.0) → Q×K^T → Sum → scale → mask Add → softmax → attn×V → Sum → output
|
||||
Gather(V_cache, idx) → GQA broadcast (Mul×1.0) ──────────────────────────────────────────→ attn×V → Sum → output
|
||||
```
|
||||
|
||||
Key anchors that prevent false matches on MLP or other ops:
|
||||
- Two Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
- GQA broadcast via `Mul(gathered, Constant(1.0))` with all-zero strides
|
||||
- Mask Add with zero-stride broadcast in the first (nheads) dimension
|
||||
- Two sequential matmul+Sum pairs connected through softmax
|
||||
|
||||
Shape dimensions are egglog variables, not pinned constants — the rule works for any model with GQA (Llama, Qwen, Mistral, etc.). The structural invariants (dimension count, zero-stride positions, Gather from 2D) are enough to avoid combinatorial explosion during saturation.
|
||||
|
||||
When the rule fires, it unions `FlashInferAttention` with the original attention output, making it an equivalent alternative in the e-graph. The GA search then profiles both paths and picks the faster one.
|
||||
|
||||
### 2. Extraction: Finding Indptrs
|
||||
|
||||
During `extract()` (called when egglog selects the FlashInferAttention e-node), `find_indptrs.rs` walks backward from the mask node in the e-graph to locate the `qo_indptr` and `kv_indptr` Input nodes. It validates the mask structure by checking for the `Mul(allowed, Constant(1e10))` pattern that `compute_attn_mask()` produces.
|
||||
|
||||
The indptrs are appended as inputs 5 and 6 to the FlashInferAttention op, so the runtime can build the CSR page table directly without recomputing anything.
|
||||
|
||||
### 3. JIT Compilation
|
||||
|
||||
FlashInfer requires `HEAD_DIM` as a compile-time template parameter. Rather than baking it at `cargo build` time, `jit.rs` JIT-compiles `wrapper.cu` with the model's actual HEAD_DIM:
|
||||
|
||||
1. First call to `ensure_compiled(head_dim)` runs `nvcc` with `-DLUMINAL_HEAD_DIM=<N>`
|
||||
2. The compiled `.so` is cached at `~/.cache/luminal/flashinfer/libflashinfer_hd<N>_<arch>.so`
|
||||
3. Subsequent calls load the cached library via `dlopen`
|
||||
4. Function pointers (plan, run, transpose, etc.) are resolved and stored in a `static OnceLock`
|
||||
|
||||
Supported HEAD_DIM values: 64, 128, 256.
|
||||
|
||||
### 4. Runtime Execution
|
||||
|
||||
`FlashInferAttention::execute()` dispatches to decode or prefill based on `total_q_tokens vs batch_size`:
|
||||
|
||||
**Common steps:**
|
||||
1. **Extract kv_indices** — a helper kernel converts the flat gather index `(c, KV_DIM)` to slot indices `(c,)`
|
||||
2. **Read indptrs to host** — copied to CPU for the plan phase
|
||||
3. **Plan** — queries GPU occupancy and decides split-KV decomposition
|
||||
4. **Run** — the fused kernel writes `(total_q_tokens, num_qo_heads, head_dim)`
|
||||
5. **Transpose** — transposes to `(num_qo_heads, total_q_tokens, head_dim)` to match the Sum reduction layout
|
||||
|
||||
**Decode path** (current, fp32): Always used. Runs FlashInfer's BatchDecode directly on fp32 buffers.
|
||||
|
||||
**Prefill path** (future, fp16/bf16 only): The prefill kernel templates are compiled into the JIT .so for fp16 (CTA_TILE_Q=16/64/128, causal mask). The C API stubs currently return -1 since the pipeline is fp32. When native fp16/bf16 dtype support is added, `execute()` will dispatch to prefill when `total_q_tokens > batch_size`.
|
||||
|
||||
Global workspaces (`static OnceLock`) are shared across all FlashInferAttention instances to avoid ~4ms allocation overhead per GA profiling candidate. Without this, the GA never selects FlashInfer because the first-run allocation cost dwarfs the kernel time.
|
||||
|
||||
## How the Attention Mask Enables FlashInfer
|
||||
|
||||
For the egglog rule to fire, the `qo_indptr` and `kv_indptr` tensors must be visible in the same e-graph chunk as the attention ops. This is why the mask is computed *inside* each layer (via `compute_attn_mask()` in the model) rather than passed as a pre-computed input.
|
||||
|
||||
The mask computation uses a specific structure:
|
||||
```rust
|
||||
let allowed = same_request * causal;
|
||||
allowed * 1e10 - 1e10 // → 0.0 for allowed, -1e10 for blocked
|
||||
```
|
||||
|
||||
The `Mul(allowed, Constant(1e10))` pattern is the anchor that `find_indptrs.rs` uses to walk backward and locate the indptr inputs.
|
||||
|
||||
## Roadmap
|
||||
|
||||
Items listed in priority order. Checked items are done.
|
||||
|
||||
- [x] Model-agnostic egglog rule (shape variables instead of Llama-specific constants)
|
||||
- [x] bs>1 supersequence decode
|
||||
- [x] Indptr-based attention mask (replaces CPU-computed mask)
|
||||
- [x] Multi-model support (verified on Llama 3 8B and Qwen 3 4B)
|
||||
- [x] BatchPrefill kernel compiled for fp16 (causal mask, CTA_TILE_Q=16/64/128)
|
||||
- [ ] Native fp16 / bf16 pipeline (enables prefill, reduces memory, eliminates cuBLAS prefill fallback)
|
||||
- [ ] HEAD_DIM dispatch for 64, 96 (JIT supports 64/128/256; wrapper.cu needs 96 for Phi)
|
||||
- [ ] Page sizes > 1 (currently page_size=1; larger pages reduce CSR overhead)
|
||||
- [ ] Sliding window, ALiBi, logits soft cap (FlashInfer `AttentionVariant` templates)
|
||||
- [ ] MHA / MQA / arbitrary GQA ratios beyond {1, 2, 4, 8}
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- **page_size=1**: Each KV cache slot is one "page". This simplifies the CSR page table (`kv_indices` = physical slot indices directly) and matches the flat `(num_slots, KV_DIM)` cache layout.
|
||||
|
||||
- **Pinned structural anchors**: The egglog rule pins the *structure* (number of dimensions, which dims are zero-stride, presence of Gather from 2D cache) but uses variables for the *values* (head counts, head_dim). This prevents saturation blowup while remaining model-agnostic.
|
||||
|
||||
- **Prefill requires fp16/bf16**: FlashInfer's prefill kernel uses tensor core MMA instructions (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically require 16-bit inputs — there is no fp32 tensor core matmul instruction. The prefill kernel templates are compiled into the .so for fp16 but the C API returns -1 for fp32 callers. When native fp16/bf16 is added, prefill will be enabled automatically.
|
||||
|
||||
- **Global workspaces**: Float workspace (128 MiB), int workspace (8 MiB), and a page-locked host buffer are allocated once via `static OnceLock` and shared across all instances.
|
||||
@@ -1,248 +0,0 @@
|
||||
//! Walk the e-graph from the mask node to find qo_indptr and kv_indptr Input nodes.
|
||||
//!
|
||||
//! The mask is produced by `compute_attn_mask(q_pos, qo_indptr, kv_indptr)` using
|
||||
//! primitive HLIR ops. This module validates the mask's structure and extracts the
|
||||
//! indptr Input node IDs so FlashInfer can use them directly.
|
||||
|
||||
use luminal::egglog_utils::{ClassId, NodeId, SerializedEGraph};
|
||||
use luminal::prelude::FxHashSet;
|
||||
|
||||
/// Result of walking the mask computation chain.
|
||||
#[derive(Debug)]
|
||||
pub struct IndptrNodes<'a> {
|
||||
pub qo_indptr: &'a NodeId,
|
||||
pub kv_indptr: &'a NodeId,
|
||||
}
|
||||
|
||||
/// Find the qo_indptr and kv_indptr Input nodes by walking backwards from the mask.
|
||||
///
|
||||
/// Validates the mask structure: `allowed * 1e10 + (-1e10)`. Then does a BFS from
|
||||
/// the `allowed` subtree to find all reachable Input nodes with names containing
|
||||
/// "qo_indptr" and "kv_indptr".
|
||||
///
|
||||
/// Panics with a diagnostic message if the structure doesn't match or the
|
||||
/// indptr inputs can't be found.
|
||||
pub fn find_indptr_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
assert!(
|
||||
mask_kind_label.contains("Add"),
|
||||
"find_indptr_inputs: mask is not an Add (kind={mask_kind_label})"
|
||||
);
|
||||
|
||||
let mask_inputs = walk_ilist_simple(egraph, &mask_children[1]);
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
"find_indptr_inputs: mask Add should have 2 inputs, got {}",
|
||||
mask_inputs.len()
|
||||
);
|
||||
|
||||
// Step 2: One of the inputs should be Mul(allowed, Constant(1e10))
|
||||
let (scaled_allowed, allowed_node) = find_1e10_mul(egraph, &mask_inputs);
|
||||
|
||||
// Step 3: BFS from `allowed` to find all reachable Input nodes
|
||||
let reachable_inputs = find_reachable_inputs(egraph, allowed_node);
|
||||
|
||||
// Step 4: Match by name
|
||||
let mut qo_indptr: Option<&NodeId> = None;
|
||||
let mut kv_indptr: Option<&NodeId> = None;
|
||||
|
||||
for (node_id, name) in &reachable_inputs {
|
||||
if name.contains("qo_indptr") {
|
||||
qo_indptr = Some(node_id);
|
||||
} else if name.contains("kv_indptr") {
|
||||
kv_indptr = Some(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let qo = qo_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'qo_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
let kv = kv_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'kv_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
IndptrNodes {
|
||||
qo_indptr: qo,
|
||||
kv_indptr: kv,
|
||||
}
|
||||
}
|
||||
|
||||
fn find_1e10_mul<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
if label != "Op" {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("Mul") {
|
||||
continue;
|
||||
}
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
for (i, &inp) in mul_inputs.iter().enumerate() {
|
||||
if is_constant(egraph, inp, 1e10) {
|
||||
let other = mul_inputs[1 - i];
|
||||
return (input_node, other);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut debug_info = String::new();
|
||||
for (i, &input_node) in mask_add_inputs.iter().enumerate() {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
debug_info.push_str(&format!("\n input[{i}]: label={label}"));
|
||||
if label == "Op" && !children.is_empty() {
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
debug_info.push_str(&format!(" kind={kind_label}"));
|
||||
for (j, kc) in egraph.enodes[kind].1.iter().enumerate() {
|
||||
let kc_node = resolve_first_node(egraph, kc);
|
||||
debug_info.push_str(&format!(" child[{j}]={}", egraph.enodes[kc_node].0));
|
||||
}
|
||||
if kind_label.contains("Mul") && children.len() >= 2 {
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for (j, &mi) in mul_inputs.iter().enumerate() {
|
||||
let (ml, mc) = &egraph.enodes[mi];
|
||||
debug_info.push_str(&format!("\n mul_input[{j}]: label={ml}"));
|
||||
if ml == "Op" && !mc.is_empty() {
|
||||
let mk = resolve_first_node(egraph, &mc[0]);
|
||||
debug_info.push_str(&format!(" kind={}", egraph.enodes[mk].0));
|
||||
for (k, mkc) in egraph.enodes[mk].1.iter().enumerate() {
|
||||
let mkc_node = resolve_first_node(egraph, mkc);
|
||||
debug_info.push_str(&format!(" ch[{k}]={}", egraph.enodes[mkc_node].0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find Mul(allowed, Constant(1e10)) in mask Add inputs.{debug_info}"
|
||||
);
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
if !kind_label.contains("Constant") {
|
||||
return false;
|
||||
}
|
||||
let val_children = &egraph.enodes[kind].1;
|
||||
if val_children.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let val_node = resolve_first_node(egraph, &val_children[0]);
|
||||
let val_str = &egraph.enodes[val_node].0;
|
||||
if let Ok(val) = val_str.parse::<f64>() {
|
||||
(val as f32 - expected).abs() < 1.0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn find_reachable_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
start: &'a NodeId,
|
||||
) -> Vec<(&'a NodeId, String)> {
|
||||
let mut found = Vec::new();
|
||||
let mut visited = FxHashSet::default();
|
||||
let mut stack = vec![start];
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
if !visited.insert(node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
|
||||
if label == "Input" {
|
||||
if children.len() >= 2 {
|
||||
let name_node = resolve_first_node(egraph, &children[1]);
|
||||
let name = egraph.enodes[name_node].0.trim_matches('"').to_string();
|
||||
found.push((node, name));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if label == "Op" && children.len() >= 2 {
|
||||
let ir_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for inp in ir_inputs {
|
||||
stack.push(inp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found
|
||||
}
|
||||
|
||||
fn walk_ilist_simple<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
ilist_eclass: &'a ClassId,
|
||||
) -> Vec<&'a NodeId> {
|
||||
let mut inputs = Vec::new();
|
||||
let mut current = resolve_first_node(egraph, ilist_eclass);
|
||||
|
||||
loop {
|
||||
let (label, children) = &egraph.enodes[current];
|
||||
if label == "INil" {
|
||||
break;
|
||||
}
|
||||
if label != "ICons" {
|
||||
break;
|
||||
}
|
||||
let ir_node = resolve_first_ir_node(egraph, &children[0]);
|
||||
inputs.push(ir_node);
|
||||
current = resolve_first_node(egraph, &children[1]);
|
||||
}
|
||||
|
||||
inputs
|
||||
}
|
||||
|
||||
fn resolve_first_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
&egraph.eclasses[eclass].1[0]
|
||||
}
|
||||
|
||||
fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
let nodes = &egraph.eclasses[eclass].1;
|
||||
for node in nodes {
|
||||
let label = &egraph.enodes[node].0;
|
||||
if label == "Op" || label == "Input" {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
; FlashInfer batch decode attention rewrite rule.
|
||||
;
|
||||
; Matches the paged attention pattern for ANY model with GQA:
|
||||
; Gather(K_cache) → GQA broadcast → Q*K^T matmul → scale → add mask → softmax → attn*V matmul
|
||||
; Gather(V_cache) → GQA broadcast ──────────────────────────────────────────→ attn*V matmul
|
||||
;
|
||||
; Structural anchors (prevent false matches on MLP/other ops):
|
||||
; - Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
; - GQA broadcast via Mul(gathered, Constant(1.0)) with all-zero strides
|
||||
; - Scale Mul(QK, constant) connecting QK scores to mask Add
|
||||
; - Mask Add with zero-stride broadcast in first dim (nheads broadcast)
|
||||
; - Data flow: two sequential matmul+reduce pairs connected through softmax
|
||||
;
|
||||
; The egglog rule captures the mask as 5th input. During extract(), a Rust
|
||||
; function walks the mask's computation chain in the e-graph to locate the
|
||||
; qo_indptr and kv_indptr Input nodes (validated via the Constant(1e10) anchor
|
||||
; and structural checks). These are appended as inputs 5 and 6 so FlashInfer
|
||||
; can build the CSR page table directly — no runtime derivation needed.
|
||||
;
|
||||
; Shape dimensions are egglog variables, not pinned constants.
|
||||
; Dynamic dims "s" (batch/seq) and "c" (context) stay pinned as MVar.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ── Second matmul: Mul(softmax_out, V_gqa) ──
|
||||
; Shape: (nheads, s, hdim, c) — 4D
|
||||
(= ?mul2 (Op (Mul
|
||||
(ECons ?nheads (ECons (MVar "s") (ECons ?hdim (ECons (MVar "c") (ENil)))))
|
||||
?mul2_a_strides
|
||||
?mul2_b_strides
|
||||
?mul2_out_strides)
|
||||
(ICons ?soft (ICons ?v_gqa (INil)))))
|
||||
|
||||
; ── Second matmul: Sum (reduction over c) → output ──
|
||||
; Shape: (nheads, s, hdim) — reduces c
|
||||
(= ?output (Op (Sum
|
||||
(ECons ?nheads2 (ECons (MVar "s") (ECons ?hdim2 (ENil))))
|
||||
(MVar "c")
|
||||
?out_in_strides
|
||||
(MIter)
|
||||
?out_out_strides)
|
||||
(ICons ?mul2 (INil))))
|
||||
|
||||
; ── V GQA broadcast: Mul(V_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, c, hdim) — 3D
|
||||
(= ?v_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?v_gqa (Op (Mul
|
||||
(ECons ?nheads3 (ECons (MVar "c") (ECons ?hdim3 (ENil))))
|
||||
?v_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?v_gqa_out_strides)
|
||||
(ICons ?v_gathered (ICons ?v_gqa_const (INil)))))
|
||||
|
||||
; ── V Gather: rows from V_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?v_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim (ENil)))
|
||||
?v_gather_strides
|
||||
(ECons ?num_slots_v (ECons ?kvdim2 (ENil)))
|
||||
?v_src_strides)
|
||||
(ICons ?v_idx (ICons ?v_cache (INil)))))
|
||||
|
||||
; ── First matmul: Mul(Q, K_gqa) ──
|
||||
; Shape: (nheads, s, c, hdim) — 4D
|
||||
(= ?mul1 (Op (Mul
|
||||
(ECons ?nheads4 (ECons (MVar "s") (ECons (MVar "c") (ECons ?hdim4 (ENil)))))
|
||||
?mul1_a_strides
|
||||
?mul1_b_strides
|
||||
?mul1_out_strides)
|
||||
(ICons ?q (ICons ?k_gqa (INil)))))
|
||||
|
||||
; ── First matmul: Sum (reduction over hdim) → QK scores ──
|
||||
; Shape: (nheads, s, c) — reduces hdim
|
||||
(= ?qk (Op (Sum
|
||||
(ECons ?nheads5 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?hdim5
|
||||
?qk_in_strides
|
||||
(MIter)
|
||||
?qk_out_strides)
|
||||
(ICons ?mul1 (INil))))
|
||||
|
||||
; ── Mask Add: Add(scaled_QK, mask) ──
|
||||
; Shape: (nheads, s, c) — 3D
|
||||
; Mask is broadcast from (s, c) via zero-stride in first dim (nheads).
|
||||
(= ?masked (Op (Add
|
||||
(ECons ?nheads8 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?mask_add_a_strides
|
||||
(ECons (MNum 0) ?mask_rest_strides)
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?k_gqa (Op (Mul
|
||||
(ECons ?nheads6 (ECons ?hdim6 (ECons (MVar "c") (ENil))))
|
||||
?k_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?k_gqa_out_strides)
|
||||
(ICons ?k_gathered (ICons ?k_gqa_const (INil)))))
|
||||
|
||||
; ── K Gather: rows from K_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?k_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim3 (ENil)))
|
||||
?k_gather_strides
|
||||
(ECons ?num_slots_k (ECons ?kvdim4 (ENil)))
|
||||
?k_src_strides)
|
||||
(ICons ?k_idx (ICons ?k_cache (INil)))))
|
||||
|
||||
; ── Dtype consistency ──
|
||||
(= ?dt (dtype ?q))
|
||||
(= ?dt (dtype ?k_cache))
|
||||
(= ?dt (dtype ?v_cache))
|
||||
)
|
||||
(
|
||||
(let ?fi (Op (FlashInferAttention
|
||||
?nheads (MDiv ?kvdim ?hdim) ?hdim (MNum 1) (MVar "s"))
|
||||
(ICons ?q (ICons ?k_cache (ICons ?v_cache (ICons ?k_idx (ICons ?mask (INil))))))))
|
||||
(union ?output ?fi)
|
||||
(set (dtype ?fi) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "FlashInfer batch decode attention"
|
||||
)
|
||||
@@ -1,504 +0,0 @@
|
||||
//! JIT compilation and dynamic loading of FlashInfer kernels.
|
||||
//!
|
||||
//! Everything runs at compile / profiling time — there is no `build.rs`.
|
||||
//! `wrapper.cu` and `wrapper.h` are embedded via `include_str!()` and
|
||||
//! extracted to the cache directory on first use. The FlashInfer + CUTLASS
|
||||
//! header trees are located by probing `LUMINAL_FLASHINFER_DIR`, a small set
|
||||
//! of default paths, and (as a last resort) by `git clone`-ing FlashInfer at
|
||||
//! a pinned commit into the cache. `nvcc` is then invoked with the model's
|
||||
//! actual `HEAD_DIM` and the resulting `.so` is `dlopen`'d.
|
||||
//!
|
||||
//! `ensure_compiled` is called from `FlashInferAttention::extract()`, i.e.
|
||||
//! during luminal's compile / GA-profiling phase, not from `execute()`. After
|
||||
//! the first call the `OnceLock` makes subsequent lookups free.
|
||||
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
hash::{Hash, Hasher},
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
// ── Function pointer types matching wrapper.h ──
|
||||
|
||||
pub type PlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
indptr_h: *mut i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type RunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
pub type ExtractFn = unsafe extern "C" fn(
|
||||
flat_idx: *const i32,
|
||||
out: *mut i32,
|
||||
c: i32,
|
||||
kv_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type DeriveIndptrFn =
|
||||
unsafe extern "C" fn(mask: *const f32, indptr: *mut i32, s: i32, c: i32, stream: *mut c_void);
|
||||
|
||||
pub type TransposeOutputFn = unsafe extern "C" fn(
|
||||
src: *const f32,
|
||||
dst: *mut f32,
|
||||
batch: i32,
|
||||
heads: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type PrefillPlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
qo_indptr_h: *mut i32,
|
||||
kv_indptr_h: *mut i32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type PrefillRunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
qo_indptr: *mut i32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
// ── Embedded CUDA sources ──
|
||||
|
||||
const WRAPPER_CU: &str = include_str!("wrapper.cu");
|
||||
const WRAPPER_H: &str = include_str!("wrapper.h");
|
||||
|
||||
// ── Loaded library handle ──
|
||||
|
||||
pub struct FlashInferLib {
|
||||
// Keep the handle alive so the dlopen'd .so remains mapped.
|
||||
_lib: libloading::Library,
|
||||
pub plan: PlanFn,
|
||||
pub run: RunFn,
|
||||
pub extract_slot_indices: ExtractFn,
|
||||
pub derive_indptr_from_mask: DeriveIndptrFn,
|
||||
pub transpose_output: TransposeOutputFn,
|
||||
pub prefill_plan: PrefillPlanFn,
|
||||
pub prefill_run: PrefillRunFn,
|
||||
}
|
||||
|
||||
// SAFETY: The library handle and function pointers are valid for the lifetime
|
||||
// of the process. All functions are called with proper CUDA stream serialization.
|
||||
unsafe impl Send for FlashInferLib {}
|
||||
unsafe impl Sync for FlashInferLib {}
|
||||
|
||||
static FLASHINFER_LIB: OnceLock<FlashInferLib> = OnceLock::new();
|
||||
|
||||
/// Ensure the FlashInfer library is compiled and loaded for the given HEAD_DIM.
|
||||
/// Returns a reference to the loaded library. Thread-safe via OnceLock.
|
||||
pub fn ensure_compiled(head_dim: usize) -> &'static FlashInferLib {
|
||||
FLASHINFER_LIB.get_or_init(|| {
|
||||
assert!(
|
||||
matches!(head_dim, 64 | 128 | 256),
|
||||
"FlashInfer: unsupported HEAD_DIM={} (must be 64, 128, or 256 for f32)",
|
||||
head_dim
|
||||
);
|
||||
let so_path = compile_or_cache(head_dim);
|
||||
unsafe {
|
||||
FlashInferLib::load(&so_path)
|
||||
.unwrap_or_else(|e| panic!("Failed to load FlashInfer library: {e}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl FlashInferLib {
|
||||
/// Load a compiled FlashInfer .so and resolve function pointers.
|
||||
///
|
||||
/// # Safety
|
||||
/// The .so must be a valid FlashInfer wrapper compiled from wrapper.cu.
|
||||
unsafe fn load(path: &Path) -> Result<Self, libloading::Error> {
|
||||
let lib = unsafe { libloading::Library::new(path)? };
|
||||
let plan: PlanFn = unsafe { *lib.get::<PlanFn>(b"flashinfer_batch_decode_plan\0")? };
|
||||
let run: RunFn = unsafe { *lib.get::<RunFn>(b"flashinfer_batch_decode_run\0")? };
|
||||
let extract_slot_indices: ExtractFn =
|
||||
unsafe { *lib.get::<ExtractFn>(b"flashinfer_extract_slot_indices\0")? };
|
||||
let derive_indptr_from_mask: DeriveIndptrFn =
|
||||
unsafe { *lib.get::<DeriveIndptrFn>(b"flashinfer_derive_indptr_from_mask\0")? };
|
||||
let transpose_output: TransposeOutputFn =
|
||||
unsafe { *lib.get::<TransposeOutputFn>(b"flashinfer_transpose_output\0")? };
|
||||
let prefill_plan: PrefillPlanFn =
|
||||
unsafe { *lib.get::<PrefillPlanFn>(b"flashinfer_batch_prefill_plan\0")? };
|
||||
let prefill_run: PrefillRunFn =
|
||||
unsafe { *lib.get::<PrefillRunFn>(b"flashinfer_batch_prefill_run\0")? };
|
||||
Ok(Self {
|
||||
_lib: lib,
|
||||
plan,
|
||||
run,
|
||||
extract_slot_indices,
|
||||
derive_indptr_from_mask,
|
||||
transpose_output,
|
||||
prefill_plan,
|
||||
prefill_run,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile wrapper.cu for the given HEAD_DIM, or return cached .so path.
|
||||
fn compile_or_cache(head_dim: usize) -> PathBuf {
|
||||
let cache_dir = cache_directory();
|
||||
std::fs::create_dir_all(&cache_dir).expect("Failed to create FlashInfer cache directory");
|
||||
|
||||
// Extract bundled wrapper sources to the cache so nvcc can compile them.
|
||||
let (wrapper_cu_path, wrapper_h_dir) = extract_wrapper_sources(&cache_dir);
|
||||
|
||||
let arch = detect_cuda_arch();
|
||||
// Bake a hash of the embedded wrapper into the .so name so old caches are
|
||||
// discarded automatically when wrapper.cu or wrapper.h change.
|
||||
let wrapper_hash = wrapper_source_hash();
|
||||
let so_name = format!(
|
||||
"libflashinfer_hd{}_{}_w{:016x}.so",
|
||||
head_dim, arch, wrapper_hash
|
||||
);
|
||||
let so_path = cache_dir.join(&so_name);
|
||||
|
||||
if so_path.exists() {
|
||||
eprintln!(
|
||||
"FlashInfer: using cached library for HEAD_DIM={} ({})",
|
||||
head_dim,
|
||||
so_path.display()
|
||||
);
|
||||
return so_path;
|
||||
}
|
||||
|
||||
let Some((flashinfer_include, cutlass_include)) = locate_flashinfer_includes() else {
|
||||
panic!(
|
||||
"FlashInfer: could not locate header tree. Set LUMINAL_FLASHINFER_DIR to the \
|
||||
FlashInfer source root (the directory containing `include/` and \
|
||||
`3rdparty/cutlass/include/`)."
|
||||
);
|
||||
};
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: JIT compiling for HEAD_DIM={}, arch={} ...",
|
||||
head_dim, arch
|
||||
);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let output = Command::new("nvcc")
|
||||
.args([
|
||||
"-shared",
|
||||
"-o",
|
||||
so_path.to_str().unwrap(),
|
||||
&format!("-DLUMINAL_HEAD_DIM={}", head_dim),
|
||||
wrapper_cu_path.to_str().unwrap(),
|
||||
"-I",
|
||||
flashinfer_include.to_str().unwrap(),
|
||||
"-I",
|
||||
cutlass_include.to_str().unwrap(),
|
||||
"-I",
|
||||
wrapper_h_dir.to_str().unwrap(),
|
||||
"-std=c++17",
|
||||
&format!("-arch={}", arch),
|
||||
"-O3",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-w",
|
||||
"-rdc=true",
|
||||
"--compiler-options",
|
||||
"-fPIC",
|
||||
])
|
||||
.output()
|
||||
.expect("Failed to run nvcc. Is the CUDA toolkit installed?");
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let _ = std::fs::remove_file(&so_path);
|
||||
panic!(
|
||||
"FlashInfer JIT compilation failed (HEAD_DIM={}, arch={}):\nstdout: {}\nstderr: {}",
|
||||
head_dim, arch, stdout, stderr
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
"FlashInfer: compiled in {:.1}s → {}",
|
||||
elapsed.as_secs_f64(),
|
||||
so_path.display()
|
||||
);
|
||||
|
||||
so_path
|
||||
}
|
||||
|
||||
/// Returns ~/.cache/luminal/flashinfer/
|
||||
fn cache_directory() -> PathBuf {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||
PathBuf::from(home)
|
||||
.join(".cache")
|
||||
.join("luminal")
|
||||
.join("flashinfer")
|
||||
}
|
||||
|
||||
/// Drop the embedded wrapper.cu/wrapper.h into the cache dir so nvcc has files
|
||||
/// on disk to compile. Returns (wrapper.cu path, directory containing wrapper.h).
|
||||
fn extract_wrapper_sources(cache_dir: &Path) -> (PathBuf, PathBuf) {
|
||||
let cu = cache_dir.join("wrapper.cu");
|
||||
let h = cache_dir.join("wrapper.h");
|
||||
write_if_changed(&cu, WRAPPER_CU.as_bytes());
|
||||
write_if_changed(&h, WRAPPER_H.as_bytes());
|
||||
(cu, cache_dir.to_path_buf())
|
||||
}
|
||||
|
||||
fn write_if_changed(path: &Path, contents: &[u8]) {
|
||||
if let Ok(existing) = std::fs::read(path)
|
||||
&& existing == contents
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::fs::write(path, contents).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"FlashInfer: failed to write wrapper source to {}: {e}",
|
||||
path.display()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn wrapper_source_hash() -> u64 {
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
WRAPPER_CU.hash(&mut hasher);
|
||||
WRAPPER_H.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// ── Pinned FlashInfer source ──
|
||||
//
|
||||
// Bumping this constant invalidates the cached source tree AND the cached .so
|
||||
// (the .so cache key incorporates the wrapper hash, which is rebuilt against
|
||||
// these headers, so different headers compile to a different .so file even at
|
||||
// the same head_dim). If you change `FLASHINFER_GIT_REV`, also re-check
|
||||
// `wrapper.cu` against the new FlashInfer API.
|
||||
|
||||
const FLASHINFER_GIT_URL: &str = "https://github.com/flashinfer-ai/flashinfer.git";
|
||||
const CUTLASS_GIT_URL: &str = "https://github.com/NVIDIA/cutlass.git";
|
||||
const FLASHINFER_GIT_REV: &str = "f1e6fdcb8f65104047697f022b5d055ef022d763";
|
||||
const CUTLASS_GIT_REV: &str = "f3fde58372d33e9a5650ba7b80fc48b3b49d40c8";
|
||||
|
||||
fn locate_flashinfer_includes() -> Option<(PathBuf, PathBuf)> {
|
||||
if let Ok(path) = std::env::var("LUMINAL_FLASHINFER_DIR")
|
||||
&& !path.is_empty()
|
||||
{
|
||||
let root = PathBuf::from(path);
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
eprintln!(
|
||||
"FlashInfer: LUMINAL_FLASHINFER_DIR={} did not contain include/ and \
|
||||
3rdparty/cutlass/include/ — falling back to default locations",
|
||||
root.display()
|
||||
);
|
||||
}
|
||||
|
||||
let home = std::env::var("HOME").unwrap_or_default();
|
||||
let candidates = [
|
||||
PathBuf::from(&home).join("luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
PathBuf::from(&home).join("luminal_cuda/flashinfer"),
|
||||
PathBuf::from("/opt/luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
];
|
||||
for root in candidates {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: fetch the pinned commit into the cache directory.
|
||||
fetch_flashinfer_source().ok().map(|root| {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
(inc, cutlass)
|
||||
})
|
||||
}
|
||||
|
||||
/// Clone FlashInfer at `FLASHINFER_GIT_REV` + CUTLASS at `CUTLASS_GIT_REV`
|
||||
/// into `~/.cache/luminal/flashinfer-src/<short_rev>/` if absent, then return
|
||||
/// the FlashInfer root directory. ~50 MB one-time download; subsequent calls
|
||||
/// short-circuit on the directory check.
|
||||
fn fetch_flashinfer_source() -> Result<PathBuf, String> {
|
||||
let short = &FLASHINFER_GIT_REV[..12];
|
||||
let cache_root = cache_directory().join("flashinfer-src").join(short);
|
||||
let inc = cache_root.join("include");
|
||||
let cutlass_inc = cache_root.join("3rdparty/cutlass/include");
|
||||
|
||||
if inc.exists() && cutlass_inc.exists() {
|
||||
return Ok(cache_root);
|
||||
}
|
||||
|
||||
let parent = cache_root.parent().unwrap();
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("failed to create {}: {e}", parent.display()))?;
|
||||
|
||||
// Clone into a staging dir, then atomic rename. Protects against multiple
|
||||
// processes racing to fetch the same source.
|
||||
let staging = parent.join(format!(".staging-{}-{}", short, std::process::id()));
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: cloning {FLASHINFER_GIT_URL} @ {short} into {} (one-time fetch, ~50 MB) …",
|
||||
cache_root.display()
|
||||
);
|
||||
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
FLASHINFER_GIT_URL,
|
||||
staging.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&staging, &["checkout", FLASHINFER_GIT_REV])?;
|
||||
|
||||
// Init only the CUTLASS submodule (skip spdlog — we don't need it for kernels).
|
||||
let cutlass_path = staging.join("3rdparty/cutlass");
|
||||
let _ = std::fs::remove_dir_all(&cutlass_path);
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
CUTLASS_GIT_URL,
|
||||
cutlass_path.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&cutlass_path, &["checkout", CUTLASS_GIT_REV])?;
|
||||
|
||||
if !staging.join("include").exists() {
|
||||
return Err(format!(
|
||||
"FlashInfer clone succeeded but include/ missing at {}",
|
||||
staging.display()
|
||||
));
|
||||
}
|
||||
if !staging.join("3rdparty/cutlass/include").exists() {
|
||||
return Err(format!(
|
||||
"CUTLASS clone succeeded but include/ missing at {}",
|
||||
staging.join("3rdparty/cutlass").display()
|
||||
));
|
||||
}
|
||||
|
||||
// Atomic-ish rename. If another process beat us to it, just keep theirs.
|
||||
match std::fs::rename(&staging, &cache_root) {
|
||||
Ok(()) => {}
|
||||
Err(_) if cache_root.exists() => {
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
}
|
||||
Err(e) => return Err(format!("rename to {} failed: {e}", cache_root.display())),
|
||||
}
|
||||
|
||||
Ok(cache_root)
|
||||
}
|
||||
|
||||
fn run_git(args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}. Is git installed?"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` failed: {}",
|
||||
args.join(" "),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_git_in(cwd: &Path, args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` in {} failed: {}",
|
||||
args.join(" "),
|
||||
cwd.display(),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detect CUDA arch via env override → nvidia-smi → default sm_80.
|
||||
fn detect_cuda_arch() -> String {
|
||||
if let Ok(arch) = std::env::var("FLASHINFER_CUDA_ARCH") {
|
||||
return arch;
|
||||
}
|
||||
|
||||
if let Ok(output) = Command::new("nvidia-smi")
|
||||
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
|
||||
.output()
|
||||
&& output.status.success()
|
||||
{
|
||||
let cap = String::from_utf8_lossy(&output.stdout);
|
||||
let cap = cap.trim().lines().next().unwrap_or("8.0");
|
||||
let sm = cap.replace('.', "");
|
||||
if !sm.is_empty() {
|
||||
return format!("sm_{}", sm);
|
||||
}
|
||||
}
|
||||
|
||||
"sm_80".to_string()
|
||||
}
|
||||
@@ -1,424 +0,0 @@
|
||||
pub mod find_indptrs;
|
||||
pub mod jit;
|
||||
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaSlice, CudaStream, DevicePtr, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// FlashInfer attention op (batch decode, fp32).
|
||||
///
|
||||
/// Replaces the full paged-GQA attention pattern (gather → broadcast → Q*K^T →
|
||||
/// scale → mask → softmax → *V) with a single FlashInfer fused kernel.
|
||||
///
|
||||
/// Graph inputs (7): Q, K_pool, V_pool, flat_gather_idx, mask, qo_indptr, kv_indptr.
|
||||
/// The egglog rule captures the first 5; `extract()` appends qo/kv indptrs after
|
||||
/// walking the e-graph from the mask. `batch_size` is derived at runtime from the
|
||||
/// indptr length (= num_sequences + 1).
|
||||
#[derive(Debug)]
|
||||
pub struct FlashInferAttention {
|
||||
pub num_qo_heads: usize,
|
||||
pub num_kv_heads: usize,
|
||||
pub head_dim: usize,
|
||||
pub page_size: usize,
|
||||
pub batch_dim: Expression,
|
||||
|
||||
pub plan_info: Mutex<Vec<i64>>,
|
||||
}
|
||||
|
||||
// SAFETY: PAGE_LOCKED_WORKSPACE holds a raw pointer to page-locked CUDA memory
|
||||
// allocated once and serialized via the CUDA stream that owns it.
|
||||
unsafe impl Send for FlashInferAttention {}
|
||||
unsafe impl Sync for FlashInferAttention {}
|
||||
|
||||
const FLOAT_WORKSPACE_SIZE: usize = 128 * 1024 * 1024; // 128 MiB
|
||||
const INT_WORKSPACE_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
|
||||
|
||||
static PAGE_LOCKED_WORKSPACE: OnceLock<PageLockedPtr> = OnceLock::new();
|
||||
|
||||
struct PageLockedPtr(*mut u8);
|
||||
|
||||
// SAFETY: The pointer is page-locked CUDA memory allocated once via
|
||||
// posix_memalign + cudaHostRegister and only mutated during OnceLock
|
||||
// initialization.
|
||||
unsafe impl Send for PageLockedPtr {}
|
||||
unsafe impl Sync for PageLockedPtr {}
|
||||
|
||||
impl std::fmt::Debug for PageLockedPtr {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PageLockedPtr({:p})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FlashInferAttention {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_qo_heads: 0,
|
||||
num_kv_heads: 0,
|
||||
head_dim: 0,
|
||||
page_size: 0,
|
||||
batch_dim: Expression::default(),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for FlashInferAttention {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FlashInferAttention",
|
||||
&[
|
||||
("num_qo_heads", EXPRESSION),
|
||||
("num_kv_heads", EXPRESSION),
|
||||
("head_dim", EXPRESSION),
|
||||
("page_size", EXPRESSION),
|
||||
("batch_dim", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
// Q, K_pool, V_pool, flat_gather_idx, mask (egglog IList).
|
||||
// extract() appends qo_indptr + kv_indptr → 7 actual inputs at runtime.
|
||||
5
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["flashinfer_attention.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let num_qo_heads = extract_expr(egraph, kind_children[0], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let num_kv_heads = extract_expr(egraph, kind_children[1], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let head_dim = extract_expr(egraph, kind_children[2], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let page_size = extract_expr(egraph, kind_children[3], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let batch_dim = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
|
||||
let extracted = Self {
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
batch_dim,
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Trigger JIT compilation (or .so cache hit) at extract time, not at
|
||||
// first execute. Pays the ~30s cold-cache nvcc cost during compile
|
||||
// rather than during the GA profiling loop, where it would dominate
|
||||
// the candidate's measured runtime and make the GA reject FlashInfer.
|
||||
let _ = jit::ensure_compiled(head_dim);
|
||||
|
||||
// Walk the mask e-graph chain to recover qo_indptr / kv_indptr Input nodes.
|
||||
// input_enodes: [Q, K_cache, V_cache, gather_idx, mask]
|
||||
let mask_node = input_enodes[4];
|
||||
let indptrs = find_indptrs::find_indptr_inputs(egraph, mask_node);
|
||||
|
||||
// Build final inputs: [Q, K_cache, V_cache, gather_idx, mask, qo_indptr, kv_indptr]
|
||||
let mut final_inputs = input_enodes;
|
||||
final_inputs.push(indptrs.qo_indptr);
|
||||
final_inputs.push(indptrs.kv_indptr);
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
(op, final_inputs)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for FlashInferAttention {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let lib = jit::ensure_compiled(self.head_dim);
|
||||
|
||||
let total_q_tokens = self
|
||||
.batch_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention batch_dim is unresolved"))?;
|
||||
let c = *dyn_map
|
||||
.get(&'c')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'c'"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'r'"))?;
|
||||
|
||||
if inputs.len() < 7 {
|
||||
anyhow::bail!(
|
||||
"FlashInferAttention expects 7 inputs (Q, K, V, flat_idx, mask, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("FlashInferAttention missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_buf = get_buf("Q", inputs[0])?;
|
||||
let k_buf = get_buf("K_cache", inputs[1])?;
|
||||
let v_buf = get_buf("V_cache", inputs[2])?;
|
||||
let flat_idx_buf = get_buf("flat_gather_idx", inputs[3])?;
|
||||
// inputs[4] = mask (unused by FlashInfer — indptrs replace it)
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[6])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
// Derive batch_size (num sequences) from r = indptr length.
|
||||
let batch_size = r.saturating_sub(1);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"FlashInferAttention",
|
||||
total_q_tokens,
|
||||
batch_size,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let kv_dim = self.num_kv_heads * self.head_dim;
|
||||
let cu_stream = stream.cu_stream() as *mut std::ffi::c_void;
|
||||
|
||||
// Extract slot indices (one per context page) from the flat gather index.
|
||||
let indices_buf = unsafe { stream.alloc::<u8>(c.max(1) * std::mem::size_of::<i32>())? };
|
||||
let (indices_ptr, _idx_guard) = indices_buf.device_ptr(stream);
|
||||
|
||||
if c > 0 {
|
||||
unsafe {
|
||||
(lib.extract_slot_indices)(
|
||||
flat_idx_buf.ptr() as *const i32,
|
||||
indices_ptr as *mut i32,
|
||||
c as i32,
|
||||
kv_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Read kv_indptr to host for the plan phase.
|
||||
let kv_indptr_bytes = r * 4;
|
||||
let mut kv_indptr_host_bytes = vec![0u8; kv_indptr_bytes];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(
|
||||
&mut kv_indptr_host_bytes,
|
||||
kv_indptr_buf.ptr(),
|
||||
stream.cu_stream(),
|
||||
)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let kv_indptr_host: Vec<i32> = unsafe {
|
||||
let mut v = std::mem::ManuallyDrop::new(kv_indptr_host_bytes);
|
||||
Vec::from_raw_parts(v.as_mut_ptr() as *mut i32, r, r)
|
||||
};
|
||||
|
||||
// kv_last_page_len = [1; batch_size] when page_size=1.
|
||||
let last_page_host: Vec<i32> = vec![1; batch_size];
|
||||
let last_page_dev: CudaSlice<u8> = if batch_size > 0 {
|
||||
stream.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
last_page_host.as_ptr() as *const u8,
|
||||
last_page_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
unsafe { stream.alloc::<u8>(1)? }
|
||||
};
|
||||
let (last_page_ptr, _lp_guard) = last_page_dev.device_ptr(stream);
|
||||
|
||||
// Global shared workspaces (allocated once across all op instances to
|
||||
// amortize the ~4ms first-allocation cost during GA profiling).
|
||||
static FLOAT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
static INT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
let float_ws = FLOAT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(FLOAT_WORKSPACE_SIZE).unwrap() });
|
||||
let int_ws = INT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(INT_WORKSPACE_SIZE).unwrap() });
|
||||
let page_locked_ws = PAGE_LOCKED_WORKSPACE.get_or_init(|| unsafe {
|
||||
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
let status = libc::posix_memalign(&mut ptr, 4096, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(status, 0, "Failed to allocate page-locked workspace");
|
||||
let cuda_status = cuda_pin_memory(ptr, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(cuda_status, 0, "Failed to pin memory");
|
||||
PageLockedPtr(ptr as *mut u8)
|
||||
});
|
||||
|
||||
let (float_ws_ptr, _fws_guard) = float_ws.device_ptr(stream);
|
||||
let (int_ws_ptr, _iws_guard) = int_ws.device_ptr(stream);
|
||||
|
||||
// FlashInfer decode writes (total_q_tokens, heads, dim);
|
||||
// luminal expects (heads, total_q_tokens, dim) — transpose at the end.
|
||||
let output_elems = total_q_tokens * self.num_qo_heads * self.head_dim;
|
||||
let temp_out_buf =
|
||||
unsafe { stream.alloc::<u8>(output_elems * std::mem::size_of::<f32>())? };
|
||||
let (temp_out_ptr, _tmp_guard) = temp_out_buf.device_ptr(stream);
|
||||
|
||||
// PrefillPlanInfo has 15 entries, DecodePlanInfo fewer — 16 is enough.
|
||||
let mut plan_info_buf = [0i64; 16];
|
||||
let mut plan_info_len: i32 = 0;
|
||||
|
||||
// ── BatchDecode path ──
|
||||
// Prefill kernels require fp16/bf16 tensor-core MMA; the C API returns -1
|
||||
// when called from the fp32 pipeline. We only use decode here.
|
||||
let plan_ret = unsafe {
|
||||
(lib.plan)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
INT_WORKSPACE_SIZE,
|
||||
page_locked_ws.0 as *mut std::ffi::c_void,
|
||||
kv_indptr_host.as_ptr() as *mut i32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
plan_info_buf.as_mut_ptr(),
|
||||
&mut plan_info_len,
|
||||
)
|
||||
};
|
||||
if plan_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode plan failed with error code {plan_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
let mut plan_info = self.plan_info.lock().unwrap();
|
||||
plan_info.clear();
|
||||
plan_info.extend_from_slice(&plan_info_buf[..plan_info_len as usize]);
|
||||
|
||||
let run_ret = unsafe {
|
||||
(lib.run)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
plan_info.as_mut_ptr(),
|
||||
plan_info.len() as i32,
|
||||
q_buf.ptr() as *mut f32,
|
||||
k_buf.ptr() as *mut f32,
|
||||
v_buf.ptr() as *mut f32,
|
||||
kv_indptr_buf.ptr() as *mut i32,
|
||||
indices_ptr as *mut i32,
|
||||
last_page_ptr as *mut i32,
|
||||
temp_out_ptr as *mut f32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
)
|
||||
};
|
||||
drop(plan_info);
|
||||
|
||||
if run_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode run failed with error code {run_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
// Transpose (total_q_tokens, heads, dim) → (heads, total_q_tokens, dim)
|
||||
unsafe {
|
||||
(lib.transpose_output)(
|
||||
temp_out_ptr as *const f32,
|
||||
out_buf.ptr() as *mut f32,
|
||||
total_q_tokens as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.batch_dim * self.num_qo_heads * self.head_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("FlashInferAttention")
|
||||
}
|
||||
}
|
||||
|
||||
/// Pin host memory for CUDA async memcpy.
|
||||
///
|
||||
/// `cudaHostRegister` lives in libcudart, which cudarc doesn't link to our
|
||||
/// binary. Resolve it via `dlopen`/`dlsym` so we don't need a build script or
|
||||
/// a `#[link]` directive — keeping the crate buildable without any nvcc-side
|
||||
/// dependencies.
|
||||
unsafe fn cuda_pin_memory(ptr: *mut std::ffi::c_void, size: usize) -> i32 {
|
||||
type HostRegisterFn = unsafe extern "C" fn(*mut std::ffi::c_void, usize, u32) -> i32;
|
||||
static FN: OnceLock<usize> = OnceLock::new();
|
||||
|
||||
let raw = *FN.get_or_init(|| unsafe {
|
||||
let lib = [
|
||||
"libcudart.so",
|
||||
"libcudart.so.13",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
]
|
||||
.iter()
|
||||
.find_map(|n| libloading::Library::new(*n).ok())
|
||||
.expect("FlashInfer: could not dlopen libcudart for cudaHostRegister");
|
||||
let sym: libloading::Symbol<HostRegisterFn> = lib
|
||||
.get(b"cudaHostRegister\0")
|
||||
.expect("FlashInfer: libcudart missing cudaHostRegister symbol");
|
||||
let ptr = *sym as *const () as usize;
|
||||
// Keep libcudart resident for the process lifetime so the function
|
||||
// pointer remains valid.
|
||||
std::mem::forget(lib);
|
||||
ptr
|
||||
});
|
||||
let f: HostRegisterFn = unsafe { std::mem::transmute(raw) };
|
||||
// cudaHostRegisterDefault = 0
|
||||
unsafe { f(ptr, size, 0) }
|
||||
}
|
||||
@@ -1,357 +0,0 @@
|
||||
// FlashInfer batch decode + prefill wrapper for luminal_cuda.
|
||||
// JIT-compiled at runtime with -DLUMINAL_HEAD_DIM=N.
|
||||
//
|
||||
// Decode: instantiated for f32 (scalar vectorized dot products, no tensor cores).
|
||||
// Prefill: instantiated for f16 (requires tensor core MMA + ldmatrix).
|
||||
// The C API accepts fp32 buffers; cast kernels convert fp32↔fp16 at the boundary.
|
||||
//
|
||||
// NHD layout. GQA group_size and page_size are runtime parameters.
|
||||
|
||||
#ifndef LUMINAL_HEAD_DIM
|
||||
#error "LUMINAL_HEAD_DIM must be defined (e.g. -DLUMINAL_HEAD_DIM=128)"
|
||||
#endif
|
||||
|
||||
// Include utils.cuh first to get the original DISPATCH_HEAD_DIM, then override it
|
||||
// to only instantiate our specific HEAD_DIM. This avoids a compile error in
|
||||
// cascade.cuh where HEAD_DIM=512 + f32 triggers vec_size=16, vec_bits=512
|
||||
// which exceeds cp_async's 256-bit limit.
|
||||
#include <flashinfer/utils.cuh>
|
||||
#undef DISPATCH_HEAD_DIM
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
{ \
|
||||
constexpr size_t HEAD_DIM = LUMINAL_HEAD_DIM; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#include <flashinfer/attention/scheduler.cuh>
|
||||
#include <flashinfer/attention/decode.cuh>
|
||||
#include <flashinfer/attention/default_decode_params.cuh>
|
||||
#include <flashinfer/attention/prefill.cuh>
|
||||
#include <flashinfer/attention/default_prefill_params.cuh>
|
||||
#include <flashinfer/attention/mask.cuh>
|
||||
#include <flashinfer/attention/variants.cuh>
|
||||
#include <flashinfer/page.cuh>
|
||||
#include <flashinfer/pos_enc.cuh>
|
||||
|
||||
#include "wrapper.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// ── Decode types (f32) ──
|
||||
using DTypeQ = float;
|
||||
using DTypeKV = float;
|
||||
using DTypeO = float;
|
||||
using IdType = int32_t;
|
||||
|
||||
// ── Prefill types (f16 compute, fp32 external interface) ──
|
||||
using PrefillDTypeQ = half;
|
||||
using PrefillDTypeKV = half;
|
||||
using PrefillDTypeO = half;
|
||||
|
||||
constexpr uint32_t HEAD_DIM = LUMINAL_HEAD_DIM;
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;
|
||||
|
||||
// Attention variants
|
||||
using Variant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
using CausalVariant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
// Decode params (f32)
|
||||
using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
|
||||
|
||||
// Prefill params (f16)
|
||||
using PrefillParams = BatchPrefillPagedParams<PrefillDTypeQ, PrefillDTypeKV, PrefillDTypeO, IdType>;
|
||||
|
||||
// Forward declarations
|
||||
namespace flashinfer {
|
||||
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
|
||||
typename Params>
|
||||
cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
|
||||
PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
|
||||
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
|
||||
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
}
|
||||
|
||||
// Explicit instantiation: decode kernel (f32)
|
||||
template cudaError_t flashinfer::BatchDecodeWithPagedKVCacheDispatched<
|
||||
HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
DecodeParams params, DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// Explicit instantiation: prefill kernels (f16, causal mask, CTA_TILE_Q=16/64/128)
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
16, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
64, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
128, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// ── fp32 ↔ fp16 cast kernels ──
|
||||
|
||||
__global__ void cast_f32_to_f16_kernel(const float* src, half* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __float2half(src[i]);
|
||||
}
|
||||
|
||||
__global__ void cast_f16_to_f32_kernel(const half* src, float* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __half2float(src[i]);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
uint32_t group_size = num_qo_heads / num_kv_heads;
|
||||
|
||||
// We need to dispatch on GROUP_SIZE to get the right work estimation function
|
||||
cudaError_t status = cudaSuccess;
|
||||
|
||||
// Use a lambda to dispatch on group size
|
||||
auto do_plan = [&]<uint32_t GROUP_SIZE>() -> cudaError_t {
|
||||
auto work_estimation_func =
|
||||
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
|
||||
GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>;
|
||||
return DecodePlan<HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
float_workspace, float_ws_size,
|
||||
int_workspace, page_locked_int_workspace,
|
||||
int_ws_size, plan_info, indptr_h,
|
||||
(uint32_t)batch_size, (uint32_t)num_qo_heads,
|
||||
(uint32_t)page_size, /*enable_cuda_graph=*/false,
|
||||
stream, work_estimation_func);
|
||||
};
|
||||
|
||||
switch (group_size) {
|
||||
case 1: status = do_plan.operator()<1>(); break;
|
||||
case 2: status = do_plan.operator()<2>(); break;
|
||||
case 4: status = do_plan.operator()<4>(); break;
|
||||
case 8: status = do_plan.operator()<8>(); break;
|
||||
default: return -1; // unsupported group size
|
||||
}
|
||||
|
||||
if (status != cudaSuccess) return (int)status;
|
||||
|
||||
auto vec = plan_info.ToVector();
|
||||
*plan_info_len_out = (int)vec.size();
|
||||
std::memcpy(plan_info_out, vec.data(), vec.size() * sizeof(int64_t));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q,
|
||||
float* k_cache,
|
||||
float* v_cache,
|
||||
int32_t* kv_indptr,
|
||||
int32_t* kv_indices,
|
||||
int32_t* kv_last_page_len,
|
||||
float* output,
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
plan_info.FromVector(std::vector<int64_t>(plan_info_vec, plan_info_vec + plan_info_len));
|
||||
|
||||
// Construct paged_kv_t with NHD layout
|
||||
paged_kv_t<DTypeKV, IdType> paged_kv(
|
||||
(uint32_t)num_kv_heads,
|
||||
(uint32_t)page_size,
|
||||
HEAD_DIM,
|
||||
(uint32_t)batch_size,
|
||||
QKVLayout::kNHD,
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
kv_last_page_len);
|
||||
|
||||
DecodeParams params;
|
||||
params.q = q;
|
||||
params.q_rope_offset = nullptr;
|
||||
params.paged_kv = paged_kv;
|
||||
params.o = output;
|
||||
params.lse = nullptr;
|
||||
params.maybe_alibi_slopes = nullptr;
|
||||
params.padded_batch_size = plan_info.padded_batch_size;
|
||||
params.num_qo_heads = (uint32_t)num_qo_heads;
|
||||
// Q buffer is (batch, num_qo_heads * head_dim) flat — the graph's split_dims + transpose
|
||||
// are stride tricks, no data movement. So the actual memory layout is (batch, heads, dim).
|
||||
params.q_stride_n = num_qo_heads * HEAD_DIM;
|
||||
params.q_stride_h = HEAD_DIM;
|
||||
params.window_left = -1; // no sliding window
|
||||
params.logits_soft_cap = 0.0f;
|
||||
params.sm_scale = 1.0f / sqrtf((float)HEAD_DIM);
|
||||
params.rope_rcp_scale = 1.0f;
|
||||
params.rope_rcp_theta = 1.0f;
|
||||
|
||||
// Set plan info pointers
|
||||
params.request_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.request_indices_offset);
|
||||
params.kv_tile_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_tile_indices_offset);
|
||||
params.o_indptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.o_indptr_offset);
|
||||
params.kv_chunk_size_ptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_chunk_size_ptr_offset);
|
||||
params.block_valid_mask = nullptr;
|
||||
params.partition_kv = false;
|
||||
|
||||
DTypeO* tmp_v = nullptr;
|
||||
float* tmp_s = nullptr;
|
||||
|
||||
if (plan_info.split_kv) {
|
||||
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_workspace, plan_info.v_offset);
|
||||
tmp_s = GetPtrFromBaseOffset<float>(float_workspace, plan_info.s_offset);
|
||||
if (plan_info.enable_cuda_graph) {
|
||||
params.block_valid_mask =
|
||||
GetPtrFromBaseOffset<bool>(int_workspace, plan_info.block_valid_mask_offset);
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t status =
|
||||
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE, Variant>(
|
||||
params, tmp_v, tmp_s, /*enable_pdl=*/false, stream);
|
||||
|
||||
return (int)status;
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
// BatchPrefill (fp16/bf16 only — tensor core MMA requires 16-bit inputs)
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
//
|
||||
// The prefill kernel templates are instantiated above for fp16. These C API
|
||||
// functions accept fp32 pointers (matching the current luminal pipeline) but
|
||||
// return -1 to indicate that fp32 prefill is not supported. When native fp16
|
||||
// support is added, these will accept fp16 pointers and call through to the
|
||||
// instantiated templates.
|
||||
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void*, size_t, void*, size_t, void*,
|
||||
int32_t*, int32_t*, int, int,
|
||||
int, int, int, int, cudaStream_t,
|
||||
int64_t*, int*)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
int flashinfer_batch_prefill_run(
|
||||
void*, size_t, void*,
|
||||
int64_t*, int,
|
||||
float*, float*, float*,
|
||||
int32_t*, int32_t*, int32_t*, int32_t*,
|
||||
float*, int, int, int, int, int, int, cudaStream_t)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
// ── Slot index extraction kernel (outside extern "C" for __global__) ──
|
||||
|
||||
__global__ void extract_slot_indices_kernel(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < c) out[i] = flat_idx[i * kv_dim] / kv_dim;
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream) {
|
||||
if (c == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (c + threads - 1) / threads;
|
||||
extract_slot_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
flat_idx, out, c, kv_dim);
|
||||
}
|
||||
|
||||
// ── Derive CSR indptr from attention mask ──
|
||||
// Mask is (s, c) f32. Entries > -1e9 are "valid" (0.0), rest are -inf.
|
||||
// Per-row count of valid entries = context length for that sequence.
|
||||
// Output: indptr[0..=s] with indptr[0]=0 and indptr[i+1] = indptr[i] + ctx_len[i].
|
||||
// Single thread is fine since s is tiny (batch_size during decode, typically 1-8).
|
||||
|
||||
__global__ void derive_indptr_kernel(
|
||||
const float* mask, int32_t* indptr, int s, int c) {
|
||||
if (threadIdx.x != 0 || blockIdx.x != 0) return;
|
||||
indptr[0] = 0;
|
||||
for (int i = 0; i < s; i++) {
|
||||
int count = 0;
|
||||
for (int j = 0; j < c; j++) {
|
||||
if (mask[i * c + j] > -1e9f) count++;
|
||||
}
|
||||
indptr[i + 1] = indptr[i] + count;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream) {
|
||||
if (s == 0) return;
|
||||
derive_indptr_kernel<<<1, 1, 0, stream>>>(mask, indptr, s, c);
|
||||
}
|
||||
|
||||
// ── Output transpose: (batch, heads, dim) → (heads, batch, dim) ──
|
||||
// FlashInfer writes output as (batch, heads, dim) but Luminal expects (heads, batch, dim).
|
||||
// For batch=1 these are identical; for batch>1 we need an explicit transpose.
|
||||
|
||||
__global__ void transpose_bhd_to_hbd_kernel(
|
||||
const float* src, float* dst, int batch, int heads, int dim) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = batch * heads * dim;
|
||||
if (idx >= total) return;
|
||||
|
||||
// Decompose linear index into (b, h, d) for src layout
|
||||
int d = idx % dim;
|
||||
int h = (idx / dim) % heads;
|
||||
int b = idx / (heads * dim);
|
||||
|
||||
// Write to (h, b, d) layout in dst
|
||||
dst[h * batch * dim + b * dim + d] = src[idx];
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream) {
|
||||
int total = batch * heads * dim;
|
||||
if (total == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (total + threads - 1) / threads;
|
||||
transpose_bhd_to_hbd_kernel<<<blocks, threads, 0, stream>>>(
|
||||
src, dst, batch, heads, dim);
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Plan phase: CPU-side scheduling. Must call before each new batch config.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase: GPU kernel launch.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [batch_size, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* kv_indptr, // [batch_size + 1]
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [batch_size, num_qo_heads, head_dim]
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Extract slot indices from a flat gather index tensor.
|
||||
// flat_idx shape: (c, kv_dim) i32, out shape: (c,) i32.
|
||||
// out[i] = flat_idx[i * kv_dim] / kv_dim
|
||||
void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Derive CSR indptr from attention mask.
|
||||
// mask shape: (s, c) f32. Entries > -1e9 are valid.
|
||||
// indptr shape: (s + 1,) i32. indptr[0] = 0, indptr[i+1] = cumsum of valid counts.
|
||||
void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Transpose output from (batch, heads, dim) to (heads, batch, dim).
|
||||
void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// ── BatchPrefill with Paged KV Cache ──
|
||||
|
||||
// Plan phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* qo_indptr_h, int32_t* kv_indptr_h,
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [total_num_rows, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* qo_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [total_num_rows, num_qo_heads, head_dim]
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,122 +1,17 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub use compute_attn_mask::ComputeAttnMask;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
moe::GLUMoE,
|
||||
compute_attn_mask::ComputeAttnMask,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTypeTuple = (
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
&'static str,
|
||||
luminal::dtype::DType,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_type_tuple(op: &dyn HostOp) -> Option<CublasLtTypeTuple> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::type_tuple)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtScaleValues = (f64, f64);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_scale_values(op: &dyn HostOp) -> Option<CublasLtScaleValues> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::scale_values)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_epilogue(op: &dyn HostOp) -> Option<&'static str> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::epilogue)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtMatrixOrders = (&'static str, &'static str, &'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_matrix_orders(op: &dyn HostOp) -> Option<CublasLtMatrixOrders> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::matrix_orders)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTransposeOps = (&'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_transpose_ops(op: &dyn HostOp) -> Option<CublasLtTransposeOps> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::transpose_ops)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
/// the reusable arena, or an external pointer. Host ops only need the pointer
|
||||
/// and the logical byte length.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct DeviceBuffer {
|
||||
ptr: u64,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl DeviceBuffer {
|
||||
pub fn new(ptr: u64, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
pub fn ptr(self) -> u64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn len(self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
|
||||
let mut host = vec![0u8; self.len];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(host)
|
||||
}
|
||||
}
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
@@ -134,7 +29,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
@@ -153,15 +48,6 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns relative lifetimes for extra buffer nodes within this host op.
|
||||
///
|
||||
/// The tuple is `(node, first_step, last_step)`, where steps are local to
|
||||
/// this host op's execution. Returning `None` tells the runtime to treat
|
||||
/// every extra buffer as live for the whole host op.
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -5,19 +5,12 @@
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared expert index/gather markers
|
||||
; 2) Shared gate-up matmul marker
|
||||
; 3) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 5) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEExpertIndexState
|
||||
(MkGLUMoEExpertIndexState Expression Expression IR)
|
||||
)
|
||||
(GLUMoEExpertGatherState
|
||||
(MkGLUMoEExpertGatherState Expression Expression IR IR)
|
||||
)
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
@@ -35,8 +28,6 @@
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_expert_index (IR) GLUMoEExpertIndexState :merge new)
|
||||
(function glumoe_expert_gather (IR) GLUMoEExpertGatherState :merge new)
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
@@ -45,38 +36,17 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?iota_base (Op (Iota ?io ?iota_base_range) (INil)))
|
||||
(= ?mul_base (Op (Mul ?mul_base_shape ?mul_base_a_stride ?mul_base_b_stride ?mul_base_out_stride) (ICons ?topk_idx (ICons ?iota_base (INil)))))
|
||||
(= ?iota_within (Op (Iota (MIter) ?iota_within_range) (INil)))
|
||||
(= ?add_idx (Op (Add ?add_shape ?add_a_stride ?add_b_stride ?add_out_stride) (ICons ?mul_base (ICons ?iota_within (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_index ?add_idx)
|
||||
(MkGLUMoEExpertIndexState ?io ?iota_within_range ?topk_idx))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert index marker"
|
||||
)
|
||||
; ===== Gate-up expert gather =====
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?index_state (glumoe_expert_index ?idx))
|
||||
(= ?index_state (MkGLUMoEExpertIndexState ?io ?within_range ?topk_idx))
|
||||
(= ?gathered (Op (Gather ?gather_idx_shape ?gather_idx_stride ?gather_data_shape ?gather_data_stride) (ICons ?idx (ICons ?weights (INil)))))
|
||||
(= ?f32 (Op (Cast ?f32_size (F32)) (ICons ?gathered (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_gather ?f32)
|
||||
(MkGLUMoEExpertGatherState ?io ?within_range ?topk_idx ?weights))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert gather marker"
|
||||
)
|
||||
; ===== Cast BF16→F32 =====
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?gather_state (glumoe_expert_gather ?gu_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?gu_io ?gu_iota_within_range ?topk_idx ?gate_up_w))
|
||||
; ===== Gate-up batched matmul =====
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
@@ -84,7 +54,6 @@
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
@@ -111,7 +80,6 @@
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
@@ -145,7 +113,6 @@
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
@@ -155,8 +122,12 @@
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -164,7 +135,6 @@
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
@@ -174,8 +144,12 @@
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -183,7 +157,6 @@
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
@@ -204,10 +177,7 @@
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
@@ -238,9 +208,6 @@
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
@@ -224,9 +224,8 @@ impl EgglogOp for GLUMoE {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
@@ -235,15 +234,17 @@ impl EgglogOp for GLUMoE {
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
),
|
||||
Rule::raw(include_str!["glumoe_rewrite.egg"]),
|
||||
]
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
@@ -294,140 +295,27 @@ impl HostOp for GLUMoE {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 6 {
|
||||
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
|
||||
}
|
||||
|
||||
// Resolve dimensions
|
||||
let hidden = self
|
||||
.gu_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
|
||||
let intermediate = self
|
||||
.dn_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
|
||||
let top_k = self
|
||||
.output_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
|
||||
let gu_io = self
|
||||
.gu_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
|
||||
let dn_io = self
|
||||
.dn_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
if hidden == 0 || intermediate == 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
|
||||
);
|
||||
}
|
||||
if top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
if gu_io % hidden != 0 {
|
||||
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
|
||||
}
|
||||
if dn_io % intermediate != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
|
||||
);
|
||||
}
|
||||
|
||||
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
|
||||
let down_hidden = dn_io / intermediate;
|
||||
if gate_up_dim != intermediate * 2 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
|
||||
gate_up_dim,
|
||||
intermediate * 2
|
||||
);
|
||||
}
|
||||
if down_hidden != hidden {
|
||||
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
|
||||
}
|
||||
|
||||
let output_bytes = self
|
||||
.output_bytes()
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
|
||||
if output_bytes % (hidden * 4) != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
|
||||
hidden * 4
|
||||
);
|
||||
}
|
||||
let seq = output_bytes / (hidden * 4);
|
||||
if seq == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
|
||||
})
|
||||
};
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
|
||||
// Get input/output buffers
|
||||
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
|
||||
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
|
||||
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
|
||||
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
if output_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
|
||||
output_buf.len()
|
||||
);
|
||||
}
|
||||
|
||||
let gu_stride_bytes = gate_up_dim * hidden * 2;
|
||||
let down_stride_bytes = hidden * intermediate * 2;
|
||||
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
|
||||
gate_up_buf.len()
|
||||
);
|
||||
}
|
||||
let num_experts = gate_up_buf.len() / gu_stride_bytes;
|
||||
if num_experts == 0 {
|
||||
anyhow::bail!("GLUMoE has no expert weights");
|
||||
}
|
||||
if down_buf.len() < num_experts * down_stride_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down weight buffer too small: have {} bytes, need {}",
|
||||
down_buf.len(),
|
||||
num_experts * down_stride_bytes
|
||||
);
|
||||
}
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
@@ -439,17 +327,21 @@ impl HostOp for GLUMoE {
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
|
||||
|
||||
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
@@ -459,16 +351,9 @@ impl HostOp for GLUMoE {
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
if per_expert_scale_host.len() < per_expert_scale_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
|
||||
per_expert_scale_host.len()
|
||||
);
|
||||
}
|
||||
let per_expert_scale_f32: &[f32] =
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
@@ -498,10 +383,10 @@ impl HostOp for GLUMoE {
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = slice_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = slice_ptr(&workspace, stream);
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
@@ -520,8 +405,8 @@ impl HostOp for GLUMoE {
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
@@ -623,11 +508,7 @@ impl HostOp for GLUMoE {
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
|
||||
buf.ptr()
|
||||
}
|
||||
|
||||
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
|
||||
@@ -6,8 +6,12 @@
|
||||
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
|
||||
// pattern. Cascade prevention by typing.
|
||||
//
|
||||
// Each FusedX must be absorbed into a FusionEnd-rooted region and compiled by
|
||||
// `region_codegen`; standalone compilation is intentionally unsupported.
|
||||
// `compile()` is a *fallback* path. The fast path collapses each FE-rooted
|
||||
// region into one CUDA kernel inside `region_codegen` and FusedX/FS/FE
|
||||
// never reach kernel_to_host's compile loop. But extraction can produce
|
||||
// LLIR shapes the detector doesn't sweep into a region, so each FusedX's
|
||||
// standalone `compile()` falls back to emitting the same kernel its
|
||||
// un-fused KernelX sibling would — correct, just one launch per op.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
@@ -23,7 +27,11 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
pub type Ops = (
|
||||
FusedSin,
|
||||
@@ -47,6 +55,135 @@ type CompileOut = (
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// Fallback kernel templates — used when a FusedX op reaches
|
||||
// `kernel_to_host` standalone (region detection missed it). Same CUDA as
|
||||
// the matching un-fused KernelX would emit, parameterised by the per-op
|
||||
// body expression. The fast path goes through `region_codegen`.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compile_unary_fallback(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
body_expr: &str, // CUDA expression on `in[{in_idx}]`, e.g. "sinf(in[{in_idx}])"
|
||||
shape: &[Expression],
|
||||
in_strides: &[Expression],
|
||||
out_strides: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
|
||||
let out_idx = flatten_strides(shape, out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(shape, in_strides).to_kernel();
|
||||
let body = body_expr.replace("{in_idx}", &in_idx);
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 out[{out_idx}] = {body};\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compile_binary_fallback(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
op_str: &str, // CUDA infix operator, e.g. "+", "*"
|
||||
out_shape: &[Expression],
|
||||
a_stride: &[Expression],
|
||||
b_stride: &[Expression],
|
||||
out_stride: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype, dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(out_shape, out_stride).to_kernel();
|
||||
let a_idx = flatten_strides(out_shape, a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(out_shape, b_stride).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *C, const {cuda_ty} *A, const {cuda_ty} *B{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 C[{out_idx}] = A[{a_idx}] {op_str} B[{b_idx}];\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
|
||||
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
|
||||
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
|
||||
@@ -118,13 +255,19 @@ macro_rules! impl_fused_unary {
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
compile_unary_fallback(
|
||||
stream,
|
||||
compile_cache,
|
||||
$kernel_name,
|
||||
$body,
|
||||
&self.shape,
|
||||
&self.in_strides,
|
||||
&self.out_strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
@@ -236,13 +379,20 @@ macro_rules! impl_fused_binary {
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
compile_binary_fallback(
|
||||
stream,
|
||||
compile_cache,
|
||||
$kernel_name,
|
||||
$op_str,
|
||||
&self.out_shape,
|
||||
&self.a_stride,
|
||||
&self.b_stride,
|
||||
&self.out_stride,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
|
||||
@@ -27,7 +27,70 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
/// Identity-memcpy kernel used as a *fallback* when a FusionStart or
|
||||
/// FusionEnd reaches `kernel_to_host`'s compile loop standalone (i.e.,
|
||||
/// region detection didn't sweep it into a `CompileUnit::Region`). The
|
||||
/// fast path is region collapse, but model-fuzz extraction sometimes
|
||||
/// produces LLIR shapes the detector doesn't catch; this keeps
|
||||
/// execution correct in those cases.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn compile_identity_kernel(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
shape: &[Expression],
|
||||
strides: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
|
||||
let idx = flatten_strides(shape, strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 out[{idx}] = in[{idx}];\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
pub type Ops = (FusionStart, FusionEnd);
|
||||
|
||||
@@ -96,10 +159,17 @@ impl EgglogOp for FusionStart {
|
||||
impl KernelOp for FusionStart {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionStart must be compiled through fusion region codegen")
|
||||
compile_identity_kernel(
|
||||
stream,
|
||||
compile_cache,
|
||||
"fusion_start_k",
|
||||
&self.shape,
|
||||
&self.strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
@@ -113,9 +183,6 @@ impl KernelOp for FusionStart {
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
@@ -142,6 +209,14 @@ impl EgglogOp for FusionEnd {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Ablation switch: with `LUMINAL_DISABLE_BINARY_FUSION=1` set, do
|
||||
// not register any fusion rules. The e-graph never sees the FS/FE
|
||||
// bracketed alternative, extraction always picks the un-fused
|
||||
// form, and the runtime path matches main with no fusion at all.
|
||||
// Used to A/B fusion's runtime impact on a single binary.
|
||||
if std::env::var("LUMINAL_DISABLE_BINARY_FUSION").is_ok() {
|
||||
return Vec::new();
|
||||
}
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
@@ -183,7 +258,7 @@ impl EgglogOp for FusionEnd {
|
||||
(let ?fu2 (Op ({fo2} ?shape ?s ?s ?dt) (ICons ?fu1 (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu2 (INil))))
|
||||
(union ?u2 ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
) :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@@ -204,7 +279,7 @@ impl EgglogOp for FusionEnd {
|
||||
(let ?fu (Op ({fu} ?shape ?o_s ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
) :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@@ -227,7 +302,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?fu (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
) :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
@@ -242,7 +317,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?fs_a (ICons ?fu (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
) :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@@ -266,7 +341,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?fbi (ICons ?fs_c (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
) :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
@@ -284,7 +359,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?fs_c (ICons ?fbi (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
) :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
@@ -299,7 +374,7 @@ impl EgglogOp for FusionEnd {
|
||||
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{ku}\")"
|
||||
) :name \"grow-FE-U-{ku}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -316,7 +391,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{lb}\")"
|
||||
) :name \"grow-FE-B-lhs-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
@@ -329,16 +404,13 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{lb}\")"
|
||||
) :name \"grow-FE-B-rhs-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 7. Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(FB(ia, ib)).
|
||||
//
|
||||
// This is destructive: after creating the larger region, subsume the
|
||||
// two smaller FusionEnd rows. Without that, independently-grown left
|
||||
// and right regions form a Cartesian product, then those alternatives
|
||||
// can merge again higher in the graph.
|
||||
// Both inners reused, no new FS — shared external tensors with
|
||||
// upstream FSes stay at one FS.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
@@ -351,9 +423,7 @@ impl EgglogOp for FusionEnd {
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(subsume (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(subsume (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{lb}\")"
|
||||
) :name \"merge-FE-FE-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -393,10 +463,17 @@ impl EgglogOp for FusionEnd {
|
||||
impl KernelOp for FusionEnd {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionEnd must be compiled through fusion region codegen")
|
||||
compile_identity_kernel(
|
||||
stream,
|
||||
compile_cache,
|
||||
"fusion_end_k",
|
||||
&self.shape,
|
||||
&self.strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
|
||||
@@ -93,8 +93,9 @@ pub(crate) enum CompileUnit {
|
||||
/// (one whose e-graph congruence-deduplicated it across multiple
|
||||
/// regions) into a different subgraph than the FE that absorbs it.
|
||||
/// Without this global view, `build_compile_units` running on the FS's
|
||||
/// subgraph would not see any FE walking back to the FS and would emit the
|
||||
/// FS as `CompileUnit::Single`; marker standalone compilation is not supported.
|
||||
/// subgraph would not see any FE walking back to the FS, would emit the
|
||||
/// FS as `CompileUnit::Single`, and the markers' identity-memcpy
|
||||
/// fallback would compile and launch — pure overhead at runtime.
|
||||
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
@@ -195,10 +196,11 @@ pub(crate) fn build_compile_units(
|
||||
// Non-marker, non-FusedX predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
// not absorb — let the kernel_to_host single
|
||||
// path handle it. This means the region is
|
||||
// malformed and we likely should not have a
|
||||
// region at all; caller will see incomplete
|
||||
// interior.
|
||||
// region at all. Caller will see incomplete
|
||||
// interior; the safer thing is to fall back.
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,10 +253,11 @@ pub(crate) fn build_compile_units(
|
||||
// FE nodes with their RegionUnit and skipping anything absorbed —
|
||||
// either by a region in *this* subgraph (`absorbed`) or by any
|
||||
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
|
||||
// latter prevents shared FS markers whose consumers live in other
|
||||
// convex subgraphs from being emitted as standalone compile units:
|
||||
// latter prevents the identity-memcpy fallback from firing on
|
||||
// shared FS markers whose consumers live in other convex subgraphs:
|
||||
// those FSes are absorbed by some other region, and the consuming
|
||||
// region reads from FS's external producer.
|
||||
// region reads from FS's external producer, so the FS never needs
|
||||
// its own kernel.
|
||||
let mut units: Vec<CompileUnit> = Vec::new();
|
||||
for &node in topo_order {
|
||||
if let Some(region) = regions.remove(&node) {
|
||||
|
||||
@@ -8,7 +8,7 @@ use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, Term, app, eq, rule, set, sort, union, v},
|
||||
api::{Rule, SortDef, app, eq, rule, set, sort, union, v},
|
||||
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
@@ -79,48 +79,7 @@ pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
args.add("dtype", dt.clone());
|
||||
let llir_kind_term = llir.call(&args);
|
||||
let llir_op = op_term(llir_kind_term, inputs);
|
||||
rule(union(hlir_op.clone(), llir_op))
|
||||
.fact(eq(dt, dtype(hlir_op)))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
/// Build a kernel rewrite for ops whose kernel dtype must match the first input.
|
||||
///
|
||||
/// This avoids extracting stale/conflicting dtype facts from the output e-class
|
||||
/// after backend alternatives have been unioned into it.
|
||||
fn kernel_rewrite_from_first_input<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
let hlir = H::default().sort();
|
||||
let llir = L::default().sort();
|
||||
let (mut args, hlir_kind_term) = hlir.new_call();
|
||||
let first_inp = v("?__first_inp");
|
||||
let tail = v("?__tail");
|
||||
let inputs = Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![first_inp.clone(), tail],
|
||||
};
|
||||
let hlir_op = op_term(hlir_kind_term, inputs.clone());
|
||||
let dt = v("?__dt");
|
||||
args.add("dtype", dt.clone());
|
||||
let llir_kind_term = llir.call(&args);
|
||||
let llir_op = op_term(llir_kind_term, inputs);
|
||||
rule(union(hlir_op, llir_op))
|
||||
.fact(eq(dt, dtype(first_inp)))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
fn dtype_for_ir_enode(egraph: &SerializedEGraph, ir_node: &ENodeId) -> Option<DType> {
|
||||
let ir_class = egraph.node_to_class.get(ir_node)?;
|
||||
let dtype_node = egraph.enodes.iter().find_map(|(node, (label, children))| {
|
||||
(label == "dtype" && children.first() == Some(ir_class)).then_some(node)
|
||||
})?;
|
||||
let dtype_class = egraph.node_to_class.get(dtype_node)?;
|
||||
egraph.eclasses.get(dtype_class)?.1.iter().find_map(|node| {
|
||||
match egraph.enodes.get(node)?.0.as_str() {
|
||||
"F32" | "F16" | "Bf16" | "Int" | "Bool" | "F4E2M1" | "F8E4M3" | "F8UE8M0" | "I4"
|
||||
| "TF32" => Some(extract_dtype(egraph, node)),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
rule(union(hlir_op.clone(), llir_op)).fact(eq(dt, dtype(hlir_op)))
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -741,7 +700,7 @@ impl EgglogOp for KernelMul {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite_from_first_input::<Mul, Self>()]
|
||||
vec![kernel_rewrite::<Mul, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -756,45 +715,17 @@ impl EgglogOp for KernelMul {
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
// Some e-graph paths (length-changing rewrites such as `merge_dims`
|
||||
// or `RemoveNthFromEnd`) leave a Mul kind enode whose shape and
|
||||
// strides children are extracted to different lengths under the
|
||||
// first-enode walk. The `enforce_consistent_first_kind_enodes`
|
||||
// pass in `src/egglog_utils/mod.rs` repairs this where it can,
|
||||
// but a handful of eclasses have *no* consistent variant in any
|
||||
// of their stride sub-eclasses. For those we truncate to the
|
||||
// SHORTEST length here so `flatten_strides` is structurally
|
||||
// satisfied — the resulting kernel is numerically wrong for that
|
||||
// candidate but harmless for the search, which profiles many
|
||||
// candidates and steers toward the consistent ones.
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
let dtype = input_enodes
|
||||
.first()
|
||||
.and_then(|node| dtype_for_ir_enode(egraph, node))
|
||||
.unwrap_or_else(|| extract_dtype(egraph, kind_children[4]));
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype,
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
@@ -934,29 +865,13 @@ impl EgglogOp for KernelGather {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match HLIR Gather (now in Op format) and rewrite to KernelGather.
|
||||
// Mirror the IList pattern used by `Gather`'s own dtype propagation
|
||||
// rule (`src/hlir.rs`): use a `?__tail` variable instead of a
|
||||
// strict `(INil)` so we don't accidentally fail to match against a
|
||||
// Gather Op whose IList tail eclass has been merged with another
|
||||
// chain by some unrelated egglog union. Without this the kernel
|
||||
// rewrite is silently skipped for some Gathers in deep models
|
||||
// (e.g. YOLO's stacked make_contiguous chains).
|
||||
// Match HLIR Gather (now in Op format) and rewrite to KernelGather
|
||||
let hlir_gather = luminal::hlir::Gather::default().sort();
|
||||
let (gather_args, gather_kind_term) = hlir_gather.new_call();
|
||||
// HLIR Gather inputs: [indexes, data] (n_inputs=2)
|
||||
let indexes = v("?__indexes");
|
||||
let data = v("?__data");
|
||||
let tail = v("?__tail");
|
||||
let gather_inputs = Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![
|
||||
indexes.clone(),
|
||||
Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![data.clone(), tail],
|
||||
},
|
||||
],
|
||||
};
|
||||
let gather_inputs = ilist(vec![indexes.clone(), data.clone()]);
|
||||
let gather_op = op_term(gather_kind_term, gather_inputs);
|
||||
|
||||
let out_strides = SORTS
|
||||
@@ -979,11 +894,7 @@ impl EgglogOp for KernelGather {
|
||||
];
|
||||
let kernel_kind_term = self.sort().call(kernel_kind_args);
|
||||
let kernel_op = op_term(kernel_kind_term, ilist(vec![indexes, data.clone()]));
|
||||
vec![
|
||||
rule(union(gather_op, kernel_op))
|
||||
.fact(eq(dt, dtype(data)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(gather_op, kernel_op)).fact(eq(dt, dtype(data)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1218,11 +1129,7 @@ impl EgglogOp for KernelScatter {
|
||||
];
|
||||
let kernel_kind_term = self.sort().call(kernel_kind_args);
|
||||
let kernel_op = op_term(kernel_kind_term, ilist(vec![dest, indexes, src.clone()]));
|
||||
vec![
|
||||
rule(union(scatter_op, kernel_op))
|
||||
.fact(eq(dt, dtype(src)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(scatter_op, kernel_op)).fact(eq(dt, dtype(src)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1499,8 +1406,7 @@ impl EgglogOp for KernelIota {
|
||||
let kernel_op = op_term(kernel_kind, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op.clone()))
|
||||
.set(dtype(kernel_op), app(&SORTS.int_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
.set(dtype(kernel_op), app(&SORTS.int_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1540,22 +1446,19 @@ impl KernelOp for KernelIota {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
vars.extend(self.range.dyn_vars());
|
||||
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let range = self.range.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void iota_k(int *C{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {range}) return;
|
||||
C[const_z] = {};
|
||||
}}
|
||||
}}",
|
||||
@@ -1574,8 +1477,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.range.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(self.range, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2177,7 +2080,7 @@ extern \"C\" {{
|
||||
__global__ void recip_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / in[{in_idx}];
|
||||
out[{out_idx}] = ({dtype})1.0f / in[{in_idx}];
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
@@ -2588,11 +2491,7 @@ impl EgglogOp for KernelLessThan {
|
||||
args.add("dtype", dt.clone());
|
||||
let kernel_kind_term = self.sort().call(&args);
|
||||
let kernel_op = op_term(kernel_kind_term, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op))
|
||||
.fact(eq(dt, dtype(inp_a)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(hlir_op, kernel_op)).fact(eq(dt, dtype(inp_a)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2749,8 +2648,7 @@ impl EgglogOp for KernelConstant {
|
||||
let kernel_op = op_term(kernel_kind, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op.clone()))
|
||||
.set(dtype(kernel_op), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
.set(dtype(kernel_op), app(&SORTS.f32_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2892,11 +2790,7 @@ impl EgglogOp for KernelCast {
|
||||
cast_args.add("src_dtype", out_dty);
|
||||
let kernel_kind_term = self.sort().call(&cast_args);
|
||||
let kernel_op = op_term(kernel_kind_term, cast_inputs);
|
||||
vec![
|
||||
rule(union(cast_op, kernel_op))
|
||||
.fact(eq(in_dty, dtype(inp)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(cast_op, kernel_op)).fact(eq(in_dty, dtype(inp)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2938,14 +2832,6 @@ impl KernelOp for KernelCast {
|
||||
) {
|
||||
let out_dtype = cuda_dtype(self.out_dtype);
|
||||
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
|
||||
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let size = self.size.to_kernel();
|
||||
|
||||
let kernel = if self.in_dtype.bits() < 8 {
|
||||
// Sub-byte packed types: multiple values packed per byte.
|
||||
@@ -2955,11 +2841,9 @@ impl KernelOp for KernelCast {
|
||||
let mask = (1u32 << bits) - 1;
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {size}) return;
|
||||
long long bit_offset = idx * {bits};
|
||||
long long byte_idx = bit_offset >> 3;
|
||||
int bit_pos = (int)(bit_offset & 7);
|
||||
@@ -2975,11 +2859,9 @@ extern \"C\" {{
|
||||
let in_dtype = cuda_dtype(self.in_dtype);
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {size}) return;
|
||||
out[const_z] = ({out_dtype})in[const_z];
|
||||
}}
|
||||
}}"
|
||||
@@ -2998,8 +2880,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.size.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(self.size, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -3162,7 +3044,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with cast mul\"
|
||||
)"),
|
||||
// Match Gather with Add(Iota, Mul(Cast(token_ids), const)) indices (reversed order)
|
||||
@@ -3182,7 +3063,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with cast mul reversed\"
|
||||
)"),
|
||||
// Match Gather with Add(Mul(token_ids, const), Iota) indices (no Cast)
|
||||
@@ -3201,7 +3081,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with mul\"
|
||||
)"),
|
||||
// Match Gather with Add(Iota, Mul(token_ids, const)) indices (reversed order, no Cast)
|
||||
@@ -3220,7 +3099,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with mul reversed\"
|
||||
)"),
|
||||
]
|
||||
@@ -3281,24 +3159,15 @@ impl KernelOp for KernelEmbed {
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.embed_dim.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
let n_elements = total_threads.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
{}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {n_elements}) return;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
long long embed_idx = idx % embed_dim;
|
||||
@@ -3308,7 +3177,10 @@ extern \"C\" {{
|
||||
int token_id = token_ids[token_offset];
|
||||
out[out_offset + embed_idx] = embed_table[(long long)token_id * embed_dim + embed_idx];
|
||||
}}
|
||||
}}"
|
||||
}}",
|
||||
vars.iter()
|
||||
.map(|i| format!("__constant__ int const_{i}[1];"))
|
||||
.join("\n"),
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
@@ -3319,14 +3191,17 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let constants = vars
|
||||
.into_iter()
|
||||
.map(|d| (d, module.get_global(&format!("const_{d}"), stream).unwrap()))
|
||||
.collect();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(total_threads.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(total_threads, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
constants,
|
||||
)
|
||||
|
||||
@@ -128,8 +128,7 @@ impl KernelOp for KernelMeanReduce {
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block: usize = 256; // 8 warps per block
|
||||
let n_warps = threads_per_block / 32;
|
||||
let threads_per_block = 256; // 8 warps per block
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
@@ -150,24 +149,12 @@ extern \"C\" {{
|
||||
long long iters = {iters};
|
||||
long long iter_stride = {iter_stride};
|
||||
|
||||
float thread_sum = 0.0f;
|
||||
for (long long i = threadIdx.x; i < iters; i += {threads_per_block})
|
||||
thread_sum += (float)in[in_start + i * iter_stride];
|
||||
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
|
||||
|
||||
__shared__ float warp_sums[{n_warps}];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp = threadIdx.x >> 5;
|
||||
if (lane == 0) warp_sums[warp] = thread_sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {{
|
||||
float sum = 0.0f;
|
||||
for (int w = 0; w < {n_warps}; w++) sum += warp_sums[w];
|
||||
out[{out_index}] = ({dtype})(sum / (float)iters);
|
||||
{dtype} sum = 0;
|
||||
for (long long i = 0; i < iters; i++) {{
|
||||
sum += in[in_start + i * iter_stride];
|
||||
}}
|
||||
|
||||
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
@@ -180,8 +167,6 @@ extern \"C\" {{
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel(),
|
||||
threads_per_block = threads_per_block,
|
||||
n_warps = n_warps,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
@@ -198,9 +183,9 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(threads_per_block.into(), 1.into(), 1.into()), // block
|
||||
0.into(), // shmem size
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
@@ -294,9 +279,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match KernelScatter and rewrite to KernelScatterNoCopy with ConsumedBuffer on dest.
|
||||
// ConsumedBuffer wraps dest to signal in-place modification.
|
||||
// This is only valid when the destination buffer can also represent
|
||||
// the scatter output layout. If dest is a strided/broadcast view,
|
||||
// regular Scatter must first materialize a contiguous output copy.
|
||||
//
|
||||
// Two-phase resolution:
|
||||
// 1. During (run): cleanup rules delete ConsumedBuffer if dest is shared (another op uses it)
|
||||
@@ -307,31 +289,12 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
|
||||
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
|
||||
let mut rules = vec![
|
||||
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail)))
|
||||
((consumed_buffer_ilist_contains ?list ?head))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-head\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail))
|
||||
(consumed_buffer_ilist_contains ?tail ?item))
|
||||
((consumed_buffer_ilist_contains ?list ?item))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-tail\"
|
||||
)",
|
||||
),
|
||||
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?dst ?os)
|
||||
(= ?dty (dtype ?src))
|
||||
)
|
||||
(
|
||||
@@ -341,7 +304,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
(union ?scatter ?nocopy)
|
||||
(set (dtype ?nocopy) ?dty)
|
||||
)
|
||||
:ruleset buffer_reuse
|
||||
:name \"scatter to scatter-no-copy\"
|
||||
)",
|
||||
),
|
||||
@@ -351,7 +313,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?dt (dtype ?a)))
|
||||
((set (dtype ?cb) ?dt))
|
||||
:ruleset dtype_prop
|
||||
:name \"consumed-buffer-dtype\"
|
||||
)",
|
||||
),
|
||||
@@ -361,28 +322,13 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?op1 (Op ?k1 ?ilist1))
|
||||
(consumed_buffer_ilist_contains ?ilist1 ?cb)
|
||||
(= ?ilist1 (ICons ?cb ?rest1))
|
||||
(= ?op2 (Op ?k2 ?ilist2))
|
||||
(!= ?op1 ?op2)
|
||||
(consumed_buffer_ilist_contains ?ilist2 ?a))
|
||||
(= ?ilist2 (ICons ?a ?t2)))
|
||||
((delete (ConsumedBuffer ?a)))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-cleanup-shared-op-use\"
|
||||
)",
|
||||
));
|
||||
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
|
||||
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?dest))
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
|
||||
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
|
||||
:ruleset post_cleanup
|
||||
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
|
||||
:name \"consumed-buffer-cleanup-pos\"
|
||||
)",
|
||||
));
|
||||
// Surviving ConsumedBuffers are valid — union with source and delete.
|
||||
@@ -509,8 +455,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
scatter_kernel,
|
||||
(n_src.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(n_src, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -713,7 +659,6 @@ impl EgglogOp for KernelBatchMatVec {
|
||||
(union ?sum ?bmv)
|
||||
(set (dtype ?bmv) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch mat-vec\"
|
||||
)"
|
||||
)]
|
||||
@@ -994,7 +939,6 @@ impl EgglogOp for KernelBatchMatMul {
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
@@ -1234,7 +1178,6 @@ impl EgglogOp for KernelSoftmax {
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
@@ -1507,7 +1450,6 @@ impl EgglogOp for KernelExp {
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
@@ -1669,17 +1611,9 @@ impl EgglogOp for KernelSigmoid {
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Stage the HLIR sigmoid pattern through a small marker so repeated
|
||||
// default passes do not re-run one large join over every Mul/Add/Recip.
|
||||
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
|
||||
Rule::raw(
|
||||
"(datatype*
|
||||
(KernelSigmoidScaledState
|
||||
(MkKernelSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function kernel_sigmoid_scaled (IR) KernelSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
"(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
@@ -1689,33 +1623,19 @@ impl EgglogOp for KernelSigmoid {
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (kernel_sigmoid_scaled ?scaled)
|
||||
(MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (kernel_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
|
||||
@@ -7,13 +7,13 @@ use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, sys::CUgraphNode,
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr,
|
||||
sys::{CUgraphNode, CUresult, cuLaunchKernel},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
@@ -23,7 +23,7 @@ use luminal::{
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
@@ -48,12 +48,8 @@ struct CompiledKernel {
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Human-readable labels for input nodes, for launch diagnostics.
|
||||
input_labels: Vec<String>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
|
||||
has_dyn_dims_param: bool,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
@@ -73,9 +69,7 @@ impl CompiledKernel {
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
input_labels: Vec<String>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
has_dyn_dims_param: bool,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
@@ -86,9 +80,7 @@ impl CompiledKernel {
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
@@ -235,7 +227,7 @@ impl HostOp for CudaGraphOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
@@ -267,40 +259,6 @@ impl HostOp for CudaGraphOp {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
let state = self.state.borrow();
|
||||
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
|
||||
let max_step = state.kernels.len().saturating_sub(1);
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
lifetimes
|
||||
.entry(node)
|
||||
.and_modify(|(first, last)| {
|
||||
*first = (*first).min(step);
|
||||
*last = (*last).max(step);
|
||||
})
|
||||
.or_insert((step, step));
|
||||
};
|
||||
|
||||
for (step, kernel) in state.kernels.iter().enumerate() {
|
||||
for &input in &kernel.inputs {
|
||||
touch(input, step);
|
||||
}
|
||||
touch(kernel.node, step);
|
||||
}
|
||||
|
||||
for node in self.extra_buffer_nodes() {
|
||||
lifetimes.entry(node).or_insert((0, max_step));
|
||||
}
|
||||
|
||||
Some(
|
||||
lifetimes
|
||||
.into_iter()
|
||||
.map(|(node, (start, end))| (node, start, end))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
@@ -311,66 +269,21 @@ impl HostOp for CudaGraphOp {
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
|
||||
match kernel_name {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_requires_output_buffer(
|
||||
kernel: &CompiledKernel,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> bool {
|
||||
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
|
||||
&& kernel.kernel_op.output_aliases_input().is_none()
|
||||
}
|
||||
|
||||
fn validate_kernel_pointers(
|
||||
kernel: &CompiledKernel,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
|
||||
if *input_ptr == 0 {
|
||||
let input_label = kernel
|
||||
.input_labels
|
||||
.get(idx)
|
||||
.map(String::as_str)
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!(
|
||||
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
input_node,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Debug path: launch each kernel sequentially with sync between, so the
|
||||
// failing kernel surfaces instead of the generic "CudaGraph" panic.
|
||||
// Enable via `LUMINAL_DEBUG_SEQ=1`. Slow — only for diagnosing
|
||||
// CUDA_ERROR_ILLEGAL_ADDRESS / NaN / wrong-output bugs in graph batching.
|
||||
if std::env::var("LUMINAL_DEBUG_SEQ").is_ok() {
|
||||
return self.execute_sequential_for_debug(stream, buffers, dyn_map);
|
||||
}
|
||||
|
||||
let mut state = self.state.borrow_mut();
|
||||
let _span = span!(Level::TRACE, "cuda_graph", kernels = state.kernels.len()).entered();
|
||||
|
||||
@@ -439,7 +352,7 @@ impl CudaGraphOp {
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.ptr());
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -487,26 +400,13 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
@@ -533,19 +433,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
@@ -569,12 +456,158 @@ impl CudaGraphOp {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Diagnostic path for kernel-level errors that surface as a generic
|
||||
/// `CUDA_ERROR_ILLEGAL_ADDRESS` panic from the batched cuda_graph_exec
|
||||
/// launch. Bypasses CUDA-graph batching entirely: builds params per
|
||||
/// kernel and launches each via `cuLaunchKernel`, syncing afterwards so
|
||||
/// the offending kernel reports itself instead of being hidden inside
|
||||
/// the graph's atomic launch.
|
||||
///
|
||||
/// Enabled via `LUMINAL_DEBUG_SEQ=1`. ~10–100× slower than the graph
|
||||
/// path; not for production.
|
||||
fn execute_sequential_for_debug(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
let num_kernels = state.kernels.len();
|
||||
|
||||
// Allocate dyn_dims_buffer if needed and copy current values.
|
||||
if !self.dyn_dims_order.is_empty() && state.dyn_dims_buffer.is_none() {
|
||||
state.dyn_dims_buffer = Some(stream.alloc_zeros::<i32>(self.dyn_dims_order.len())?);
|
||||
}
|
||||
if !self.dyn_dims_order.is_empty() {
|
||||
let values: Vec<i32> = self
|
||||
.dyn_dims_order
|
||||
.iter()
|
||||
.map(|d| dyn_map.get(d).copied().unwrap_or(0) as i32)
|
||||
.collect();
|
||||
if let Some(buf) = state.dyn_dims_buffer.as_mut() {
|
||||
stream.memcpy_htod(&values, buf)?;
|
||||
}
|
||||
}
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Collect buffer pointers (mirrors the graph path).
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
for kernel in state.kernels.iter() {
|
||||
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
|
||||
&& let Some(&input_ptr) = buffer_ptrs.get(&kernel.inputs[input_idx])
|
||||
{
|
||||
buffer_ptrs.insert(kernel.node, input_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate internal buffers + run pre_execute for every kernel up front.
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &mut state.kernels[idx];
|
||||
if kernel.internal_bufs.is_empty() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
kernel.kernel_op.pre_execute(
|
||||
stream,
|
||||
&mut kernel.internal_bufs,
|
||||
&mut kernel.constants,
|
||||
&buffer_ptrs,
|
||||
dyn_map,
|
||||
);
|
||||
}
|
||||
|
||||
let cu_stream = stream.cu_stream();
|
||||
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &state.kernels[idx];
|
||||
let kernel_name = kernel.kernel_op.kernel_name();
|
||||
let node = kernel.node;
|
||||
|
||||
let grid = (
|
||||
kernel.grid.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let block = (
|
||||
kernel.block.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&node).copied().unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
let result = unsafe {
|
||||
cuLaunchKernel(
|
||||
cu_func,
|
||||
grid.0,
|
||||
grid.1,
|
||||
grid.2,
|
||||
block.0,
|
||||
block.1,
|
||||
block.2,
|
||||
shared_mem,
|
||||
cu_stream,
|
||||
params.as_cuda_params(),
|
||||
std::ptr::null_mut(),
|
||||
)
|
||||
};
|
||||
if result != CUresult::CUDA_SUCCESS {
|
||||
eprintln!(
|
||||
"[seq-debug] kernel #{idx}/{num_kernels} '{kernel_name}' \
|
||||
node={node:?} grid={grid:?} block={block:?} \
|
||||
output_ptr={output_ptr:#x} inputs={input_ptrs:#x?} \
|
||||
LAUNCH FAILED: {result:?}"
|
||||
);
|
||||
anyhow::bail!(
|
||||
"kernel #{idx} '{kernel_name}' (node {node:?}) launch failed: {result:?}"
|
||||
);
|
||||
}
|
||||
if let Err(e) = stream.synchronize() {
|
||||
eprintln!(
|
||||
"[seq-debug] kernel #{idx}/{num_kernels} '{kernel_name}' \
|
||||
node={node:?} grid={grid:?} block={block:?} \
|
||||
output_ptr={output_ptr:#x} inputs={input_ptrs:#x?} \
|
||||
SYNC FAILED: {e}"
|
||||
);
|
||||
anyhow::bail!(
|
||||
"kernel #{idx} '{kernel_name}' (node {node:?}) sync failed: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the CUDA graph from compiled kernels.
|
||||
fn build_graph(
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
@@ -596,7 +629,7 @@ impl CudaGraphOp {
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.ptr());
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -643,19 +676,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
@@ -664,41 +684,18 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
|
||||
eprintln!(
|
||||
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
kernel.inputs.len(),
|
||||
kernel.has_dyn_dims_param,
|
||||
params.values.len(),
|
||||
);
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
@@ -815,41 +812,11 @@ pub fn kernel_to_host(
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress the
|
||||
// identity-memcpy fallback for shared FS leaves whose consumers live
|
||||
// in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
|
||||
graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
|
||||
name_of(graph, node) == Some("FusionStart")
|
||||
|| graph[node].to_op::<LoopStart>().is_some()
|
||||
|| graph[node].to_op::<LoopEnd>().is_some()
|
||||
|| graph[node].to_op::<LoopInput>().is_some()
|
||||
|| graph[node].to_op::<LoopInputStatic>().is_some()
|
||||
|| graph[node].to_op::<LoopOutput>().is_some()
|
||||
|| graph[node].to_op::<LoopOutputSelect>().is_some()
|
||||
};
|
||||
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
|
||||
let mut visited = FxHashSet::default();
|
||||
while visited.insert(node) && is_transparent_input(graph, node) {
|
||||
let Some(pred) = graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.next()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
node = pred;
|
||||
}
|
||||
node
|
||||
};
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
// Track all CudaGraphOp nodes and their subgraphs for edge creation
|
||||
@@ -866,7 +833,6 @@ pub fn kernel_to_host(
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
let mut external_inputs = FxHashSet::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
@@ -880,7 +846,9 @@ pub fn kernel_to_host(
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
@@ -898,35 +866,14 @@ pub fn kernel_to_host(
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect_vec();
|
||||
if let Some(expected_inputs) =
|
||||
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
|
||||
{
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
expected_inputs,
|
||||
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel_op_ref.kernel_name(),
|
||||
kernel_node_idx,
|
||||
);
|
||||
}
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
@@ -937,12 +884,6 @@ pub fn kernel_to_host(
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
@@ -953,9 +894,7 @@ pub fn kernel_to_host(
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op.clone(),
|
||||
has_dyn_dims_param,
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
@@ -968,7 +907,6 @@ pub fn kernel_to_host(
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
@@ -980,20 +918,7 @@ pub fn kernel_to_host(
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region
|
||||
.external_inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect();
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
let inputs: Vec<NodeIndex> = region.external_inputs.clone();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
@@ -1001,12 +926,6 @@ pub fn kernel_to_host(
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
@@ -1017,9 +936,7 @@ pub fn kernel_to_host(
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
@@ -1064,17 +981,16 @@ pub fn kernel_to_host(
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into
|
||||
// subgraph. Also include normalized FusionStart predecessors, because
|
||||
// the compiled kernels read from the concrete producer buffer rather
|
||||
// than the marker node.
|
||||
external_inputs.extend(subgraph.iter().flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
}));
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -41,8 +41,9 @@ fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
all_names
|
||||
}
|
||||
|
||||
/// When dest is NOT shared with any other compute op, KernelScatterNoCopy should
|
||||
/// be the only scatter variant left after post-cleanup.
|
||||
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
|
||||
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
|
||||
/// the ConsumedBuffer (not in any other ICons).
|
||||
#[test]
|
||||
fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -61,17 +62,12 @@ fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should be the only scatter variant (dest is not shared)
|
||||
// KernelScatterNoCopy should be available (dest is not shared)
|
||||
assert!(
|
||||
names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy to be available but got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "Scatter"),
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid, got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
|
||||
@@ -113,74 +109,8 @@ fn test_scatter_nocopy_not_selected_when_dest_shared() {
|
||||
);
|
||||
}
|
||||
|
||||
/// Shared-use detection must catch the destination in non-first input
|
||||
/// positions too. Gather takes indexes first and data second, so this would
|
||||
/// miss the unsafe read if cleanup only inspected the head of the input list.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_when_dest_shared_as_later_input() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let scatter_indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
let read_indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
let scatter_result = src.scatter(scatter_indexes, dest);
|
||||
let _dest_also_read = dest.gather(read_indexes).output();
|
||||
let _result = scatter_result.output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest is read by another op, got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected regular Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// ScatterNoCopy aliases the destination buffer as the output, so it is only
|
||||
/// valid when the destination layout already matches the contiguous scatter
|
||||
/// output layout. Broadcast/expanded destinations need regular Scatter's
|
||||
/// copy-then-scatter materialization.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_for_expanded_dest_layout() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dest = cx.tensor(128).expand_dim(0, 4).persist();
|
||||
let src = cx.tensor((4, 128)).persist();
|
||||
let indexes = cx.tensor((4, 128)).as_dtype(DType::Int).persist();
|
||||
|
||||
let _result = src.scatter(indexes, dest).output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest layout differs from output, got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected regular Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// Actually execute the scatter and verify correctness.
|
||||
/// Post-cleanup should force the valid no-copy extraction.
|
||||
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
|
||||
#[test]
|
||||
fn test_scatter_execution_correctness() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -205,8 +135,9 @@ fn test_scatter_execution_correctness() {
|
||||
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
|
||||
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
|
||||
|
||||
// Try many random extractions; each valid choice should now use ScatterNoCopy.
|
||||
// Try many random extractions to cover both Scatter and ScatterNoCopy
|
||||
let mut rng = rand::rng();
|
||||
let mut tested_scatter = false;
|
||||
let mut tested_nocopy = false;
|
||||
|
||||
for _ in 0..50 {
|
||||
@@ -249,24 +180,27 @@ fn test_scatter_execution_correctness() {
|
||||
|
||||
let actual = rt.get_f32(result);
|
||||
|
||||
assert!(
|
||||
has_nocopy,
|
||||
"Expected ScatterNoCopy after post-cleanup, got no no-copy scatter"
|
||||
);
|
||||
assert!(
|
||||
!has_scatter,
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid"
|
||||
);
|
||||
tested_nocopy = true;
|
||||
let variant = if has_nocopy {
|
||||
tested_nocopy = true;
|
||||
"ScatterNoCopy"
|
||||
} else if has_scatter {
|
||||
tested_scatter = true;
|
||||
"Scatter"
|
||||
} else {
|
||||
"Unknown"
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Scatter result mismatch with ScatterNoCopy: got {:?}, expected {:?}",
|
||||
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
|
||||
actual, expected
|
||||
);
|
||||
}
|
||||
|
||||
println!("Tested ScatterNoCopy: {}", tested_nocopy);
|
||||
println!(
|
||||
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
|
||||
tested_scatter, tested_nocopy
|
||||
);
|
||||
assert!(
|
||||
tested_nocopy,
|
||||
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
|
||||
@@ -308,28 +242,14 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Selected: {name}");
|
||||
// Print which scatter variant was selected
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
scatter_names.contains(&"ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy in KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
assert!(
|
||||
!scatter_names.contains(&"Scatter"),
|
||||
"Regular Scatter should be pruned from KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
@@ -424,31 +344,19 @@ fn test_scatter_dual_cache() {
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
// Use seeded search for deterministic variant selection.
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Dual test selected: {name}");
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Dual test selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
!scatter_names.is_empty(),
|
||||
"Expected scatter kernels in dual-cache search result"
|
||||
);
|
||||
assert!(
|
||||
scatter_names.iter().all(|name| *name == "ScatterNoCopy"),
|
||||
"Expected only ScatterNoCopy in dual-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,941 +0,0 @@
|
||||
//! Unit + integration tests for the FlashInfer port.
|
||||
//!
|
||||
//! Four layers:
|
||||
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
|
||||
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
|
||||
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
|
||||
//! 3. Mask op correctness (GPU): `ComputeAttnMask` produces the right (s, c) mask.
|
||||
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
|
||||
//! compared against a luminal-compiled reference attention graph.
|
||||
//!
|
||||
//! GPU-dependent tests short-circuit when no CUDA device is available.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use cudarc::driver::{CudaStream, DevicePtr};
|
||||
use luminal::egglog_utils::{hlir_to_egglog, run_egglog};
|
||||
use luminal::op::{EgglogOp, IntoEgglogOp};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::host::flashinfer::FlashInferAttention;
|
||||
use crate::host::{ComputeAttnMask, DeviceBuffer, HostOp};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::get_cuda_stream;
|
||||
|
||||
/// Look up an op in `CudaRuntime::Ops::into_vec()` by its egglog sort name.
|
||||
fn ops_contains_sort(name: &str) -> bool {
|
||||
let ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.iter().any(|op| {
|
||||
// `SortDef` is opaque; its Debug repr starts with the sort name.
|
||||
let sort_dbg = format!("{:?}", op.sort());
|
||||
sort_dbg.contains(name)
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Test-wide model dimensions ───────────────────────────────────────────
|
||||
//
|
||||
// Small Llama-shaped GQA model: nheads=8, kv_heads=2, group=4, head_dim=64.
|
||||
// Chosen so HEAD_DIM ∈ {64, 128, 256} (FlashInfer constraint) and the test
|
||||
// suite fits in O(1ms) of GPU time per case.
|
||||
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_KV_HEADS: usize = 2;
|
||||
const KV_GROUPS: usize = 4;
|
||||
const N_HEADS: usize = N_KV_HEADS * KV_GROUPS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const HIDDEN: usize = N_HEADS * HEAD_DIM;
|
||||
|
||||
// ─── Reference attention graph (Q*K^T → softmax → *V via the compiler) ───
|
||||
|
||||
fn build_attention_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', HIDDEN));
|
||||
let k_ctx = cx.named_tensor("k_ctx", ('c', KV_DIM));
|
||||
let v_ctx_input = cx.named_tensor("v_ctx", ('c', KV_DIM));
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k_ctx.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx_input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// GQA broadcast: zero-stride Mul by 1.0
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let weights = scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
(cx, q_rope, k_ctx, v_ctx_input, attn_out)
|
||||
}
|
||||
|
||||
fn run_reference_attention(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
batch_size: usize,
|
||||
context_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
|
||||
cx.set_dim('s', batch_size);
|
||||
cx.set_dim('c', context_len);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, 3);
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt.execute(&cx.dyn_map);
|
||||
rt.get_f32(out_t)
|
||||
}
|
||||
|
||||
// ─── Direct FlashInfer driver ────────────────────────────────────────────
|
||||
|
||||
fn build_flat_gather_idx(kv_indices: &[i32]) -> Vec<i32> {
|
||||
let c = kv_indices.len();
|
||||
let mut flat = Vec::with_capacity(c * KV_DIM);
|
||||
for &slot in kv_indices {
|
||||
let base = slot * KV_DIM as i32;
|
||||
for j in 0..KV_DIM as i32 {
|
||||
flat.push(base + j);
|
||||
}
|
||||
}
|
||||
flat
|
||||
}
|
||||
|
||||
fn transpose_hbd_to_bhd(data: &[f32], heads: usize, batch: usize, dim: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; data.len()];
|
||||
for h in 0..heads {
|
||||
for b in 0..batch {
|
||||
for d in 0..dim {
|
||||
out[b * heads * dim + h * dim + d] = data[h * batch * dim + b * dim + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn alloc_dev(stream: &Arc<CudaStream>, bytes: usize) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = bytes.max(1);
|
||||
unsafe { stream.alloc::<u8>(bytes).unwrap() }
|
||||
}
|
||||
|
||||
fn copy_to_dev<T: Copy>(stream: &Arc<CudaStream>, data: &[T]) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
|
||||
};
|
||||
stream.clone_htod(bytes).unwrap()
|
||||
}
|
||||
|
||||
/// Run FlashInferAttention.execute() directly and reshape the output to the
|
||||
/// reference (batch, heads, dim) layout used by `run_reference_attention`.
|
||||
fn run_flashinfer(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k_cache: &[f32],
|
||||
v_cache: &[f32],
|
||||
kv_indptr: &[i32],
|
||||
kv_indices: &[i32],
|
||||
batch_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let q_buf = copy_to_dev(stream, q);
|
||||
let k_buf = copy_to_dev(stream, k_cache);
|
||||
let v_buf = copy_to_dev(stream, v_cache);
|
||||
let flat_idx = build_flat_gather_idx(kv_indices);
|
||||
let flat_idx_buf = copy_to_dev(stream, &flat_idx);
|
||||
let mask_buf = alloc_dev(stream, 4); // unused but reserved
|
||||
let qo_indptr: Vec<i32> = (0..=batch_size as i32).collect();
|
||||
let qo_indptr_buf = copy_to_dev(stream, &qo_indptr);
|
||||
let kv_indptr_buf = copy_to_dev(stream, kv_indptr);
|
||||
let out_buf = alloc_dev(stream, batch_size * HIDDEN * 4);
|
||||
|
||||
let fi = FlashInferAttention {
|
||||
num_qo_heads: N_HEADS,
|
||||
num_kv_heads: N_KV_HEADS,
|
||||
head_dim: HEAD_DIM,
|
||||
page_size: 1,
|
||||
batch_dim: Expression::from('s'),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Reserve dedicated NodeIndex values for the test ports.
|
||||
let nodes: Vec<NodeIndex> = (0..8).map(NodeIndex::new).collect();
|
||||
let (q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n, out_n) = (
|
||||
nodes[0], nodes[1], nodes[2], nodes[3], nodes[4], nodes[5], nodes[6], nodes[7],
|
||||
);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
let q_ptr = q_buf.device_ptr(stream).0;
|
||||
let k_ptr = k_buf.device_ptr(stream).0;
|
||||
let v_ptr = v_buf.device_ptr(stream).0;
|
||||
let idx_ptr = flat_idx_buf.device_ptr(stream).0;
|
||||
let mask_ptr = mask_buf.device_ptr(stream).0;
|
||||
let qo_ptr = qo_indptr_buf.device_ptr(stream).0;
|
||||
let kv_ptr = kv_indptr_buf.device_ptr(stream).0;
|
||||
let out_ptr = out_buf.device_ptr(stream).0;
|
||||
buffers.insert(q_n, DeviceBuffer::new(q_ptr, q.len() * 4));
|
||||
buffers.insert(k_n, DeviceBuffer::new(k_ptr, k_cache.len() * 4));
|
||||
buffers.insert(v_n, DeviceBuffer::new(v_ptr, v_cache.len() * 4));
|
||||
buffers.insert(idx_n, DeviceBuffer::new(idx_ptr, flat_idx.len() * 4));
|
||||
buffers.insert(mask_n, DeviceBuffer::new(mask_ptr, 4));
|
||||
buffers.insert(qo_n, DeviceBuffer::new(qo_ptr, qo_indptr.len() * 4));
|
||||
buffers.insert(kv_n, DeviceBuffer::new(kv_ptr, kv_indptr.len() * 4));
|
||||
buffers.insert(out_n, DeviceBuffer::new(out_ptr, batch_size * HIDDEN * 4));
|
||||
|
||||
let inputs = [q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n];
|
||||
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', batch_size);
|
||||
dyn_map.insert('c', kv_indices.len());
|
||||
dyn_map.insert('r', kv_indptr.len());
|
||||
|
||||
fi.execute(stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.expect("FlashInferAttention execute failed");
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Output is (heads, batch, dim); reshape to (batch, heads, dim).
|
||||
let mut out_bytes = vec![0u8; batch_size * HIDDEN * 4];
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtoh_async(&mut out_bytes, out_ptr, stream.cu_stream())
|
||||
.unwrap();
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let raw: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(out_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
transpose_hbd_to_bhd(&raw, N_HEADS, batch_size, HEAD_DIM)
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn deterministic_f32(n: usize, seed: f32, scale: f32) -> Vec<f32> {
|
||||
(0..n).map(|i| (i as f32 * seed).sin() * scale).collect()
|
||||
}
|
||||
|
||||
fn assert_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"length mismatch: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
);
|
||||
let mut worst = (0usize, 0.0f32);
|
||||
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
|
||||
let diff = (x - y).abs();
|
||||
if diff > worst.1 {
|
||||
worst = (i, diff);
|
||||
}
|
||||
let tol = atol + rtol * y.abs();
|
||||
assert!(
|
||||
diff <= tol,
|
||||
"mismatch at idx {i}: {x} vs {y} (|diff|={diff}, tol={tol})"
|
||||
);
|
||||
}
|
||||
eprintln!("max |diff| = {:.2e} @ idx {}", worst.1, worst.0);
|
||||
}
|
||||
|
||||
// ─── Layer 1: egglog metadata sanity (no GPU) ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_registers_via_into_egglog() {
|
||||
// Confirm the op is reachable through the Runtime::Ops tuple. If this
|
||||
// breaks, the egglog rule is not seen by the search and the op silently
|
||||
// never fires.
|
||||
assert!(
|
||||
ops_contains_sort("FlashInferAttention"),
|
||||
"FlashInferAttention is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_egg_rule_parses() {
|
||||
// Rule::raw() returns the rule with no validation; egglog parses it at
|
||||
// graph build. Smoke-test by running it through the egglog frontend via
|
||||
// a tiny program string.
|
||||
let op = FlashInferAttention::default();
|
||||
let rewrites = op.rewrites();
|
||||
assert_eq!(rewrites.len(), 1);
|
||||
// The rule must mention FlashInferAttention to be the right one.
|
||||
let s = format!("{:?}", rewrites[0]);
|
||||
assert!(
|
||||
s.contains("FlashInferAttention"),
|
||||
"rewrite is not the FlashInfer rule: {s}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_sort_shape() {
|
||||
let op = FlashInferAttention::default();
|
||||
let s = op.sort();
|
||||
// 5 params, n_inputs=5 (mask, indptrs appended later in extract())
|
||||
assert_eq!(op.n_inputs(), 5);
|
||||
let dbg = format!("{:?}", s);
|
||||
assert!(dbg.contains("FlashInferAttention"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_registers() {
|
||||
assert!(
|
||||
ops_contains_sort("ComputeAttnMask"),
|
||||
"ComputeAttnMask is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 2: ComputeAttnMask correctness ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_matches_cpu_reference() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// 2 sequences, seq0 length=3, seq1 length=2 → s=2 queries (one per seq, decode),
|
||||
// c=5 total context tokens (3+2).
|
||||
let s_dim = 2usize;
|
||||
let c_dim = 5usize;
|
||||
let q_pos: Vec<i32> = vec![2, 1]; // last position in each seq
|
||||
let qo_indptr: Vec<i32> = vec![0, 1, 2];
|
||||
let kv_indptr: Vec<i32> = vec![0, 3, 5];
|
||||
let r = kv_indptr.len();
|
||||
|
||||
let q_pos_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(q_pos.as_ptr() as *const u8, q_pos.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let qo_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(qo_indptr.as_ptr() as *const u8, qo_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let kv_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(kv_indptr.as_ptr() as *const u8, kv_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let out_bytes = s_dim * c_dim * 4;
|
||||
let out_buf = unsafe { stream.alloc::<u8>(out_bytes).unwrap() };
|
||||
|
||||
let op = ComputeAttnMask {
|
||||
s_dim: Expression::from(s_dim),
|
||||
c_dim: Expression::from(c_dim),
|
||||
};
|
||||
|
||||
let q_pos_n = NodeIndex::new(0);
|
||||
let qo_n = NodeIndex::new(1);
|
||||
let kv_n = NodeIndex::new(2);
|
||||
let out_n = NodeIndex::new(3);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
buffers.insert(
|
||||
q_pos_n,
|
||||
DeviceBuffer::new(q_pos_buf.device_ptr(&stream).0, q_pos.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
qo_n,
|
||||
DeviceBuffer::new(qo_buf.device_ptr(&stream).0, qo_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
kv_n,
|
||||
DeviceBuffer::new(kv_buf.device_ptr(&stream).0, kv_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
out_n,
|
||||
DeviceBuffer::new(out_buf.device_ptr(&stream).0, out_bytes),
|
||||
);
|
||||
|
||||
let inputs = [q_pos_n, qo_n, kv_n];
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('r', r);
|
||||
|
||||
op.execute(&stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.unwrap();
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let host_bytes = stream.clone_dtoh(&out_buf).unwrap();
|
||||
let mask: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
|
||||
// Expected: query 0 (q_pos=2, seq 0) attends to ctx [0, 3) i.e. mask[0, 0..3]=0;
|
||||
// query 1 (q_pos=1, seq 1) attends to ctx [3, 5) i.e. mask[1, 3..5]=0.
|
||||
// Everywhere else is -1e10.
|
||||
let mut expected = vec![-1e10f32; s_dim * c_dim];
|
||||
for j in 0..3 {
|
||||
expected[0 * c_dim + j] = 0.0;
|
||||
}
|
||||
for j in 3..5 {
|
||||
expected[1 * c_dim + j] = 0.0;
|
||||
}
|
||||
|
||||
assert_eq!(mask, expected);
|
||||
}
|
||||
|
||||
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs1_ctx4() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k = deterministic_f32(context_len * KV_DIM, 0.021, 0.1);
|
||||
let v = deterministic_f32(context_len * KV_DIM, 0.031, 0.1);
|
||||
let expected = run_reference_attention(&stream, &q, &k, &v, batch_size, context_len);
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = (0..context_len as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs2_supersequence() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 2;
|
||||
let ctx0 = 8;
|
||||
let ctx1 = 3;
|
||||
let total_ctx = ctx0 + ctx1;
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.014, 0.1);
|
||||
let k = deterministic_f32(total_ctx * KV_DIM, 0.022, 0.1);
|
||||
let v = deterministic_f32(total_ctx * KV_DIM, 0.032, 0.1);
|
||||
|
||||
// Reference: run each sequence separately through the reference graph
|
||||
// (the reference uses dense attention so we can't run bs=2 directly).
|
||||
let expected0 = run_reference_attention(
|
||||
&stream,
|
||||
&q[..HIDDEN],
|
||||
&k[..ctx0 * KV_DIM],
|
||||
&v[..ctx0 * KV_DIM],
|
||||
1,
|
||||
ctx0,
|
||||
);
|
||||
let expected1 = run_reference_attention(
|
||||
&stream,
|
||||
&q[HIDDEN..],
|
||||
&k[ctx0 * KV_DIM..],
|
||||
&v[ctx0 * KV_DIM..],
|
||||
1,
|
||||
ctx1,
|
||||
);
|
||||
let expected: Vec<f32> = expected0.into_iter().chain(expected1).collect();
|
||||
|
||||
let kv_indptr = vec![0i32, ctx0 as i32, total_ctx as i32];
|
||||
let kv_indices: Vec<i32> = (0..total_ctx as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_noncontiguous_page_table() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let num_slots = 8;
|
||||
let slot_indices = [3usize, 0, 7, 1];
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k_full = deterministic_f32(num_slots * KV_DIM, 0.022, 0.1);
|
||||
let v_full = deterministic_f32(num_slots * KV_DIM, 0.033, 0.1);
|
||||
|
||||
// Reference operates on the contiguous gathered cache.
|
||||
let mut k_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
let mut v_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
for (i, &slot) in slot_indices.iter().enumerate() {
|
||||
k_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&k_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
v_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&v_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
}
|
||||
let expected = run_reference_attention(
|
||||
&stream,
|
||||
&q,
|
||||
&k_gathered,
|
||||
&v_gathered,
|
||||
batch_size,
|
||||
context_len,
|
||||
);
|
||||
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = slot_indices.iter().map(|&s| s as i32).collect();
|
||||
let result = run_flashinfer(
|
||||
&stream,
|
||||
&q,
|
||||
&k_full,
|
||||
&v_full,
|
||||
&kv_indptr,
|
||||
&kv_indices,
|
||||
batch_size,
|
||||
);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
// ─── Layer 3b: HEAD_DIM 128 path (validates the head-dim JIT dispatch) ────
|
||||
//
|
||||
// Each FlashInfer .so is compiled for one HEAD_DIM. JIT caches by head dim;
|
||||
// the OnceLock means only one is loaded per process. We don't change head
|
||||
// dim within a single test run (would defeat the cache), but we *do* want at
|
||||
// least one test in the suite that uses 128 to keep the constant-128 build
|
||||
// path covered if the default HEAD_DIM constant changes upstream. We assert
|
||||
// the constraint here rather than firing a second JIT.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_jit_head_dim_assertion() {
|
||||
// 64 / 128 / 256 must be the only allowed values.
|
||||
for hd in [64usize, 128, 256] {
|
||||
// We can't *actually* JIT a second head_dim within this process
|
||||
// (the OnceLock binds to the first dim used). Just check the dim
|
||||
// is in the supported set.
|
||||
assert!(matches!(hd, 64 | 128 | 256));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Layer 4: egglog rule firing (no GPU) ────────────────────────────────
|
||||
//
|
||||
// These tests build HLIR graphs and run egglog saturation. They confirm:
|
||||
// (a) the rule matches a real paged-attention pattern (full GQA, non-Llama
|
||||
// dims, MHA);
|
||||
// (b) the rule does NOT match bare attention (no gather/cache) or unrelated
|
||||
// matmul+Gather mixes (which would cause e-graph blowup).
|
||||
//
|
||||
// Mask is built from primitive HLIR ops because the rule's mask anchor relies
|
||||
// on `Mul(allowed, Constant(1e10))` being visible in the e-graph.
|
||||
|
||||
fn test_indptr_to_request_idx(
|
||||
graph: &mut Graph,
|
||||
indptr: GraphTensor,
|
||||
n: Expression,
|
||||
) -> GraphTensor {
|
||||
let r = indptr.dims1();
|
||||
let indices = graph.arange(n.clone()).expand_dim(1, r.clone());
|
||||
let indptr_2d = indptr.expand_dim(0, n);
|
||||
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
|
||||
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
|
||||
}
|
||||
|
||||
fn test_compute_attn_mask(
|
||||
graph: &mut Graph,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
c: Expression,
|
||||
) -> GraphTensor {
|
||||
let s = q_pos.dims1();
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s.clone());
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c.clone());
|
||||
let c_arange = graph.arange(c.clone());
|
||||
let c_kv_start = kv_indptr.gather(c_request);
|
||||
let c_local_pos = c_arange - c_kv_start;
|
||||
let q_req_2d = q_request.expand_dim(1, c.clone());
|
||||
let c_req_2d = c_request.expand_dim(0, s.clone());
|
||||
let same = q_req_2d.eq(c_req_2d);
|
||||
let c_pos_2d = c_local_pos.expand_dim(0, s);
|
||||
let qp_2d = q_pos.expand_dim(1, c);
|
||||
let causal = c_pos_2d.le(qp_2d);
|
||||
let allowed = same.cast(luminal::dtype::DType::F32) * causal.cast(luminal::dtype::DType::F32);
|
||||
allowed * 1e10 - 1e10
|
||||
}
|
||||
|
||||
fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = data.graph().arange(d as i32).expand_dim(0, n);
|
||||
data.gather(base + col)
|
||||
}
|
||||
|
||||
fn scatter_rows(
|
||||
src: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
dest: GraphTensor,
|
||||
d: usize,
|
||||
) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = src.graph().arange(d as i32).expand_dim(0, n);
|
||||
src.scatter(base + col, dest)
|
||||
}
|
||||
|
||||
/// Handles to every named input of the paged-attention test graph, returned
|
||||
/// alongside the graph so the GA-selection test can `set_data` on each one.
|
||||
struct PagedAttnHandles {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v_new: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
}
|
||||
|
||||
/// Build a full paged-attention HLIR graph with the structural anchors the
|
||||
/// FlashInfer egglog rule looks for: scatter into a 2D cache, gather rows out
|
||||
/// by index, GQA broadcast via `Mul(..., 1.0)` with zero strides, Q*K^T → Sum
|
||||
/// → scale → mask Add → softmax → *V → Sum.
|
||||
fn build_paged_attention_graph(
|
||||
n_heads: usize,
|
||||
n_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> (Graph, PagedAttnHandles) {
|
||||
let kv_groups = n_heads / n_kv_heads;
|
||||
let kv_dim = n_kv_heads * head_dim;
|
||||
let hidden = n_heads * head_dim;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', hidden));
|
||||
let k_rope = cx.named_tensor("k_rope", ('s', kv_dim));
|
||||
let v_new = cx.named_tensor("v_new", ('s', kv_dim));
|
||||
let k_cache = cx.named_tensor("k_cache", (2048, kv_dim)).persist();
|
||||
let v_cache = cx.named_tensor("v_cache", (2048, kv_dim)).persist();
|
||||
let scatter_idx = cx
|
||||
.named_tensor("scatter_idx", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let q_pos = cx
|
||||
.named_tensor("q_pos", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let qo_indptr = cx
|
||||
.named_tensor("qo_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let kv_indptr = cx
|
||||
.named_tensor("kv_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, kv_dim);
|
||||
let v_cache_out = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, kv_dim);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, kv_dim);
|
||||
|
||||
let c: Expression = 'c'.into();
|
||||
let attn_mask = test_compute_attn_mask(&mut cx, q_pos, qo_indptr, kv_indptr, c);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.split_dims(1, head_dim).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (head_dim as f32).sqrt();
|
||||
let mask = attn_mask.expand_dim(0, n_heads);
|
||||
let masked_scores = scores + mask;
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
attn_out.output();
|
||||
k_cache_out.output();
|
||||
v_cache_out.output();
|
||||
|
||||
(
|
||||
cx,
|
||||
PagedAttnHandles {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
q_pos,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Saturate egglog on the graph and report whether a FlashInferAttention
|
||||
/// e-node was produced. Helper used by the rule-firing tests.
|
||||
fn saturate_and_has_flashinfer(cx: &Graph) -> (bool, Vec<String>) {
|
||||
let (program, root) = hlir_to_egglog(cx);
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
// cleanup=false: keep every saturation-introduced e-node so we can inspect
|
||||
// whether the FlashInferAttention rule produced a node, regardless of
|
||||
// whether downstream extraction would have pruned it.
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
let has_flashinfer = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
|
||||
// Collect distinct OpKind labels so a failure can print what *did* match.
|
||||
let mut op_kinds: Vec<String> = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.filter(|(l, _)| {
|
||||
!l.starts_with('(')
|
||||
&& ![
|
||||
"Op",
|
||||
"Input",
|
||||
"Output",
|
||||
"OutputJoin",
|
||||
"ICons",
|
||||
"INil",
|
||||
"ECons",
|
||||
"ENil",
|
||||
"MNum",
|
||||
"MVar",
|
||||
"MMul",
|
||||
"MDiv",
|
||||
"MIter",
|
||||
]
|
||||
.contains(&l.as_str())
|
||||
})
|
||||
.map(|(l, _)| l.clone())
|
||||
.collect();
|
||||
op_kinds.sort();
|
||||
op_kinds.dedup();
|
||||
|
||||
(has_flashinfer, op_kinds)
|
||||
}
|
||||
|
||||
/// Debug aid: dump the egglog program and key e-graph metrics for the lite
|
||||
/// paged-attention test so we can see why the FlashInfer rule isn't matching.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn flashinfer_dump_paged_attn_egglog() {
|
||||
// First sanity-check that each Ops member returns its rewrites and that
|
||||
// FlashInferAttention's rule appears in the combined corpus.
|
||||
let ops_vec = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
eprintln!("==== Ops rewrites count ====");
|
||||
let mut fi_rewrites = 0usize;
|
||||
let mut total_rewrites = 0usize;
|
||||
for op in &ops_vec {
|
||||
let rws = op.rewrites();
|
||||
total_rewrites += rws.len();
|
||||
for r in &rws {
|
||||
let s = format!("{r:?}");
|
||||
if s.contains("FlashInferAttention") {
|
||||
fi_rewrites += 1;
|
||||
eprintln!("FOUND FlashInfer rewrite ({} chars)", s.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"==== ops_vec.len()={} total_rewrites={total_rewrites} fi_rewrites={fi_rewrites} ====",
|
||||
ops_vec.len()
|
||||
);
|
||||
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (program, root) = hlir_to_egglog(&cx);
|
||||
eprintln!("==== EGGLOG PROGRAM (root={root}) ====");
|
||||
for (i, line) in program.lines().enumerate() {
|
||||
eprintln!("{:5}: {line}", i + 1);
|
||||
}
|
||||
eprintln!(
|
||||
"==== END EGGLOG PROGRAM ({} lines) ====",
|
||||
program.lines().count()
|
||||
);
|
||||
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
// Bucket enode labels by frequency.
|
||||
let mut counts: std::collections::HashMap<String, usize> = Default::default();
|
||||
for (label, _) in egraph.enodes.values() {
|
||||
*counts.entry(label.clone()).or_default() += 1;
|
||||
}
|
||||
let mut sorted: Vec<_> = counts.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
eprintln!("==== E-GRAPH LABEL HISTOGRAM (top 60) ====");
|
||||
for (label, n) in sorted.iter().take(60) {
|
||||
eprintln!(" {n:6} {label}");
|
||||
}
|
||||
let has_fi = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
eprintln!("==== has FlashInferAttention enode: {has_fi} ====");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_bare_attention() {
|
||||
// Dense attention without paged gather + cache should NOT match.
|
||||
let (cx, _, _, _, _) = build_attention_graph();
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on bare attention (no gather/cache)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_unrelated_matmuls() {
|
||||
// A Gather + plain matmul (MLP-shaped projection) plus two chained matmuls
|
||||
// through softmax — close to attention structurally but missing the GQA
|
||||
// broadcast / mask Add anchors. The rule must reject this.
|
||||
let mut cx = Graph::default();
|
||||
let cache = cx.named_tensor("cache", (4096, KV_DIM)).persist();
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let weight = cx.named_tensor("weight", (HIDDEN, KV_DIM)).persist();
|
||||
|
||||
let n = gather_idx.dims1();
|
||||
let base = (gather_idx * KV_DIM).expand_dim(1, KV_DIM);
|
||||
let col = cx.arange(KV_DIM as i32).expand_dim(0, n);
|
||||
let gathered = cache.gather(base + col);
|
||||
let proj = gathered.matmul(weight.t());
|
||||
proj.output();
|
||||
|
||||
let a = cx.named_tensor("a", ('s', HIDDEN));
|
||||
let b = cx.named_tensor("b", (HIDDEN, HIDDEN)).persist();
|
||||
let c_tensor = cx.named_tensor("c_tensor", (HIDDEN, HIDDEN)).persist();
|
||||
let ab = a.matmul(b.t());
|
||||
let abc = ab.softmax(1).matmul(c_tensor.t());
|
||||
abc.output();
|
||||
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on unrelated matmuls + Gather"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_full_paged_attention() {
|
||||
// Default Llama-shaped test dims (HEAD_DIM=64, N_HEADS=8, N_KV_HEADS=2).
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found in the e-graph (Llama-shaped paged attention). \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_non_llama_dims() {
|
||||
// Different head counts: HEAD_DIM=64, N_HEADS=16, N_KV_HEADS=4 (group=4).
|
||||
// Exercises the model-agnostic structural variables in the rule.
|
||||
let (cx, _) = build_paged_attention_graph(16, 4, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for non-Llama dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_mha() {
|
||||
// MHA: KV_GROUPS=1 (n_heads == n_kv_heads). The GQA broadcast still
|
||||
// structurally appears (expand_dim(1, 1) + merge), so the rule should
|
||||
// still match.
|
||||
let (cx, _) = build_paged_attention_graph(12, 12, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for MHA dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 5: extraction reachability (no GPU) ───────────────────────────
|
||||
//
|
||||
// After `build_search_space` saturates egglog, the GA picks an extraction by
|
||||
// cost. In a tiny test graph the cuBLAS+kernel path is often faster than the
|
||||
// FlashInfer host op (which pays a `plan()` setup cost per call), so asserting
|
||||
// "GA picked FlashInfer" is flaky. Instead, sample many random valid genomes
|
||||
// from the search space and assert that the FlashInfer extraction is reachable
|
||||
// — meaning the rule fired AND `find_indptrs` extraction succeeded for at
|
||||
// least one offspring. That is the end-to-end check we actually want.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_extraction_reachable_from_search_space() {
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
let (mut cx, _h) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
cx.set_dim('s', 1usize);
|
||||
cx.set_dim('c', 16usize);
|
||||
cx.set_dim('r', 2usize);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let egraph = cx
|
||||
.egraph()
|
||||
.expect("egraph missing after build_search_space");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("egglog_ops missing after build_search_space");
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0xf1a541);
|
||||
let mut prev: FxHashSet<u64> = FxHashSet::default();
|
||||
let initial = luminal::egglog_utils::random_initial_choice(egraph, &mut rng);
|
||||
prev.insert(luminal::egglog_utils::hash_choice_set(&initial));
|
||||
let mut base = initial;
|
||||
|
||||
let mut found = false;
|
||||
'outer: for _ in 0..50 {
|
||||
let offspring =
|
||||
luminal::egglog_utils::extract_generation(egraph, &base, 10, 2, &mut prev, &mut rng);
|
||||
if offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
for genome in offspring {
|
||||
if luminal::egglog_utils::validate_choice_set(egraph, &genome, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
// Catch a possible panic from find_indptrs walking the mask — we
|
||||
// want the test to fail with a clean message, not abort.
|
||||
let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
luminal::egglog_utils::egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
)
|
||||
}));
|
||||
let Ok(llir_graph) = panicked else { continue };
|
||||
|
||||
let has_fi = llir_graph.node_indices().any(|n| {
|
||||
llir_graph[n]
|
||||
.to_dialect::<dyn HostOp>()
|
||||
.and_then(|op| op.stats_name())
|
||||
== Some("FlashInferAttention")
|
||||
});
|
||||
if has_fi {
|
||||
found = true;
|
||||
break 'outer;
|
||||
}
|
||||
base = genome;
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found,
|
||||
"FlashInferAttention extraction not reachable from search space after 50 generations"
|
||||
);
|
||||
}
|
||||
@@ -5,10 +5,6 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod cublaslt_rewrite_tests;
|
||||
#[cfg(test)]
|
||||
mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
|
||||
//!
|
||||
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
|
||||
//! reference to catch incorrect HLIR kernel rewrites.
|
||||
//!
|
||||
//! These are marked ignored by default because each test builds a model-shaped
|
||||
//! graph and checks many extraction genomes. Run them explicitly with
|
||||
//! `cargo test -p luminal_cuda_lite -- --ignored` when touching extraction,
|
||||
//! scheduling, or model-pattern rewrites.
|
||||
//! reference to catch incorrect HLIR kernel fallback rewrites.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
@@ -382,38 +377,32 @@ mod llama {
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
|
||||
}
|
||||
|
||||
/// Force HLIR-only (no block ops) to specifically test that extraction path.
|
||||
/// Force HLIR-only (no block ops) to specifically test the fallback path.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
|
||||
}
|
||||
@@ -435,26 +424,22 @@ mod gemma {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
|
||||
}
|
||||
|
||||
/// Gemma has extra post-attention and post-feedforward norms.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer_full_norms() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -579,14 +564,12 @@ mod gemma {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Gemma dimensions.
|
||||
/// Force HLIR-only to test fallback path with Gemma dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
|
||||
}
|
||||
@@ -608,26 +591,22 @@ mod qwen {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
|
||||
}
|
||||
|
||||
/// Qwen uses tied embeddings: lm_head = embedding^T
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_lm_head() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -689,20 +668,17 @@ mod qwen {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Qwen dimensions.
|
||||
/// Force HLIR-only to test fallback path with Qwen dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
|
||||
}
|
||||
|
||||
@@ -16,16 +16,9 @@ use super::utilities::{
|
||||
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
|
||||
};
|
||||
|
||||
// The property-based op tests each build/search CUDA graphs for multiple random
|
||||
// shapes. They are ignored by default to keep the main CUDA unit suite short;
|
||||
// run `cargo test -p luminal_cuda_lite -- --ignored` for the broader sweeps.
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -35,9 +28,6 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -47,27 +37,18 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_matmul(
|
||||
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
|
||||
@@ -138,8 +119,6 @@ proptest! {
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
@@ -148,9 +127,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
@@ -159,9 +135,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -169,9 +142,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
@@ -179,9 +149,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
@@ -190,17 +157,12 @@ proptest! {
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
test_mod(x, x, |a, b| a % b, seed);
|
||||
test_mod((y, x), (y, x), |a, b| a % b, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
@@ -373,8 +335,6 @@ proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
|
||||
use luminal::dtype::DType;
|
||||
@@ -567,9 +527,6 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(3))]
|
||||
|
||||
// This walks random extraction genomes and is intentionally opt-in so the
|
||||
// default CUDA unit suite keeps a tight feedback loop.
|
||||
#[ignore = "expensive CUDA genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
#[test]
|
||||
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
|
||||
fuzz_test_cuda_genomes_impl(seed);
|
||||
@@ -637,9 +594,6 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_embed_proptest(
|
||||
vocab_size in 10usize..200,
|
||||
|
||||
@@ -3,7 +3,10 @@ use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
@@ -71,9 +74,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +133,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
@@ -173,9 +176,10 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
|
||||
@@ -136,15 +136,14 @@ pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
|
||||
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
|
||||
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
|
||||
let Some((major, minor)) = gpu_compute_cap() else {
|
||||
let Some((major, _)) = gpu_compute_cap() else {
|
||||
return false;
|
||||
};
|
||||
match dtype {
|
||||
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
|
||||
luminal::dtype::DType::F8E4M3 | luminal::dtype::DType::F8E5M2 => {
|
||||
major > 8 || (major == 8 && minor >= 9)
|
||||
} // Ada/Hopper (sm_89+)
|
||||
luminal::dtype::DType::F4E2M1 | luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
luminal::dtype::DType::F4E2M1
|
||||
| luminal::dtype::DType::F8E4M3
|
||||
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,21 +102,6 @@ fn metal_copy_value(dtype: DType, buffer: &str, index: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn metal_binary_op_values(
|
||||
output_dtype: DType,
|
||||
a_dtype: DType,
|
||||
b_dtype: DType,
|
||||
a_idx: &str,
|
||||
b_idx: &str,
|
||||
) -> (String, String) {
|
||||
let read: fn(DType, &str, &str) -> String = if output_dtype == DType::Int {
|
||||
metal_copy_value
|
||||
} else {
|
||||
metal_numeric_read
|
||||
};
|
||||
(read(a_dtype, "a", a_idx), read(b_dtype, "b", b_idx))
|
||||
}
|
||||
|
||||
fn call_sort_from_args(sort: &SortDef, args: &Args) -> EggTerm {
|
||||
let mut filtered_args = Args::new();
|
||||
for field in &sort.fields {
|
||||
@@ -132,11 +117,9 @@ fn unary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
args["__inputs"].clone(),
|
||||
);
|
||||
let dt = v("?__dt");
|
||||
rule(union(hlir_match.clone(), metal_op.clone()))
|
||||
.subsume(hlir_match)
|
||||
rule(union(hlir_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(args["inp"].clone())))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
fn binary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
@@ -146,11 +129,9 @@ fn binary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
args["__inputs"].clone(),
|
||||
);
|
||||
let dt = v("?__dt");
|
||||
rule(union(hlir_match.clone(), metal_op.clone()))
|
||||
.subsume(hlir_match)
|
||||
rule(union(hlir_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(args["inp_a"].clone())))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -304,7 +285,7 @@ macro_rules! metal_unary_op {
|
||||
device {input_ty} *inp [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -388,10 +369,8 @@ impl EgglogOp for MetalAdd {
|
||||
|
||||
vec![
|
||||
binary_dtype_rewrite(&Add::default().sort(), &self.sort()),
|
||||
rule(union(hlir_match2.clone(), metal_op2.clone()))
|
||||
.subsume(hlir_match2)
|
||||
.set(dtype(metal_op2), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
rule(union(hlir_match2, metal_op2.clone()))
|
||||
.set(dtype(metal_op2), app(&SORTS.f32_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -444,7 +423,8 @@ impl MetalKernelOp for MetalAdd {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("({a_val}) + ({b_val})"));
|
||||
|
||||
let source = format!(
|
||||
@@ -457,7 +437,7 @@ impl MetalKernelOp for MetalAdd {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -576,7 +556,8 @@ impl MetalKernelOp for MetalMul {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("({a_val}) * ({b_val})"));
|
||||
|
||||
let source = format!(
|
||||
@@ -589,7 +570,7 @@ impl MetalKernelOp for MetalMul {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -718,13 +699,9 @@ impl MetalKernelOp for MetalMod {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let out_expr = if output_dtype == DType::Int {
|
||||
format!("({a_val}) % ({b_val})")
|
||||
} else {
|
||||
format!("fmod({a_val}, {b_val})")
|
||||
};
|
||||
let out_val = metal_numeric_write(output_dtype, &out_expr);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("fmod({a_val}, {b_val})"));
|
||||
|
||||
let source = format!(
|
||||
r#"
|
||||
@@ -736,7 +713,7 @@ impl MetalKernelOp for MetalMod {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -876,7 +853,7 @@ impl MetalKernelOp for MetalLessThan {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -1023,7 +1000,7 @@ impl MetalKernelOp for MetalSumReduce {
|
||||
const device {input_ty} *in [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
device uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
@@ -1204,7 +1181,7 @@ impl MetalKernelOp for MetalMaxReduce {
|
||||
const device {input_ty} *in [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
device uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
@@ -1742,10 +1719,8 @@ impl EgglogOp for MetalConstant {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, const_match) = new_op_call(&Constant::default().sort(), &[]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(const_match.clone(), metal_op.clone()))
|
||||
.subsume(const_match)
|
||||
.set(dtype(metal_op), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(const_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), app(&SORTS.f32_dt, vec![]))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1852,10 +1827,8 @@ impl EgglogOp for MetalIota {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, iota_match) = new_op_call(&Iota::default().sort(), &[]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(iota_match.clone(), metal_op.clone()))
|
||||
.subsume(iota_match)
|
||||
.set(dtype(metal_op), app(&SORTS.int_dt, vec![]))
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(iota_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), app(&SORTS.int_dt, vec![]))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1899,7 +1872,7 @@ impl MetalKernelOp for MetalIota {
|
||||
kernel void mkernel(
|
||||
device int *out [[buffer(0)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -1951,7 +1924,6 @@ impl MetalKernelOp for MetalIota {
|
||||
pub struct MetalGather {
|
||||
out_shape: Vec<Expression>,
|
||||
index_stride: Vec<Expression>,
|
||||
data_shape: Vec<Expression>,
|
||||
data_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
}
|
||||
@@ -1966,7 +1938,6 @@ impl EgglogOp for MetalGather {
|
||||
("indexes", IR),
|
||||
("index_strides", ELIST),
|
||||
("data", IR),
|
||||
("data_shape", ELIST),
|
||||
("data_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
],
|
||||
@@ -1988,7 +1959,6 @@ impl EgglogOp for MetalGather {
|
||||
gather_args["index_strides"].clone(),
|
||||
),
|
||||
("data".to_string(), gather_args["data"].clone()),
|
||||
("data_shape".to_string(), gather_args["data_shape"].clone()),
|
||||
(
|
||||
"data_strides".to_string(),
|
||||
gather_args["data_strides"].clone(),
|
||||
@@ -1996,11 +1966,9 @@ impl EgglogOp for MetalGather {
|
||||
("out_strides".to_string(), out_strides),
|
||||
];
|
||||
let metal_op = self.sort().call(metal_args);
|
||||
vec![rule(union(gather_match.clone(), metal_op.clone()))
|
||||
.subsume(gather_match)
|
||||
vec![rule(union(gather_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(gather_args["data"].clone())))
|
||||
.ruleset("kernel_lower")]
|
||||
.fact(eq(dt, dtype(gather_args["data"].clone())))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2021,10 +1989,9 @@ impl EgglogOp for MetalGather {
|
||||
out_shape: extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap(),
|
||||
index_stride: extract_expr_list(egraph, children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
data_shape: extract_expr_list(egraph, children[4], list_cache, expr_cache).unwrap(),
|
||||
data_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache)
|
||||
data_stride: extract_expr_list(egraph, children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[6], list_cache, expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap(),
|
||||
})),
|
||||
vec![children[1], children[3]],
|
||||
)
|
||||
@@ -2048,7 +2015,7 @@ impl MetalKernelOp for MetalGather {
|
||||
"idx",
|
||||
);
|
||||
let data_idx = lower_expression_for_metal(
|
||||
&flatten_strides(&self.data_shape, &self.data_stride),
|
||||
&flatten_strides(&self.out_shape, &self.data_stride),
|
||||
"gathered_index",
|
||||
);
|
||||
let gathered_val = metal_copy_value(data_dtype, "data", &data_idx);
|
||||
@@ -2063,7 +2030,7 @@ impl MetalKernelOp for MetalGather {
|
||||
const device {data_ty} *data [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -2089,10 +2056,6 @@ impl MetalKernelOp for MetalGather {
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.get(1).copied().unwrap_or(DType::F32)
|
||||
}
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
@@ -2214,11 +2177,9 @@ impl EgglogOp for MetalScatter {
|
||||
("out_strides".to_string(), out_strides),
|
||||
];
|
||||
let metal_op = self.sort().call(metal_args);
|
||||
vec![rule(union(scatter_match.clone(), metal_op.clone()))
|
||||
.subsume(scatter_match)
|
||||
vec![rule(union(scatter_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))
|
||||
.ruleset("kernel_lower")]
|
||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2282,7 +2243,7 @@ impl MetalKernelOp for MetalScatter {
|
||||
kernel void copy_kernel(
|
||||
device {out_ty} *out [[buffer(0)]],
|
||||
const device {dest_ty} *dest [[buffer(1)]],
|
||||
constant uint &n_elements [[buffer(2)]],
|
||||
device uint &n_elements [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
@@ -2316,7 +2277,7 @@ impl MetalKernelOp for MetalScatter {
|
||||
device {out_ty} *out [[buffer(0)]],
|
||||
const device int *indexes [[buffer(1)]],
|
||||
const device {src_ty} *src [[buffer(2)]],
|
||||
constant uint &n_elements [[buffer(3)]],
|
||||
device uint &n_elements [[buffer(3)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
@@ -2447,10 +2408,7 @@ impl EgglogOp for MetalCast {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, cast_match) = new_op_call(&Cast::default().sort(), &["inp"]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(cast_match.clone(), metal_op.clone()))
|
||||
.subsume(cast_match)
|
||||
.set(dtype(metal_op), args["dtype"].clone())
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(cast_match, metal_op.clone())).set(dtype(metal_op), args["dtype"].clone())]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2509,7 +2467,7 @@ impl MetalKernelOp for MetalCast {
|
||||
device {input_ty} *inp [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
|
||||
@@ -282,8 +282,6 @@ impl Runtime for MetalRuntime {
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -294,7 +292,6 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
|
||||
@@ -250,23 +250,6 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
assert_close(&out, &[9.0, 12.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_int_arithmetic_preserves_large_values() {
|
||||
let mut cx = Graph::default();
|
||||
let token = cx.tensor(1).as_dtype(DType::Int);
|
||||
let large_index = (token * 1024) + 123;
|
||||
let mod_output = (large_index % 65_537).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(mod_output), vec![891.0]);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
@@ -988,28 +971,6 @@ fn test_scatter_basic() {
|
||||
assert_close(&out, &[0.0, 10.0, 0.0, 20.0, 30.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((4, 3));
|
||||
let data = input.transpose(0, 1);
|
||||
let indexes = cx.tensor((2, 2)).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(
|
||||
input,
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[0.0, 9.0, 1.0, 10.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_into_nonzero_dest() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1051,21 +1012,3 @@ fn test_scatter_all_positions() {
|
||||
let out = rt.get_f32(result);
|
||||
assert_close(&out, &[10.0, 20.0, 30.0, 40.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_preserves_data_dtype() {
|
||||
let mut cx = Graph::default();
|
||||
let data = cx.tensor(2);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[2.5], 0.001);
|
||||
}
|
||||
|
||||
@@ -61,8 +61,7 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -479,8 +478,7 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -749,6 +749,92 @@ candidates rejected" during search, check whether the rejection is from actual f
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
|
||||
## 2026-04-22 — Benchmark python_luminal Path: NativeRuntime Panic on CUDA Weights
|
||||
|
||||
### What the symptom was
|
||||
|
||||
Running `benchmarks/ttft/run.py` with the `python_luminal` path panicked deep in Rust:
|
||||
|
||||
```
|
||||
thread panicked at src/hlir.rs:2239:40: no entry found for key
|
||||
```
|
||||
|
||||
The panic occurred in `NativeRuntime::execute` when the `Output` node tried to read its
|
||||
predecessor's buffer from `self.buffers` — and the buffer wasn't there.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The luminal Python wheel was built without `--features cuda` (plain `maturin build --release`).
|
||||
This means `_cuda_lite_factory_capsule` is not compiled into the `.so` file. In `main.py`,
|
||||
`_detect_factory_capsule` catches the resulting `ImportError` and **silently** falls back to
|
||||
`_native_factory_capsule` (NativeRuntime / CPU runtime).
|
||||
|
||||
The benchmark model (`LlamaForCausalLM.from_pretrained(...).to("cuda")`) has all weights as
|
||||
CUDA device pointers. `BackendCompileArgs.device_ptrs` is populated with these GPU pointers.
|
||||
NativeRuntime has no mechanism to handle GPU-resident weight data — the `device_ptrs` map is
|
||||
simply ignored. After search completes (it can search because it uses dummy CPU data during
|
||||
profiling), the first real `execute()` call processes the graph:
|
||||
|
||||
1. `Input` nodes are skipped (their buffers should be pre-populated by `set_input_from_ptr`)
|
||||
2. Weight `Input` nodes were set via `set_input_device_ptr` — but NativeRuntime's
|
||||
`set_input_device_ptr` likely no-ops or stores garbage, leaving those buffers empty
|
||||
3. The `Output` node looks up its predecessor's buffer → key not found → panic
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Silent fallback**: `_detect_factory_capsule` catches `ImportError` without logging a
|
||||
warning. Nothing in stdout indicates you're running on CPU when the model is on GPU.
|
||||
2. **Search succeeds**: The e-graph search runs to completion (searches 1 group, 1 chunk in
|
||||
~15s) because it uses 1.0f32 dummy data that doesn't need GPU. The failure only occurs at
|
||||
first real execution.
|
||||
3. **Misleading error site**: `hlir.rs:2239` is in NativeRuntime's buffer-copy loop for Output
|
||||
nodes — it gives no indication that the root cause is a missing CUDA feature flag at build time.
|
||||
4. **Backtrace required**: Without `RUST_BACKTRACE=1`, only the panic message is visible;
|
||||
the `NativeRuntime` frame that reveals the CPU fallback is hidden.
|
||||
|
||||
### The fix
|
||||
|
||||
Rebuild the wheel with CUDA support:
|
||||
```bash
|
||||
maturin build --release --features cuda
|
||||
pip install target/wheels/luminal_python-*.whl --force-reinstall
|
||||
```
|
||||
|
||||
Or via the test runner: `./run_tests_cuda.sh` uses `maturin develop --features cuda -r`.
|
||||
|
||||
Consider adding an explicit warning or error in `_detect_factory_capsule` when CUDA inputs are
|
||||
detected but no CUDA factory is available:
|
||||
|
||||
```python
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
return _cuda_lite_factory_capsule()
|
||||
except ImportError:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"CUDA inputs detected but luminal was built without --features cuda. "
|
||||
"Falling back to NativeRuntime (CPU) — this will likely panic at runtime.",
|
||||
RuntimeWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
```
|
||||
|
||||
### The regression test
|
||||
|
||||
`test_hf_llama3_8b_instruct_1layer` in `tests/test_llama3.py` — tests the exact architecture
|
||||
from the benchmark (Meta-Llama-3-8B-Instruct, 4096 hidden, 32 attn heads, 8 KV heads) with
|
||||
1 layer and random weights. This test passes with `--features cuda` and panics without it.
|
||||
|
||||
### General principle
|
||||
|
||||
**When a feature gate silently changes the runtime backend, assert that the selected backend
|
||||
is compatible with the input device.** A CUDA tensor flowing into a CPU-only runtime is always
|
||||
a programming error, not a graceful degradation. The failure should surface at factory
|
||||
selection time (with a clear error message), not deep in a Rust buffer-copy loop.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
|
||||
|
||||
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
|
||||
@@ -757,6 +843,44 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — NativeRuntime Multi-Call Panic: Input Buffers Cleared After Each Run
|
||||
|
||||
1. **Symptom**: The compiled model panicked with `hlir.rs:XXXX: no entry found for key` on the second call. First call succeeded; subsequent calls failed.
|
||||
2. **Root cause**: `NativeRuntime::execute` in `src/hlir.rs` called `self.buffers.retain(|k, _| output_nodes.contains(k))` after each run to free intermediate buffers. This correctly pruned temporary buffers but also pruned the Input-node buffers that hold model weights — so on the second call, the weight tensors were gone.
|
||||
3. **Why hard**: The bug never manifested in the test suite because every test called the compiled model exactly once per compile. The issue only appeared when running a bench loop that called the model multiple times. The panic location (deep in buffer lookup) gave no indication that the root cause was in the buffer retention policy.
|
||||
4. **Fix**: Changed the retain predicate to keep both `Output` and `Input` nodes:
|
||||
```rust
|
||||
let keep_nodes = graph.node_indices()
|
||||
.filter(|n| is::<Output> || is::<Input>)
|
||||
.collect();
|
||||
self.buffers.retain(|k, _| keep_nodes.contains(k));
|
||||
```
|
||||
5. **Principle**: When buffer lifetime policies are changed to free memory after a run, always verify that *persistent* state (model weights stored in Input nodes) is excluded from the cleanup sweep. A test that compiles + calls once per test function will never catch a multi-call regression — add a dedicated multi-call test for any compiled runtime.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — PT2 USER_INPUT_MUTATION Outputs Confuse Dynamo Caller
|
||||
|
||||
1. **Symptom**: With `StaticCache`, the compiled model returned `[1]` (cumulative_length update) instead of `[1, vocab_size]` logits. The wrong tensor was silently mapped to the output variable.
|
||||
2. **Root cause**: When `torch.export` encounters in-place mutations to input tensors (KV cache updates via `index_copy_`), it lifts them as `USER_INPUT_MUTATION` output specs, placed *before* the actual `USER_OUTPUT` logits in `ep.graph_signature.output_specs`. The compiled model returned all outputs; dynamo mapped index 0 (the mutation) to the first return value.
|
||||
3. **Why hard**: The output shape `[1]` from `cumulative_length` looked like a valid (though wrong) output. No error was raised — just wrong logits. Required inspecting `ep.graph_signature.output_specs` and understanding the ordering convention for different `OutputKind` values.
|
||||
4. **Fix**: In `pt2_backend`, parse `output_specs` to build a `mutation_mappings` list and `user_output_indices`. Wrap the compiled model to: (a) copy mutation outputs back into the corresponding input tensors, and (b) return only the `USER_OUTPUT` tensors.
|
||||
5. **Principle**: After `torch.export(...).run_decompositions()`, always inspect `ep.graph_signature.output_specs` when the model has in-place operations (KV cache, BN running stats). The output ordering is: mutations first, then actual outputs — and the caller only expects actual outputs.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — CUDA Version Mismatch: torch+cuXXX Must Match System Driver
|
||||
|
||||
1. **Symptom**: `torch.cuda.is_available()` returned `False` despite `nvidia-smi` showing a GPU. Warning: "CUDA initialization: The NVIDIA driver on your system is too old (found version 12080)."
|
||||
2. **Root cause**: `torch==2.11.0+cu130` requires CUDA 13.0 which needs driver >= 575. The system has driver 570 (CUDA 12.8 max). The mismatch caused silent CPU fallback — no error, just False from `is_available()`.
|
||||
3. **Why hard**: The bench appeared to start successfully (model loaded, compilation ran) but produced no results because it was running an 8B model on CPU. Zero output with exit code 0 looked like a hang or silent crash.
|
||||
4. **Fix**: Installed `torch==2.11.0+cu128` from `https://download.pytorch.org/whl/cu128`. CUDA 12.8 matches driver 570. Also needed matching `torchvision==0.26.0+cu128` and the `nvidia-cusparselt-cu12` runtime library.
|
||||
5. **Principle**: Before running any CUDA-dependent bench or test, verify `torch.cuda.is_available()` returns `True`. Check `nvidia-smi` CUDA Version field against the `+cuXXX` suffix in `torch.__version__` — they must match (CUDA runtime ≤ driver's max supported version). Never assume CPU fallback "works" for large model benchmarks.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — Loop unroll-union rules silently disabled in full egglog stage
|
||||
|
||||
1. **Symptom**: Python `test_llama_transformer_block` (CUDA backend) produced output ~1e-2 off from PyTorch (atol=1e-4) on the `loop_rolling` branch. All component tests (RMSNorm, attention, SwiGLU, RoPE) passed. The diff pattern was suspicious: row 0 of the (1,4,32) output matched exactly, rows 1–3 differed slightly. Disabling rolling fixed it.
|
||||
@@ -767,6 +891,8 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
4. **Fix**: Register `binary_op_unroll_rules` in BOTH `early_rewrites()` (so fusion patterns like GLUMoE can match before the early-stage extract, which is what fixed `test_glumoe_gemma_gelu_matches_unfused_output` earlier in the session) AND `rewrites()` (so kernel-level rewrites like `direct-exp-fusion` can match in the full stage on the unrolled chain). One block per binary op (`Add`, `Mul`, `Mod`, `LessThan`).
|
||||
5. **Principle**: When egglog has multiple stages (early/full) with disjoint rule sets, any rewrite that materialises new HLIR/IR enodes (rather than just lowering to LLIR) needs to fire in BOTH stages if downstream rewrites in BOTH stages might want to see the new structure. Putting "preparatory" rewrites only in `early_rewrites` means their effect is lost across the early→full handoff. The narrow rule of thumb: if your rule's outputs are intended to enable matches by other rules, audit which stages those other rules run in and register accordingly.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — `unroll_loops_in_llir` panicked on iteration-invariant body producers
|
||||
|
||||
1. **Symptom**: Modal CI/CD job for the gemma example panicked at `src/graph.rs:1867` with `no entry found for key`. The line is `clone_map[i - 1][&body_producer]` inside `unroll_loops_in_llir`'s `resolve_src` closure — `body_producer` (the LoopEnd's incoming source for that slot) wasn't a key in the per-iteration clone map. cuda_lite/python tests didn't repro: only triggered by the specific genome and graph shapes that gemma's longer search settles on.
|
||||
@@ -775,6 +901,8 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
4. **Fix**: in `unroll_loops_in_llir::resolve_src`, when the LoopStart-resolved `body_producer` isn't in `body_nodes`, return `body_producer` itself for iter > 0 instead of indexing `clone_map[i - 1]`. The body op didn't depend on the loop variable, so every iter > 0 carries the same value forward — using `body_producer` directly is semantically correct. Mirrored the same `unwrap_or(body_producer)` fallback in the post-loop substitution map (`marker_post_sub` for LoopEnd / LoopOutputSelect). Added a backward-walk-from-end-markers backfill in `collapse_loops_to_first_iter` so its body-node iteration also covers these nodes (it doesn't have a clone_map, but does need to rewire body ops' incoming edges before deleting markers).
|
||||
5. **Principle**: When a graph-walk-derived set is used as a hashmap key requirement, every code path that *could* produce a key outside that set needs a graceful fallback — not just a defensive `expect`. For loop unrolling specifically, the rule is: `body_nodes` is the set of "ops that participate in per-iter computation"; ops on the LoopEnd's path that *don't* participate (iteration-invariant) are still legitimate, and need a "no clone, share across iters" path through `resolve_src` and `marker_post_sub`. Forward-walk-only `body_nodes` is correct only when extraction never produces iteration-invariant body producers — and in an egglog-driven search, that's not a guarantee you can make.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — Iteration-invariant state slots are a first-class concept, not a defensive fallback
|
||||
|
||||
1. **Symptom + fix recap**: gemma Modal CI panicked at `clone_map[i-1][&body_producer]` because some state slots' `body_producer` (LoopEnd's incoming) isn't in `body_nodes` (forward walk from input markers). The first commit pair (16de9638 / 93fb02c4) caught this with `.unwrap_or(body_producer)` — which works but reads as "defensive, unclear *why* this case exists."
|
||||
@@ -855,13 +983,70 @@ Two important details:
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
---
|
||||
|
||||
## 2026-05-01 — `KernelScatter` float4 vectorization wrote 2× past end of buffer for bf16/f16 KV cache
|
||||
|
||||
### What the symptom was
|
||||
|
||||
After the `translate_grouped_mm` gather rewrite (above) cleared the OOM, the qwen3-moe bench progressed past search but panicked during execution roughly 40% of the time:
|
||||
```
|
||||
crates/luminal_cuda_lite/src/runtime.rs:1204:
|
||||
CUDA execute error in "CudaGraph":
|
||||
DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, "an illegal memory access was encountered")
|
||||
```
|
||||
qwen3-4b (dense) was unaffected; the bf16 KV cache in HF `StaticCache` was the only path triggering it. The rust `examples/qwen3_moe` ran fine because it uses an F32 KV cache.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`KernelScatter::compile` in `crates/luminal_cuda_lite/src/kernel/hlir.rs` emitted a hand-written CUDA copy phase that vectorised through `float4` (16-byte) reads/writes:
|
||||
|
||||
```cuda
|
||||
long long n_vec = n_dest / 4; // ← assumes 4-byte dtype
|
||||
float4 *out4 = (float4 *)out;
|
||||
const float4 *dest4 = (const float4 *)dest;
|
||||
for (long long i = tid; i < n_vec; i += blockDim.x) {
|
||||
out4[i] = dest4[i]; // ← writes 16 B per iteration
|
||||
}
|
||||
long long remainder_start = n_vec * 4; // ← also assumes 4 elem/vec
|
||||
```
|
||||
|
||||
For `dtype=F32` (4 bytes), `n_vec * 16 = n_dest * 4` bytes — exactly fills the buffer. For `dtype=Bf16` (2 bytes), `n_vec * 16 = (n_dest/4) * 16 = n_dest * 4` bytes, which is **2× the actual buffer size of `n_dest * 2` bytes**. The write walks half the buffer past the end of `out` (and reads past `dest`).
|
||||
|
||||
Whether that produced an `ILLEGAL_ADDRESS` depended on whether the OOB region happened to land on an unmapped page. For different search outcomes, the surrounding allocator state differed → ~60% it was silent corruption, ~40% it crashed the CUDA context. That probabilistic mix is why the bug had been hidden — no test exercised a bf16 scatter (every existing scatter test uses F32 by default), and the rust example uses F32 KV cache so it was never seen there either.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Probabilistic, but search-determinate**: the rewrite from HLIR `Scatter` → `KernelScatter` always fires (it's the only non-NoCopy path), so the kernel is always present. The crash depends on memory layout, which depends on which other kernels the search picked. Made it look like an egglog-mutation issue rather than a kernel-correctness issue.
|
||||
2. **Existing test coverage was F32-only**: `test_scatter_execution_correctness` (in `tests/consumed_buffer_tests.rs`) explicitly tries 50 random extractions to cover both `Scatter` and `ScatterNoCopy`, but always with `cx.tensor(5)` which defaults to F32. The bug would never surface there.
|
||||
3. **The panic message hid the kernel name**: it surfaced as a generic `"CudaGraph"` host-op panic — the cuda_graph_exec batches all kernels into one atomic launch, so the failing kernel disappears into the batch. To localize it I had to add a `LUMINAL_DEBUG_SEQ` env var to `CudaGraphOp::execute_internal` that bypasses graph batching and launches each kernel via `cuLaunchKernel` with a sync afterwards, surfacing kernel name + node + grid/block/pointers when one fails.
|
||||
|
||||
### The fix
|
||||
|
||||
Parameterise `n_vec` and the remainder-loop start by the number of dtype elements that fit in 16 bytes:
|
||||
|
||||
```rust
|
||||
let elements_per_vec: usize = match self.dtype {
|
||||
DType::F64 => 2,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 8,
|
||||
DType::Bool | DType::I8 | DType::U8
|
||||
| DType::F8UE8M0 | DType::F8E4M3 | DType::F8E5M2 => 16,
|
||||
other => panic!("Unsupported dtype for scatter vectorization: {other:?}"),
|
||||
};
|
||||
```
|
||||
and substitute `{elements_per_vec}` into the kernel template (both the `n_vec` calc and `remainder_start`). For F32 / Int the generated code is byte-for-byte identical to before, so existing F32 tests are unaffected; for any other dtype the byte coverage now exactly equals `n_dest * sizeof(dtype)` as intended.
|
||||
|
||||
### Result
|
||||
|
||||
Before fix: 3/5 success at iters=10 (probabilistic).
|
||||
After fix: 5/5 at iters=10, 3/3 at iters=50. All 206 HLIR tests still pass. TTFT/TPOT identical (~9.35s / ~1.17s).
|
||||
|
||||
### General principle
|
||||
|
||||
**Hand-rolled CUDA vectorisation with a fixed-width type (`float4`, `float2`, `int4`, …) is almost always specialised to one element size.** When the same kernel template is parameterised by `dtype`, every byte-count expression has to be too. The cheapest correct form is "elements per vector load" computed from the dtype's byte size — never hardcode `/4`.
|
||||
|
||||
Also: **F32 is not a representative test dtype for kernels with vector loads.** When a kernel is written generic-over-dtype, the test matrix needs to actually exercise the dtypes (bf16, f16, bool) where the vector-element-count differs. A `test_scatter_bf16` would have caught this years before the qwen3-moe bench did. Same trap likely exists wherever else `float4` is cast over a `{dtype} *` template.
|
||||
|
||||
Diagnostic also added: `LUMINAL_DEBUG_SEQ=1` on the python_luminal path will now bypass `CudaGraphOp` batching at execute time, launching each kernel sequentially with a sync afterwards. If a future ILLEGAL_ADDRESS hides inside a batched graph again, this surfaces the kernel name and node index immediately.
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
2. **Root cause #1**: the dispatch table in `crates/luminal_python/rust/src/translator/dispatch.rs` mapped `sigmoid`, `tanh`, `relu` etc. but not `gelu` or `silu`. Whisper's encoder uses `F.gelu`, so the activation hit a hole.
|
||||
3. **Root cause #2**: PyTorch serializes `float("-inf")` in PT2 as the string `"-Infinity"` (and `"NaN"`/`"Infinity"` analogously). `translate_full`'s `get_float_arg` only accepts numeric float/int payloads, so any `torch.full((..), -inf)` (the obvious way to write a causal mask) blows up. Decoder mask code is the most common spot.
|
||||
4. **Why it was tricky**: both errors arrive from inside `pt2_backend` with a stack trace that ends in `process_pt2`, hiding the actual ATen target inside the message. You only see the offending op name in the error string itself, so you have to read `RuntimeError: Failed to translate node N: …` carefully and grep `dispatch.rs` for it.
|
||||
5. **Fix in this session**:
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
# luminal_python
|
||||
|
||||
PyTorch `torch.compile` integration for Luminal.
|
||||
|
||||
## CUDA Tests
|
||||
|
||||
The Python CUDA CI job builds the Rust extension with the CUDA feature and runs
|
||||
the non-slow pytest suite:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
|
||||
```
|
||||
|
||||
The slow tests are explicit opt-in. They include large/pretrained model tests,
|
||||
full-width architecture compiles, Whisper end-to-end cases, and other cases that
|
||||
can take a long time or need a large GPU / Hugging Face cache.
|
||||
|
||||
Run the full Python CUDA suite, including slow tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s
|
||||
```
|
||||
|
||||
Run only the slow Python CUDA tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m slow
|
||||
```
|
||||
|
||||
The helper script follows the same convention:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
./run_tests_cuda.sh # non-slow CUDA suite
|
||||
./run_tests_cuda.sh --slow-only # only slow CUDA tests
|
||||
./run_tests_cuda.sh --include-slow
|
||||
```
|
||||
|
||||
The GitHub/Modal entrypoint uses the same marker split:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s -m "not slow"
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s
|
||||
```
|
||||
|
||||
@@ -1,497 +0,0 @@
|
||||
"""Whisper transcription demo using the luminal torch.compile backend.
|
||||
|
||||
Implements a small PyTorch port of ``openai/whisper-tiny.en`` that mirrors the
|
||||
luminal Rust example (``examples/whisper`` in the workspace), loads the official
|
||||
HuggingFace weights, and runs greedy decoding through the luminal backend via
|
||||
``torch.compile``.
|
||||
|
||||
Usage::
|
||||
|
||||
uv run python examples/whisper.py [path/to/audio.wav]
|
||||
|
||||
If no path is provided, falls back to the JFK sample bundled with the Rust
|
||||
``examples/whisper`` crate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
WhisperFeatureExtractor,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperTokenizer,
|
||||
)
|
||||
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
REPO_ID = "openai/whisper-tiny.en"
|
||||
|
||||
# whisper-tiny.en hyperparameters
|
||||
N_MELS = 80
|
||||
N_AUDIO_CTX = 1500
|
||||
D_MODEL = 384
|
||||
N_HEADS = 6
|
||||
HEAD_DIM = D_MODEL // N_HEADS
|
||||
N_AUDIO_LAYER = 4
|
||||
N_TEXT_LAYER = 4
|
||||
N_TEXT_CTX = 448
|
||||
FF_DIM = 4 * D_MODEL
|
||||
N_VOCAB = 51864
|
||||
LAYER_NORM_EPS = 1e-5
|
||||
|
||||
# Decoder special tokens
|
||||
TOKEN_SOT = 50257
|
||||
TOKEN_NO_TIMESTAMPS = 50362
|
||||
TOKEN_EOT = 50256
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model — mirrors the HLIR encoder/decoder in examples/whisper/src/model.rs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhisperAttention(torch.nn.Module):
|
||||
"""Multi-head attention with separate q/k/v projections (no bias on k_proj)."""
|
||||
|
||||
def __init__(self, d_model: int = D_MODEL, n_heads: int = N_HEADS):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
self.q_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.out_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
kv_input: Optional[torch.Tensor] = None,
|
||||
causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# x: (seq, d_model). kv_input is None → self-attn; otherwise cross-attn.
|
||||
kv = x if kv_input is None else kv_input
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(kv)
|
||||
v = self.v_proj(kv)
|
||||
|
||||
seq_q = q.shape[0]
|
||||
seq_kv = k.shape[0]
|
||||
|
||||
# (seq, d_model) -> (n_heads, seq, head_dim)
|
||||
q = q.reshape(seq_q, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
v = v.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
scale = 1.0 / (self.head_dim**0.5)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (h, sq, sk)
|
||||
if causal:
|
||||
# Use a large finite negative instead of -inf so the export pipeline
|
||||
# serializes a float instead of the unsupported "-Infinity" sentinel.
|
||||
mask = torch.triu(
|
||||
torch.full((seq_q, seq_kv), -1e10, device=x.device),
|
||||
diagonal=1,
|
||||
)
|
||||
scores = scores + mask
|
||||
weights = torch.softmax(scores, dim=-1)
|
||||
attn = torch.matmul(weights, v) # (h, sq, hd)
|
||||
merged = attn.transpose(0, 1).reshape(seq_q, -1)
|
||||
return self.out_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x))
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(
|
||||
N_MELS, D_MODEL, kernel_size=3, padding=1, bias=True
|
||||
)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
D_MODEL, D_MODEL, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
# Position embedding stored as a regular parameter (matches HF layout).
|
||||
self.embed_positions = torch.nn.Embedding(N_AUDIO_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[EncoderLayer() for _ in range(N_AUDIO_LAYER)]
|
||||
)
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
# mel: (n_mels, 3000) -> add batch dim for conv1d
|
||||
x = mel.unsqueeze(0)
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
# (1, d_model, 1500) -> (1500, d_model)
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
x = x + self.embed_positions.weight
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.layer_norm(x)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.encoder_attn = WhisperAttention()
|
||||
self.encoder_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x), causal=True)
|
||||
x = x + self.encoder_attn(self.encoder_attn_layer_norm(x), kv_input=xa)
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(N_VOCAB, D_MODEL)
|
||||
self.embed_positions = torch.nn.Embedding(N_TEXT_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList([DecoderLayer() for _ in range(N_TEXT_LAYER)])
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
# tokens: (seq,) of int64 — absolute positions are 0..seq-1
|
||||
seq = tokens.shape[0]
|
||||
pos = torch.arange(seq, dtype=torch.long, device=tokens.device)
|
||||
x = self.embed_tokens(tokens) + self.embed_positions(pos)
|
||||
for layer in self.layers:
|
||||
x = layer(x, xa)
|
||||
x = self.layer_norm(x)
|
||||
# Tied projection
|
||||
return torch.matmul(x, self.embed_tokens.weight.transpose(0, 1))
|
||||
|
||||
|
||||
class Whisper(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = WhisperEncoder()
|
||||
self.decoder = WhisperDecoder()
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
|
||||
xa = self.encoder(mel)
|
||||
return self.decoder(tokens, xa)
|
||||
|
||||
|
||||
class DecoderWithFixedXa(torch.nn.Module):
|
||||
"""Wraps the decoder with the encoder output stored as a buffer.
|
||||
|
||||
The audio is fixed for the whole utterance, so ``xa`` is a constant relative
|
||||
to the per-token decode loop. Storing it as a buffer lets us compile the
|
||||
decoder once with a single dynamic-length ``tokens`` input, avoiding a full
|
||||
recompilation at every step as the sequence grows.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder: WhisperDecoder, xa: torch.Tensor):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.register_buffer("xa", xa)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(tokens, self.xa)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight loading: HF state_dict -> our model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_hf_weights_into(model: Whisper) -> None:
|
||||
"""Copy HF whisper-tiny.en weights into our matching modules."""
|
||||
hf = WhisperForConditionalGeneration.from_pretrained(REPO_ID).eval()
|
||||
sd = hf.state_dict()
|
||||
|
||||
def get(name: str) -> torch.Tensor:
|
||||
return sd[f"model.{name}"].clone()
|
||||
|
||||
enc = model.encoder
|
||||
enc.conv1.weight.data.copy_(get("encoder.conv1.weight"))
|
||||
enc.conv1.bias.data.copy_(get("encoder.conv1.bias"))
|
||||
enc.conv2.weight.data.copy_(get("encoder.conv2.weight"))
|
||||
enc.conv2.bias.data.copy_(get("encoder.conv2.bias"))
|
||||
enc.embed_positions.weight.data.copy_(get("encoder.embed_positions.weight"))
|
||||
enc.layer_norm.weight.data.copy_(get("encoder.layer_norm.weight"))
|
||||
enc.layer_norm.bias.data.copy_(get("encoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(enc.layers):
|
||||
prefix = f"encoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
dec = model.decoder
|
||||
dec.embed_tokens.weight.data.copy_(get("decoder.embed_tokens.weight"))
|
||||
dec.embed_positions.weight.data.copy_(get("decoder.embed_positions.weight"))
|
||||
dec.layer_norm.weight.data.copy_(get("decoder.layer_norm.weight"))
|
||||
dec.layer_norm.bias.data.copy_(get("decoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(dec.layers):
|
||||
prefix = f"decoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.encoder_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.q_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.k_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.bias")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio loading + decoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_wav_16k_mono(path: Path) -> np.ndarray:
|
||||
with wave.open(str(path), "rb") as w:
|
||||
sr = w.getframerate()
|
||||
n = w.getnframes()
|
||||
ch = w.getnchannels()
|
||||
sw = w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
|
||||
if sw == 2:
|
||||
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif sw == 4:
|
||||
samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
elif sw == 1:
|
||||
samples = (
|
||||
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||
) / 128.0
|
||||
else:
|
||||
raise ValueError(f"unsupported sample width {sw}")
|
||||
|
||||
if ch > 1:
|
||||
samples = samples.reshape(-1, ch).mean(axis=1)
|
||||
|
||||
if sr != 16000:
|
||||
ratio = sr / 16000
|
||||
out_len = int(len(samples) / ratio)
|
||||
idx = np.arange(out_len, dtype=np.float64) * ratio
|
||||
lo = idx.astype(np.int64)
|
||||
frac = (idx - lo).astype(np.float32)
|
||||
hi = np.clip(lo + 1, 0, len(samples) - 1)
|
||||
samples = samples[lo] * (1.0 - frac) + samples[hi] * frac
|
||||
|
||||
return samples.astype(np.float32)
|
||||
|
||||
|
||||
def greedy_decode(logits_row: torch.Tensor, suppress_first_eot: bool) -> int:
|
||||
masked = logits_row.clone()
|
||||
masked[TOKEN_SOT:] = float("-inf")
|
||||
if suppress_first_eot:
|
||||
masked[TOKEN_EOT] = float("-inf")
|
||||
return int(torch.argmax(masked).item())
|
||||
|
||||
|
||||
def find_default_audio() -> Optional[Path]:
|
||||
here = Path(__file__).resolve()
|
||||
workspace_root = here.parents[3]
|
||||
candidate = workspace_root / "examples" / "whisper" / "assets" / "jfk.wav"
|
||||
return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
audio_arg = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
if audio_arg:
|
||||
audio_path = Path(audio_arg)
|
||||
else:
|
||||
audio_path = find_default_audio()
|
||||
if audio_path is None:
|
||||
print(
|
||||
"error: no audio file given and bundled jfk.wav not found",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
print("Loading audio:", audio_path)
|
||||
audio = load_wav_16k_mono(audio_path)
|
||||
|
||||
print("Computing log-mel features...")
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(REPO_ID)
|
||||
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
|
||||
mel: torch.Tensor = features.input_features[0].to(device) # (80, 3000)
|
||||
assert mel.shape == (N_MELS, 3000), mel.shape
|
||||
|
||||
print("Building model and loading weights...")
|
||||
model = Whisper().eval().to(device)
|
||||
load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
# 1. Run the encoder once eagerly. The audio doesn't change during decode,
|
||||
# so xa is a constant input to the decoder.
|
||||
with torch.no_grad():
|
||||
xa = model.encoder(mel)
|
||||
|
||||
# 2. Wrap the decoder so its only varying input is `tokens`, then compile
|
||||
# once with a dynamic length dim. Subsequent calls reuse the same
|
||||
# compiled graph — no recompile per token.
|
||||
decoder_only = DecoderWithFixedXa(model.decoder, xa).eval().to(device)
|
||||
example_tokens = torch.tensor(
|
||||
[TOKEN_SOT, TOKEN_NO_TIMESTAMPS], dtype=torch.long, device=device
|
||||
)
|
||||
print(
|
||||
f"Compiling decoder with dynamic seq dim (search_iters={search_iters})..."
|
||||
)
|
||||
compile_start = time.time()
|
||||
compiled_decoder = luminal_compile(
|
||||
decoder_only,
|
||||
example_tokens,
|
||||
search_iterations=search_iters,
|
||||
dynamic_dim=0,
|
||||
)
|
||||
print(f"Compiled in {time.time() - compile_start:.1f}s")
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
out = compiled_decoder(decoder_input_ids)
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
else:
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return model(mel, decoder_input_ids)
|
||||
|
||||
tokens = [TOKEN_SOT, TOKEN_NO_TIMESTAMPS]
|
||||
|
||||
print("Transcribing", end="", flush=True)
|
||||
decode_start = time.time()
|
||||
for step in range(max_new_tokens):
|
||||
decoder_input_ids = torch.tensor(tokens, dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
logits = step_logits(decoder_input_ids)
|
||||
|
||||
next_token = greedy_decode(logits[-1], suppress_first_eot=(step == 0))
|
||||
if next_token == TOKEN_EOT:
|
||||
break
|
||||
tokens.append(next_token)
|
||||
piece = tokenizer.decode([next_token], skip_special_tokens=False)
|
||||
print(piece, end="", flush=True)
|
||||
elapsed = time.time() - decode_start
|
||||
print()
|
||||
|
||||
transcription = tokenizer.decode(tokens[2:], skip_special_tokens=True)
|
||||
print(f"\nFinal transcription: {transcription}")
|
||||
print(
|
||||
f"Generated {len(tokens) - 2} tokens in {elapsed:.2f}s "
|
||||
f"({(len(tokens) - 2) / max(elapsed, 1e-6):.1f} tok/s)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -22,7 +22,7 @@ from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 2 * 60 * 60
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
@@ -168,37 +168,6 @@ def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _build_cuda_extension(env: dict[str, str]) -> None:
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"maturin",
|
||||
"develop",
|
||||
"--manifest-path",
|
||||
f"{PROJECT_DIR}/rust/Cargo.toml",
|
||||
"--features",
|
||||
"cuda",
|
||||
"--profile",
|
||||
"release",
|
||||
]
|
||||
subprocess.run(cmd, env=env, cwd=PROJECT_DIR, check=True)
|
||||
|
||||
|
||||
def _effective_timeout(timeout: int) -> int:
|
||||
if os.environ.get("GITHUB_ACTIONS") == "true" and timeout < DEFAULT_TIMEOUT:
|
||||
print(
|
||||
f"Using Modal timeout {DEFAULT_TIMEOUT}s instead of requested "
|
||||
f"{timeout}s in GitHub Actions.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return DEFAULT_TIMEOUT
|
||||
return timeout
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
@@ -225,8 +194,6 @@ class TestRunner:
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
_build_cuda_extension(env)
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
@@ -251,6 +218,8 @@ class TestRunner:
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
@@ -316,7 +285,7 @@ class TestRunner:
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int, bool, str | None, list[str]]:
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
@@ -331,8 +300,7 @@ def _parse_cli_args(
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=DEFAULT_TIMEOUT,
|
||||
help="Modal execution timeout in seconds. Defaults to %(default)s seconds.",
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
@@ -366,11 +334,11 @@ def main(*cli_args: str):
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
timeout = _effective_timeout(timeout)
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
runner_options["timeout"] = timeout
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
|
||||
@@ -32,7 +32,7 @@ module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: tests that download large models, compile full-width model graphs, fuzz many CUDA search choices, or otherwise require explicit opt-in",
|
||||
"slow: tests that download large models or require pre-generated artifacts",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -46,4 +46,5 @@ dev = [
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"modal>=1.3.5",
|
||||
"matplotlib>=3.8",
|
||||
]
|
||||
|
||||
@@ -1,43 +1,34 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
|
||||
CUDA_TESTS="tests/"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 1: Building native backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
uv run --group dev pytest $NATIVE_TESTS -v
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 2: Building CUDA backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "Slow CUDA tests are opt-in. To include them, run:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -v -s"
|
||||
echo "Or, for only slow tests:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -m slow -v -s"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -4,34 +4,17 @@ set -e
|
||||
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
|
||||
echo ""
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
PYTEST_MARK='not slow'
|
||||
if [[ "${1:-}" == "--include-slow" ]]; then
|
||||
PYTEST_MARK=''
|
||||
elif [[ "${1:-}" == "--slow-only" ]]; then
|
||||
PYTEST_MARK='slow'
|
||||
elif [[ "${1:-}" != "" ]]; then
|
||||
echo "Usage: ./run_tests_cuda.sh [--include-slow|--slow-only]"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
if [[ -n "$PYTEST_MARK" ]]; then
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -m "$PYTEST_MARK" -v -s
|
||||
else
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -v -s
|
||||
fi
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -12,67 +12,6 @@ use crate::typed_data::TypedData;
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Recover a single-variable dim's variable value from an observed runtime size.
|
||||
///
|
||||
/// Returns `Some((var, value))` when the expression contains exactly one
|
||||
/// variable, is affine in that variable, and `value` round-trips through
|
||||
/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise
|
||||
/// — multi-variable expressions, non-affine forms, slope==0, and inversions
|
||||
/// that don't divide cleanly are all rejected so we never write a wrong
|
||||
/// guess into `dyn_map`.
|
||||
fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> {
|
||||
use luminal::shape::Term;
|
||||
let terms = expr.terms.read();
|
||||
|
||||
// Identify the unique variable, if any.
|
||||
let mut var: Option<char> = None;
|
||||
for t in terms.iter() {
|
||||
if let Term::Var(c) = t {
|
||||
match var {
|
||||
None => var = Some(*c),
|
||||
Some(existing) if existing == *c => {}
|
||||
Some(_) => return None, // multi-var — bail out
|
||||
}
|
||||
}
|
||||
}
|
||||
let var = var?;
|
||||
|
||||
// Bare-var fast path — terms is exactly `[Var]`.
|
||||
if terms.len() == 1 {
|
||||
return Some((var, dim_val));
|
||||
}
|
||||
|
||||
// Probe two points to recover slope/intercept of an assumed affine form
|
||||
// `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default
|
||||
// dynamic-dim min is 2, and 3 keeps the inputs small in case the
|
||||
// expression includes a multiplication that could overflow at scale).
|
||||
drop(terms);
|
||||
let f2 = expr.exec_single_var_checked(2)? as i64;
|
||||
let f3 = expr.exec_single_var_checked(3)? as i64;
|
||||
let slope = f3 - f2;
|
||||
if slope == 0 {
|
||||
return None;
|
||||
}
|
||||
let intercept = f2 - 2 * slope;
|
||||
let target = dim_val as i64 - intercept;
|
||||
if slope == 0 || target % slope != 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = target / slope;
|
||||
if candidate < 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = candidate as usize;
|
||||
|
||||
// Verify by re-evaluating with the candidate value. Catches non-affine
|
||||
// forms whose probe points happen to be collinear (e.g. `min(s, 100)`
|
||||
// would look affine for s ∈ {2, 3} but flatten beyond 100).
|
||||
if expr.exec_single_var_checked(candidate)? != dim_val {
|
||||
return None;
|
||||
}
|
||||
Some((var, candidate))
|
||||
}
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
@@ -98,12 +37,7 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -129,9 +63,7 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -287,27 +219,17 @@ impl CompiledGraph {
|
||||
}
|
||||
|
||||
/// Auto-detect and set dynamic dimensions from input tensor shapes.
|
||||
///
|
||||
/// For each user input we walk the symbolic shape expressions side-by-side
|
||||
/// with the concrete sizes Dynamo handed us at runtime and try to recover
|
||||
/// each unbound variable's value. Two cases are handled:
|
||||
///
|
||||
/// * Bare-variable dim (`s`): set directly from the size.
|
||||
/// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a`
|
||||
/// by sampling the expression at two probe points to extract the
|
||||
/// slope, recovering the intercept, and verifying that plugging the
|
||||
/// recovered value back through `exec_single_var_checked` reproduces
|
||||
/// the observed size. The verification step rejects everything
|
||||
/// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong
|
||||
/// guess to `dyn_map`.
|
||||
///
|
||||
/// Multi-variable dims are skipped here; another input's shape — or an
|
||||
/// explicit `set_dim` call — is expected to bind those.
|
||||
/// For each user input, matches the concrete shape against its symbolic
|
||||
/// shape expressions and sets the corresponding dyn_map entries.
|
||||
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
|
||||
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
|
||||
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
|
||||
if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) {
|
||||
self.graph.set_dim(var, value);
|
||||
// Check if this expression is a bare symbolic variable
|
||||
let terms = dim_expr.terms.read();
|
||||
if terms.len() == 1
|
||||
&& let luminal::shape::Term::Var(c) = terms[0]
|
||||
{
|
||||
self.graph.set_dim(c, dim_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -483,7 +405,10 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes.clone()
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -23,169 +23,20 @@ fn resolve_dim_sizes(
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
let s = e.as_expr.expr_str.trim();
|
||||
// Try the full sympy-style parse first so compound forms like
|
||||
// `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and
|
||||
// similar dim-altering ops) propagate as a real Expression
|
||||
// rather than collapsing to the size-1 fallback. Fall back to
|
||||
// the bare-Symbol fast path when that fails — the parser
|
||||
// bails on unrecognised heads (Pow, Min, etc.) and we'd
|
||||
// rather lose the symbolic info than misinterpret it.
|
||||
parse_sympy_expr(s, sym_to_char)
|
||||
.or_else(|| {
|
||||
pt2_parser::extract_symbol_name_pub(s)
|
||||
.and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c)))
|
||||
})
|
||||
.or_else(|| {
|
||||
// As a last resort, if the EP gave us a concrete `hint`
|
||||
// (the value used to seed shape tracing), use it. The
|
||||
// dim is technically dynamic but at least output-shape
|
||||
// resolution won't return 1 for unset dims.
|
||||
e.as_expr
|
||||
.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(|h| Expression::from(h as usize))
|
||||
})
|
||||
.unwrap_or_else(|| Expression::from(1usize))
|
||||
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
|
||||
if let Some(c) = sym_to_char.get(&sym) {
|
||||
Expression::from(*c)
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
|
||||
///
|
||||
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
|
||||
///
|
||||
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
|
||||
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
|
||||
/// * `Integer(N)` / `Number(N)` — concrete int.
|
||||
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
|
||||
///
|
||||
/// Returns `None` for anything else so the caller can fall back to a less
|
||||
/// precise representation rather than committing a wrong expression.
|
||||
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
|
||||
let s = s.trim();
|
||||
if s.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Bare integer literal — `srepr` doesn't usually emit this at the top
|
||||
// level (it wraps in `Integer(...)`), but accept it for robustness.
|
||||
if let Ok(n) = s.parse::<i64>() {
|
||||
return Some(Expression::from(n as usize));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(s)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
// Body is `'name', positive=True, integer=True` etc. Pull the
|
||||
// first quoted token as the name.
|
||||
let name = extract_first_quoted(body)?;
|
||||
sym_to_char.get(&name).map(|c| Expression::from(*c))
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let n: i64 = body.trim().parse().ok()?;
|
||||
Some(Expression::from(n as usize))
|
||||
}
|
||||
"Mul" | "Add" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
|
||||
for p in iter {
|
||||
let rhs = parse_sympy_expr(p, sym_to_char)?;
|
||||
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into (head, body); returns None if not in that form.
|
||||
fn split_head(s: &str) -> Option<(&str, &str)> {
|
||||
let open = s.find('(')?;
|
||||
if !s.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&s[..open], &s[open + 1..s.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list,
|
||||
/// e.g. `'s77', positive=True` → `s77`.
|
||||
fn extract_first_quoted(s: &str) -> Option<String> {
|
||||
let bytes = s.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(s[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
|
||||
/// dimensional information).
|
||||
fn split_top_level_args(s: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = s.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = s[start..i].trim();
|
||||
// Drop `key=value` kwargs — they're metadata sympy uses
|
||||
// for pretty-printing, not arguments to the operator.
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = s[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
// sympy kwargs are bare identifiers like `positive`, `integer`.
|
||||
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
@@ -262,13 +113,10 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,14 +132,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -160,7 +160,31 @@ pub fn parse_pt2(path: &str) -> Result<ParsedPT2> {
|
||||
let file = File::open(path).with_context(|| format!("Failed to open PT2 file: {path}"))?;
|
||||
let mut archive = ZipArchive::new(file).context("Failed to read PT2 ZIP archive")?;
|
||||
|
||||
// Determine archive prefix from the first entry
|
||||
// Torch >= 2.6 uses a flat archive with no prefix directory; detect by presence of the
|
||||
// well-known root-level file. Older torch used a prefix (e.g. "archive/models/model.json").
|
||||
let is_new_format = archive
|
||||
.file_names()
|
||||
.any(|n| n == "serialized_exported_program.json");
|
||||
|
||||
if is_new_format {
|
||||
let program: ExportedProgram = {
|
||||
let mut entry = archive.by_name("serialized_exported_program.json")?;
|
||||
let mut buf = String::new();
|
||||
entry.read_to_string(&mut buf)?;
|
||||
serde_json::from_str(&buf)
|
||||
.context("Failed to parse serialized_exported_program.json")?
|
||||
};
|
||||
// Tensor constants live in serialized_constants.pt; Python extracts them
|
||||
// and loads them post-compile via set_weight_from_ptr.
|
||||
return Ok(ParsedPT2 {
|
||||
program,
|
||||
constants_config: None,
|
||||
archive_prefix: String::new(),
|
||||
pt2_path: path.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Old prefix-based format.
|
||||
let archive_prefix = {
|
||||
let first = archive
|
||||
.file_names()
|
||||
|
||||
@@ -15,16 +15,7 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
pub min_val: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Which SDPA variant we're translating. Governs argument positions and
|
||||
/// which output slots are consumed downstream.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum SdpaVariant {
|
||||
/// `aten._scaled_dot_product_efficient_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False, *, scale=None)
|
||||
/// -> (output, log_sumexp, philox_seed, philox_offset)`
|
||||
Efficient,
|
||||
/// `aten._scaled_dot_product_flash_attention.default(q, k, v, dropout_p=0.,
|
||||
/// is_causal=False, return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// rng_state, unused, debug_attn_mask)`
|
||||
Flash,
|
||||
/// `aten._scaled_dot_product_flash_attention_for_cpu.default(q, k, v,
|
||||
/// dropout_p=0., is_causal=False, *, attn_mask=None, scale=None)
|
||||
/// -> (output, logsumexp)`
|
||||
FlashForCpu,
|
||||
/// `aten._scaled_dot_product_cudnn_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False,
|
||||
/// return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// philox_seed, philox_offset, debug_attn_mask)`
|
||||
Cudnn,
|
||||
/// `aten.scaled_dot_product_attention.default(q, k, v, attn_mask=None,
|
||||
/// dropout_p=0., is_causal=False, *, scale=None, enable_gqa=False)
|
||||
/// -> Tensor` (single output, no tuple).
|
||||
Unified,
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate any SDPA op variant into `softmax((Q@K^T)*scale + causal_mask +
|
||||
/// attn_bias) @ V`. Stores the primary `output` by the node's first output
|
||||
/// name. Other tuple outputs (logsumexp, philox_seed, etc.) are unused in
|
||||
/// inference — left unbound; the downstream `getitem(node, 0)` resolves
|
||||
/// to `output` via the tuple-output name list.
|
||||
pub(crate) fn translate_sdpa(&mut self, node: &Node, variant: SdpaVariant) -> Result<()> {
|
||||
let query = self.get_input_tensor(node, 0)?;
|
||||
let key = self.get_input_tensor(node, 1)?;
|
||||
let value = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// Resolve args by NAME rather than positional index. PT2 serializes
|
||||
// kwargs inline in `node.inputs` with `kind=2`, so any arg that wasn't
|
||||
// passed positionally by the caller shifts the indices of subsequent
|
||||
// positional args. Name-based lookup is unambiguous across variants
|
||||
// and across caller argument-passing styles.
|
||||
let arg_by_name =
|
||||
|name: &str| -> Option<&NodeInput> { node.inputs.iter().find(|i| i.name == name) };
|
||||
let tensor_arg = |name: &str| -> Option<GraphTensor> {
|
||||
arg_by_name(name)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.and_then(|n| self.get_tensor(n).ok())
|
||||
};
|
||||
let float_arg =
|
||||
|name: &str| -> Option<f64> { arg_by_name(name).and_then(|i| i.arg.as_float()) };
|
||||
let bool_arg =
|
||||
|name: &str| -> Option<bool> { arg_by_name(name).and_then(|i| i.arg.as_bool()) };
|
||||
|
||||
// attn_bias (Efficient/Cudnn/Unified) or attn_mask (FlashForCpu/Unified).
|
||||
let additive = tensor_arg("attn_bias").or_else(|| tensor_arg("attn_mask"));
|
||||
|
||||
let dropout_p = float_arg("dropout_p").unwrap_or(0.0) as f32;
|
||||
anyhow::ensure!(
|
||||
dropout_p == 0.0,
|
||||
"SDPA: dropout_p={dropout_p} unsupported (inference only)"
|
||||
);
|
||||
let is_causal = bool_arg("is_causal").unwrap_or(false);
|
||||
// Silence compiler warnings — variant arg remains for branch-specific
|
||||
// logic (output tuple-name resolution below) and for future divergence.
|
||||
let _ = variant;
|
||||
|
||||
// `scale` kwarg, default 1/sqrt(head_dim).
|
||||
let head_dim = query
|
||||
.shape
|
||||
.dims
|
||||
.last()
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA: query head_dim must be concrete")?;
|
||||
let default_scale = 1.0_f32 / (head_dim as f32).sqrt();
|
||||
let scale = float_arg("scale")
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(default_scale);
|
||||
|
||||
// Math form: scores = (Q @ K^T) * scale; + causal_mask; + attn_bias;
|
||||
// attn = softmax(scores, dim=-1); out = attn @ V.
|
||||
let q_ndim = query.shape.len();
|
||||
anyhow::ensure!(
|
||||
q_ndim >= 2,
|
||||
"SDPA: query must have at least 2 dims (got {q_ndim})"
|
||||
);
|
||||
// Transpose last two dims of key.
|
||||
let mut perm: Vec<usize> = (0..q_ndim).collect();
|
||||
perm.swap(q_ndim - 2, q_ndim - 1);
|
||||
let key_t = key.permute(perm);
|
||||
let (q_for_mm, k_for_mm) = ensure_same_dtype(query, key_t);
|
||||
let scores = q_for_mm.matmul(k_for_mm);
|
||||
let scale_t = self
|
||||
.graph
|
||||
.constant_float(scale)
|
||||
.cast(scores.dtype)
|
||||
.expand_rhs(scores.shape);
|
||||
let mut scores = scores * scale_t;
|
||||
|
||||
if is_causal {
|
||||
let s_q = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 2)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_q must be concrete")?;
|
||||
let s_k = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 1)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_k must be concrete")?;
|
||||
let size = s_q.max(s_k);
|
||||
// triu with diagonal=1 = 1 strictly above diagonal, 0 elsewhere.
|
||||
let mut mask = self.graph.triu(size, 1).cast(DType::F32);
|
||||
if s_q != size || s_k != size {
|
||||
mask = mask.slice_along(0..s_q, 0).slice_along(0..s_k, 1);
|
||||
}
|
||||
// -1e9 * mask ≈ -inf where masked, 0 otherwise. Broadcast across
|
||||
// batch/head prefix dims of `scores`.
|
||||
let neg_large = mask * (-1e9_f32);
|
||||
let mut neg_large = neg_large.cast(scores.dtype);
|
||||
for _ in 0..(q_ndim - 2) {
|
||||
neg_large = neg_large.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let (scores_b, mask_b) = broadcast_binary(scores, neg_large);
|
||||
scores = scores_b + mask_b;
|
||||
}
|
||||
if let Some(bias) = additive {
|
||||
let (scores_b, bias_b) = ensure_same_dtype(scores, bias);
|
||||
let (scores_b, bias_b) = broadcast_binary(scores_b, bias_b);
|
||||
scores = scores_b + bias_b;
|
||||
}
|
||||
|
||||
let attn = scores.softmax(q_ndim - 1);
|
||||
let (attn, value) = ensure_same_dtype(attn, value);
|
||||
let out = attn.matmul(value);
|
||||
|
||||
// Store the primary output by name. The other tuple outputs are
|
||||
// inference-time dead ends — downstream getitem(node, 0) resolves to
|
||||
// the same tensor name we bind here, because pt2 serializes the
|
||||
// multi-output name list with output[0] as the primary slot.
|
||||
let out_name = if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else if variant == SdpaVariant::Unified {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.or_else(|| {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(name) = out_name
|
||||
&& !name.is_empty()
|
||||
{
|
||||
self.tensors.insert(name, out);
|
||||
} else {
|
||||
anyhow::bail!("SDPA: no output tensor name found on node {}", node.target);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for SdpaVariant {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
matches!(
|
||||
(self, other),
|
||||
(SdpaVariant::Efficient, SdpaVariant::Efficient)
|
||||
| (SdpaVariant::Flash, SdpaVariant::Flash)
|
||||
| (SdpaVariant::FlashForCpu, SdpaVariant::FlashForCpu)
|
||||
| (SdpaVariant::Cudnn, SdpaVariant::Cudnn)
|
||||
| (SdpaVariant::Unified, SdpaVariant::Unified)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,11 +389,8 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -5,8 +5,6 @@ use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -70,8 +68,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
|
||||
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
@@ -148,7 +144,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -188,16 +183,8 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
// `empty` and `empty_permuted` allocate uninitialised tensors of
|
||||
// a given shape; the caller fills them. We lower to zeros with
|
||||
// the same shape+dtype — downstream reads are officially UB on
|
||||
// PyTorch's side, and downstream writes overwrite our zeros.
|
||||
// Qwen3MoE's MoE block uses `empty_permuted` to allocate the
|
||||
// expert-output staging tensor before scatter-adding into it.
|
||||
"torch.ops.aten.empty.memory_format" | "torch.ops.aten.empty_permuted.default" => {
|
||||
self.translate_empty(node)?
|
||||
}
|
||||
// Qwen3-MoE's expert-balance counts tokens-per-expert via histc.
|
||||
"torch.ops.aten.empty_permuted.default"
|
||||
| "torch.ops.aten.empty.memory_format" => self.translate_empty(node)?,
|
||||
"torch.ops.aten.histc.default" => self.translate_histc(node)?,
|
||||
|
||||
// Grouped matmul (MoE expert dispatch).
|
||||
@@ -219,18 +206,9 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
"torch.ops.aten.ge.Scalar" => self.translate_scalar_comparison(node, |a, s| a.ge(s))?,
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
"torch.ops.aten.eq.Scalar" => self.translate_scalar_comparison(node, |a, s| a.eq(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -248,13 +226,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -293,27 +264,18 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -326,35 +288,49 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.ceil.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// ceil(x) = trunc(x) + (x > trunc(x)).
|
||||
// Cast-to-Int rounds toward zero, so for any positive fractional
|
||||
// `x` the trunc sits below `x` and we add 1; for negatives we
|
||||
// have `trunc >= x` and adjust=0. Avoids the two extra
|
||||
// mul-by-(-1) nodes that the `-floor(-x)` lowering emits.
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = a.gt(trunc).cast(DType::F32);
|
||||
trunc + adjust
|
||||
// ceil(x) = -floor(-x)
|
||||
let neg_a = a * (-1.0);
|
||||
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = neg_a.lt(trunc).cast(DType::F32);
|
||||
let floor_neg = trunc - adjust;
|
||||
floor_neg * (-1.0)
|
||||
}
|
||||
"torch.ops.aten.erf.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// Abramowitz & Stegun approximation 7.1.28 (max error ~1.5e-7)
|
||||
// erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
|
||||
// where t = 1/(1 + 0.3275911*|x|), poly in Horner form
|
||||
let ax = a.abs();
|
||||
let x2 = a * a;
|
||||
let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
|
||||
// Horner: t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
|
||||
let poly = t
|
||||
* (t * (t
|
||||
* (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
|
||||
+ (-0.284_496_72_f32))
|
||||
+ 0.254_829_6_f32);
|
||||
let result_abs =
|
||||
self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
|
||||
// sign(x) = 2*(x >= 0) - 1
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
|
||||
let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
|
||||
result_abs * sign
|
||||
self.erf_approx(a)
|
||||
}
|
||||
"torch.ops.aten.gelu.default" => {
|
||||
let a_in = self.get_input_tensor(node, 0)?;
|
||||
// PyTorch's gelu has a kwarg `approximate` (default "none").
|
||||
// "none" → 0.5 * x * (1 + erf(x / sqrt(2))) (exact)
|
||||
// "tanh" → 0.5 * x * (1 + tanh(c * (x + 0.044715*x^3)))
|
||||
// where c = sqrt(2/pi) ≈ 0.7978845608
|
||||
// Gemma family uses approximate="tanh" but lowering may emit
|
||||
// either form; honour whatever the FX graph carries.
|
||||
let approximate = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "approximate"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
// Promote to F32 around the constants/comparisons (same reason
|
||||
// as clamp/erf — luminal binary ops assert matching dtypes).
|
||||
let orig = a_in.dtype;
|
||||
let a = if orig == DType::F32 { a_in } else { a_in.cast(DType::F32) };
|
||||
let half = self.graph.constant_float(0.5).expand_rhs(a.shape);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(a.shape);
|
||||
let result = if approximate.as_deref() == Some("tanh") {
|
||||
let x2 = a * a;
|
||||
let inner = a * (x2 * 0.044715_f32 + 1.0) * 0.797_884_56_f32;
|
||||
half * a * (one + inner.tanh())
|
||||
} else {
|
||||
let scaled = a * 0.707_106_77_f32; // 1 / sqrt(2)
|
||||
let erf_val = self.erf_approx(scaled);
|
||||
half * a * (one + erf_val)
|
||||
};
|
||||
if orig == DType::F32 { result } else { result.cast(orig) }
|
||||
}
|
||||
"torch.ops.aten.isnan.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -409,17 +385,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -450,29 +415,6 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Scaled dot-product attention — each variant binds args slightly
|
||||
// differently but all lower to matmul+softmax via translate_sdpa.
|
||||
"torch.ops.aten._scaled_dot_product_efficient_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Efficient)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Flash)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::FlashForCpu)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_cudnn_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Cudnn)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten.scaled_dot_product_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Unified)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
@@ -483,28 +425,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//!
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod attention;
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
@@ -69,6 +68,9 @@ impl<'a> Translator<'a> {
|
||||
fn translate_graph(&mut self) -> Result<()> {
|
||||
self.create_inputs()?;
|
||||
|
||||
// Per-block partitioning is now handled automatically by the upstream
|
||||
// loop-rolling prepass; this translator no longer needs to insert
|
||||
// manual graph breaks at RMSNorm boundaries.
|
||||
let nodes = &self.parsed.program.graph_module.graph.nodes;
|
||||
for (i, node) in nodes.iter().enumerate() {
|
||||
self.translate_node(node)
|
||||
@@ -188,21 +190,8 @@ impl<'a> Translator<'a> {
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Ok(v);
|
||||
}
|
||||
// Fall through to symbolic-aware resolution. Op-arg slots like `dim`
|
||||
// and `axis` are always concrete in practice, but with dynamic shapes
|
||||
// PT2 occasionally hands us a SymInt that is fully bound at export
|
||||
// time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`);
|
||||
// accept those when they reduce to a concrete int rather than failing
|
||||
// with the misleading "not an int" diagnostic.
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg)
|
||||
&& let Some(v) = expr.to_usize()
|
||||
{
|
||||
return Ok(v as i64);
|
||||
}
|
||||
anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg)
|
||||
arg.as_int()
|
||||
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
|
||||
@@ -221,37 +210,11 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
|
||||
use crate::pt2_schema::SymIntEntry;
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
// Symbolic int lists: tolerate them as long as every entry is a
|
||||
// bound concrete value. Prevents false "not an int list" failures on
|
||||
// graphs where torch.export emits sym_ints for what is dimensionally
|
||||
// a static parameter (kernel sizes, etc. with dynamic batch).
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
let mut out = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
let v = match entry {
|
||||
SymIntEntry::Int(i) => Some(i.as_int),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.and_then(|e| e.to_usize().map(|u| u as i64)),
|
||||
};
|
||||
match v {
|
||||
Some(n) => out.push(n),
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Input {idx} of {} contains an unresolved sym_int entry",
|
||||
node.target
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(out);
|
||||
}
|
||||
arg.as_ints()
|
||||
.map(|v| v.to_vec())
|
||||
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))
|
||||
@@ -376,3 +339,4 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -120,47 +120,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -300,15 +259,21 @@ impl<'a> Translator<'a> {
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
|
||||
// Normalize negative indices for this dimension. Stay in Int —
|
||||
// multiplying an Int tensor by an Expression broadcasts the axis
|
||||
// size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int
|
||||
// for the result, Bool→F32 for the negative mask) per indexed dim.
|
||||
let axis_size = src_shape[dim_idx];
|
||||
let idx_int = idx_tensor.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
|
||||
let is_negative = idx_int.lt(zero).cast(DType::Int);
|
||||
let idx_int = idx_int + is_negative * axis_size;
|
||||
// Normalize negative indices for this dimension
|
||||
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"index.Tensor: dim {} must be concrete for negative index normalization",
|
||||
dim_idx
|
||||
)
|
||||
})?;
|
||||
let idx_f32 = idx_tensor.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_size as f32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_negative = idx_f32.lt(zero).cast(DType::F32);
|
||||
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let stride = &strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
@@ -374,34 +339,20 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
// (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for
|
||||
// the negative mask) that the previous F32-routed path emitted.
|
||||
let axis_dim = a.shape.dims[dim];
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(indices_int.shape);
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
|
||||
})?;
|
||||
let indices_f32 = indices.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_dim as f32)
|
||||
.expand_rhs(indices_f32.shape);
|
||||
let is_negative = indices_f32.lt(zero).cast(DType::F32);
|
||||
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -438,50 +389,100 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
.arg
|
||||
.as_tensors()
|
||||
.context("index_put: indices not as_tensors")?;
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
// --- all-tensor indices: bool-mask blend or scatter_nd ---
|
||||
if let Some(index_names) = node.inputs[1].arg.as_tensors() {
|
||||
if index_names.len() == 1 {
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
// Broadcast the (often scalar) value tensor to match data shape,
|
||||
// then blend by mask. Cast mask to data's dtype for the
|
||||
// arithmetic so this works for both integer and float data.
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// where(mask, value, a) as `a + mask*(value - a)`. Saves a mul
|
||||
// and the `1.0` constant compared to the `a*(1 - m) + v*m`
|
||||
// form; works for any numeric dtype without a dedicated cond.
|
||||
return Ok(a + mask_f * (values_b - a));
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// Implements where(mask, value, a) as
|
||||
// a*(1 - mask) + value*mask
|
||||
// — works without a dedicated cond op for any numeric dtype.
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
return Ok(a * (one - mask_f) + values_b * mask_f);
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. Always pad
|
||||
// a trailing size-1 dim so rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
return Ok(a.scatter_nd(indices, values));
|
||||
}
|
||||
bail!("index_put with multiple all-tensor indices not yet supported");
|
||||
}
|
||||
|
||||
// --- optional-tensor indices: [None, arange_tensor, None, ...] ---
|
||||
// Each None means "all of that dimension"; one tensor means "index into that dim".
|
||||
// StaticCache uses this for KV updates: cache[:, :, position, :] = new_value.
|
||||
if let Some(opt_tensors) = node.inputs[1].arg.as_optional_tensors() {
|
||||
use crate::pt2_schema::OptionalTensorEntry;
|
||||
let mut first_non_none_dim = 0usize;
|
||||
let mut idx_name: Option<String> = None;
|
||||
let mut non_none_count = 0usize;
|
||||
|
||||
for (i, entry) in opt_tensors.iter().enumerate() {
|
||||
if let OptionalTensorEntry::Tensor(t) = entry {
|
||||
if idx_name.is_none() {
|
||||
first_non_none_dim = i;
|
||||
}
|
||||
idx_name = Some(t.as_tensor.name.clone());
|
||||
non_none_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. idx_tensor can
|
||||
// be ANY shape — its whole shape is "batch dims" in scatter_nd terms,
|
||||
// and K is always 1 (number of dims we're indexing into). Always pad
|
||||
// a trailing size-1 dim so the rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
if non_none_count != 1 {
|
||||
bail!(
|
||||
"index_put with optional tensors: only single non-None index supported \
|
||||
(got {non_none_count})"
|
||||
);
|
||||
}
|
||||
|
||||
let mut indices = self.get_tensor(&idx_name.unwrap())?.cast(DType::Int);
|
||||
|
||||
// Expand 1-D indices [P] to values.shape for scatter_elements:
|
||||
// Build [1, ..., 1, P, 1, ..., 1] with P at first_non_none_dim, then broadcast.
|
||||
let rank = a.shape.len();
|
||||
// Insert singleton dims before first_non_none_dim
|
||||
for i in 0..first_non_none_dim {
|
||||
indices = indices.expand_dim(i, Expression::from(1usize));
|
||||
}
|
||||
// Insert singleton dims after first_non_none_dim
|
||||
let current_rank = indices.shape.len();
|
||||
for j in current_rank..rank {
|
||||
indices = indices.expand_dim(j, Expression::from(1usize));
|
||||
}
|
||||
// Broadcast singletons to values shape
|
||||
let values_shape: Vec<Expression> = values.shape.dims[..rank].to_vec();
|
||||
indices.shape.expand(values_shape);
|
||||
|
||||
return Ok(a.scatter_elements(indices, values, first_non_none_dim));
|
||||
}
|
||||
|
||||
bail!(
|
||||
"index_put: unsupported indices format: {:?}",
|
||||
node.inputs[1].arg
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
|
||||
@@ -6,20 +6,6 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -51,26 +37,16 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -94,100 +70,4 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,94 +72,70 @@ impl<'a> Translator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Lower `aten.histc.default` for the integer-bincount case.
|
||||
/// Translate `aten.histc.default(input, bins, min, max)` → `Tensor[bins]`.
|
||||
///
|
||||
/// Qwen3-MoE's expert-balance layer calls
|
||||
/// `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)` to count how
|
||||
/// many tokens were routed to each expert. With those args every
|
||||
/// integer value `i ∈ [0, K-1]` maps to exactly bin `i`, and the result
|
||||
/// is equivalent to `torch.bincount`. We implement that case as a
|
||||
/// broadcast equality + sum:
|
||||
/// Counts how many input elements fall in each of `bins` equal-width
|
||||
/// buckets over `[min, max]`. PyTorch's histc accepts only 1D input;
|
||||
/// HF MoE forwards emit it on flattened expert-assignment tensors to
|
||||
/// produce per-expert token counts (one_hot + sum, essentially).
|
||||
///
|
||||
/// counts[b] = sum_i (input[i] == b + min) for b in [0, bins)
|
||||
///
|
||||
/// More general histc bin widths (`bins != max - min + 1`, or
|
||||
/// non-integer values that span fractional bins) are not supported
|
||||
/// today — the equality path would silently drop them. We bail rather
|
||||
/// than produce wrong counts.
|
||||
/// Implementation: arange over bins, broadcast to [G, N], element-wise
|
||||
/// `(lower <= input < upper)` into a F32 mask, sum over the input axis.
|
||||
/// The right edge of the last bin is technically inclusive in PyTorch;
|
||||
/// we treat it as exclusive — for the typical MoE use (integer expert
|
||||
/// IDs in `[0, num_experts)`), no input ever equals `max` so this is
|
||||
/// indistinguishable.
|
||||
pub(crate) fn translate_histc(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let bins_i64: i64 = self
|
||||
.get_int_arg(node, 1)
|
||||
.context("histc: missing `bins` arg (#1)")?;
|
||||
// `min`/`max` are float kwargs (default 0.0 each, which means
|
||||
// "auto-pick from input"); for the qwen3-moe call they're always
|
||||
// integers passed as floats.
|
||||
let min = self.get_float_arg(node, 2).unwrap_or(0.0);
|
||||
let max = self.get_float_arg(node, 3).unwrap_or(0.0);
|
||||
let bins = self.get_int_arg(node, 1)? as usize;
|
||||
let min_val = self.get_float_arg(node, 2)? as f32;
|
||||
let max_val = self.get_float_arg(node, 3)? as f32;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 1,
|
||||
"histc: only 1D input is supported, got {}D",
|
||||
"histc: only 1D input supported (got {}D)",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
bins_i64 > 0,
|
||||
"histc: bins must be positive, got {}",
|
||||
bins_i64
|
||||
);
|
||||
// Bincount-equivalent case: one integer value per bin.
|
||||
anyhow::ensure!(
|
||||
(max - min - (bins_i64 - 1) as f64).abs() < 1e-6,
|
||||
"histc: only the bincount-equivalent case (bins == max - min + 1) is \
|
||||
supported; got bins={}, min={}, max={}. Other cases would need a \
|
||||
general bin-width / right-edge-inclusion implementation.",
|
||||
bins_i64,
|
||||
min,
|
||||
max,
|
||||
);
|
||||
|
||||
let bins_u = bins_i64 as usize;
|
||||
let n = input.shape.dims[0];
|
||||
let g = Expression::from(bins);
|
||||
|
||||
// arange(bins) [bins] → cast to input dtype, optionally shift by min,
|
||||
// broadcast to [bins, N], compare for equality with input broadcast.
|
||||
let mut bins_arange = self.graph.arange(Expression::from(bins_u));
|
||||
if min != 0.0 {
|
||||
// `min` is non-zero (uncommon in the qwen3-moe path but legal)
|
||||
// — shift the comparison values to start at min.
|
||||
let min_i = min as i64;
|
||||
let shift = self
|
||||
.graph
|
||||
.constant_float(min_i as f32)
|
||||
.cast(bins_arange.dtype)
|
||||
.expand_rhs(bins_arange.shape);
|
||||
bins_arange += shift;
|
||||
}
|
||||
let bins_expanded = bins_arange.cast(input.dtype).expand_dim(1, n);
|
||||
let input_expanded = input.expand_dim(0, Expression::from(bins_u));
|
||||
let matches = input_expanded.eq(bins_expanded); // Bool [bins, N]
|
||||
let input_f = input.cast(DType::F32);
|
||||
let step = (max_val - min_val) / bins as f32;
|
||||
|
||||
let out_dtype = self.output_meta_dtype(node)?;
|
||||
Ok(matches.cast(out_dtype).sum(1))
|
||||
// Per-bin lower edges: arange(bins) * step + min.
|
||||
let bin_idx = self.graph.arange(g).cast(DType::F32);
|
||||
let lower_1d = bin_idx * step + min_val;
|
||||
let upper_1d = lower_1d + step;
|
||||
|
||||
// Broadcast to [G, N] and produce the boolean mask.
|
||||
let input_b = input_f.expand_dim(0, g);
|
||||
let lower = lower_1d.expand_dim(1, n);
|
||||
let upper = upper_1d.expand_dim(1, n);
|
||||
|
||||
let in_lower = input_b.ge(lower).cast(DType::F32);
|
||||
let in_upper = input_b.lt(upper).cast(DType::F32);
|
||||
let mask = in_lower * in_upper;
|
||||
|
||||
Ok(mask.sum(1))
|
||||
}
|
||||
|
||||
/// Lower `aten.empty.memory_format` and `aten.empty_permuted.default`.
|
||||
/// Translate `aten.empty_permuted.default(size, physical_layout, **kwargs)`
|
||||
/// → zero-filled tensor of shape `size`.
|
||||
///
|
||||
/// Both allocate an uninitialised tensor; the caller is responsible for
|
||||
/// writing into it. We materialise zeros instead — luminal has no
|
||||
/// "uninitialised" notion, and PyTorch's contract on `empty` outputs is
|
||||
/// undefined for any read prior to a write, so a zero-fill is sound.
|
||||
/// `aten.empty_permuted` additionally takes a `physical_layout` arg
|
||||
/// (the storage permutation); for a zero-filled tensor that's a no-op.
|
||||
/// PyTorch's `empty_permuted` allocates uninitialized memory with a given
|
||||
/// stride permutation; downstream code typically overwrites every element
|
||||
/// before reading. Luminal's tensor abstraction doesn't expose strides, so
|
||||
/// the physical_layout hint is irrelevant — we just emit a zero tensor of
|
||||
/// the requested shape and dtype. (Same approach works for `aten.empty`
|
||||
/// variants when they show up.)
|
||||
pub(crate) fn translate_empty(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let zero = self.graph.constant_float(0.0).cast(dtype);
|
||||
let value = self.graph.constant_float(0.0).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
zero
|
||||
value
|
||||
} else {
|
||||
zero.expand_rhs(shape)
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -242,26 +218,28 @@ impl<'a> Translator<'a> {
|
||||
let k = weight.shape.dims[1];
|
||||
let n = weight.shape.dims[2];
|
||||
|
||||
// expert_id[m] = number of g s.t. m >= offs[g], clamped to [0, G-1].
|
||||
// Same value as HF MoE's `expert_ids.clamp(0, num_experts-1)` for
|
||||
// invalid expert IDs from EP, AND protects search-time profiling:
|
||||
// dummy-1 input bytes give offs=[1,…,1], which pushes the raw count
|
||||
// to G for any token with index ≥ 1 and would OOB the weight gather.
|
||||
//
|
||||
// Stay in Int throughout — arange / offs are already Int, ge → Bool
|
||||
// → cast(Int), sum stays Int, and the binary `minimum` handles the
|
||||
// clamp without an F32 round-trip.
|
||||
let _ = g
|
||||
// expert_id[m] = number of g s.t. m >= offs[g]
|
||||
// = first g s.t. m < offs[g], i.e. the expert assigned to m.
|
||||
// Clamp to [0, G-1] before using as gather index. Matches HF MoE's
|
||||
// `expert_ids.clamp(0, num_experts-1)` for invalid IDs from EP, AND
|
||||
// protects search-time profiling: dummy-1 input bytes give offs=[1,…,1],
|
||||
// which makes `m >= offs[g]` true for m≥1 and pushes expert_id to G,
|
||||
// out of bounds for the weight gather. Clamping keeps the gather safe.
|
||||
let g_max_f = (g
|
||||
.to_usize()
|
||||
.context("_grouped_mm: G (num_experts) must be concrete")?;
|
||||
let s_arange = self.graph.arange(s); // Int [S]
|
||||
let ge_int = s_arange
|
||||
.context("_grouped_mm: G (num_experts) must be concrete")?
|
||||
as f32)
|
||||
- 1.0;
|
||||
let offs_f = offs.cast(DType::F32);
|
||||
let s_arange_f = self.graph.arange(s).cast(DType::F32);
|
||||
let ge_boundary = s_arange_f
|
||||
.expand_dim(0, g)
|
||||
.ge(offs.expand_dim(1, s)) // Bool [G, S]
|
||||
.cast(DType::Int); // Int [G, S]
|
||||
let raw = ge_int.sum(0); // Int [S], values in [0, G]
|
||||
let cap = self.graph.constant(g - 1).expand_dim(0, s); // Int [S], all G-1
|
||||
let expert_id = raw.minimum(cap); // Int [S]
|
||||
.ge(offs_f.expand_dim(1, s))
|
||||
.cast(DType::F32);
|
||||
let expert_id = ge_boundary
|
||||
.sum(0)
|
||||
.minimum_f32(g_max_f)
|
||||
.cast(DType::Int); // [S] Int
|
||||
|
||||
// Flat gather index into weight (treated as a length-G*K*N 1D buffer):
|
||||
// flat[m, k_, n_] = expert_id[m] * (K*N) + k_ * N + n_
|
||||
@@ -274,65 +252,46 @@ impl<'a> Translator<'a> {
|
||||
let exp_within = within.expand_dim(0, s);
|
||||
let flat_idx = exp_base + exp_within;
|
||||
|
||||
// Gather → [S, K, N], preserves weight's native dtype (bf16 stays bf16).
|
||||
// Gather → [S, K, N]. Preserves weight's native dtype (bf16 stays bf16).
|
||||
let weight_gathered = weight.gather(flat_idx);
|
||||
|
||||
// Cast for matmul — now on the small gathered slice, not the full weight.
|
||||
let input_f = input.cast(DType::F32);
|
||||
let weight_f = weight_gathered.cast(DType::F32);
|
||||
|
||||
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
|
||||
// Operands stay in their native dtype — no F32 cast on the gathered
|
||||
// weight or the input. The earlier cast(F32) was a holdover from the
|
||||
// broadcast-and-mask version (which had to use F32 because of the
|
||||
// cast(F32) on the mask). Gather-then-matmul has no such requirement,
|
||||
// and casting `[S, K, N]` to F32 doubled the gather scratch (~100 MB
|
||||
// to ~200 MB per layer for Qwen3-30B-A3B prefill). Matmul rewrites
|
||||
// (cuBLASLt etc.) handle bf16 input with F32 accumulator internally.
|
||||
let result = input.unsqueeze(1).matmul(weight_gathered).squeeze(1);
|
||||
let result = input_f.unsqueeze(1).matmul(weight_f).squeeze(1);
|
||||
|
||||
Ok(result.cast(input.dtype))
|
||||
}
|
||||
|
||||
/// Build the where-formula graph: `cond * x + (1 - cond) * y`, computed
|
||||
/// in F32, cast back to `out_dtype`. Shared between `translate_where`,
|
||||
/// `translate_where_scalar_other`, and `translate_masked_fill_scalar` so
|
||||
/// they all go through one well-tested code path.
|
||||
pub(crate) fn where_formula(
|
||||
&mut self,
|
||||
cond: GraphTensor,
|
||||
x: GraphTensor,
|
||||
y: GraphTensor,
|
||||
out_dtype: DType,
|
||||
) -> GraphTensor {
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
// Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops
|
||||
// plus the explicit `1.0` constant. Mathematically identical for
|
||||
// c ∈ {0, 1} and produces the same F32 output type.
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
// Cast back: an F32 result downstream-interpreted as bf16 walks the
|
||||
// buffer at half-stride, returning every-other-element zeros.
|
||||
(y_f + c * (x_f - y_f)).cast(out_dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
let out_dtype = x.dtype;
|
||||
Ok(self.where_formula(cond, x, y, out_dtype))
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
let out_dtype = x.dtype;
|
||||
// Build a tensor for the scalar `other` matching `x`'s shape so we
|
||||
// can route through the shared where_formula helper.
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(x.shape);
|
||||
Ok(self.where_formula(cond, x, other, out_dtype))
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -387,37 +346,33 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names
|
||||
let tuple_outputs = node.outputs.first().and_then(|o| o.as_tensors.as_ref());
|
||||
let values_name = if let Some(ts) = tuple_outputs {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
};
|
||||
let indices_name = if let Some(ts) = tuple_outputs {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build top-k outputs from a full stable argsort. Slice the indices
|
||||
// before gathering values so the gather shape matches the requested
|
||||
// top-k output rather than the full sort width.
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(topk_indices, dim);
|
||||
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
self.tensors.insert(idx_name, topk_indices);
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -51,19 +51,13 @@ impl<'a> Translator<'a> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
let dtype_int = input
|
||||
.arg
|
||||
.as_int()
|
||||
.map(|i| i as u32)
|
||||
.or_else(|| input.arg.as_scalar_type());
|
||||
if let Some(d) = dtype_int {
|
||||
let dtype = torch_dtype_int_to_luminal(d);
|
||||
// Skip emitting a Cast op when the dtype already matches —
|
||||
// PT2 graphs frequently emit `_to_copy` purely as a clone hint
|
||||
// (e.g. dtype=float32 on a tensor that is already F32), and
|
||||
// every redundant Cast inflates the graph and survives until
|
||||
// optimization passes can prove it as a no-op.
|
||||
return Ok(if a.dtype == dtype { a } else { a.cast(dtype) });
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,34 +131,37 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
// `masked_fill(input, mask, fill)` = `where(mask, fill, input)`.
|
||||
// Routes through the shared `where_formula` helper so we exercise
|
||||
// the exact same code path as `aten.where.self`, which is verified
|
||||
// to handle the bf16 cast-back correctly. Hand-rolling the same
|
||||
// formula directly here used to drift (egglog made different
|
||||
// rewrite choices on the rebuilt-locally graph), so we deliberately
|
||||
// re-use the helper.
|
||||
// `aten.masked_fill.Scalar(input, mask, fill)` ≡
|
||||
// `aten.where.self(mask, full_like(input, fill), input)`. The
|
||||
// `full_like + where` sequence is the verified-working path
|
||||
// (test: `where(mask, torch.zeros_like(x), x)` round-trips with
|
||||
// max_diff = 0); we reproduce its exact graph-build order here.
|
||||
// Hand-rolling the formula in any other shape (single-mul, F32
|
||||
// throughout, alternative constant-cast orderings) routes egglog
|
||||
// through a rewrite that returns an F32 buffer downstream-read as
|
||||
// bf16 — the every-other-element-zero pattern.
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let out_dtype = input.dtype;
|
||||
// Build fill_t exactly like translate_full_like does:
|
||||
// constant_float(val).cast(dtype).expand_rhs(reference.shape)
|
||||
let fill_t = self
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(out_dtype)
|
||||
.expand_rhs(input.shape);
|
||||
Ok(self.where_formula(mask, fill_t, input, out_dtype))
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -213,18 +210,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -266,61 +257,54 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
let mut result = a;
|
||||
// maximum_f32 / minimum_f32 internally use `.lt(F32 scalar)`, which
|
||||
// asserts matching tensor dtypes. Without this, clamp on an Int tensor
|
||||
// (e.g. Qwen3-MoE routes `cache_position.clamp(...)` through here)
|
||||
// panics inside luminal core. Promote to F32 around the bounds check
|
||||
// and cast back at the end.
|
||||
let original_dtype = a.dtype;
|
||||
let needs_promote = original_dtype != DType::F32;
|
||||
let mut result = if needs_promote { a.cast(DType::F32) } else { a };
|
||||
if let Some(min) = min_val {
|
||||
result = result.maximum_f32(min);
|
||||
}
|
||||
if let Some(max) = max_val {
|
||||
result = result.minimum_f32(max);
|
||||
}
|
||||
if needs_promote {
|
||||
result = result.cast(original_dtype);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
/// Compute `erf(a)` via the Abramowitz & Stegun 7.1.28 approximation
|
||||
/// (max error ~1.5e-7). Shared by `aten.erf.default` and the exact
|
||||
/// `aten.gelu.default` (which is `0.5 * x * (1 + erf(x / sqrt(2)))`).
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
/// erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
|
||||
/// where t = 1/(1 + 0.3275911*|x|), poly is degree 5 in Horner form.
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
/// Promotes the input to F32 internally (the approximation constants are
|
||||
/// F32 anyway, and luminal's binary ops assert matching dtypes — running
|
||||
/// this on Bf16 input directly trips the assertion at `a.ge(zero)`).
|
||||
/// Restores the original dtype on return.
|
||||
pub(crate) fn erf_approx(&mut self, a: GraphTensor) -> GraphTensor {
|
||||
let orig = a.dtype;
|
||||
let a = if orig == DType::F32 { a } else { a.cast(DType::F32) };
|
||||
let ax = a.abs();
|
||||
let x2 = a * a;
|
||||
let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
|
||||
let poly = t
|
||||
* (t * (t
|
||||
* (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
|
||||
+ (-0.284_496_72_f32))
|
||||
+ 0.254_829_6_f32);
|
||||
let result_abs =
|
||||
self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
|
||||
// sign(x) = 2*(x >= 0) - 1
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
|
||||
let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
|
||||
let result = result_abs * sign;
|
||||
if orig == DType::F32 { result } else { result.cast(orig) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
# Import Python components
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
import torch.export._unlift as _torch_export_unlift
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
@@ -11,6 +13,49 @@ from .main import luminal_backend, register_backend
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Suppress torch.export's `_guards_fn` insertion when luminal is on the stack.
|
||||
#
|
||||
# When `torch._dynamo.config.automatic_dynamic_shapes=True` (the default) and
|
||||
# a model is called with shapes that vary across calls, dynamo promotes the
|
||||
# changing dim to a SymInt and re-traces. During the re-trace, torch.export's
|
||||
# `_unlift_exported_program_lifted_states` (in `torch/export/_unlift.py`)
|
||||
# generates a `_guards_fn` submodule whose body closes over `L` — dynamo's
|
||||
# locals namespace. When aot_autograd later evaluates the resulting
|
||||
# GraphModule via fx.Interpreter, the closure's free `L` reference doesn't
|
||||
# resolve and we get
|
||||
# NameError: name 'L' is not defined
|
||||
# (gemma3 + StaticCache reproduces this deterministically).
|
||||
#
|
||||
# torch.export's own opt-out — `_ok_to_generate_guards_fn` — already walks
|
||||
# the call stack for filename patterns to suppress guard generation for
|
||||
# specific embedders (executorch, modai, on_device_ai, torchao). Add
|
||||
# "luminal" to the same suppression set by monkey-patching the function.
|
||||
# Net effect: torch.export never inserts `_guards_fn`, so re-tracing
|
||||
# succeeds, dynamic-shape compile-once-run-many works, and StaticCache
|
||||
# decode loops compile in ~one shot instead of per-token.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_orig_ok_to_generate_guards_fn = _torch_export_unlift._ok_to_generate_guards_fn
|
||||
|
||||
|
||||
def _luminal_aware_ok_to_generate_guards_fn() -> bool:
|
||||
"""Return False whenever luminal is anywhere in the call stack."""
|
||||
import inspect
|
||||
|
||||
frame = inspect.currentframe()
|
||||
try:
|
||||
while frame is not None:
|
||||
if "luminal" in frame.f_code.co_filename:
|
||||
return False
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame # avoid reference cycle
|
||||
return _orig_ok_to_generate_guards_fn()
|
||||
|
||||
|
||||
_torch_export_unlift._ok_to_generate_guards_fn = _luminal_aware_ok_to_generate_guards_fn
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,10 +8,6 @@ from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class DTypeBoundaryWarning(UserWarning):
|
||||
"""Warns when the PyTorch boundary must cast input data before execution."""
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
@@ -82,10 +77,7 @@ class CompiledModel:
|
||||
)
|
||||
user_inputs = inputs
|
||||
|
||||
# Use the first *user* input for device detection — when torch.compile
|
||||
# has lifted SymInts or weights into the call args, `inputs[0]` may not
|
||||
# be a tensor. user_inputs has been filtered to actual tensors.
|
||||
input_device = user_inputs[0].device if user_inputs else torch.device("cpu")
|
||||
input_device = inputs[0].device if inputs else torch.device("cpu")
|
||||
|
||||
# Auto-detect dynamic dims from input shapes
|
||||
if self._has_dynamic_dims:
|
||||
@@ -100,15 +92,6 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if tensor.dtype != expected_dtype:
|
||||
warnings.warn(
|
||||
"Luminal compiled input "
|
||||
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
|
||||
f"expects {expected_dtype}; converting at every call will "
|
||||
"allocate/copy input data.",
|
||||
DTypeBoundaryWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
@@ -149,11 +132,6 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
@@ -169,12 +147,11 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -202,13 +179,9 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype in _int_dtypes:
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
|
||||
@@ -10,22 +10,24 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device."""
|
||||
# Dynamo can prefix `example_inputs` with SymInt entries when shapes are
|
||||
# dynamic — those have no `.device`. Pick the first real tensor instead.
|
||||
first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None)
|
||||
device = first_tensor.device if first_tensor is not None else torch.device("cpu")
|
||||
"""Pick the best built-in factory capsule based on input device.
|
||||
|
||||
Walks example_inputs for the first Tensor to read .device from. With
|
||||
dynamic=True, dynamo may pass SymInt/SymFloat alongside Tensors and those
|
||||
don't have a .device attribute — falling back to CPU on a SymInt-only call
|
||||
would silently route to the wrong backend, so prefer the first Tensor."""
|
||||
device = torch.device("cpu")
|
||||
for v in example_inputs or ():
|
||||
if isinstance(v, torch.Tensor):
|
||||
device = v.device
|
||||
break
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError) as exc:
|
||||
raise RuntimeError(
|
||||
"CUDA input was provided, but luminal_python was not built with "
|
||||
"the cuda feature. Rebuild with `maturin develop --features cuda` "
|
||||
"or run through `run_tests_cuda.sh`/the Modal CUDA test runner."
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
@@ -83,7 +85,7 @@ def register_backend(factory_capsule):
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule)
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
@@ -102,7 +104,7 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -110,8 +112,16 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule)
|
||||
search_iterations = None
|
||||
if options is not None:
|
||||
search_iterations = options.get("search_iterations")
|
||||
return pt2_backend(
|
||||
gm,
|
||||
example_inputs,
|
||||
factory=factory_capsule,
|
||||
search_iterations=search_iterations,
|
||||
)
|
||||
|
||||
@@ -110,35 +110,45 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _decomp_table():
|
||||
"""Decomposition table for `ep.run_decompositions()` that preserves SDPA.
|
||||
def _extract_pt2_constants(pt2_path):
|
||||
"""Extract tensor constants from the new flat PT2 format (torch >= 2.6).
|
||||
|
||||
The default table decomposes `aten.scaled_dot_product_attention.default`
|
||||
into ~20 ops (matmul/softmax + an `eq.Scalar`/`logical_not`/`any.dim`/
|
||||
`where`/`full_like` "all-masked" sentinel chain). We translate SDPA as a
|
||||
single fused op via `translate_sdpa`, so we strip the SDPA decompositions
|
||||
here to let them survive into the FX graph the translator walks.
|
||||
In the new format, inline constants (e.g. ``torch.tensor([1., 2.])``) are
|
||||
stored in ``serialized_constants.pt`` rather than individual ZIP entries.
|
||||
The Rust parser skips them (returns constants_config=None); this function
|
||||
reads them back and returns a cpu_ptrs dict ready for _load_cpu_weights.
|
||||
|
||||
Returns (keep_alive, cpu_ptrs) — keep_alive must stay alive until after
|
||||
_load_cpu_weights returns (set_weight_from_ptr copies the bytes).
|
||||
"""
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
try:
|
||||
from torch.export import default_decompositions
|
||||
except ImportError:
|
||||
return None
|
||||
table = default_decompositions()
|
||||
sdpa_ops = [
|
||||
torch.ops.aten.scaled_dot_product_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
|
||||
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
|
||||
]
|
||||
for op in sdpa_ops:
|
||||
table.pop(op, None)
|
||||
return table
|
||||
with zipfile.ZipFile(pt2_path) as z:
|
||||
if "serialized_constants.pt" not in z.namelist():
|
||||
return [], {}
|
||||
data = z.read("serialized_constants.pt")
|
||||
except Exception:
|
||||
return [], {}
|
||||
|
||||
constants = torch.load(io.BytesIO(data), weights_only=False)
|
||||
if not constants:
|
||||
return [], {}
|
||||
|
||||
keep_alive = []
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in constants.items():
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
keep_alive.append(t)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
cpu_ptrs[name] = (t.data_ptr(), n_bytes, _torch_dtype_code(t.dtype))
|
||||
return keep_alive, cpu_ptrs
|
||||
|
||||
|
||||
def _save_and_compile(
|
||||
ep_or_path, factory, search_iterations, original_weights=None, user_indices=None
|
||||
):
|
||||
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
@@ -173,185 +183,38 @@ def _save_and_compile(
|
||||
pt2_path, "", search_iterations, factory, weight_device_ptrs
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
# Load CPU weights; also load inline tensor constants from the new flat
|
||||
# PT2 format (torch >= 2.6 stores them in serialized_constants.pt).
|
||||
const_keep_alive, const_cpu_weights = _extract_pt2_constants(pt2_path)
|
||||
cpu_weights.update(const_cpu_weights)
|
||||
_load_cpu_weights(compiled, cpu_weights)
|
||||
del const_keep_alive # bytes were copied by set_weight_from_ptr
|
||||
|
||||
return CompiledModel(
|
||||
compiled, weight_refs=keep_alive, user_indices=user_indices
|
||||
)
|
||||
return CompiledModel(compiled, weight_refs=keep_alive)
|
||||
finally:
|
||||
if owns_tmpdir and tmpdir:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _safe_int_bound(value):
|
||||
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
|
||||
|
||||
Range bounds returned by ShapeEnv can be sympy `Infinity` / `-Infinity`
|
||||
(as well as the internal `int_oo` sentinel), which both raise on `int(...)`.
|
||||
Treat anything non-finite — and anything that simply doesn't coerce — as
|
||||
"no bound."
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
# Stringify is robust against the various sentinel types: sympy.Infinity,
|
||||
# torch.utils._sympy.numbers.IntInfinity, etc. all stringify to "oo"/"-oo".
|
||||
s = str(value)
|
||||
if "oo" in s or "inf" in s.lower():
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError, OverflowError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def _strip_symint_placeholders(gm, example_inputs):
|
||||
"""Rewrite SymInt graph inputs into tensor.size(d) calls, then drop them.
|
||||
|
||||
When Dynamo decides a dim is dynamic it emits the symbol as a separate
|
||||
placeholder (e.g. `s77`) alongside the user's tensor (whose FakeTensor shape
|
||||
references the same symbol). torch.export.export rejects mixed
|
||||
SymInt/Tensor positional args, and the Rust pipeline doesn't model SymInt
|
||||
inputs anyway — so we replace each SymInt placeholder's uses with
|
||||
`aten.sym_size.int(tensor, dim)` for the first tensor placeholder whose
|
||||
example_value's shape[dim] matches the symbol, then erase the placeholder.
|
||||
|
||||
Returns `(post_strip_inputs, kept_indices, ok)` where:
|
||||
- `post_strip_inputs` is `example_inputs` filtered to tensor-only entries
|
||||
- `kept_indices` is the indices into `example_inputs` we kept (used by
|
||||
the caller to compose with any prior input filter, e.g. lifted-weight
|
||||
re-internalization, when handing `user_indices` to CompiledModel)
|
||||
- `ok` is False when at least one SymInt placeholder couldn't be
|
||||
rewritten (compound expression with users, or no matching tensor dim);
|
||||
the caller should fall back to no-dynamic export in that case.
|
||||
"""
|
||||
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
|
||||
|
||||
# Collect (placeholder_node, example_input_idx) for every SymInt placeholder.
|
||||
symint_entries = []
|
||||
tensor_entries = []
|
||||
for idx, node in enumerate(placeholders):
|
||||
ev = node.meta.get("example_value")
|
||||
if isinstance(ev, torch.SymInt) or (
|
||||
ev is None
|
||||
and idx < len(example_inputs)
|
||||
and isinstance(example_inputs[idx], torch.SymInt)
|
||||
):
|
||||
symint_entries.append((node, idx))
|
||||
else:
|
||||
tensor_entries.append((node, idx))
|
||||
|
||||
if not symint_entries:
|
||||
return example_inputs, list(range(len(example_inputs))), True
|
||||
|
||||
# Build a symbol -> (tensor_node, dim) lookup from the tensor placeholders'
|
||||
# example FakeTensor shapes. Any tensor whose shape[d] is the SymInt
|
||||
# is a valid source — pick the first.
|
||||
sym_to_source = {}
|
||||
for t_node, _ in tensor_entries:
|
||||
ev = t_node.meta.get("example_value")
|
||||
if not torch.is_tensor(ev):
|
||||
continue
|
||||
for d, s in enumerate(ev.shape):
|
||||
if isinstance(s, torch.SymInt):
|
||||
key = str(s.node.expr)
|
||||
sym_to_source.setdefault(key, (t_node, d))
|
||||
|
||||
# Rewrite each SymInt placeholder's uses to sym_size calls, then erase it.
|
||||
all_clean = True
|
||||
for s_node, _ in symint_entries:
|
||||
ev = s_node.meta.get("example_value")
|
||||
if ev is None:
|
||||
all_clean = False
|
||||
continue
|
||||
# The placeholder's example_value is the SymInt itself; its expr is the
|
||||
# symbol name (or a compound expression we can't lift this way).
|
||||
expr_str = str(ev.node.expr)
|
||||
source = sym_to_source.get(expr_str)
|
||||
if source is None:
|
||||
# Compound expression or no tensor carries this symbol — bail.
|
||||
if len(s_node.users) > 0:
|
||||
all_clean = False
|
||||
continue
|
||||
gm.graph.erase_node(s_node)
|
||||
continue
|
||||
|
||||
if len(s_node.users) > 0:
|
||||
t_node, dim = source
|
||||
with gm.graph.inserting_after(t_node):
|
||||
size_node = gm.graph.call_function(
|
||||
torch.ops.aten.sym_size.int, (t_node, dim)
|
||||
)
|
||||
size_node.meta["val"] = ev
|
||||
size_node.meta["example_value"] = ev
|
||||
s_node.replace_all_uses_with(size_node)
|
||||
gm.graph.erase_node(s_node)
|
||||
|
||||
if not all_clean:
|
||||
# Recompile defensively even on partial success — some erases may have
|
||||
# happened. Caller will decide whether to proceed.
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
return example_inputs, list(range(len(example_inputs))), False
|
||||
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
# Filter the runtime example_inputs to drop the stripped SymInt entries.
|
||||
kept_indices = [idx for _, idx in tensor_entries]
|
||||
keep_set = set(kept_indices)
|
||||
new_inputs = [v for i, v in enumerate(example_inputs) if i in keep_set]
|
||||
return new_inputs, kept_indices, True
|
||||
|
||||
|
||||
def _build_dynamic_shapes_from_gm(gm):
|
||||
"""Construct a torch.export.export `dynamic_shapes` spec from FX metadata.
|
||||
|
||||
Walks each tensor placeholder's `meta['example_value']` FakeTensor and
|
||||
marks every SymInt dim as `Dim.AUTO`. Sharing/equality relationships
|
||||
between symbolic dims are already encoded in the FakeTensor shapes —
|
||||
torch.export's symbolic-shape engine recovers them during the trace, so
|
||||
we don't need to allocate named `Dim` objects ourselves.
|
||||
|
||||
The returned spec is wrapped under `{"args": (...)}` because Dynamo's
|
||||
`GraphModule.forward(*args, **kwargs)` signature treats positional inputs
|
||||
as the `args` tuple.
|
||||
|
||||
Returns None if there are no symbolic dims to mark.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
|
||||
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
|
||||
|
||||
per_input_spec = []
|
||||
saw_dynamic = False
|
||||
for node in placeholders:
|
||||
ev = node.meta.get("example_value")
|
||||
if not torch.is_tensor(ev):
|
||||
per_input_spec.append(None)
|
||||
continue
|
||||
spec = {}
|
||||
for d, s in enumerate(ev.shape):
|
||||
if isinstance(s, torch.SymInt):
|
||||
spec[d] = Dim.AUTO
|
||||
saw_dynamic = True
|
||||
per_input_spec.append(spec if spec else None)
|
||||
|
||||
if not saw_dynamic:
|
||||
return None
|
||||
return {"args": tuple(per_input_spec)}
|
||||
|
||||
|
||||
def _reinternalize_lifted_params(gm, example_inputs):
|
||||
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
|
||||
|
||||
torch.compile lifts model parameters out of the module and passes them as
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
lifting by registering them as buffers and replacing the placeholder nodes
|
||||
with get_attr nodes.
|
||||
|
||||
SymInt/SymFloat/SymBool values in example_inputs are rejected by
|
||||
torch.export.export as user inputs ("Unsupported input type
|
||||
<class 'torch.SymInt'>"). We don't restructure the graph for this — we
|
||||
specialize the *value* to its concrete hint (a plain int/float/bool), which
|
||||
torch.export accepts. The placeholder stays in place; the traced graph
|
||||
proceeds as if dynamo had specialized this dim. Invisible to callers of
|
||||
`torch.compile(..., backend=luminal_backend)`.
|
||||
|
||||
Returns (gm, user_inputs, original_weights) where:
|
||||
- user_inputs contains only the real inputs
|
||||
- user_inputs contains only real inputs (Tensors and concrete scalars)
|
||||
- original_weights maps buffer name -> original tensor (for zero-copy device pointers)
|
||||
"""
|
||||
buffer_indices = []
|
||||
@@ -385,12 +248,47 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
user_inputs = (
|
||||
raw_user_inputs = (
|
||||
[example_inputs[i] for i in user_indices]
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
return gm, user_inputs, original_weights, user_indices
|
||||
user_inputs = [
|
||||
_specialize_sym_scalar(v) if _is_sym_scalar(v) else v
|
||||
for v in raw_user_inputs
|
||||
]
|
||||
return gm, user_inputs, original_weights
|
||||
|
||||
|
||||
def _is_sym_scalar(val) -> bool:
|
||||
"""True for torch SymInt/SymFloat/SymBool — anything torch.export's fakify
|
||||
rejects as a user input. Plain int/float/bool are fine; only the symbolic
|
||||
wrappers need specialization."""
|
||||
if val is None:
|
||||
return False
|
||||
if isinstance(val, torch.Tensor):
|
||||
return False
|
||||
return type(val).__name__ in ("SymInt", "SymFloat", "SymBool") or isinstance(
|
||||
val, (torch.SymInt, torch.SymFloat, torch.SymBool)
|
||||
)
|
||||
|
||||
|
||||
def _specialize_sym_scalar(val):
|
||||
"""Resolve a SymInt/SymFloat/SymBool to its concrete hint. Falls back to
|
||||
str(val) -> primitive parse if the SymNode hint is missing."""
|
||||
try:
|
||||
if isinstance(val, torch.SymBool):
|
||||
return bool(val)
|
||||
if isinstance(val, torch.SymFloat):
|
||||
return float(val)
|
||||
return int(val)
|
||||
except Exception:
|
||||
# SymNodes without a hint — try parsing the str form as a last resort.
|
||||
s = str(val)
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return float(s)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -405,238 +303,120 @@ def compile(
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
dynamic_shapes=None,
|
||||
):
|
||||
"""Compile a PyTorch model to run on Luminal via PT2 pipeline.
|
||||
|
||||
Args:
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor — or a list/tuple of tensors for
|
||||
multi-input models.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
|
||||
symbolic dim is needed.
|
||||
* `None` (default): leave shapes static.
|
||||
* `int`: mark that dim of the (first) input as `Dim.AUTO`.
|
||||
* `Iterable[int]`: mark each listed dim of the first input.
|
||||
* `"auto"`: mark every non-trivial dim (size > 1) of the
|
||||
first input as `Dim.AUTO` — works for floating-point and
|
||||
integer inputs alike.
|
||||
dynamic_shapes: Direct passthrough to `torch.export.export`'s
|
||||
`dynamic_shapes` argument. When provided, takes precedence over
|
||||
`dynamic_dim`. Use this for full control: per-input specs,
|
||||
`Dim("name", min=, max=)` ranges, shared dims across inputs, etc.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(
|
||||
example_input
|
||||
if isinstance(example_input, (list, tuple))
|
||||
else [example_input]
|
||||
)
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if isinstance(example_input, (list, tuple)):
|
||||
example_args = tuple(example_input)
|
||||
else:
|
||||
example_args = (example_input,)
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule([example_input])
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
|
||||
# Build dynamic_shapes from the convenience knob if the caller didn't
|
||||
# hand us a full spec. `dynamic_dim=None` falls back to the legacy
|
||||
# `"auto"` behavior (mark the last axis of an integer input as dynamic)
|
||||
# so callers that relied on the previous default keep working.
|
||||
if dynamic_shapes is None:
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = _legacy_auto_dim(example_args)
|
||||
if dynamic_dim is not None:
|
||||
dynamic_shapes = _build_dynamic_shapes_from_dim_arg(
|
||||
dynamic_dim, example_args
|
||||
)
|
||||
|
||||
# `torch.export.export` is finicky: when `dynamic_shapes` is set it
|
||||
# validates the spec against the example shapes and raises on any
|
||||
# disagreement (e.g. the user marked a dim as dynamic but their model
|
||||
# specialises it to a constant). Fall back to a static export so the
|
||||
# caller still gets a usable CompiledModel rather than a hard error.
|
||||
ep = None
|
||||
if dynamic_shapes is not None:
|
||||
try:
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_args,
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
except Exception:
|
||||
ep = None
|
||||
|
||||
# Try dynamic dimension export
|
||||
candidate_dims = []
|
||||
if isinstance(dynamic_dim, int):
|
||||
candidate_dims = [dynamic_dim]
|
||||
elif dynamic_dim == "auto" and example_input.dim() >= 2:
|
||||
if not example_input.is_floating_point():
|
||||
candidate_dims = [example_input.dim() - 1]
|
||||
|
||||
if candidate_dims:
|
||||
from torch.export import Dim
|
||||
|
||||
for dim_idx in candidate_dims:
|
||||
try:
|
||||
seq = Dim("seq", min=2)
|
||||
arg_shapes = {dim_idx: seq}
|
||||
kwarg_shapes = {k: None for k in kwargs}
|
||||
dynamic_shapes = (
|
||||
(arg_shapes,) + tuple(kwarg_shapes.values())
|
||||
if kwarg_shapes
|
||||
else (arg_shapes,)
|
||||
)
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
(example_input,),
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ep is None:
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_args,
|
||||
(example_input,),
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def _drop_input_guards(ep):
|
||||
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
|
||||
def pt2_backend(gm, example_inputs, factory=None, search_iterations=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
|
||||
``i = torch.tensor(2)``), torch.export records two equivalent input
|
||||
guards: ``L['i'].item() == 2`` (referencing the original local source)
|
||||
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
|
||||
Two failures stack on top of each other:
|
||||
|
||||
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
|
||||
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
|
||||
a literal ``L`` reference in the generated ``_guards_fn`` and raising
|
||||
``NameError: name 'L' is not defined`` during retracing.
|
||||
2. Even after dropping the unresolvable guard, the surviving
|
||||
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
|
||||
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
|
||||
a graph break.
|
||||
|
||||
These guards exist solely to validate inputs at runtime in eager-mode
|
||||
consumers of the ExportedProgram; the luminal compiler does its own
|
||||
input shape/dtype checks against the compiled graph signature, so we
|
||||
are not losing any safety by clearing them.
|
||||
"""
|
||||
|
||||
if hasattr(ep, "_guards_code"):
|
||||
ep._guards_code = []
|
||||
|
||||
|
||||
def _drop_dead_data_dependent_ops(gm):
|
||||
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
|
||||
|
||||
When dynamo specializes a 0-d int input by tracing through ``.item()``,
|
||||
the resulting graph may contain a dead ``aten.item.default`` node whose
|
||||
output is never consumed. luminal's translator does not lower
|
||||
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
|
||||
node in the graph causes a graph break at compile time. Eliminating it
|
||||
keeps the (correctly specialized) downstream graph in a single subgraph.
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
changed = False
|
||||
for node in list(graph.nodes):
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
|
||||
and len(node.users) == 0
|
||||
):
|
||||
graph.erase_node(node)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
graph.eliminate_dead_code()
|
||||
graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
"""Match the historical `dynamic_dim="auto"` heuristic.
|
||||
|
||||
Returns the last axis of the first input when that input is a 2-D-or-
|
||||
larger integer tensor (the typical token-id sequence pattern), and
|
||||
`None` otherwise. Float inputs and 1-D tensors fall through to the
|
||||
static export path the legacy code did.
|
||||
"""
|
||||
if not example_args:
|
||||
return None
|
||||
first = example_args[0]
|
||||
if not torch.is_tensor(first):
|
||||
return None
|
||||
if first.is_floating_point():
|
||||
return None
|
||||
if first.dim() < 2:
|
||||
return None
|
||||
return first.dim() - 1
|
||||
|
||||
|
||||
def _build_dynamic_shapes_from_dim_arg(dynamic_dim, example_args):
|
||||
"""Translate the `dynamic_dim` shorthand into a full `dynamic_shapes` spec.
|
||||
|
||||
Always targets the first positional input — multi-input dynamic specs
|
||||
require the caller to use `dynamic_shapes=` directly so they can name
|
||||
which input each dim belongs to.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
|
||||
if not example_args:
|
||||
return None
|
||||
first = example_args[0]
|
||||
if not torch.is_tensor(first):
|
||||
return None
|
||||
|
||||
if isinstance(dynamic_dim, int):
|
||||
dims = [dynamic_dim]
|
||||
elif isinstance(dynamic_dim, str) and dynamic_dim == "auto":
|
||||
# Mark every dim with size > 1 as dynamic. Dim.AUTO leaves
|
||||
# torch.export to pick a Dim per axis and infer relationships from
|
||||
# the example FakeTensor.
|
||||
dims = [d for d, s in enumerate(first.shape) if int(s) > 1]
|
||||
elif hasattr(dynamic_dim, "__iter__"):
|
||||
dims = [int(d) for d in dynamic_dim]
|
||||
else:
|
||||
return None
|
||||
|
||||
if not dims:
|
||||
return None
|
||||
|
||||
spec = {d: Dim.AUTO for d in dims}
|
||||
rest = (None,) * (len(example_args) - 1)
|
||||
return (spec,) + rest
|
||||
|
||||
|
||||
def _eager_pt2_compile(
|
||||
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
|
||||
):
|
||||
"""Run torch.export → save → Rust compile end-to-end. Returns CompiledModel.
|
||||
|
||||
Factored out so both the eager (static-shapes) and lazy (dynamic-shapes)
|
||||
backend paths share a single implementation.
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import gc
|
||||
|
||||
try:
|
||||
ep = torch.export.export(
|
||||
gm,
|
||||
tuple(user_inputs),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**_export_kwargs(),
|
||||
)
|
||||
except Exception:
|
||||
# If torch.export rejects the dynamic spec (e.g. user code introduced
|
||||
# a constraint we didn't model), retry without it. Better to lose the
|
||||
# dynamic-dim optimization than to hand the user a hard failure.
|
||||
if dynamic_shapes is None:
|
||||
raise
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
# LUM-499: drop dynamo-emitted input guards before run_decompositions
|
||||
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
|
||||
# data-dependent .item() calls and unresolved `L[...]` references.
|
||||
_drop_input_guards(ep)
|
||||
_drop_dead_data_dependent_ops(ep.graph_module)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
if search_iterations is None:
|
||||
search_iterations = 10
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers
|
||||
# from the EP before saving. The Rust side uses device pointers for these
|
||||
# weights, not the .pt2 file data, so serializing them is pure IO waste
|
||||
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
|
||||
gm = gm.eval()
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# Detect USER_INPUT_MUTATION outputs (e.g., in-place KV cache updates).
|
||||
# These must be written back to the original input tensors after each call.
|
||||
# Only USER_OUTPUT results are returned to the torch.compile caller.
|
||||
try:
|
||||
from torch.export.graph_signature import OutputKind
|
||||
|
||||
mutation_mappings = [] # list of (compiled_output_idx, user_input_idx)
|
||||
user_output_indices = []
|
||||
for i, spec in enumerate(ep.graph_signature.output_specs):
|
||||
if spec.kind == OutputKind.USER_INPUT_MUTATION:
|
||||
# target is 'args_N' — index into user_inputs
|
||||
try:
|
||||
arg_idx = int(spec.target.split("_")[1])
|
||||
mutation_mappings.append((i, arg_idx))
|
||||
except (ValueError, IndexError):
|
||||
user_output_indices.append(i)
|
||||
else:
|
||||
user_output_indices.append(i)
|
||||
except ImportError:
|
||||
mutation_mappings = []
|
||||
user_output_indices = None # unknown; return all outputs
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
|
||||
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
@@ -644,9 +424,9 @@ def _eager_pt2_compile(
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
|
||||
# Save EP to disk, then free it and the traced graph module before Rust
|
||||
# compilation. torch.export clones the state_dict internally; holding ep
|
||||
# alive during compile would double weight memory on GPU.
|
||||
# Save the exported program to disk, then free it and the traced graph module
|
||||
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
|
||||
# holding ep alive during compilation would double the weight memory on GPU.
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep, pt2_path)
|
||||
@@ -657,129 +437,28 @@ def _eager_pt2_compile(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
return _save_and_compile(
|
||||
pt2_path,
|
||||
factory,
|
||||
10,
|
||||
original_weights=original_weights,
|
||||
user_indices=user_indices,
|
||||
result = _save_and_compile(
|
||||
pt2_path, factory, search_iterations, original_weights=original_weights
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
# Wrap the compiled model to handle USER_INPUT_MUTATION: write updated tensors
|
||||
# back into the original input buffers and return only USER_OUTPUT tensors.
|
||||
if mutation_mappings:
|
||||
_compiled = result
|
||||
_mut = mutation_mappings
|
||||
_usr = user_output_indices
|
||||
|
||||
class _LazyDynamicCompiledModel:
|
||||
"""Defers torch.export + Rust compile to the first invocation.
|
||||
def _mutation_wrapper(*inputs):
|
||||
outputs = _compiled(*inputs)
|
||||
for out_idx, inp_idx in _mut:
|
||||
if inp_idx < len(inputs) and out_idx < len(outputs):
|
||||
inputs[inp_idx].copy_(outputs[out_idx])
|
||||
if _usr is not None:
|
||||
return tuple(outputs[i] for i in _usr if i < len(outputs))
|
||||
return outputs
|
||||
|
||||
Calling `torch.export.export(..., dynamic_shapes=...)` from inside a
|
||||
Dynamo backend frame triggers an internal "Guard failed on the same
|
||||
frame it was created" assertion in PyTorch — `torch.export`'s symbolic
|
||||
tracer mutates the ShapeEnv that Dynamo is also relying on for the
|
||||
surrounding compile, leaving the just-installed guards in an
|
||||
inconsistent state. Punting all of that work to the first runtime call
|
||||
sidesteps the issue: by then Dynamo's guard installation is finished,
|
||||
so the shape-env mutations no longer matter.
|
||||
return _mutation_wrapper
|
||||
|
||||
This wrapper is API-compatible with `CompiledModel` for the bits the
|
||||
caller cares about (`__call__`, `has_dynamic_dims`, `dim_params`,
|
||||
`set_dim`). Subsequent calls forward straight to the inner CompiledModel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gm,
|
||||
user_inputs,
|
||||
original_weights,
|
||||
user_indices,
|
||||
dynamic_shapes,
|
||||
factory,
|
||||
):
|
||||
self._gm = gm
|
||||
self._user_inputs = user_inputs
|
||||
self._original_weights = original_weights
|
||||
self._user_indices = user_indices
|
||||
self._dynamic_shapes = dynamic_shapes
|
||||
self._factory = factory
|
||||
self._compiled = None
|
||||
|
||||
def _ensure_compiled(self):
|
||||
if self._compiled is None:
|
||||
self._compiled = _eager_pt2_compile(
|
||||
self._gm,
|
||||
self._user_inputs,
|
||||
self._original_weights,
|
||||
self._user_indices,
|
||||
self._dynamic_shapes,
|
||||
self._factory,
|
||||
)
|
||||
# Drop references to inputs we no longer need — the Rust side
|
||||
# holds onto weights via device pointers / CPU buffers.
|
||||
self._gm = None
|
||||
self._user_inputs = None
|
||||
self._original_weights = None
|
||||
return self._compiled
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
return self._ensure_compiled()(*inputs, **kwargs)
|
||||
|
||||
@property
|
||||
def has_dynamic_dims(self):
|
||||
return self._ensure_compiled().has_dynamic_dims
|
||||
|
||||
@property
|
||||
def dim_params(self):
|
||||
return self._ensure_compiled().dim_params
|
||||
|
||||
def set_dim(self, name, value):
|
||||
return self._ensure_compiled().set_dim(name, value)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, factory=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import copy as _copy
|
||||
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
|
||||
# Work on a private copy of the GraphModule. Dynamo holds onto the
|
||||
# original to install guards and to retrace on shape changes; mutating it
|
||||
# here (erasing SymInt placeholders, re-internalizing lifted weights)
|
||||
# corrupts that bookkeeping and surfaces as cryptic "guard failed on the
|
||||
# same frame" assertions on the next call. The deepcopy is cheap relative
|
||||
# to the rest of the export pipeline.
|
||||
gm = _copy.deepcopy(gm).eval()
|
||||
gm, user_inputs, original_weights, post_lift_indices = _reinternalize_lifted_params(
|
||||
gm, example_inputs
|
||||
)
|
||||
|
||||
# Lift any SymInt placeholders Dynamo emitted alongside the tensor inputs
|
||||
# into `aten.sym_size.int` calls so the re-export sees a tensor-only
|
||||
# signature, then derive the `dynamic_shapes` spec from the surviving
|
||||
# tensor placeholders' FakeTensor shapes. If the strip can't fully clean
|
||||
# the graph (e.g. a compound-expr SymInt with users), we drop dynamic
|
||||
# info and fall back to per-shape recompilation — same as today.
|
||||
user_inputs, post_strip_subindices, strip_ok = _strip_symint_placeholders(
|
||||
gm, user_inputs
|
||||
)
|
||||
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
|
||||
|
||||
# Compose both filter steps into a single user_indices list relative to
|
||||
# the *original* example_inputs Dynamo will pass at runtime — so
|
||||
# CompiledModel.__call__ can drop both lifted weights and SymInt args.
|
||||
user_indices = [post_lift_indices[i] for i in post_strip_subindices]
|
||||
|
||||
if dynamic_shapes is not None:
|
||||
# See `_LazyDynamicCompiledModel` for why dynamic-shape compiles must
|
||||
# be deferred — torch.export with dynamic_shapes mutates ShapeEnv state
|
||||
# Dynamo is still relying on, and running it inside the backend frame
|
||||
# corrupts the freshly-installed guards.
|
||||
return _LazyDynamicCompiledModel(
|
||||
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
|
||||
)
|
||||
|
||||
return _eager_pt2_compile(
|
||||
gm, user_inputs, original_weights, user_indices, None, factory
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
from luminal.compiled_model import DTypeBoundaryWarning
|
||||
|
||||
|
||||
class BoundaryNoopModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.dtype is torch.bool:
|
||||
return x | torch.zeros((), dtype=torch.bool, device=x.device)
|
||||
return x + torch.zeros((), dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTypeCase:
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
values: Callable[[], torch.Tensor]
|
||||
xfail_reason: str | None = None
|
||||
|
||||
|
||||
DTYPE_CASES = [
|
||||
DTypeCase(
|
||||
"bool",
|
||||
torch.bool,
|
||||
lambda: torch.tensor([True, False, True], dtype=torch.bool),
|
||||
),
|
||||
DTypeCase(
|
||||
"uint8",
|
||||
torch.uint8,
|
||||
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int8",
|
||||
torch.int8,
|
||||
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int16",
|
||||
torch.int16,
|
||||
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
|
||||
),
|
||||
DTypeCase(
|
||||
"int32",
|
||||
torch.int32,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int64,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float16",
|
||||
torch.float16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
|
||||
),
|
||||
DTypeCase(
|
||||
"bfloat16",
|
||||
torch.bfloat16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
|
||||
),
|
||||
DTypeCase(
|
||||
"float32",
|
||||
torch.float32,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_f32_exact",
|
||||
torch.float64,
|
||||
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
xfail_reason=(
|
||||
"Luminal currently collapses integer inputs through i32 at the "
|
||||
"compiled boundary, so out-of-range int64 values lose information."
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
torch.float64,
|
||||
lambda: torch.tensor(
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
xfail_reason=(
|
||||
"Luminal currently routes float64 no-op computation through f32 "
|
||||
"storage/outputs before restoring the PyTorch-visible dtype."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _cuda_skip_reason() -> str | None:
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA is not available"
|
||||
|
||||
try:
|
||||
from luminal.luminal import _cuda_lite_factory_capsule
|
||||
|
||||
_cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError, RuntimeError) as exc:
|
||||
return f"luminal_python was not built with CUDA support: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
|
||||
def boundary_device(request) -> torch.device:
|
||||
device_name = request.param
|
||||
if device_name == "cuda":
|
||||
skip_reason = _cuda_skip_reason()
|
||||
if skip_reason is not None:
|
||||
pytest.skip(skip_reason)
|
||||
return torch.device(device_name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(
|
||||
case,
|
||||
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
|
||||
if case.xfail_reason is not None
|
||||
else (),
|
||||
id=case.name,
|
||||
)
|
||||
for case in DTYPE_CASES
|
||||
],
|
||||
)
|
||||
def test_boundary_noop_preserves_dtype_and_values(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
x = case.values().to(boundary_device)
|
||||
expected = model(x)
|
||||
actual = compiled(x)
|
||||
|
||||
assert isinstance(actual, torch.Tensor)
|
||||
assert actual.dtype == expected.dtype
|
||||
assert torch.equal(actual.cpu(), expected.cpu())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
|
||||
compiled(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
|
||||
],
|
||||
)
|
||||
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
warnings.simplefilter("always")
|
||||
compiled(x)
|
||||
|
||||
dtype_boundary_warnings = [
|
||||
record
|
||||
for record in records
|
||||
if issubclass(record.category, DTypeBoundaryWarning)
|
||||
]
|
||||
assert dtype_boundary_warnings == []
|
||||
@@ -1,312 +0,0 @@
|
||||
"""End-to-end tests for dynamic-shape support through ``torch.compile``.
|
||||
|
||||
These exercise the path that the standard PyTorch user hits — i.e. wrapping a
|
||||
model with ``torch.compile(model, backend=luminal_backend)`` and calling it
|
||||
with varying input shapes. The luminal backend is expected to recognise
|
||||
Dynamo-emitted SymInt placeholders, propagate the symbolic dims through the
|
||||
PT2 export, and reuse a single compiled graph across shape changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal.main import luminal_backend
|
||||
|
||||
|
||||
def _compile(model, count_holder):
|
||||
def wrapper(gm, example_inputs):
|
||||
out = luminal_backend(gm, example_inputs)
|
||||
count_holder.append(1)
|
||||
return out
|
||||
|
||||
return torch.compile(model, backend=wrapper)
|
||||
|
||||
|
||||
def _compile_with_dynamic_true(model, count_holder):
|
||||
def wrapper(gm, example_inputs):
|
||||
out = luminal_backend(gm, example_inputs)
|
||||
count_holder.append(1)
|
||||
return out
|
||||
|
||||
return torch.compile(model, backend=wrapper, dynamic=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_automatic_dynamic():
|
||||
"""Make sure the tests run with Dynamo's automatic-dynamic detection on.
|
||||
|
||||
Other tests in the suite flip this off; reset state between tests so the
|
||||
cache that backs the previous suppression doesn't carry over. We also
|
||||
raise the recompile limit because Dynamo defaults to 1 (which trips
|
||||
before automatic-dynamic kicks in) and have to do an extra reset to
|
||||
drop any cached frames from prior tests in the suite.
|
||||
"""
|
||||
torch._dynamo.reset()
|
||||
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
|
||||
prev_limit = torch._dynamo.config.recompile_limit
|
||||
torch._dynamo.config.automatic_dynamic_shapes = True
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
|
||||
torch._dynamo.config.recompile_limit = prev_limit
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — the dynamic-shape backend wiring is exercised end to end against the cuda_lite runtime",
|
||||
)
|
||||
def test_dynamic_seq_via_torch_compile_reuses_compile(device: torch.device):
|
||||
"""A varying seq dim should produce two backend invocations total.
|
||||
|
||||
First call: Dynamo emits a static-shape graph (no SymInt placeholders).
|
||||
Second call: Dynamo detects the size mismatch and re-traces with the dim
|
||||
marked dynamic. From that point on, every subsequent shape variation
|
||||
must be served by the same compiled graph — no further backend calls.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
s = x.shape[0]
|
||||
return x.reshape(s, -1).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
assert len(counts) == 2, (
|
||||
f"expected exactly 2 backend invocations (one static, one dynamic), got {len(counts)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_dynamic_via_torch_compile_with_lifted_weights(device: torch.device):
|
||||
"""Combines lifted-weight re-internalization with the SymInt strip.
|
||||
|
||||
Most real models hit both paths simultaneously (Dynamo lifts every
|
||||
`nn.Parameter` AND emits SymInt placeholders for any dim that varies
|
||||
between calls), so the two filters need to compose without losing
|
||||
track of input positions.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(8, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lin(x).sum(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [3, 4, 5, 6, 4]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
assert len(counts) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_compound_shape_expression_auto_resolves(device: torch.device):
|
||||
"""Affine shape expressions (`2*s` etc.) should still let auto-detect work.
|
||||
|
||||
The `auto_set_dims_from_input_shapes` Rust path used to only handle bare
|
||||
`Term::Var(c)` shape expressions and silently skip anything else, leaving
|
||||
affine dims unresolved on the CompiledGraph and the corresponding output
|
||||
sizes stale. We now invert single-variable affine forms `a*x + b` by
|
||||
sampling two probe points; this test exercises that path by constructing
|
||||
a model whose first axis evolves into `2*s` after a `cat` along it.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
# `cat([x, x], dim=0)` doubles the leading dim — torch.export
|
||||
# encodes the resulting shape as `2*s` rather than `s`.
|
||||
return torch.cat([x, x], dim=0).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_torch_compile_dynamic_true_single_compile(device: torch.device):
|
||||
"""`torch.compile(model, backend=luminal_backend, dynamic=True)` works.
|
||||
|
||||
`dynamic=True` skips Dynamo's specialise-then-promote dance and emits a
|
||||
fully-symbolic graph from the first call. The luminal backend must
|
||||
handle the SymInt placeholders Dynamo passes alongside the tensor
|
||||
inputs and reuse a single compiled graph across all shape variations —
|
||||
one backend invocation total, in contrast to the 2 we'd see under
|
||||
automatic-dynamic mode (which burns a static compile on call 1 before
|
||||
promoting to dynamic on call 2).
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
s = x.shape[0]
|
||||
return x.reshape(s, -1).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile_with_dynamic_true(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
assert len(counts) == 1, (
|
||||
f"dynamic=True should produce a single backend invocation, got {len(counts)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_explicit_compile_float_input_dynamic(device: torch.device):
|
||||
"""`luminal.pt2.compile(model, example, dynamic_dim=...)` with a float input.
|
||||
|
||||
The previous version of `compile()` silently fell back to a static export
|
||||
for floating-point inputs (the `"auto"` heuristic was integer-only). The
|
||||
new spec accepts an explicit `int` or `Iterable[int]` regardless of dtype,
|
||||
and `"auto"` now picks every non-trivial axis.
|
||||
"""
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return (x * 2.0).sum(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
example = torch.randn(4, 8, device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=3, dynamic_dim=0)
|
||||
|
||||
assert compiled.has_dynamic_dims, "compile() should have produced a dynamic graph"
|
||||
|
||||
for shp in [4, 5, 6, 7]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
# `compile()` returns a tuple of outputs; extract the first.
|
||||
out_t = out[0] if isinstance(out, tuple) else out
|
||||
assert out_t.shape == ref.shape, (
|
||||
f"shape={shp}: got {out_t.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out_t, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_explicit_compile_dynamic_shapes_passthrough(device: torch.device):
|
||||
"""`luminal.pt2.compile(... , dynamic_shapes=...)` accepts a full spec.
|
||||
|
||||
Lets the caller specify named `Dim` objects with ranges — the previous
|
||||
API hardcoded `Dim("seq", min=2)` for any single dynamic dim.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.mean(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
example = torch.randn(4, 8, device=device)
|
||||
seq = Dim("seq_len", min=2, max=64)
|
||||
compiled = luminal_compile(
|
||||
model, example, search_iterations=3, dynamic_shapes=({0: seq},)
|
||||
)
|
||||
assert compiled.has_dynamic_dims
|
||||
# torch.export rewrites user-supplied Dim names to its internal s77/s33
|
||||
# convention before saving — what we actually need to verify is that a
|
||||
# symbolic dim was registered, not what label it ended up with.
|
||||
assert len(compiled.dim_params) == 1, (
|
||||
f"expected exactly one dynamic dim, got {compiled.dim_params}"
|
||||
)
|
||||
|
||||
for shp in [3, 5, 16]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
out_t = out[0] if isinstance(out, tuple) else out
|
||||
assert out_t.shape == ref.shape
|
||||
assert torch.allclose(out_t, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_dynamic_two_dim_via_torch_compile(device: torch.device):
|
||||
"""Both batch and seq dynamic — should still reuse a single compile."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
# Vary batch and seq together so Dynamo marks both as dynamic.
|
||||
for batch, seq in [(2, 8), (3, 9), (4, 10), (5, 11), (3, 12)]:
|
||||
x = torch.randn(batch, seq, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
# Allow at most a small number of compiles — two shape transitions can
|
||||
# legitimately take Dynamo two retraces (one per newly-dynamic dim).
|
||||
assert len(counts) <= 3, (
|
||||
f"expected ≤3 compiles for two-dim dynamic, got {len(counts)}"
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -171,7 +170,8 @@ from test_models import (
|
||||
ScatterElementsAxis0TestModel,
|
||||
# ScatterElements models
|
||||
ScatterElementsTestModel,
|
||||
# ScatterND model
|
||||
# ScatterND / IndexPut models
|
||||
IndexPutOptionalModel,
|
||||
ScatterNDTestModel,
|
||||
ShapeReshapeBatchFlattenModel,
|
||||
ShapeReshapeKeepBatchModel,
|
||||
@@ -221,7 +221,6 @@ from test_models import (
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv1dFloorDivPositionalModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
@@ -1098,17 +1097,6 @@ def test_reduce_sum_all_axes(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
|
||||
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
|
||||
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
|
||||
eager = model(x)
|
||||
out = model_compiled(x)
|
||||
assert out.dtype == eager.dtype == torch.int64
|
||||
assert torch.equal(out, eager)
|
||||
|
||||
|
||||
def test_reduce_sum_3d_axis1(device: torch.device):
|
||||
"""Test sum reduction along axis 1 for a 3D tensor."""
|
||||
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
|
||||
@@ -1860,60 +1848,6 @@ def test_scaled_dot_product_attention(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== F.scaled_dot_product_attention (SDPA aten variants) ==========
|
||||
# Tests for `torch.nn.functional.scaled_dot_product_attention`, which lowers
|
||||
# to one of `aten._scaled_dot_product_*_attention.default` (variant chosen by
|
||||
# PyTorch's dispatcher: efficient/flash/flash_for_cpu/cudnn). Coverage here
|
||||
# exercises `translate_sdpa` end-to-end.
|
||||
|
||||
|
||||
def _sdpa_qkv(device: torch.device, b: int = 1, h: int = 2, s: int = 4, d: int = 8):
|
||||
"""Build a `(B, H, S, D)` Q/K/V triple of float32 tensors on `device`."""
|
||||
torch.manual_seed(0)
|
||||
q = torch.rand((b, h, s, d), device=device)
|
||||
k = torch.rand((b, h, s, d), device=device)
|
||||
v = torch.rand((b, h, s, d), device=device)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def test_sdpa_basic(device: torch.device):
|
||||
"""`F.scaled_dot_product_attention(q, k, v)` — default scale, no mask."""
|
||||
from test_models import SdpaBasicModel
|
||||
|
||||
model: torch.nn.Module = SdpaBasicModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
expected: torch.Tensor = model(q, k, v)
|
||||
actual: torch.Tensor = compiled(q, k, v)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_sdpa_causal(device: torch.device):
|
||||
"""`F.scaled_dot_product_attention(q, k, v, is_causal=True)`."""
|
||||
from test_models import SdpaCausalModel
|
||||
|
||||
model: torch.nn.Module = SdpaCausalModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
expected: torch.Tensor = model(q, k, v)
|
||||
actual: torch.Tensor = compiled(q, k, v)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_sdpa_with_attn_bias(device: torch.device):
|
||||
"""SDPA with an additive `attn_mask` (float bias) broadcast over heads."""
|
||||
from test_models import SdpaWithBiasModel
|
||||
|
||||
model: torch.nn.Module = SdpaWithBiasModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
bias = torch.zeros((1, 1, q.shape[-2], k.shape[-2]), device=device)
|
||||
bias[..., 0, 1] = -1.0 # any non-trivial bias to verify it's actually applied
|
||||
expected: torch.Tensor = model(q, k, v, bias)
|
||||
actual: torch.Tensor = compiled(q, k, v, bias)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_mlp_block(device: torch.device):
|
||||
"""Test two-layer MLP: Linear(8,16) -> ReLU -> Linear(16,4) on input (2,8)."""
|
||||
model: torch.nn.Module = MLPBlockModel().to(device)
|
||||
@@ -2035,16 +1969,9 @@ def test_split(device: torch.device):
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking.
|
||||
|
||||
Parametrized over int32/int64 to verify luminal preserves whichever
|
||||
integer dtype the eager model declares (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
|
||||
device
|
||||
)
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
@@ -2053,21 +1980,13 @@ def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype)
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
|
||||
assert output.dtype == original.dtype, (
|
||||
f"luminal returned {output.dtype}, eager produced {original.dtype}"
|
||||
)
|
||||
assert torch.equal(output, original)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Focused proof for built MoE routing support.
|
||||
|
||||
Parametrized over int32/int64 for the integer-valued outputs to verify
|
||||
luminal preserves the dtype declared by the eager model (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
@@ -2078,10 +1997,17 @@ def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
for actual, eager in zip(output, expected):
|
||||
assert actual.dtype == eager.dtype, (
|
||||
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
|
||||
)
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
@@ -2099,23 +2025,6 @@ def test_topk_values(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
def test_topk_values_width_128_with_indices(device: torch.device):
|
||||
"""Regression for router-sized TopK values when both tuple outputs are used."""
|
||||
|
||||
class TopKValuesAndIndices(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
values, indices = torch.topk(torch.softmax(x, dim=-1), 8, dim=1)
|
||||
return values, indices
|
||||
|
||||
model = TopKValuesAndIndices().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(4, 128, device=device)
|
||||
actual_values, actual_indices = model_compiled(x)
|
||||
expected_values, expected_indices = model(x)
|
||||
assert torch.allclose(actual_values, expected_values, atol=1e-5)
|
||||
assert torch.equal(actual_indices.to(expected_indices.dtype), expected_indices)
|
||||
|
||||
|
||||
def test_topk_indices(device: torch.device):
|
||||
"""Tests TopK indices output for 2D tensor along axis=1."""
|
||||
model: torch.nn.Module = TopKIndicesTestModel().to(device)
|
||||
@@ -2173,6 +2082,16 @@ def test_scatter_nd(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_index_put_optional(device: torch.device):
|
||||
"""Tests index_put with optional (None) indices — mirrors StaticCache KV update."""
|
||||
model: torch.nn.Module = IndexPutOptionalModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.zeros(2, 2, 8, 4, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== Bool-mask index_put correctness tests ==========
|
||||
#
|
||||
# `x[bool_mask] = scalar` is semantically `where(mask, scalar, x)`, NOT a
|
||||
@@ -2498,17 +2417,6 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_floor_div_positional_pt2(device: torch.device):
|
||||
"""Conv1d stride output uses floor division before positional add."""
|
||||
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
|
||||
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.shape == original.shape == (15, 16)
|
||||
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user